Atari
Environment setup
```
!pip install keras-rl
!pip install easydict
!pip install pyvirtualdisplay
!apt-get install xvfb
```
Setting up the virtual display
```python
# Start virtual display
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1024, 768))
display.start()
import os
os.environ["DISPLAY"] = ":" + str(display.display) + "." + str(display.screen)
```
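Before moving on, it can help to confirm that rendering actually works under the virtual display. The cell below is a minimal sketch of such a check: it renders one frame of an Atari environment as an RGB array and shows it with matplotlib. The environment name and the use of matplotlib here are illustrative choices, not part of the original setup.

```python
# Sanity check: render a single frame under the virtual display.
# Environment name and matplotlib usage are illustrative assumptions.
import gym
import matplotlib.pyplot as plt

env = gym.make("BreakoutDeterministic-v4")
env.reset()
frame = env.render(mode='rgb_array')  # (height, width, 3) uint8 array
env.close()

plt.imshow(frame)
plt.axis('off')
plt.show()
```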
Atari sample
```python
from __future__ import division
import argparse
from PIL import Image
import numpy as np
import easydict  # added

import gym
from gym import wrappers  # added

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Convolution2D, Permute
from keras.optimizers import Adam
import keras.backend as K

from rl.agents.dqn import DQNAgent
from rl.policy import LinearAnnealedPolicy, BoltzmannQPolicy, EpsGreedyQPolicy
from rl.memory import SequentialMemory
from rl.core import Processor
from rl.callbacks import FileLogger, ModelIntervalCheckpoint


INPUT_SHAPE = (84, 84)
WINDOW_LENGTH = 4


class AtariProcessor(Processor):
    def process_observation(self, observation):
        assert observation.ndim == 3  # (height, width, channel)
        img = Image.fromarray(observation)
        img = img.resize(INPUT_SHAPE).convert('L')  # resize and convert to grayscale
        processed_observation = np.array(img)
        assert processed_observation.shape == INPUT_SHAPE
        return processed_observation.astype('uint8')  # saves storage in experience memory

    def process_state_batch(self, batch):
        # We could perform this processing step in `process_observation`. In this case, however,
        # we would need to store a `float32` array instead, which is 4x more memory intensive than
        # an `uint8` array. This matters if we store 1M observations.
        processed_batch = batch.astype('float32') / 255.
        return processed_batch

    def process_reward(self, reward):
        return np.clip(reward, -1., 1.)


## argparse fails in notebooks. Use easydict instead.
## see https://qiita.com/LittleWat/items/6e56857e1f97c842b261
#parser = argparse.ArgumentParser()
#parser.add_argument('--mode', choices=['train', 'test'], default='train')
#parser.add_argument('--env-name', type=str, default='BreakoutDeterministic-v4')
#parser.add_argument('--weights', type=str, default=None)
#args = parser.parse_args()
args = easydict.EasyDict({  # use this instead of argparse
    "mode": "train",
    "env_name": "BreakoutDeterministic-v4",
    "weights": None,
})

# Get the environment and extract the number of actions.
env = gym.make(args.env_name)
env = wrappers.Monitor(env, './', force=True)  # save animations
np.random.seed(123)
env.seed(123)
nb_actions = env.action_space.n

# Next, we build our model. We use the same model that was described by Mnih et al. (2015).
input_shape = (WINDOW_LENGTH,) + INPUT_SHAPE
model = Sequential()
if K.image_dim_ordering() == 'tf':
    # (width, height, channels)
    model.add(Permute((2, 3, 1), input_shape=input_shape))
elif K.image_dim_ordering() == 'th':
    # (channels, width, height)
    model.add(Permute((1, 2, 3), input_shape=input_shape))
else:
    raise RuntimeError('Unknown image_dim_ordering.')
model.add(Convolution2D(32, (8, 8), strides=(4, 4)))
model.add(Activation('relu'))
model.add(Convolution2D(64, (4, 4), strides=(2, 2)))
model.add(Activation('relu'))
model.add(Convolution2D(64, (3, 3), strides=(1, 1)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
print(model.summary())

# Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
# even the metrics!
memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH)
processor = AtariProcessor()

# Select a policy. We use eps-greedy action selection, which means that a random action is selected
# with probability eps. We anneal eps from 1.0 to 0.1 over the course of 1M steps. This is done so that
# the agent initially explores the environment (high eps) and then gradually sticks to what it knows
# (low eps). We also set a dedicated eps value that is used during testing. Note that we set it to 0.05
# so that the agent still performs some random actions. This ensures that the agent cannot get stuck.
policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.1,
                              value_test=.05, nb_steps=1000000)

# The trade-off between exploration and exploitation is difficult and an on-going research topic.
# If you want, you can experiment with the parameters or use a different policy. Another popular one
# is Boltzmann-style exploration:
# policy = BoltzmannQPolicy(tau=1.)
# Feel free to give it a try!

dqn = DQNAgent(model=model, nb_actions=nb_actions, policy=policy, memory=memory,
               processor=processor, nb_steps_warmup=50000, gamma=.99, target_model_update=10000,
               train_interval=4, delta_clip=1.)
dqn.compile(Adam(lr=.00025), metrics=['mae'])

if args.mode == 'train':
    # Okay, now it's time to learn something! We capture the interrupt exception so that training
    # can be prematurely aborted. Notice that you can use the built-in Keras callbacks!
    weights_filename = 'dqn_{}_weights.h5f'.format(args.env_name)
    checkpoint_weights_filename = 'dqn_' + args.env_name + '_weights_{step}.h5f'
    log_filename = 'dqn_{}_log.json'.format(args.env_name)
    callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=250000)]
    callbacks += [FileLogger(log_filename, interval=100)]
    dqn.fit(env, callbacks=callbacks, nb_steps=1750000, log_interval=10000)

    # After training is done, we save the final weights one more time.
    dqn.save_weights(weights_filename, overwrite=True)

    # Finally, evaluate our algorithm for 10 episodes.
    dqn.test(env, nb_episodes=10, visualize=False)
elif args.mode == 'test':
    weights_filename = 'dqn_{}_weights.h5f'.format(args.env_name)
    if args.weights:
        weights_filename = args.weights
    dqn.load_weights(weights_filename)
    dqn.test(env, nb_episodes=10, visualize=True)
```
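The FileLogger callback above writes training metrics to dqn_&lt;env_name&gt;_log.json. The snippet below is a rough sketch of how that log could be inspected after training; it assumes the JSON contains per-episode lists such as 'episode' and 'episode_reward', which is how keras-rl's FileLogger lays out its data, and it hard-codes the Breakout file name used in this example.

```python
# Plot the per-episode reward recorded by keras-rl's FileLogger.
# Assumes the log contains 'episode' and 'episode_reward' lists.
import json
import matplotlib.pyplot as plt

with open('dqn_BreakoutDeterministic-v4_log.json') as f:
    log = json.load(f)

plt.plot(log['episode'], log['episode_reward'])
plt.xlabel('episode')
plt.ylabel('episode reward')
plt.show()
```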
Creating the videos
```python
# download videos
from google.colab import files
import glob

for file in glob.glob("*.mp4"):
    files.download(file)
```
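Instead of downloading the recordings, they can also be played back directly inside the notebook. The cell below is a small sketch that base64-encodes the first .mp4 produced by gym's Monitor wrapper and embeds it with IPython's HTML display; the glob pattern simply assumes the Monitor output sits in the working directory as above.

```python
# Embed the first recorded .mp4 inline instead of downloading it.
import base64
import glob
from IPython.display import HTML, display

mp4_files = sorted(glob.glob("*.mp4"))
if mp4_files:
    with open(mp4_files[0], 'rb') as f:
        encoded = base64.b64encode(f.read()).decode('ascii')
    display(HTML(
        '<video width="400" controls>'
        '<source src="data:video/mp4;base64,{}" type="video/mp4">'
        '</video>'.format(encoded)
    ))
```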