Spaces:
Paused
Paused
| from torch import Tensor | |
| from jaxtyping import ( | |
| Float, | |
| Int, | |
| Bool | |
| ) | |
# Despite its name, jaxtyping is not JAX-specific; it also works with PyTorch tensors.
class TorchTyping:
    """Bind an abstract dtype annotation to ``torch.Tensor``.

    Wraps an object that is subscripted as ``dtype[ArrayType, "shape"]`` so
    that callers only have to supply the shape string: ``wrapper["b n"]`` is
    forwarded as ``dtype[Tensor, "b n"]``.
    """

    def __init__(self, abstract_dtype):
        """Store the dtype object that subscriptions are forwarded to.

        Args:
            abstract_dtype: Any object supporting
                ``abstract_dtype[Tensor, shapes]`` subscription.
        """
        self.abstract_dtype = abstract_dtype

    def __getitem__(self, shapes: str):
        """Return ``abstract_dtype[Tensor, shapes]`` for the given shape string.

        Args:
            shapes: Shape specification passed through unchanged as the
                second subscript element.
        """
        return self.abstract_dtype[Tensor, shapes]

    def __repr__(self) -> str:
        """Show the wrapped dtype to make debugging output readable."""
        return f"{type(self).__name__}({self.abstract_dtype!r})"
# Rebind the imported dtype names to Tensor-bound wrappers so that callers
# can write ``Float["batch dim"]`` instead of ``Float[Tensor, "batch dim"]``.
Float = TorchTyping(Float)
Int = TorchTyping(Int)
Bool = TorchTyping(Bool)

# ``__all__`` must contain the exported *names* as strings. Listing the
# objects themselves makes ``from <module> import *`` raise a TypeError.
__all__ = [
    "Float",
    "Int",
    "Bool",
]