import os
from pathlib import Path
from typing import Dict, Optional, Union

from huggingface_hub import PyTorchModelHubMixin
from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, is_torch_available

if is_torch_available():
    import torch  # type: ignore


class CompatiblePyTorchModelHubMixin(PyTorchModelHubMixin):
    """Hub mixin that writes pickle weights and reads safetensors or pickle.

    Saving always produces a ``PYTORCH_WEIGHTS_NAME`` file via ``torch.save``.
    Loading tries the ``SAFETENSORS_SINGLE_FILE`` checkpoint first and falls
    back to the ``PYTORCH_WEIGHTS_NAME`` checkpoint when the former is missing,
    both for local directories and for Hub repositories.
    """

    def _save_pretrained(self, save_directory: Path) -> None:
        """Write the model's ``state_dict`` into ``save_directory``.

        Args:
            save_directory: Directory that receives the
                ``PYTORCH_WEIGHTS_NAME`` file.
        """
        # ``torch.save`` is used on purpose so that the weights are NOT
        # written as safetensors, which would otherwise be the default.
        # If the object exposes a ``module`` attribute, that inner object is
        # the one serialised (presumably a wrapper such as DataParallel --
        # confirm against callers).
        if hasattr(self, "module"):
            unwrapped = self.module  # type: ignore
        else:
            unwrapped = self
        weights = unwrapped.state_dict()
        torch.save(weights, save_directory / PYTORCH_WEIGHTS_NAME)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Instantiate ``cls`` and populate it with pretrained weights.

        Args:
            model_id: Either a path to a local directory or a Hub repo id.
            revision: Forwarded to ``hf_hub_download``.
            cache_dir: Forwarded to ``hf_hub_download``.
            force_download: Forwarded to ``hf_hub_download``.
            proxies: Forwarded to ``hf_hub_download``.
            resume_download: Forwarded to ``hf_hub_download``.
            local_files_only: Forwarded to ``hf_hub_download``.
            token: Forwarded to ``hf_hub_download``.
            map_location: Passed through to the weight loaders.
            strict: Passed through to the weight loaders.
            **model_kwargs: Keyword arguments used to construct ``cls``.

        Returns:
            Whatever ``cls._load_as_safetensor`` / ``cls._load_as_pickle``
            return (these loaders are not defined in this class; presumably
            inherited from ``PyTorchModelHubMixin``).
        """
        instance = cls(**model_kwargs)

        if not os.path.isdir(model_id):
            # Not a local directory: treat ``model_id`` as a Hub repository.
            # Both download attempts share every argument except the filename.
            hub_kwargs = dict(
                repo_id=model_id,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
            try:
                weights_path = hf_hub_download(
                    filename=SAFETENSORS_SINGLE_FILE, **hub_kwargs
                )
                return cls._load_as_safetensor(
                    instance, weights_path, map_location, strict
                )
            except EntryNotFoundError:
                # No safetensors entry available: fall back to pickle weights.
                weights_path = hf_hub_download(
                    filename=PYTORCH_WEIGHTS_NAME, **hub_kwargs
                )
                return cls._load_as_pickle(
                    instance, weights_path, map_location, strict
                )

        print("Loading weights from local directory")
        try:
            weights_path = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            return cls._load_as_safetensor(
                instance, weights_path, map_location, strict
            )
        except FileNotFoundError:
            # Safetensors file missing locally: fall back to pickle weights.
            weights_path = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
            return cls._load_as_pickle(
                instance, weights_path, map_location, strict
            )