# Spaces: Running on Zero
| from environs import Env | |
| from torch import Tensor | |
| from beartype import beartype | |
| from beartype.door import is_bearable | |
| from jaxtyping import ( | |
| Float, | |
| Int, | |
| Bool, | |
| jaxtyped | |
| ) | |
# environment

# Module-wide environment reader. Per the environs documentation,
# `read_env()` looks for a `.env` file and loads its variables into the
# process environment, so later `env.<type>(...)` lookups can see them.
env = Env()
env.read_env()
# function

def always(value):
    """Build a callable that ignores all of its arguments and returns `value`.

    Handy as a drop-in replacement for a predicate or hook that should be
    disabled, e.g. ``always(True)`` accepts anything it is called with.
    """

    def _constant(*_args, **_kwargs):
        # arguments are accepted only so the signature matches any caller
        return value

    return _constant


def identity(t):
    """Return `t` unchanged; usable as a no-op decorator or transform."""
    return t
class TorchTyping:
    """Adapter that lets jaxtyping dtype annotations target torch tensors.

    Despite its name, jaxtyping also works for PyTorch. Wrapping one of its
    abstract dtypes lets callers write ``Float['b n d']`` as shorthand for
    ``Float[Tensor, 'b n d']``.
    """

    def __init__(self, abstract_dtype):
        # the jaxtyping abstract dtype being wrapped (e.g. Float, Int, Bool)
        self.abstract_dtype = abstract_dtype

    def __getitem__(self, shapes: str):
        """Return ``abstract_dtype[Tensor, shapes]`` for a shape string."""
        key = (Tensor, shapes)
        return self.abstract_dtype[key]
# Shadow the jaxtyping dtypes with torch-aware wrappers, so that e.g.
# `Float['b n']` expands to jaxtyping's `Float[Tensor, 'b n']`.
Float, Int, Bool = TorchTyping(Float), TorchTyping(Int), TorchTyping(Bool)
# The TYPECHECK environment variable (default: off) decides whether
# beartype + jaxtyping runtime checking is active.
should_typecheck = env.bool('TYPECHECK', False)

if should_typecheck:
    # real runtime checking: jaxtyping shape checks driven by beartype
    typecheck = jaxtyped(typechecker = beartype)
    beartype_isinstance = is_bearable
else:
    # checking disabled: decorator is a no-op, isinstance check always passes
    typecheck = identity
    beartype_isinstance = always(True)
# Public API of this module. `__all__` must contain the exported *names* as
# strings: `from <module> import *` raises TypeError ("Item in
# <module>.__all__ must be str") if an entry is an object instead.
__all__ = [
    'Float',
    'Int',
    'Bool',
    'typecheck',
    'beartype_isinstance',
]