# A2C playing AntBulletEnv-v0
# from https://github.com/sgoodfriend/rl-algo-impls/tree/0760ef7d52b17f30219a27c18ba52c8895025ae3
# (commit 0126ac9)
# Standard library
from abc import ABC, abstractmethod
from typing import List, Optional, TypeVar

# Third-party
import gym
import torch
from torch.utils.tensorboard.writer import SummaryWriter

# Project-local
from shared.callbacks.callback import Callback
from shared.policy.policy import Policy
from wrappers.vectorable_wrapper import VecEnv

# Bound TypeVar so subclass `learn` implementations return their own concrete
# type, preserving the fluent `algo.learn(...)` chaining interface.
AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")
class Algorithm(ABC):
    """Abstract base class for a reinforcement-learning training algorithm.

    Holds the shared pieces every concrete algorithm needs: the policy being
    trained, the vectorized environment it trains in, the torch device to run
    on, and a TensorBoard writer for logging. Subclasses implement `learn`.
    """

    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        **kwargs,
    ) -> None:
        """Store the shared training components on the instance.

        :param policy: Policy to be trained (and used for rollouts).
        :param env: Vectorized environment to train in.
        :param device: Torch device computations should run on.
        :param tb_writer: TensorBoard writer for metric logging.
        :param kwargs: Absorbed so subclass constructors can pass through
            extra config without this base raising TypeError.
        """
        super().__init__()
        self.policy = policy
        self.env = env
        self.device = device
        self.tb_writer = tb_writer

    # `abstractmethod` was imported but never applied; without it a subclass
    # that forgets to override `learn` would silently return None instead of
    # failing fast at instantiation time.
    @abstractmethod
    def learn(
        self: AlgorithmSelf, total_timesteps: int, callback: Optional[Callback] = None
    ) -> AlgorithmSelf:
        """Train the policy for `total_timesteps` environment steps.

        :param total_timesteps: Total number of environment timesteps to train for.
        :param callback: Optional hook invoked during training (e.g. for
            evaluation or checkpointing).
        :return: self, to allow call chaining.
        """
        ...