Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Path: blob/main/intermediate_source/mario_rl_tutorial.py
Views: 712
# -*- coding: utf-8 -*-1"""2Train a Mario-playing RL Agent3===============================45**Authors:** `Yuansong Feng <https://github.com/YuansongFeng>`__, `Suraj Subramanian <https://github.com/suraj813>`__, `Howard Wang <https://github.com/hw26>`__, `Steven Guo <https://github.com/GuoYuzhang>`__.678This tutorial walks you through the fundamentals of Deep Reinforcement9Learning. At the end, you will implement an AI-powered Mario (using10`Double Deep Q-Networks <https://arxiv.org/pdf/1509.06461.pdf>`__) that11can play the game by itself.1213Although no prior knowledge of RL is necessary for this tutorial, you14can familiarize yourself with these RL15`concepts <https://spinningup.openai.com/en/latest/spinningup/rl_intro.html>`__,16and have this handy17`cheatsheet <https://colab.research.google.com/drive/1eN33dPVtdPViiS1njTW_-r-IYCDTFU7N>`__18as your companion. The full code is available19`here <https://github.com/yuansongFeng/MadMario/>`__.2021.. figure:: /_static/img/mario.gif22:alt: mario2324"""252627######################################################################28#29#30# .. code-block:: bash31#32# %%bash33# pip install gym-super-mario-bros==7.4.034# pip install tensordict==0.3.035# pip install torchrl==0.3.036#3738import torch39from torch import nn40from torchvision import transforms as T41from PIL import Image42import numpy as np43from pathlib import Path44from collections import deque45import random, datetime, os4647# Gym is an OpenAI toolkit for RL48import gym49from gym.spaces import Box50from gym.wrappers import FrameStack5152# NES Emulator for OpenAI Gym53from nes_py.wrappers import JoypadSpace5455# Super Mario environment for OpenAI Gym56import gym_super_mario_bros5758from tensordict import TensorDict59from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage6061######################################################################62# RL Definitions63# """"""""""""""""""64#65# **Environment** The world that an agent interacts with and learns from.66#67# **Action** :math:`a` : How the Agent responds to the Environment. The68# set of all possible Actions is called *action-space*.69#70# **State** :math:`s` : The current characteristic of the Environment. The71# set of all possible States the Environment can be in is called72# *state-space*.73#74# **Reward** :math:`r` : Reward is the key feedback from Environment to75# Agent. It is what drives the Agent to learn and to change its future76# action. An aggregation of rewards over multiple time steps is called77# **Return**.78#79# **Optimal Action-Value function** :math:`Q^*(s,a)` : Gives the expected80# return if you start in state :math:`s`, take an arbitrary action81# :math:`a`, and then for each future time step take the action that82# maximizes returns. :math:`Q` can be said to stand for the “quality” of83# the action in a state. We try to approximate this function.84#858687######################################################################88# Environment89# """"""""""""""""90#91# Initialize Environment92# ------------------------93#94# In Mario, the environment consists of tubes, mushrooms and other95# components.96#97# When Mario makes an action, the environment responds with the changed98# (next) state, reward and other info.99#100101# Initialize Super Mario environment (in v0.26 change render mode to 'human' to see results on the screen)102if gym.__version__ < '0.26':103env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", new_step_api=True)104else:105env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", render_mode='rgb', apply_api_compatibility=True)106107# Limit the action-space to108# 0. walk right109# 1. jump right110env = JoypadSpace(env, [["right"], ["right", "A"]])111112env.reset()113next_state, reward, done, trunc, info = env.step(action=0)114print(f"{next_state.shape},\n {reward},\n {done},\n {info}")115116117######################################################################118# Preprocess Environment119# ------------------------120#121# Environment data is returned to the agent in ``next_state``. As you saw122# above, each state is represented by a ``[3, 240, 256]`` size array.123# Often that is more information than our agent needs; for instance,124# Mario’s actions do not depend on the color of the pipes or the sky!125#126# We use **Wrappers** to preprocess environment data before sending it to127# the agent.128#129# ``GrayScaleObservation`` is a common wrapper to transform an RGB image130# to grayscale; doing so reduces the size of the state representation131# without losing useful information. Now the size of each state:132# ``[1, 240, 256]``133#134# ``ResizeObservation`` downsamples each observation into a square image.135# New size: ``[1, 84, 84]``136#137# ``SkipFrame`` is a custom wrapper that inherits from ``gym.Wrapper`` and138# implements the ``step()`` function. Because consecutive frames don’t139# vary much, we can skip n-intermediate frames without losing much140# information. The n-th frame aggregates rewards accumulated over each141# skipped frame.142#143# ``FrameStack`` is a wrapper that allows us to squash consecutive frames144# of the environment into a single observation point to feed to our145# learning model. This way, we can identify if Mario was landing or146# jumping based on the direction of his movement in the previous several147# frames.148#149150151class SkipFrame(gym.Wrapper):152def __init__(self, env, skip):153"""Return only every `skip`-th frame"""154super().__init__(env)155self._skip = skip156157def step(self, action):158"""Repeat action, and sum reward"""159total_reward = 0.0160for i in range(self._skip):161# Accumulate reward and repeat the same action162obs, reward, done, trunk, info = self.env.step(action)163total_reward += reward164if done:165break166return obs, total_reward, done, trunk, info167168169class GrayScaleObservation(gym.ObservationWrapper):170def __init__(self, env):171super().__init__(env)172obs_shape = self.observation_space.shape[:2]173self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)174175def permute_orientation(self, observation):176# permute [H, W, C] array to [C, H, W] tensor177observation = np.transpose(observation, (2, 0, 1))178observation = torch.tensor(observation.copy(), dtype=torch.float)179return observation180181def observation(self, observation):182observation = self.permute_orientation(observation)183transform = T.Grayscale()184observation = transform(observation)185return observation186187188class ResizeObservation(gym.ObservationWrapper):189def __init__(self, env, shape):190super().__init__(env)191if isinstance(shape, int):192self.shape = (shape, shape)193else:194self.shape = tuple(shape)195196obs_shape = self.shape + self.observation_space.shape[2:]197self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)198199def observation(self, observation):200transforms = T.Compose(201[T.Resize(self.shape, antialias=True), T.Normalize(0, 255)]202)203observation = transforms(observation).squeeze(0)204return observation205206207# Apply Wrappers to environment208env = SkipFrame(env, skip=4)209env = GrayScaleObservation(env)210env = ResizeObservation(env, shape=84)211if gym.__version__ < '0.26':212env = FrameStack(env, num_stack=4, new_step_api=True)213else:214env = FrameStack(env, num_stack=4)215216217######################################################################218# After applying the above wrappers to the environment, the final wrapped219# state consists of 4 gray-scaled consecutive frames stacked together, as220# shown above in the image on the left. Each time Mario makes an action,221# the environment responds with a state of this structure. The structure222# is represented by a 3-D array of size ``[4, 84, 84]``.223#224# .. figure:: /_static/img/mario_env.png225# :alt: picture226#227#228229230######################################################################231# Agent232# """""""""233#234# We create a class ``Mario`` to represent our agent in the game. Mario235# should be able to:236#237# - **Act** according to the optimal action policy based on the current238# state (of the environment).239#240# - **Remember** experiences. Experience = (current state, current241# action, reward, next state). Mario *caches* and later *recalls* his242# experiences to update his action policy.243#244# - **Learn** a better action policy over time245#246247248class Mario:249def __init__():250pass251252def act(self, state):253"""Given a state, choose an epsilon-greedy action"""254pass255256def cache(self, experience):257"""Add the experience to memory"""258pass259260def recall(self):261"""Sample experiences from memory"""262pass263264def learn(self):265"""Update online action value (Q) function with a batch of experiences"""266pass267268269######################################################################270# In the following sections, we will populate Mario’s parameters and271# define his functions.272#273274275######################################################################276# Act277# --------------278#279# For any given state, an agent can choose to do the most optimal action280# (**exploit**) or a random action (**explore**).281#282# Mario randomly explores with a chance of ``self.exploration_rate``; when283# he chooses to exploit, he relies on ``MarioNet`` (implemented in284# ``Learn`` section) to provide the most optimal action.285#286287288class Mario:289def __init__(self, state_dim, action_dim, save_dir):290self.state_dim = state_dim291self.action_dim = action_dim292self.save_dir = save_dir293294self.device = "cuda" if torch.cuda.is_available() else "cpu"295296# Mario's DNN to predict the most optimal action - we implement this in the Learn section297self.net = MarioNet(self.state_dim, self.action_dim).float()298self.net = self.net.to(device=self.device)299300self.exploration_rate = 1301self.exploration_rate_decay = 0.99999975302self.exploration_rate_min = 0.1303self.curr_step = 0304305self.save_every = 5e5 # no. of experiences between saving Mario Net306307def act(self, state):308"""309Given a state, choose an epsilon-greedy action and update value of step.310311Inputs:312state(``LazyFrame``): A single observation of the current state, dimension is (state_dim)313Outputs:314``action_idx`` (``int``): An integer representing which action Mario will perform315"""316# EXPLORE317if np.random.rand() < self.exploration_rate:318action_idx = np.random.randint(self.action_dim)319320# EXPLOIT321else:322state = state[0].__array__() if isinstance(state, tuple) else state.__array__()323state = torch.tensor(state, device=self.device).unsqueeze(0)324action_values = self.net(state, model="online")325action_idx = torch.argmax(action_values, axis=1).item()326327# decrease exploration_rate328self.exploration_rate *= self.exploration_rate_decay329self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)330331# increment step332self.curr_step += 1333return action_idx334335336######################################################################337# Cache and Recall338# ----------------------339#340# These two functions serve as Mario’s “memory” process.341#342# ``cache()``: Each time Mario performs an action, he stores the343# ``experience`` to his memory. His experience includes the current344# *state*, *action* performed, *reward* from the action, the *next state*,345# and whether the game is *done*.346#347# ``recall()``: Mario randomly samples a batch of experiences from his348# memory, and uses that to learn the game.349#350351352class Mario(Mario): # subclassing for continuity353def __init__(self, state_dim, action_dim, save_dir):354super().__init__(state_dim, action_dim, save_dir)355self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000, device=torch.device("cpu")))356self.batch_size = 32357358def cache(self, state, next_state, action, reward, done):359"""360Store the experience to self.memory (replay buffer)361362Inputs:363state (``LazyFrame``),364next_state (``LazyFrame``),365action (``int``),366reward (``float``),367done(``bool``))368"""369def first_if_tuple(x):370return x[0] if isinstance(x, tuple) else x371state = first_if_tuple(state).__array__()372next_state = first_if_tuple(next_state).__array__()373374state = torch.tensor(state)375next_state = torch.tensor(next_state)376action = torch.tensor([action])377reward = torch.tensor([reward])378done = torch.tensor([done])379380# self.memory.append((state, next_state, action, reward, done,))381self.memory.add(TensorDict({"state": state, "next_state": next_state, "action": action, "reward": reward, "done": done}, batch_size=[]))382383def recall(self):384"""385Retrieve a batch of experiences from memory386"""387batch = self.memory.sample(self.batch_size).to(self.device)388state, next_state, action, reward, done = (batch.get(key) for key in ("state", "next_state", "action", "reward", "done"))389return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()390391392######################################################################393# Learn394# --------------395#396# Mario uses the `DDQN algorithm <https://arxiv.org/pdf/1509.06461>`__397# under the hood. DDQN uses two ConvNets - :math:`Q_{online}` and398# :math:`Q_{target}` - that independently approximate the optimal399# action-value function.400#401# In our implementation, we share feature generator ``features`` across402# :math:`Q_{online}` and :math:`Q_{target}`, but maintain separate FC403# classifiers for each. :math:`\theta_{target}` (the parameters of404# :math:`Q_{target}`) is frozen to prevent updating by backprop. Instead,405# it is periodically synced with :math:`\theta_{online}` (more on this406# later).407#408# Neural Network409# ~~~~~~~~~~~~~~~~~~410411412class MarioNet(nn.Module):413"""mini CNN structure414input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) x 2 -> output415"""416417def __init__(self, input_dim, output_dim):418super().__init__()419c, h, w = input_dim420421if h != 84:422raise ValueError(f"Expecting input height: 84, got: {h}")423if w != 84:424raise ValueError(f"Expecting input width: 84, got: {w}")425426self.online = self.__build_cnn(c, output_dim)427428self.target = self.__build_cnn(c, output_dim)429self.target.load_state_dict(self.online.state_dict())430431# Q_target parameters are frozen.432for p in self.target.parameters():433p.requires_grad = False434435def forward(self, input, model):436if model == "online":437return self.online(input)438elif model == "target":439return self.target(input)440441def __build_cnn(self, c, output_dim):442return nn.Sequential(443nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),444nn.ReLU(),445nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),446nn.ReLU(),447nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),448nn.ReLU(),449nn.Flatten(),450nn.Linear(3136, 512),451nn.ReLU(),452nn.Linear(512, output_dim),453)454455456######################################################################457# TD Estimate & TD Target458# ~~~~~~~~~~~~~~~~~~~~~~~~~~459#460# Two values are involved in learning:461#462# **TD Estimate** - the predicted optimal :math:`Q^*` for a given state463# :math:`s`464#465# .. math::466#467#468# {TD}_e = Q_{online}^*(s,a)469#470# **TD Target** - aggregation of current reward and the estimated471# :math:`Q^*` in the next state :math:`s'`472#473# .. math::474#475#476# a' = argmax_{a} Q_{online}(s', a)477#478# .. math::479#480#481# {TD}_t = r + \gamma Q_{target}^*(s',a')482#483# Because we don’t know what next action :math:`a'` will be, we use the484# action :math:`a'` maximizes :math:`Q_{online}` in the next state485# :math:`s'`.486#487# Notice we use the488# `@torch.no_grad() <https://pytorch.org/docs/stable/generated/torch.no_grad.html#no-grad>`__489# decorator on ``td_target()`` to disable gradient calculations here490# (because we don’t need to backpropagate on :math:`\theta_{target}`).491#492493494class Mario(Mario):495def __init__(self, state_dim, action_dim, save_dir):496super().__init__(state_dim, action_dim, save_dir)497self.gamma = 0.9498499def td_estimate(self, state, action):500current_Q = self.net(state, model="online")[501np.arange(0, self.batch_size), action502] # Q_online(s,a)503return current_Q504505@torch.no_grad()506def td_target(self, reward, next_state, done):507next_state_Q = self.net(next_state, model="online")508best_action = torch.argmax(next_state_Q, axis=1)509next_Q = self.net(next_state, model="target")[510np.arange(0, self.batch_size), best_action511]512return (reward + (1 - done.float()) * self.gamma * next_Q).float()513514515######################################################################516# Updating the model517# ~~~~~~~~~~~~~~~~~~~~~~518#519# As Mario samples inputs from his replay buffer, we compute :math:`TD_t`520# and :math:`TD_e` and backpropagate this loss down :math:`Q_{online}` to521# update its parameters :math:`\theta_{online}` (:math:`\alpha` is the522# learning rate ``lr`` passed to the ``optimizer``)523#524# .. math::525#526#527# \theta_{online} \leftarrow \theta_{online} + \alpha \nabla(TD_e - TD_t)528#529# :math:`\theta_{target}` does not update through backpropagation.530# Instead, we periodically copy :math:`\theta_{online}` to531# :math:`\theta_{target}`532#533# .. math::534#535#536# \theta_{target} \leftarrow \theta_{online}537#538#539540541class Mario(Mario):542def __init__(self, state_dim, action_dim, save_dir):543super().__init__(state_dim, action_dim, save_dir)544self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025)545self.loss_fn = torch.nn.SmoothL1Loss()546547def update_Q_online(self, td_estimate, td_target):548loss = self.loss_fn(td_estimate, td_target)549self.optimizer.zero_grad()550loss.backward()551self.optimizer.step()552return loss.item()553554def sync_Q_target(self):555self.net.target.load_state_dict(self.net.online.state_dict())556557558######################################################################559# Save checkpoint560# ~~~~~~~~~~~~~~~~~~561#562563564class Mario(Mario):565def save(self):566save_path = (567self.save_dir / f"mario_net_{int(self.curr_step // self.save_every)}.chkpt"568)569torch.save(570dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate),571save_path,572)573print(f"MarioNet saved to {save_path} at step {self.curr_step}")574575576######################################################################577# Putting it all together578# ~~~~~~~~~~~~~~~~~~~~~~~~~~579#580581582class Mario(Mario):583def __init__(self, state_dim, action_dim, save_dir):584super().__init__(state_dim, action_dim, save_dir)585self.burnin = 1e4 # min. experiences before training586self.learn_every = 3 # no. of experiences between updates to Q_online587self.sync_every = 1e4 # no. of experiences between Q_target & Q_online sync588589def learn(self):590if self.curr_step % self.sync_every == 0:591self.sync_Q_target()592593if self.curr_step % self.save_every == 0:594self.save()595596if self.curr_step < self.burnin:597return None, None598599if self.curr_step % self.learn_every != 0:600return None, None601602# Sample from memory603state, next_state, action, reward, done = self.recall()604605# Get TD Estimate606td_est = self.td_estimate(state, action)607608# Get TD Target609td_tgt = self.td_target(reward, next_state, done)610611# Backpropagate loss through Q_online612loss = self.update_Q_online(td_est, td_tgt)613614return (td_est.mean().item(), loss)615616617######################################################################618# Logging619# --------------620#621622import numpy as np623import time, datetime624import matplotlib.pyplot as plt625626627class MetricLogger:628def __init__(self, save_dir):629self.save_log = save_dir / "log"630with open(self.save_log, "w") as f:631f.write(632f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"633f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"634f"{'TimeDelta':>15}{'Time':>20}\n"635)636self.ep_rewards_plot = save_dir / "reward_plot.jpg"637self.ep_lengths_plot = save_dir / "length_plot.jpg"638self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"639self.ep_avg_qs_plot = save_dir / "q_plot.jpg"640641# History metrics642self.ep_rewards = []643self.ep_lengths = []644self.ep_avg_losses = []645self.ep_avg_qs = []646647# Moving averages, added for every call to record()648self.moving_avg_ep_rewards = []649self.moving_avg_ep_lengths = []650self.moving_avg_ep_avg_losses = []651self.moving_avg_ep_avg_qs = []652653# Current episode metric654self.init_episode()655656# Timing657self.record_time = time.time()658659def log_step(self, reward, loss, q):660self.curr_ep_reward += reward661self.curr_ep_length += 1662if loss:663self.curr_ep_loss += loss664self.curr_ep_q += q665self.curr_ep_loss_length += 1666667def log_episode(self):668"Mark end of episode"669self.ep_rewards.append(self.curr_ep_reward)670self.ep_lengths.append(self.curr_ep_length)671if self.curr_ep_loss_length == 0:672ep_avg_loss = 0673ep_avg_q = 0674else:675ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)676ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)677self.ep_avg_losses.append(ep_avg_loss)678self.ep_avg_qs.append(ep_avg_q)679680self.init_episode()681682def init_episode(self):683self.curr_ep_reward = 0.0684self.curr_ep_length = 0685self.curr_ep_loss = 0.0686self.curr_ep_q = 0.0687self.curr_ep_loss_length = 0688689def record(self, episode, epsilon, step):690mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)691mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)692mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)693mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)694self.moving_avg_ep_rewards.append(mean_ep_reward)695self.moving_avg_ep_lengths.append(mean_ep_length)696self.moving_avg_ep_avg_losses.append(mean_ep_loss)697self.moving_avg_ep_avg_qs.append(mean_ep_q)698699last_record_time = self.record_time700self.record_time = time.time()701time_since_last_record = np.round(self.record_time - last_record_time, 3)702703print(704f"Episode {episode} - "705f"Step {step} - "706f"Epsilon {epsilon} - "707f"Mean Reward {mean_ep_reward} - "708f"Mean Length {mean_ep_length} - "709f"Mean Loss {mean_ep_loss} - "710f"Mean Q Value {mean_ep_q} - "711f"Time Delta {time_since_last_record} - "712f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"713)714715with open(self.save_log, "a") as f:716f.write(717f"{episode:8d}{step:8d}{epsilon:10.3f}"718f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"719f"{time_since_last_record:15.3f}"720f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"721)722723for metric in ["ep_lengths", "ep_avg_losses", "ep_avg_qs", "ep_rewards"]:724plt.clf()725plt.plot(getattr(self, f"moving_avg_{metric}"), label=f"moving_avg_{metric}")726plt.legend()727plt.savefig(getattr(self, f"{metric}_plot"))728729730######################################################################731# Let’s play!732# """""""""""""""733#734# In this example we run the training loop for 40 episodes, but for Mario to truly learn the ways of735# his world, we suggest running the loop for at least 40,000 episodes!736#737use_cuda = torch.cuda.is_available()738print(f"Using CUDA: {use_cuda}")739print()740741save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")742save_dir.mkdir(parents=True)743744mario = Mario(state_dim=(4, 84, 84), action_dim=env.action_space.n, save_dir=save_dir)745746logger = MetricLogger(save_dir)747748episodes = 40749for e in range(episodes):750751state = env.reset()752753# Play the game!754while True:755756# Run agent on the state757action = mario.act(state)758759# Agent performs action760next_state, reward, done, trunc, info = env.step(action)761762# Remember763mario.cache(state, next_state, action, reward, done)764765# Learn766q, loss = mario.learn()767768# Logging769logger.log_step(reward, loss, q)770771# Update state772state = next_state773774# Check if end of game775if done or info["flag_get"]:776break777778logger.log_episode()779780if (e % 20 == 0) or (e == episodes - 1):781logger.record(episode=e, epsilon=mario.exploration_rate, step=mario.curr_step)782783784######################################################################785# Conclusion786# """""""""""""""787#788# In this tutorial, we saw how we can use PyTorch to train a game-playing AI. You can use the same methods789# to train an AI to play any of the games at the `OpenAI gym <https://gym.openai.com/>`__. Hope you enjoyed this tutorial, feel free to reach us at790# `our github <https://github.com/yuansongFeng/MadMario/>`__!791792793