PPO playing CarRacing-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
07ba300
| import logging | |
| from dataclasses import asdict, dataclass | |
| from time import perf_counter | |
| from typing import List, NamedTuple, Optional, TypeVar | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.optim import Adam | |
| from torch.utils.tensorboard.writer import SummaryWriter | |
| from rl_algo_impls.shared.algorithm import Algorithm | |
| from rl_algo_impls.shared.callbacks import Callback | |
| from rl_algo_impls.shared.gae import compute_advantages | |
| from rl_algo_impls.shared.policy.actor_critic import ActorCritic | |
| from rl_algo_impls.shared.schedule import ( | |
| constant_schedule, | |
| linear_schedule, | |
| schedule, | |
| update_learning_rate, | |
| ) | |
| from rl_algo_impls.shared.stats import log_scalars | |
| from rl_algo_impls.wrappers.vectorable_wrapper import ( | |
| VecEnv, | |
| single_action_space, | |
| single_observation_space, | |
| ) | |
# Metrics recorded for a single PPO minibatch optimization step. TrainStats
# averages a list of these records field by field. Declared with the
# functional NamedTuple form; field order is significant for positional
# construction.
TrainStepStats = NamedTuple(
    "TrainStepStats",
    [
        ("loss", float),
        ("pi_loss", float),
        ("v_loss", float),
        ("entropy_loss", float),
        ("approx_kl", float),
        ("clipped_frac", float),
        ("val_clipped_frac", float),
    ],
)
@dataclass(eq=False)
class TrainStats:
    """Training statistics for one PPO update, averaged over minibatch steps.

    The class is declared as a dataclass so that ``write_to_tensorboard`` can
    enumerate its fields with ``dataclasses.asdict``, which raises
    ``TypeError`` for non-dataclass instances. The hand-written ``__init__``
    and ``__repr__`` below are kept because ``dataclass`` never replaces
    methods the class defines itself. ``eq=False`` keeps the previous
    identity-based ``==`` and hashing.
    """

    loss: float
    pi_loss: float
    v_loss: float
    entropy_loss: float
    approx_kl: float
    clipped_frac: float
    val_clipped_frac: float
    explained_var: float

    def __init__(
        self, step_stats: "List[TrainStepStats]", explained_var: float
    ) -> None:
        """Average each per-step metric over ``step_stats``.

        Args:
            step_stats: Per-minibatch statistics to average. Each item must
                expose ``loss``, ``pi_loss``, ``v_loss``, ``entropy_loss``,
                ``approx_kl``, ``clipped_frac`` and ``val_clipped_frac``. An
                empty list yields NaN means (``np.mean`` of an empty list).
            explained_var: Explained variance of the value predictions; stored
                unchanged.
        """
        self.loss = np.mean([s.loss for s in step_stats]).item()
        self.pi_loss = np.mean([s.pi_loss for s in step_stats]).item()
        self.v_loss = np.mean([s.v_loss for s in step_stats]).item()
        self.entropy_loss = np.mean([s.entropy_loss for s in step_stats]).item()
        self.approx_kl = np.mean([s.approx_kl for s in step_stats]).item()
        self.clipped_frac = np.mean([s.clipped_frac for s in step_stats]).item()
        self.val_clipped_frac = np.mean([s.val_clipped_frac for s in step_stats]).item()
        self.explained_var = explained_var

    def write_to_tensorboard(self, tb_writer: "SummaryWriter", global_step: int) -> None:
        """Log every field as a scalar tagged ``losses/<field name>``."""
        for name, value in asdict(self).items():
            tb_writer.add_scalar(f"losses/{name}", value, global_step=global_step)

    def __repr__(self) -> str:
        # Compact one-line summary; explained_var is intentionally not shown.
        return " | ".join(
            [
                f"Loss: {round(self.loss, 2)}",
                f"Pi L: {round(self.pi_loss, 2)}",
                f"V L: {round(self.v_loss, 2)}",
                f"E L: {round(self.entropy_loss, 2)}",
                f"Apx KL Div: {round(self.approx_kl, 2)}",
                f"Clip Frac: {round(self.clipped_frac, 2)}",
                f"Val Clip Frac: {round(self.val_clipped_frac, 2)}",
            ]
        )
# Self-type so ``PPO.learn`` returns the caller's concrete (sub)class.
PPOSelf = TypeVar("PPOSelf", bound="PPO")


class PPO(Algorithm):
    """Proximal Policy Optimization with a clipped surrogate objective.

    Each iteration of :meth:`learn` collects ``n_steps`` transitions from every
    environment in ``env``, computes advantages with ``compute_advantages``
    (``gamma`` / ``gae_lambda``), then runs ``n_epochs`` passes of shuffled
    minibatch gradient steps. Learning rate, entropy coefficient, clip ranges
    and gamma are re-evaluated from their schedules at the start of every
    iteration from the current training progress.
    """

    def __init__(
        self,
        policy: ActorCritic,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        learning_rate: float = 3e-4,
        learning_rate_decay: str = "none",
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: float = 0.2,
        clip_range_decay: str = "none",
        clip_range_vf: Optional[float] = None,
        clip_range_vf_decay: str = "none",
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        ent_coef_decay: str = "none",
        vf_coef: float = 0.5,
        ppo2_vf_coef_halving: bool = False,
        max_grad_norm: float = 0.5,
        sde_sample_freq: int = -1,
        update_advantage_between_epochs: bool = True,
        update_returns_between_epochs: bool = False,
        gamma_end: Optional[float] = None,
    ) -> None:
        """Configure the optimizer, hyperparameter schedules and update settings.

        Args:
            policy: Actor-critic network being trained (optimized with Adam).
            env: Vectorized environment used for rollouts. If it exposes
                ``get_action_mask``, masks are collected and passed to the
                policy.
            device: Device the rollout tensors are moved to for updates.
            tb_writer: TensorBoard writer for training scalars.
            learning_rate: Initial Adam learning rate.
            learning_rate_decay: Name of the schedule applied to the rate.
            n_steps: Rollout length per environment between updates.
            batch_size: Minibatch size for gradient steps.
            n_epochs: Number of passes over each rollout.
            gamma: Discount factor (start value if ``gamma_end`` is given).
            gae_lambda: Lambda passed to ``compute_advantages``.
            clip_range: Policy-ratio clip range.
            clip_range_decay: Name of the schedule applied to ``clip_range``.
            clip_range_vf: Value clip range; ``None`` or 0 disables value
                clipping.
            clip_range_vf_decay: Name of the schedule applied to
                ``clip_range_vf``.
            normalize_advantage: Standardize advantages within each minibatch.
            ent_coef: Entropy-loss coefficient.
            ent_coef_decay: Name of the schedule applied to ``ent_coef``.
            vf_coef: Value-loss coefficient.
            ppo2_vf_coef_halving: Multiply the value loss by 0.5 before
                weighting it with ``vf_coef``.
            max_grad_norm: Global gradient-norm clipping threshold.
            sde_sample_freq: If > 0, ``policy.reset_noise()`` is additionally
                called every ``sde_sample_freq`` rollout steps.
            update_advantage_between_epochs: Recompute advantages with the
                current policy at every epoch instead of only the first.
            update_returns_between_epochs: Recompute return targets at every
                epoch instead of only the first.
            gamma_end: If given, gamma follows
                ``linear_schedule(gamma, gamma_end)`` over training progress.

        Raises:
            AssertionError: If ``normalize_advantage`` is set and any
                minibatch, including the final partial one, would hold a
                single sample. The standard deviation of one sample is NaN
                and would poison the loss.
        """
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy
        # Optional hook: environments with invalid actions expose masks.
        self.get_action_mask = getattr(env, "get_action_mask", None)

        self.gamma_schedule = (
            linear_schedule(gamma, gamma_end)
            if gamma_end is not None
            else constant_schedule(gamma)
        )
        self.gae_lambda = gae_lambda
        self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
        self.lr_schedule = schedule(learning_rate_decay, learning_rate)
        self.max_grad_norm = max_grad_norm
        self.clip_range_schedule = schedule(clip_range_decay, clip_range)
        self.clip_range_vf_schedule = None
        if clip_range_vf:
            self.clip_range_vf_schedule = schedule(clip_range_vf_decay, clip_range_vf)

        if normalize_advantage:
            rollout_size = env.num_envs * n_steps
            # The final minibatch holds rollout_size % batch_size samples when
            # that is non-zero; a size-1 minibatch has a NaN std.
            assert (
                rollout_size > 1
                and batch_size > 1
                and rollout_size % batch_size != 1
            ), (
                "Each minibatch must be larger than 1 to support normalization "
                f"(rollout size {rollout_size}, batch size {batch_size})"
            )
        self.normalize_advantage = normalize_advantage

        self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
        self.vf_coef = vf_coef
        self.ppo2_vf_coef_halving = ppo2_vf_coef_halving

        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.sde_sample_freq = sde_sample_freq

        self.update_advantage_between_epochs = update_advantage_between_epochs
        self.update_returns_between_epochs = update_returns_between_epochs

    def learn(
        self: PPOSelf,
        train_timesteps: int,
        callbacks: Optional[List[Callback]] = None,
        total_timesteps: Optional[int] = None,
        start_timesteps: int = 0,
    ) -> PPOSelf:
        """Alternate rollout collection and PPO updates for ``train_timesteps``.

        Args:
            train_timesteps: Environment steps (summed over all envs) to run in
                this call. Work proceeds in whole rollouts of
                ``n_steps * num_envs`` steps, so the total may overshoot.
            callbacks: After every update, ``on_step(timesteps_elapsed=
                rollout_steps)`` is called on each callback in order. The first
                falsy result stops both the remaining calls and training.
            total_timesteps: Denominator for schedule progress; defaults to
                ``train_timesteps``.
            start_timesteps: Step count this call resumes from; offsets both
                schedule progress and logged step numbers.

        Returns:
            ``self``.
        """
        if total_timesteps is None:
            total_timesteps = train_timesteps
        assert start_timesteps + train_timesteps <= total_timesteps

        epoch_dim = (self.n_steps, self.env.num_envs)
        step_dim = (self.env.num_envs,)
        obs_space = single_observation_space(self.env)
        act_space = single_action_space(self.env)
        act_shape = self.policy.action_shape

        next_obs = self.env.reset()
        next_action_masks = self.get_action_mask() if self.get_action_mask else None
        next_episode_starts = np.full(step_dim, True, dtype=np.bool_)

        # Rollout buffers, indexed [step, env, ...].
        obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)  # type: ignore
        actions = np.zeros(epoch_dim + act_shape, dtype=act_space.dtype)  # type: ignore
        rewards = np.zeros(epoch_dim, dtype=np.float32)
        episode_starts = np.zeros(epoch_dim, dtype=np.bool_)
        values = np.zeros(epoch_dim, dtype=np.float32)
        logprobs = np.zeros(epoch_dim, dtype=np.float32)
        action_masks = (
            np.zeros(
                (self.n_steps,) + next_action_masks.shape, dtype=next_action_masks.dtype
            )
            if next_action_masks is not None
            else None
        )

        timesteps_elapsed = start_timesteps
        while timesteps_elapsed < start_timesteps + train_timesteps:
            start_time = perf_counter()

            # Re-evaluate every scheduled hyperparameter from current progress.
            progress = timesteps_elapsed / total_timesteps
            ent_coef = self.ent_coef_schedule(progress)
            learning_rate = self.lr_schedule(progress)
            update_learning_rate(self.optimizer, learning_rate)
            pi_clip = self.clip_range_schedule(progress)
            gamma = self.gamma_schedule(progress)
            chart_scalars = {
                "learning_rate": self.optimizer.param_groups[0]["lr"],
                "ent_coef": ent_coef,
                "pi_clip": pi_clip,
                "gamma": gamma,
            }
            if self.clip_range_vf_schedule:
                v_clip = self.clip_range_vf_schedule(progress)
                chart_scalars["v_clip"] = v_clip
            else:
                v_clip = None
            log_scalars(self.tb_writer, "charts", chart_scalars, timesteps_elapsed)

            # --- Rollout collection ---
            self.policy.eval()
            self.policy.reset_noise()
            for s in range(self.n_steps):
                timesteps_elapsed += self.env.num_envs
                if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
                    self.policy.reset_noise()

                obs[s] = next_obs
                episode_starts[s] = next_episode_starts
                if action_masks is not None:
                    action_masks[s] = next_action_masks

                (
                    actions[s],
                    values[s],
                    logprobs[s],
                    clamped_action,
                ) = self.policy.step(next_obs, action_masks=next_action_masks)
                next_obs, rewards[s], next_episode_starts, _ = self.env.step(
                    clamped_action
                )
                next_action_masks = (
                    self.get_action_mask() if self.get_action_mask else None
                )

            # --- Optimization ---
            self.policy.train()

            # Flatten [step, env, ...] buffers into a single batch dimension.
            b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)  # type: ignore
            b_actions = torch.tensor(actions.reshape((-1,) + act_shape)).to(  # type: ignore
                self.device
            )
            b_logprobs = torch.tensor(logprobs.reshape(-1)).to(self.device)
            b_action_masks = (
                torch.tensor(action_masks.reshape((-1,) + action_masks.shape[2:])).to(
                    self.device
                )
                if action_masks is not None
                else None
            )

            y_pred = values.reshape(-1)
            b_values = torch.tensor(y_pred).to(self.device)

            step_stats = []
            # Define variables that will definitely be set through the first epoch
            advantages: np.ndarray = None  # type: ignore
            b_advantages: torch.Tensor = None  # type: ignore
            y_true: np.ndarray = None  # type: ignore
            b_returns: torch.Tensor = None  # type: ignore
            for e in range(self.n_epochs):
                if e == 0 or self.update_advantage_between_epochs:
                    advantages = compute_advantages(
                        rewards,
                        values,
                        episode_starts,
                        next_episode_starts,
                        next_obs,
                        self.policy,
                        gamma,
                        self.gae_lambda,
                    )
                    b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
                if e == 0 or self.update_returns_between_epochs:
                    returns = advantages + values
                    y_true = returns.reshape(-1)
                    b_returns = torch.tensor(y_true).to(self.device)

                b_idxs = torch.randperm(len(b_obs))
                # Only record last epoch's stats
                step_stats.clear()
                for i in range(0, len(b_obs), self.batch_size):
                    self.policy.reset_noise(self.batch_size)

                    mb_idxs = b_idxs[i : i + self.batch_size]

                    mb_obs = b_obs[mb_idxs]
                    mb_actions = b_actions[mb_idxs]
                    mb_values = b_values[mb_idxs]
                    mb_logprobs = b_logprobs[mb_idxs]
                    mb_action_masks = (
                        b_action_masks[mb_idxs] if b_action_masks is not None else None
                    )

                    mb_adv = b_advantages[mb_idxs]
                    if self.normalize_advantage:
                        mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)
                    mb_returns = b_returns[mb_idxs]

                    new_logprobs, entropy, new_values = self.policy(
                        mb_obs, mb_actions, action_masks=mb_action_masks
                    )

                    # Clipped surrogate objective (negated for minimization).
                    logratio = new_logprobs - mb_logprobs
                    ratio = torch.exp(logratio)
                    clipped_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
                    pi_loss = torch.max(-ratio * mb_adv, -clipped_ratio * mb_adv).mean()

                    # Value loss; with v_clip, take the worse of the unclipped
                    # error and the error of a prediction kept within v_clip
                    # of the rollout-time value.
                    v_loss_unclipped = (new_values - mb_returns) ** 2
                    if v_clip:
                        v_loss_clipped = (
                            mb_values
                            + torch.clamp(new_values - mb_values, -v_clip, v_clip)
                            - mb_returns
                        ) ** 2
                        v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
                    else:
                        v_loss = v_loss_unclipped.mean()

                    if self.ppo2_vf_coef_halving:
                        v_loss *= 0.5

                    entropy_loss = -entropy.mean()

                    loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss

                    self.optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(
                        self.policy.parameters(), self.max_grad_norm
                    )
                    self.optimizer.step()

                    with torch.no_grad():
                        # KL estimate mean((r - 1) - log r) from the ratios.
                        approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
                        clipped_frac = (
                            ((ratio - 1).abs() > pi_clip)
                            .float()
                            .mean()
                            .cpu()
                            .numpy()
                            .item()
                        )
                        val_clipped_frac = (
                            ((new_values - mb_values).abs() > v_clip)
                            .float()
                            .mean()
                            .cpu()
                            .numpy()
                            .item()
                            if v_clip
                            else 0
                        )

                    step_stats.append(
                        TrainStepStats(
                            loss.item(),
                            pi_loss.item(),
                            v_loss.item(),
                            entropy_loss.item(),
                            approx_kl,
                            clipped_frac,
                            val_clipped_frac,
                        )
                    )

            # Explained variance of rollout-time values vs. return targets.
            var_y = np.var(y_true).item()
            explained_var = (
                np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
            )
            TrainStats(step_stats, explained_var).write_to_tensorboard(
                self.tb_writer, timesteps_elapsed
            )

            end_time = perf_counter()
            rollout_steps = self.n_steps * self.env.num_envs
            self.tb_writer.add_scalar(
                "train/steps_per_second",
                rollout_steps / (end_time - start_time),
                timesteps_elapsed,
            )

            if callbacks:
                # on_step receives the number of steps added by this rollout.
                if not all(
                    c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
                ):
                    logging.info(
                        f"Callback terminated training at {timesteps_elapsed} timesteps"
                    )
                    break

        return self