PPO playing CarRacing-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
76ee962
| import gym | |
| import numpy as np | |
| from typing import Any, Dict, Tuple, Union | |
| from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper | |
# Observation alias used in the step() return annotation below: either a
# NumPy array or a dict.
| ObsType = Union[np.ndarray, dict] | |
# Action alias used for the step() argument below: a scalar (int or float),
# a NumPy array, or a dict.
| ActType = Union[int, float, np.ndarray, dict] | |
class InitialStepTruncateWrapper(VecotarableWrapper):
    """Force a single ``done=True`` after a fixed number of initial steps.

    The wrapper counts calls to :meth:`step` until the count reaches
    ``initial_steps_to_truncate``. On that step it reports ``done=True``
    regardless of what the wrapped environment returned, then becomes a
    transparent pass-through for the rest of its lifetime. This class never
    resets the counter itself.

    Attributes:
        initial_steps_to_truncate: Step count at which the forced ``done``
            is emitted. A value of ``0`` disables the behaviour entirely.
        initialized: ``True`` once the forced ``done`` has been emitted (or
            immediately, when the threshold is ``0``).
        steps: Number of steps observed while ``initialized`` was ``False``.
    """

    def __init__(self, env: gym.Env, initial_steps_to_truncate: int) -> None:
        """Wrap ``env`` and arm the one-shot truncation.

        Args:
            env: Environment to wrap.
            initial_steps_to_truncate: Number of steps after which ``done``
                is forced once; ``0`` means never force it.
        """
        super().__init__(env)
        self.initial_steps_to_truncate = initial_steps_to_truncate
        self.steps = 0
        # A zero threshold means there is nothing to truncate, so the wrapper
        # starts out already "initialized" and never touches `done`.
        self.initialized = initial_steps_to_truncate == 0

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        """Step the wrapped environment, forcing ``done`` once at the threshold.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            The ``(obs, rew, done, info)`` tuple from the wrapped environment,
            with ``done`` replaced by ``True`` on the single truncation step.
        """
        obs, rew, done, info = self.env.step(action)

        # Fast path: the one-shot truncation has already fired (or was disabled).
        if self.initialized:
            return obs, rew, done, info

        self.steps += 1
        if self.steps < self.initial_steps_to_truncate:
            return obs, rew, done, info

        # Threshold reached: announce it, then disarm so this happens only once.
        print(f"Truncation at {self.steps} steps")
        self.initialized = True
        return obs, rew, True, info