ljpzzz / machinelearning

My blogs and code for machine learning. http://cnblogs.com/pinard

Hello, could you take a look at this Nature-DQN code for me? It runs extremely slowly.

miss-fang opened this issue

```python
import tensorflow as tf
import numpy as np
import gym
import random
from collections import deque

# The script mixes eager-style code (GradientTape, per-step evaluation) with
# TF 1.x APIs (tf.train.AdamOptimizer), so eager execution must be enabled.
tf.enable_eager_execution()


class QNetwork(tf.keras.Model):
    """Online Q-network (dense1/dense2) and target Q-network (dense3/dense4)."""

    def __init__(self):
        super().__init__()
        # Online network
        self.dense1 = tf.keras.layers.Dense(24, activation='relu')
        self.dense2 = tf.keras.layers.Dense(2)
        # Target network
        self.dense3 = tf.keras.layers.Dense(24, activation='relu')
        self.dense4 = tf.keras.layers.Dense(2)

    def call(self, inputs):
        # Forward pass through the online network
        x = self.dense1(inputs)
        return self.dense2(x)

    def tarNet_Q(self, inputs):
        # Forward pass through the target network
        x = self.dense3(inputs)
        return self.dense4(x)

    def get_action(self, inputs):
        # Greedy action w.r.t. the online network; .numpy() replaces K.eval in eager mode
        q_values = self(inputs)
        return int(tf.argmax(q_values, axis=-1).numpy()[0])


env = gym.make('CartPole-v0')

num_episodes = 300
num_exploration = 200
max_len = 400
batch_size = 32
lr = 1e-3
gamma = 0.9
initial_epsilon = 0.5
final_epsilon = 0.01
replay_buffer = deque(maxlen=10000)
tarNet_update_frequence = 10
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
qNet = QNetwork()

for i in range(1, num_episodes + 1):
    state = env.reset()
    # Linearly anneal epsilon during the exploration phase
    epsilon = max(initial_epsilon * (num_exploration - i) / num_exploration, final_epsilon)
    for t in range(max_len):  # cap the episode length
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = qNet.get_action(tf.constant(np.expand_dims(state, axis=0), dtype=tf.float32))
        next_state, reward, done, info = env.step(action)
        reward = -1. if done else reward
        replay_buffer.append((state, action, reward, next_state, done))
        state = next_state
        if done:
            print('episode %d, epsilon %f, score %d' % (i, epsilon, t))
            break
        if len(replay_buffer) >= batch_size:
            # Note the * when unpacking the sampled transitions into columns
            batch_state, batch_action, batch_reward, batch_next_state, batch_done = \
                [np.array(a, dtype=np.float32) for a in zip(*random.sample(replay_buffer, batch_size))]
            # Target values come from the target network on the next states
            q_value = qNet.tarNet_Q(tf.constant(batch_next_state, dtype=tf.float32))
            y = batch_reward + (gamma * tf.reduce_max(q_value, axis=1)) * (1 - batch_done)
            with tf.GradientTape() as tape:
                # Select the Q-value of the taken action with a one-hot mask;
                # reduce_sum (not reduce_max) is correct when Q-values can be negative
                q_selected = tf.reduce_sum(
                    qNet(tf.constant(batch_state))
                    * tf.keras.utils.to_categorical(batch_action, num_classes=2), axis=1)
                loss = tf.reduce_mean(tf.square(y - q_selected))
            # Only the online network's variables (the first four) are trained
            grads = tape.gradient(loss, qNet.variables[:4])
            optimizer.apply_gradients(grads_and_vars=zip(grads, qNet.variables[:4]))
    if i % tarNet_update_frequence == 0:
        # Copy the online weights into the target layers every few episodes
        for j in range(2):
            qNet.variables[4 + j].assign(qNet.dense1.get_weights()[j])
            qNet.variables[6 + j].assign(qNet.dense2.get_weights()[j])
env.close()
```
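For reference, the training step above implements the standard Nature-DQN target, where $\theta^-$ are the target-network parameters refreshed from the online network every `tarNet_update_frequence` episodes:

$$y_j = r_j + \gamma\,(1 - d_j)\,\max_{a'} Q(s'_j, a';\, \theta^-)$$

so gradients flow only through the online network's estimate $Q(s_j, a_j;\, \theta)$.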
I think it runs slowly because the way I copy the network parameters is wrong. If anyone reading this has a suggestion, please share it.
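If the parameter copy is indeed the concern, one common alternative is to keep the online and target networks as two separate models and copy all weights in a single `set_weights(get_weights())` call, instead of assigning variables one by one through `qNet.variables` indices. A minimal sketch under that assumption (the `QNet` class and the two-model split are illustrative, not from the original code):

```python
import tensorflow as tf

class QNet(tf.keras.Model):
    """Two-layer Q-network; one instance serves as online net, one as target."""
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(24, activation='relu')
        self.dense2 = tf.keras.layers.Dense(2)

    def call(self, inputs):
        return self.dense2(self.dense1(inputs))

q_net, target_net = QNet(), QNet()
# Run one dummy forward pass so both models create their weights before copying.
dummy = tf.constant([[0., 0., 0., 0.]], dtype=tf.float32)  # CartPole state is 4-dimensional
q_net(dummy)
target_net(dummy)

# Periodic target update: one bulk copy instead of per-variable assign calls.
target_net.set_weights(q_net.get_weights())
```

This also removes the dependence on the ordering of `qNet.variables`, which is easy to break when layers are added or built in a different order. That said, much of the per-step cost likely comes from running eager ops one transition at a time, so it is worth profiling before and after the change.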