Path: blob/master/examples/rl/actor_critic_cartpole.py
"""1Title: Actor Critic Method2Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)3Date created: 2020/05/134Last modified: 2024/02/225Description: Implement Actor Critic Method in CartPole environment.6Accelerator: NONE7Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)8"""910"""11## Introduction1213This script shows an implementation of Actor Critic method on CartPole-V0 environment.1415### Actor Critic Method1617As an agent takes actions and moves through an environment, it learns to map18the observed state of the environment to two possible outputs:19201. Recommended action: A probability value for each action in the action space.21The part of the agent responsible for this output is called the **actor**.222. Estimated rewards in the future: Sum of all rewards it expects to receive in the23future. The part of the agent responsible for this output is the **critic**.2425Agent and Critic learn to perform their tasks, such that the recommended actions26from the actor maximize the rewards.2728### CartPole-V02930A pole is attached to a cart placed on a frictionless track. The agent has to apply31force to move the cart. It is rewarded for every time step the pole32remains upright. The agent, therefore, must learn to keep the pole from falling over.3334### References3536- [Environment documentation](https://gymnasium.farama.org/environments/classic_control/cart_pole/)37- [CartPole paper](http://www.derongliu.org/adp/adp-cdrom/Barto1983.pdf)38- [Actor Critic Method](https://hal.inria.fr/hal-00840470/document)39"""40"""41## Setup42"""4344import os4546os.environ["KERAS_BACKEND"] = "tensorflow"47import gym48import numpy as np49import keras50from keras import ops51from keras import layers52import tensorflow as tf5354# Configuration parameters for the whole setup55seed = 4256gamma = 0.99 # Discount factor for past rewards57max_steps_per_episode = 1000058# Adding `render_mode='human'` will show the attempts of the agent59env = gym.make("CartPole-v0") # Create the environment60env.reset(seed=seed)61eps = np.finfo(np.float32).eps.item() # Smallest number such that 1.0 + eps != 1.06263"""64## Implement Actor Critic network6566This network learns two functions:67681. Actor: This takes as input the state of our environment and returns a69probability value for each action in its action space.702. 
"""
## Implement Actor Critic network

This network learns two functions:

1. Actor: This takes as input the state of our environment and returns a
probability value for each action in its action space.
2. Critic: This takes as input the state of our environment and returns
an estimate of total rewards in the future.

In our implementation, they share the initial layer.
"""

num_inputs = 4
num_actions = 2
num_hidden = 128

inputs = layers.Input(shape=(num_inputs,))
common = layers.Dense(num_hidden, activation="relu")(inputs)
action = layers.Dense(num_actions, activation="softmax")(common)
critic = layers.Dense(1)(common)

model = keras.Model(inputs=inputs, outputs=[action, critic])
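"""
As a quick, optional sanity check of the untrained network, we can run a single
dummy state through the model: the actor head should return a probability
distribution over `num_actions` actions and the critic head a single scalar value
estimate. The `dummy_state` variable below is illustrative only and is not used in
training.
"""

dummy_state = ops.zeros((1, num_inputs))  # batch of one observation
dummy_action_probs, dummy_value = model(dummy_state)
print("Actor output shape:", dummy_action_probs.shape)  # (1, num_actions)
print("Critic output shape:", dummy_value.shape)  # (1, 1)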
"""
## Train
"""

optimizer = keras.optimizers.Adam(learning_rate=0.01)
huber_loss = keras.losses.Huber()
action_probs_history = []
critic_value_history = []
rewards_history = []
running_reward = 0
episode_count = 0

while True:  # Run until solved
    state = env.reset()[0]
    episode_reward = 0
    with tf.GradientTape() as tape:
        for timestep in range(1, max_steps_per_episode):
            state = ops.convert_to_tensor(state)
            state = ops.expand_dims(state, 0)

            # Predict action probabilities and estimated future rewards
            # from environment state
            action_probs, critic_value = model(state)
            critic_value_history.append(critic_value[0, 0])

            # Sample action from action probability distribution
            action = np.random.choice(num_actions, p=np.squeeze(action_probs))
            action_probs_history.append(ops.log(action_probs[0, action]))

            # Apply the sampled action in our environment
            state, reward, done, *_ = env.step(action)
            rewards_history.append(reward)
            episode_reward += reward

            if done:
                break

        # Update running reward to check condition for solving
        running_reward = 0.05 * episode_reward + (1 - 0.05) * running_reward

        # Calculate expected value from rewards
        # - At each timestep what was the total reward received after that timestep
        # - Rewards in the past are discounted by multiplying them with gamma
        # - These are the labels for our critic
        returns = []
        discounted_sum = 0
        for r in rewards_history[::-1]:
            discounted_sum = r + gamma * discounted_sum
            returns.insert(0, discounted_sum)

        # Normalize
        returns = np.array(returns)
        returns = (returns - np.mean(returns)) / (np.std(returns) + eps)
        returns = returns.tolist()

        # Calculating loss values to update our network
        history = zip(action_probs_history, critic_value_history, returns)
        actor_losses = []
        critic_losses = []
        for log_prob, value, ret in history:
            # At this point in history, the critic estimated that we would get a
            # total reward = `value` in the future. We took an action with log
            # probability of `log_prob` and ended up receiving a total reward = `ret`.
            # The actor must be updated so that it predicts an action that leads to
            # high rewards (compared to the critic's estimate) with high probability.
            diff = ret - value
            actor_losses.append(-log_prob * diff)  # actor loss

            # The critic must be updated so that it predicts a better estimate of
            # the future rewards.
            critic_losses.append(
                huber_loss(ops.expand_dims(value, 0), ops.expand_dims(ret, 0))
            )

        # Backpropagation
        loss_value = sum(actor_losses) + sum(critic_losses)
        grads = tape.gradient(loss_value, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Clear the loss and reward history
        action_probs_history.clear()
        critic_value_history.clear()
        rewards_history.clear()

    # Log details
    episode_count += 1
    if episode_count % 10 == 0:
        template = "running reward: {:.2f} at episode {}"
        print(template.format(running_reward, episode_count))

    if running_reward > 195:  # Condition to consider the task solved
        print("Solved at episode {}!".format(episode_count))
        break

"""
## Visualizations

In early stages of training:

In later stages of training:
"""
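"""
Once the running reward crosses the solving threshold, the learned policy can be
inspected by running one extra episode greedily (always choosing the most probable
action instead of sampling). The short loop below is a minimal, optional sketch of
such an evaluation pass and is not needed for training; creating the environment
with `render_mode="human"` would additionally show the cart on screen.
"""

state = env.reset()[0]
eval_reward = 0
for _ in range(max_steps_per_episode):
    state_tensor = ops.expand_dims(ops.convert_to_tensor(state), 0)
    action_probs, _ = model(state_tensor)
    greedy_action = int(np.argmax(np.squeeze(action_probs)))  # most probable action
    state, reward, done, *_ = env.step(greedy_action)
    eval_reward += reward
    if done:
        break

print("Greedy evaluation reward:", eval_reward)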