A2C playing AntBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
2e370ab
| import copy | |
| import random | |
| from collections import deque | |
| from typing import Any, Deque, Dict, List, Optional | |
| import numpy as np | |
| from rl_algo_impls.runner.config import Config | |
| from rl_algo_impls.shared.policy.policy import Policy | |
| from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker | |
| from rl_algo_impls.wrappers.vectorable_wrapper import ( | |
| VecEnvObs, | |
| VecEnvStepReturn, | |
| VecotarableWrapper, | |
| ) | |
class SelfPlayWrapper(VecotarableWrapper):
    """Vector-env wrapper for self-play against frozen policies and loaded bots.

    Envs are grouped into consecutive pairs; within a pair, offset 0 is
    player 1 and offset 1 is player 2. ``policy_assignments[i]`` is ``None``
    when the learner controls env ``i`` and otherwise holds the ``Policy``
    that controls it. Only learner-controlled envs are exposed through
    ``step`` / ``reset`` / ``get_action_mask``; actions for the other envs are
    computed inside ``step``.

    Slot layout:
      * ``[0, 2 * num_old_policies)``: pairs where one side is played by a
        checkpointed copy of the learner (installed by ``checkpoint_policy``
        and periodically re-drawn by ``swap_policy`` from ``step``).
      * the next ``2 * sum(selfplay_bots.values())`` slots: pairs where one
        side is played by a policy loaded from a model path
        (``initialize_selfplay_bots``).
      * everything after that: learner-only envs.

    Note: before the first ``checkpoint_policy`` call the old-policy slots are
    still ``None`` and are therefore reported as learner envs.
    """

    # Observations / action masks for ALL envs from the latest reset or step;
    # used to compute opponent actions on the next step.
    next_obs: VecEnvObs
    next_action_masks: Optional[np.ndarray]

    def __init__(
        self,
        env,
        config: Config,
        num_old_policies: int = 0,
        save_steps: int = 20_000,
        swap_steps: int = 10_000,
        window: int = 10,
        swap_window_size: int = 2,
        selfplay_bots: Optional[Dict[str, Any]] = None,
        bot_always_player_2: bool = False,
    ) -> None:
        """Create the wrapper and load any configured bot policies.

        Args:
            env: Wrapped vector env; must expose ``num_envs`` and
                ``get_action_mask()``.
            config: Run config; ``algo`` and ``policy_hyperparams`` are used
                to build bot policies.
            num_old_policies: Number of envs controlled by checkpointed
                policies (one per pair, so ``2 * num_old_policies`` envs are
                used). Must be even and a multiple of ``swap_window_size``.
            save_steps: Stored only; not read by this class (presumably used
                by an external checkpointing callback — verify).
            swap_steps: An old-policy window gets a new opponent once its
                step counter exceeds this value.
            window: Maximum number of checkpointed policies kept in the pool.
            swap_window_size: Number of consecutive old-policy pairs that are
                swapped together to the same newly drawn policy.
            selfplay_bots: Mapping of model path to the number of pairs the
                loaded policy controls one side of.
            bot_always_player_2: Put opponents always on player 2 instead of
                alternating between players.
        """
        super().__init__(env)
        assert num_old_policies % 2 == 0, "num_old_policies must be even"
        assert swap_window_size > 0, "swap_window_size must be positive"
        assert (
            num_old_policies % swap_window_size == 0
        ), "num_old_policies must be a multiple of swap_window_size"

        self.config = config
        self.num_old_policies = num_old_policies
        self.save_steps = save_steps
        self.swap_steps = swap_steps
        self.swap_window_size = swap_window_size
        self.selfplay_bots = selfplay_bots
        self.bot_always_player_2 = bot_always_player_2

        # Pool of checkpointed learner copies; oldest entries fall off.
        self.policies: Deque[Policy] = deque(maxlen=window)
        # None => learner controls the env; otherwise the controlling policy.
        self.policy_assignments: List[Optional[Policy]] = [None] * env.num_envs
        self.steps_since_swap = np.zeros(env.num_envs)
        self.selfplay_policies: Dict[str, Policy] = {}

        # Envs visible to the learner: every opponent-controlled slot is hidden.
        self.num_envs = env.num_envs - num_old_policies
        if self.selfplay_bots:
            self.num_envs -= sum(self.selfplay_bots.values())
            self.initialize_selfplay_bots()

    def get_action_mask(self) -> Optional[np.ndarray]:
        """Return the wrapped env's action masks for learner envs only.

        Returns ``None`` when the wrapped env reports no action mask, matching
        how ``step`` treats a ``None`` mask.
        """
        masks = self.env.get_action_mask()
        if masks is None:
            return None
        return masks[self.learner_indexes()]

    def learner_indexes(self) -> List[bool]:
        """Return a boolean mask over all envs: ``True`` where the learner acts.

        Despite the name, this is a mask (not integer indexes); it is used
        directly for NumPy boolean indexing.
        """
        return [p is None for p in self.policy_assignments]

    def checkpoint_policy(self, copied_policy: Policy) -> None:
        """Add a frozen copy of the learner to the opponent pool.

        On the first call (while every old-policy slot is still unassigned)
        the copy is also installed as the opponent in every old-policy pair.

        Args:
            copied_policy: Policy snapshot; ``train(False)`` is called on it.
        """
        copied_policy.train(False)
        self.policies.append(copied_policy)
        old_policy_slots = self.policy_assignments[: 2 * self.num_old_policies]
        if all(p is None for p in old_policy_slots):
            for i in range(self.num_old_policies):
                # Alternate the opponent between player 1 and player 2 unless
                # it is pinned to player 2.
                player_offset = 1 if self.bot_always_player_2 else i % 2
                self.policy_assignments[2 * i + player_offset] = copied_policy

    def swap_policy(self, idx: int, swap_window_size: int = 1) -> None:
        """Give a window of old-policy pairs a new randomly drawn opponent.

        Args:
            idx: Any env index inside the first pair of the window; it is
                rounded down to the start of its pair.
            swap_window_size: Number of consecutive pairs that receive the
                same newly drawn policy.
        """
        if not self.policies:
            # Nothing checkpointed yet, so there is no opponent to swap in
            # (random.choice would raise IndexError on the empty pool).
            return
        policy = random.choice(self.policies)
        start = idx // 2 * 2
        end = start + 2 * swap_window_size
        for slot in range(start, end):
            # Only replace opponent slots; learner slots (None) stay learner.
            if self.policy_assignments[slot] is not None:
                self.policy_assignments[slot] = policy
        self.steps_since_swap[start:end] = 0

    def initialize_selfplay_bots(self) -> None:
        """Load each configured bot policy and assign it to its env pairs.

        Bot pairs start right after the old-policy pairs. A bot configured
        with count ``n`` takes the next ``n`` pairs, playing player 2 if
        ``bot_always_player_2`` and otherwise alternating between players.
        """
        if not self.selfplay_bots:
            return
        # Imported lazily, presumably to avoid an import cycle — verify.
        from rl_algo_impls.runner.running_utils import get_device, make_policy

        env = self.env  # type: ignore
        device = get_device(self.config, env)
        start_idx = 2 * self.num_old_policies
        for model_path, n in self.selfplay_bots.items():
            policy = make_policy(
                self.config.algo,
                env,
                device,
                load_path=model_path,
                **self.config.policy_hyperparams,
            ).eval()
            # Key by the actual path so multiple bots don't overwrite each other.
            self.selfplay_policies[model_path] = policy
            for idx in range(start_idx, start_idx + 2 * n, 2):
                bot_idx = (
                    (idx + 1) if self.bot_always_player_2 else (idx + idx // 2 % 2)
                )
                self.policy_assignments[bot_idx] = policy
            start_idx += 2 * n

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """Step every env, filling in opponent actions, and return learner results.

        Args:
            actions: Actions for the learner envs only, in env order.

        Returns:
            ``(obs, rew, done, info)`` restricted to learner envs. ``rew``,
            ``done`` and ``info`` describe the transition just taken and use
            the pre-step learner mask; ``obs`` uses the mask recomputed after
            any opponent swaps so it lines up with the next ``step`` call.
        """
        env = self.env  # type: ignore
        all_actions = np.zeros((env.num_envs,) + actions.shape[1:], dtype=actions.dtype)
        learner_mask = self.learner_indexes()
        all_actions[learner_mask] = actions

        # dict.fromkeys dedupes while keeping first-seen order, so policies act
        # in a deterministic order (a set would order them by object id).
        for policy in dict.fromkeys(
            p for p in self.policy_assignments if p is not None
        ):
            policy_mask = [p is policy for p in self.policy_assignments]
            all_actions[policy_mask] = policy.act(
                self.next_obs[policy_mask],
                deterministic=False,
                action_masks=self.next_action_masks[policy_mask]
                if self.next_action_masks is not None
                else None,
            )

        self.next_obs, rew, done, info = env.step(all_actions)
        self.next_action_masks = self.env.get_action_mask()
        rew = rew[learner_mask]
        done = done[learner_mask]
        info = [i for i, is_learner in zip(info, learner_mask) if is_learner]

        self.steps_since_swap += 1
        for idx in range(0, 2 * self.num_old_policies, 2 * self.swap_window_size):
            if self.steps_since_swap[idx] > self.swap_steps:
                self.swap_policy(idx, self.swap_window_size)

        next_learner_mask = self.learner_indexes()
        return self.next_obs[next_learner_mask], rew, done, info

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env and return observations for learner envs only."""
        self.next_obs = super().reset()
        self.next_action_masks = self.env.get_action_mask()
        return self.next_obs[self.learner_indexes()]