jax>=0.3.0
jaxlib>=0.3.0
flax>=0.5.0
optax>=0.1.0
tensorflow-datasets>=4.0.0
numpy>=1.19.0
tqdm>=4.0.0
matplotlib>=3.0.0
einops>=0.3.0