DQN Reinforcement Learning for Tetris

Contents: Algorithm Workflow · Directory Structure · Model Structure · Game Environment · Training Code · Test Code · Results
Algorithm Workflow
The goal of this project is to train a Tetris-playing agent based on deep reinforcement learning. Concretely, the code implements training through the following steps:
First, set random seeds so that results can be reproduced in later training runs. Then:

1. Create a Tetris environment instance; this environment is the Tetris game used to simulate the agent's interaction with the game.
2. Create a DeepQNetwork model, a deep-learning-based reinforcement-learning model used to predict the value of the best next move.
3. Create an optimizer and a loss function (criterion) for training the model.
4. In each training epoch, compute all possible next states (next_steps) from the current state, select an action according to an exploration-or-exploitation policy, and obtain the resulting reward and whether the episode has terminated (done).
5. Append the current state, reward, next state, and done flag to the replay memory.
6. If the current state is terminal, reset the environment and record the final score (final_score), number of tetrominoes placed (final_tetrominoes), and number of cleared lines (final_cleared_lines).
7. Randomly sample a batch from the replay memory and use it to train the model. Specifically, unpack state_batch, reward_batch, next_state_batch, and done_batch, convert each to a tensor, compute the target value y_batch for each sample, use it to compute the loss, backpropagate the loss gradient, and finally update the model parameters with the optimizer (a small sketch of this target computation follows the list).
8. Print the information for the current epoch and log the score, tetromino count, and cleared lines to TensorBoard.
9. Whenever the epoch number is a multiple of a chosen save interval, save the model to disk.
10. Repeat the above steps until the specified number of training epochs is reached.
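The core of step 7 is the one-step temporal-difference target used by DQN: for a terminal transition the target is just the reward, otherwise it is the reward plus the discounted value the network assigns to the next state. A minimal sketch of that computation on a single hypothetical sample (all numbers made up; names follow the step description above rather than any particular file):

import torch

gamma = 0.99                              # discount factor
reward = torch.tensor([1.0])              # reward returned by the environment step
done = False                              # whether the episode ended after this step
next_state_value = torch.tensor([12.3])   # network's value estimate for the next state

# Target y: reward alone for terminal transitions, otherwise bootstrap from the next state
y = reward if done else reward + gamma * next_state_value

# The training loss is the mean-squared error between the network's estimate
# for the current state and this target; its gradient is backpropagated and
# the optimizer updates the weights.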
Directory Structure
├── output.mp4
├── src
│ ├── deep_q_network.py  model structure
│ └── tetris.py  game environment
├── tensorboard
│ └── events.out.tfevents.1676879249.aifs3-worker-2
├── test.py  test code
├── trained_models  models saved during training
│ ├── tetris
│ ├── tetris_1000
│ ├── tetris_1500
│ ├── tetris_2000
│ └── tetris_500
└── train.py  training code
Model Structure
import torch.nn as nn


class DeepQNetwork(nn.Module):
    def __init__(self):
        super(DeepQNetwork, self).__init__()

        self.conv1 = nn.Sequential(nn.Linear(4, 64), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Linear(64, 64), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Linear(64, 1))

        self._create_weights()

    def _create_weights(self):
        # Xavier-initialize the weights and zero the biases of every linear layer
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        return x
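The network above is a small fully connected value head: it maps the four board features (lines cleared, holes, bumpiness, height) to a single scalar score. A minimal usage sketch, assuming the file lives at src/deep_q_network.py as shown in the directory tree (the feature values below are made up):

import torch
from src.deep_q_network import DeepQNetwork

model = DeepQNetwork()
# Two hypothetical feature vectors: [lines_cleared, holes, bumpiness, height]
states = torch.FloatTensor([[0.0, 3.0, 6.0, 12.0],
                            [1.0, 1.0, 2.0, 8.0]])
with torch.no_grad():
    values = model(states)   # shape (2, 1): one scalar value per candidate state
print(values.shape)          # torch.Size([2, 1])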
Game Environment

import numpy as np
from PIL import Image
import cv2
from matplotlib import style
import torch
import random

style.use("ggplot")


class Tetris:
    piece_colors = [
        (0, 0, 0),
        (255, 255, 0),
        (147, 88, 254),
        (54, 175, 144),
        (255, 0, 0),
        (102, 217, 238),
        (254, 151, 32),
        (0, 0, 255)
    ]

    pieces = [
        [[1, 1],
         [1, 1]],

        [[0, 2, 0],
         [2, 2, 2]],

        [[0, 3, 3],
         [3, 3, 0]],

        [[4, 4, 0],
         [0, 4, 4]],

        [[5, 5, 5, 5]],

        [[0, 0, 6],
         [6, 6, 6]],

        [[7, 0, 0],
         [7, 7, 7]]
    ]

    def __init__(self, height=20, width=10, block_size=20):
        self.height = height
        self.width = width
        self.block_size = block_size
        self.extra_board = np.ones((self.height * self.block_size, self.width * int(self.block_size / 2), 3),
                                   dtype=np.uint8) * np.array([204, 204, 255], dtype=np.uint8)
        self.text_color = (200, 20, 220)
        self.reset()

    # ------------------------------------------------------------------
    # Reset the game
    # ------------------------------------------------------------------
    def reset(self):
        self.board = [[0] * self.width for _ in range(self.height)]
        self.score = 0
        self.tetrominoes = 0
        self.cleared_lines = 0
        self.bag = list(range(len(self.pieces)))
        random.shuffle(self.bag)
        self.ind = self.bag.pop()
        self.piece = [row[:] for row in self.pieces[self.ind]]
        self.current_pos = {"x": self.width // 2 - len(self.piece[0]) // 2, "y": 0}
        self.gameover = False
        return self.get_state_properties(self.board)

    # ------------------------------------------------------------------
    # Rotate a piece
    # ------------------------------------------------------------------
    def rotate(self, piece):
        num_rows_orig = num_cols_new = len(piece)
        num_rows_new = len(piece[0])
        rotated_array = []
        for i in range(num_rows_new):
            new_row = [0] * num_cols_new
            for j in range(num_cols_new):
                new_row[j] = piece[(num_rows_orig - 1) - j][i]
            rotated_array.append(new_row)
        return rotated_array

    # ------------------------------------------------------------------
    # Get the feature properties of a board state
    # ------------------------------------------------------------------
    def get_state_properties(self, board):
        lines_cleared, board = self.check_cleared_rows(board)
        holes = self.get_holes(board)
        bumpiness, height = self.get_bumpiness_and_height(board)
        return torch.FloatTensor([lines_cleared, holes, bumpiness, height])

    # ------------------------------------------------------------------
    # Number of holes on the board
    # ------------------------------------------------------------------
    def get_holes(self, board):
        num_holes = 0
        for col in zip(*board):
            row = 0
            while row < self.height and col[row] == 0:
                row += 1
            num_holes += len([x for x in col[row + 1:] if x == 0])
        return num_holes

    # ------------------------------------------------------------------
    # Compute the bumpiness and total height of the board
    # ------------------------------------------------------------------
    def get_bumpiness_and_height(self, board):
        board = np.array(board)
        mask = board != 0
        invert_heights = np.where(mask.any(axis=0), np.argmax(mask, axis=0), self.height)
        heights = self.height - invert_heights
        total_height = np.sum(heights)
        currs = heights[:-1]
        nexts = heights[1:]
        diffs = np.abs(currs - nexts)
        total_bumpiness = np.sum(diffs)
        return total_bumpiness, total_height

    # ------------------------------------------------------------------
    # Get all possible next states
    # ------------------------------------------------------------------
    def get_next_states(self):
        states = {}
        piece_id = self.ind
        curr_piece = [row[:] for row in self.piece]
        if piece_id == 0:  # O piece
            num_rotations = 1
        elif piece_id == 2 or piece_id == 3 or piece_id == 4:
            num_rotations = 2
        else:
            num_rotations = 4

        for i in range(num_rotations):
            valid_xs = self.width - len(curr_piece[0])
            for x in range(valid_xs + 1):
                piece = [row[:] for row in curr_piece]
                pos = {"x": x, "y": 0}
                while not self.check_collision(piece, pos):
                    pos["y"] += 1
                self.truncate(piece, pos)
                board = self.store(piece, pos)
                states[(x, i)] = self.get_state_properties(board)
            curr_piece = self.rotate(curr_piece)
        return states

    # ------------------------------------------------------------------
    # Get the current board state (board plus the falling piece)
    # ------------------------------------------------------------------
    def get_current_board_state(self):
        board = [x[:] for x in self.board]
        for y in range(len(self.piece)):
            for x in range(len(self.piece[y])):
                board[y + self.current_pos["y"]][x + self.current_pos["x"]] = self.piece[y][x]
        return board

    # ------------------------------------------------------------------
    # Spawn a new piece
    # ------------------------------------------------------------------
    def new_piece(self):
        if not len(self.bag):
            self.bag = list(range(len(self.pieces)))
            random.shuffle(self.bag)
        self.ind = self.bag.pop()
        self.piece = [row[:] for row in self.pieces[self.ind]]
        self.current_pos = {"x": self.width // 2 - len(self.piece[0]) // 2,
                            "y": 0}
        if self.check_collision(self.piece, self.current_pos):
            self.gameover = True

    # ------------------------------------------------------------------
    # Collision check; input: piece shape and position
    # ------------------------------------------------------------------
    def check_collision(self, piece, pos):
        future_y = pos["y"] + 1
        for y in range(len(piece)):
            for x in range(len(piece[y])):
                if future_y + y > self.height - 1 or self.board[future_y + y][pos["x"] + x] and piece[y][x]:
                    return True
        return False

    def truncate(self, piece, pos):
        gameover = False
        last_collision_row = -1
        for y in range(len(piece)):
            for x in range(len(piece[y])):
                if self.board[pos["y"] + y][pos["x"] + x] and piece[y][x]:
                    if y > last_collision_row:
                        last_collision_row = y

        if pos["y"] - (len(piece) - last_collision_row) < 0 and last_collision_row > -1:
            while last_collision_row >= 0 and len(piece) > 1:
                gameover = True
                last_collision_row = -1
                del piece[0]
                for y in range(len(piece)):
                    for x in range(len(piece[y])):
                        if self.board[pos["y"] + y][pos["x"] + x] and piece[y][x] and y > last_collision_row:
                            last_collision_row = y
        return gameover

    def store(self, piece, pos):
        board = [x[:] for x in self.board]
        for y in range(len(piece)):
            for x in range(len(piece[y])):
                if piece[y][x] and not board[y + pos["y"]][x + pos["x"]]:
                    board[y + pos["y"]][x + pos["x"]] = piece[y][x]
        return board

    def check_cleared_rows(self, board):
        to_delete = []
        for i, row in enumerate(board[::-1]):
            if 0 not in row:
                to_delete.append(len(board) - 1 - i)
        if len(to_delete) > 0:
            board = self.remove_row(board, to_delete)
        return len(to_delete), board

    def remove_row(self, board, indices):
        for i in indices[::-1]:
            del board[i]
            board = [[0 for _ in range(self.width)]] + board
        return board

    def step(self, action, render=True, video=None):
        x, num_rotations = action
        self.current_pos = {"x": x, "y": 0}
        for _ in range(num_rotations):
            self.piece = self.rotate(self.piece)

        while not self.check_collision(self.piece, self.current_pos):
            self.current_pos["y"] += 1
            if render:
                self.render(video)

        overflow = self.truncate(self.piece, self.current_pos)
        if overflow:
            self.gameover = True

        self.board = self.store(self.piece, self.current_pos)

        lines_cleared, self.board = self.check_cleared_rows(self.board)
        score = 1 + (lines_cleared ** 2) * self.width
        self.score += score
        self.tetrominoes += 1
        self.cleared_lines += lines_cleared
        if not self.gameover:
            self.new_piece()
        if self.gameover:
            self.score -= 2

        return score, self.gameover

    def render(self, video=None):
        if not self.gameover:
            img = [self.piece_colors[p] for row in self.get_current_board_state() for p in row]
        else:
            img = [self.piece_colors[p] for row in self.board for p in row]
        img = np.array(img).reshape((self.height, self.width, 3)).astype(np.uint8)
        img = img[..., ::-1]
        img = Image.fromarray(img, "RGB")
        img = img.resize((self.width * self.block_size, self.height * self.block_size))
        img = np.array(img)
        img[[i * self.block_size for i in range(self.height)], :, :] = 0
        img[:, [i * self.block_size for i in range(self.width)], :] = 0
        img = np.concatenate((img, self.extra_board), axis=1)

        cv2.putText(img, "Score:", (self.width * self.block_size + int(self.block_size / 2), self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)
        cv2.putText(img, str(self.score),
                    (self.width * self.block_size + int(self.block_size / 2), 2 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)

        cv2.putText(img, "Pieces:", (self.width * self.block_size + int(self.block_size / 2), 4 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)
        cv2.putText(img, str(self.tetrominoes),
                    (self.width * self.block_size + int(self.block_size / 2), 5 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)

        cv2.putText(img, "Lines:", (self.width * self.block_size + int(self.block_size / 2), 7 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)
        cv2.putText(img, str(self.cleared_lines),
                    (self.width * self.block_size + int(self.block_size / 2), 8 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)

        if video:
            video.write(img)

        cv2.imshow("Deep Q-Learning Tetris", img)
        cv2.waitKey(1)
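Putting the environment and the network together, one decision step looks like the sketch below: get_next_states enumerates every (column, rotation) placement of the current piece, the network scores the board features each placement would produce, and the best-scoring placement is played. This is the same pattern the training and test scripts below follow; here the model is untrained, so the choice is effectively arbitrary:

import torch
from src.deep_q_network import DeepQNetwork
from src.tetris import Tetris

env = Tetris(width=10, height=20, block_size=20)
model = DeepQNetwork()
state = env.reset()                                  # 4-feature tensor for the empty board

next_steps = env.get_next_states()                   # {(x, rotation): feature tensor, ...}
next_actions, next_states = zip(*next_steps.items())
with torch.no_grad():
    values = model(torch.stack(next_states))[:, 0]   # one value per candidate placement
action = next_actions[torch.argmax(values).item()]

reward, done = env.step(action, render=False)        # drop the piece, receive the shaped reward
print(action, reward, done)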
Training Code
import argparse
import os
import shutil
from random import random, randint, sample

import numpy as np
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
import time
from src.deep_q_network import DeepQNetwork
from src.tetris import Tetris
from collections import deque


def get_args():
    parser = argparse.ArgumentParser("Implementation of Deep Q Network to play Tetris")
    parser.add_argument("--width", type=int, default=10, help="The common width for all images")
    parser.add_argument("--height", type=int, default=20, help="The common height for all images")
    parser.add_argument("--block_size", type=int, default=30, help="Size of a block")
    parser.add_argument("--batch_size", type=int, default=512, help="The number of images per batch")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--initial_epsilon", type=float, default=1)
    parser.add_argument("--final_epsilon", type=float, default=1e-3)
    parser.add_argument("--num_decay_epochs", type=float, default=2000)
    parser.add_argument("--num_epochs", type=int, default=3000)
    parser.add_argument("--save_interval", type=int, default=500)
    parser.add_argument("--replay_memory_size", type=int, default=30000,
                        help="Number of epoches between testing phases")
    parser.add_argument("--log_path", type=str, default="tensorboard")
    parser.add_argument("--saved_path", type=str, default="trained_models")
    args = parser.parse_args()
    return args


def train(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)
    writer = SummaryWriter(opt.log_path)
    env = Tetris(width=opt.width, height=opt.height, block_size=opt.block_size)
    model = DeepQNetwork()
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    criterion = nn.MSELoss()

    state = env.reset()
    if torch.cuda.is_available():
        model.cuda()
        state = state.cuda()

    replay_memory = deque(maxlen=opt.replay_memory_size)
    epoch = 0
    t1 = time.time()
    total_time = 0
    best_score = 1000
    while epoch < opt.num_epochs:
        start_time = time.time()
        next_steps = env.get_next_states()
        # Exploration or exploitation
        epsilon = opt.final_epsilon + (max(opt.num_decay_epochs - epoch, 0) * (
                opt.initial_epsilon - opt.final_epsilon) / opt.num_decay_epochs)
        u = random()
        random_action = u <= epsilon
        next_actions, next_states = zip(*next_steps.items())
        next_states = torch.stack(next_states)
        if torch.cuda.is_available():
            next_states = next_states.cuda()
        model.eval()
        with torch.no_grad():
            predictions = model(next_states)[:, 0]
        model.train()
        if random_action:
            index = randint(0, len(next_steps) - 1)
        else:
            index = torch.argmax(predictions).item()

        next_state = next_states[index, :]
        action = next_actions[index]

        reward, done = env.step(action, render=True)

        if torch.cuda.is_available():
            next_state = next_state.cuda()
        replay_memory.append([state, reward, next_state, done])
        if done:
            final_score = env.score
            final_tetrominoes = env.tetrominoes
            final_cleared_lines = env.cleared_lines
            state = env.reset()
            if torch.cuda.is_available():
                state = state.cuda()
        else:
            state = next_state
            continue
        if len(replay_memory) < opt.replay_memory_size / 10:
            continue
        epoch += 1
        batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
        state_batch, reward_batch, next_state_batch, done_batch = zip(*batch)
        state_batch = torch.stack(tuple(state for state in state_batch))
        reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])
        next_state_batch = torch.stack(tuple(state for state in next_state_batch))

        if torch.cuda.is_available():
            state_batch = state_batch.cuda()
            reward_batch = reward_batch.cuda()
            next_state_batch = next_state_batch.cuda()

        print(state_batch, state_batch.shape)
        q_values = model(state_batch)
        model.eval()
        with torch.no_grad():
            next_prediction_batch = model(next_state_batch)
        model.train()

        y_batch = torch.cat(
            tuple(reward if done else reward + opt.gamma * prediction for reward, done, prediction in
                  zip(reward_batch, done_batch, next_prediction_batch)))[:, None]

        optimizer.zero_grad()
        loss = criterion(q_values, y_batch)
        loss.backward()
        optimizer.step()

        end_time = time.time()
        use_time = end_time - t1 - total_time
        total_time = end_time - t1
        print("Epoch: {}/{}, Action: {}, Score: {}, Tetrominoes {}, Cleared lines: {}, Used time: {}, total used time: {}".format(
            epoch,
            opt.num_epochs,
            action,
            final_score,
            final_tetrominoes,
            final_cleared_lines,
            use_time,
            total_time))
        writer.add_scalar("Train/Score", final_score, epoch - 1)
        writer.add_scalar("Train/Tetrominoes", final_tetrominoes, epoch - 1)
        writer.add_scalar("Train/Cleared lines", final_cleared_lines, epoch - 1)

        if epoch > 0 and epoch % opt.save_interval == 0:
            print("save interval model: {}".format(epoch))
            torch.save(model, "{}/tetris_{}".format(opt.saved_path, epoch))
        elif final_score > best_score:
            best_score = final_score
            print("save best model: {}".format(best_score))
            torch.save(model, "{}/tetris_{}".format(opt.saved_path, best_score))


if __name__ == "__main__":
    opt = get_args()
    train(opt)
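With the default arguments, exploration decays linearly: epsilon starts at initial_epsilon, reaches final_epsilon after num_decay_epochs epochs, and stays there. A small sketch of the schedule implemented in the loop above (values rounded):

# Epsilon-decay schedule with the default hyperparameters from get_args()
initial_epsilon, final_epsilon, num_decay_epochs = 1.0, 1e-3, 2000

def epsilon_at(epoch):
    return final_epsilon + (max(num_decay_epochs - epoch, 0)
                            * (initial_epsilon - final_epsilon) / num_decay_epochs)

for epoch in [0, 500, 1000, 2000, 3000]:
    print(epoch, round(epsilon_at(epoch), 4))
# 0 -> 1.0, 500 -> ~0.75, 1000 -> ~0.5, 2000 and later -> 0.001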
Test Code
import argparse
import torch
import cv2
from src.tetris import Tetris


def get_args():
    parser = argparse.ArgumentParser("Implementation of Deep Q Network to play Tetris")
    parser.add_argument("--width", type=int, default=10, help="The common width for all images")
    parser.add_argument("--height", type=int, default=20, help="The common height for all images")
    parser.add_argument("--block_size", type=int, default=30, help="Size of a block")
    parser.add_argument("--fps", type=int, default=300, help="frames per second")
    parser.add_argument("--saved_path", type=str, default="trained_models")
    parser.add_argument("--output", type=str, default="output.mp4")
    args = parser.parse_args()
    return args


def test(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    if torch.cuda.is_available():
        model = torch.load("{}/tetris_2000".format(opt.saved_path))
    else:
        model = torch.load("{}/tetris_2000".format(opt.saved_path), map_location=lambda storage, loc: storage)
    model.eval()
    env = Tetris(width=opt.width, height=opt.height, block_size=opt.block_size)
    env.reset()
    if torch.cuda.is_available():
        model.cuda()
    out = cv2.VideoWriter(opt.output, cv2.VideoWriter_fourcc(*"MJPG"), opt.fps,
                          (int(1.5 * opt.width * opt.block_size), opt.height * opt.block_size))
    while True:
        next_steps = env.get_next_states()
        next_actions, next_states = zip(*next_steps.items())
        next_states = torch.stack(next_states)
        if torch.cuda.is_available():
            next_states = next_states.cuda()
        predictions = model(next_states)[:, 0]
        index = torch.argmax(predictions).item()
        action = next_actions[index]
        _, done = env.step(action, render=True, video=out)

        if done:
            out.release()
            break


if __name__ == "__main__":
    opt = get_args()
    test(opt)
Results