|
import os |
|
from pathlib import Path |
|
from typing import Dict, Optional, Union |
|
|
|
from huggingface_hub import PyTorchModelHubMixin |
|
from huggingface_hub.constants import (PYTORCH_WEIGHTS_NAME, |
|
SAFETENSORS_SINGLE_FILE) |
|
from huggingface_hub.file_download import hf_hub_download |
|
from huggingface_hub.utils import EntryNotFoundError, is_torch_available |
|
|
|
|
|
if is_torch_available(): |
|
import torch |
|
|
|
|
|
class CompatiblePyTorchModelHubMixin(PyTorchModelHubMixin):
    """Mixin class to load Pytorch models from the Hub.

    Saving always writes the pickle-format weights file
    (``PYTORCH_WEIGHTS_NAME``). Loading prefers the safetensors file
    (``SAFETENSORS_SINGLE_FILE``) and falls back to ``PYTORCH_WEIGHTS_NAME``
    when no safetensors file is available, both for local directories and for
    Hub repositories.
    """

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights from a Pytorch model to a local directory.

        Args:
            save_directory: Directory in which ``PYTORCH_WEIGHTS_NAME`` is
                written via ``torch.save``. A plain ``str`` path is accepted
                too and converted to ``Path``.
        """
        # If the model is wrapped in a container exposing ``.module``
        # (presumably torch.nn.DataParallel / DistributedDataParallel),
        # save the inner model's state dict rather than the wrapper's.
        model_to_save = self.module if hasattr(self, "module") else self
        weights_path = Path(save_directory) / PYTORCH_WEIGHTS_NAME
        torch.save(model_to_save.state_dict(), weights_path)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model.

        The model is first instantiated as ``cls(**model_kwargs)``; weights
        are then loaded into it.

        If ``model_id`` is an existing local directory, the safetensors file
        inside it is used when present, otherwise the pickle weights file.
        Otherwise ``model_id`` is treated as a Hub repo id: the safetensors
        file is downloaded, and if the repo has no such entry
        (``EntryNotFoundError``) the pickle weights file is downloaded
        instead.

        Args:
            model_id: Local directory path or Hub repository id.
            revision, cache_dir, force_download, proxies, resume_download,
            local_files_only, token: Forwarded unchanged to
                ``hf_hub_download`` (remote case only).
            map_location: Forwarded to the weight loader.
            strict: Forwarded to the weight loader (presumably strict
                state-dict key matching — confirm against the base mixin).
            **model_kwargs: Keyword arguments for the model constructor.

        Returns:
            The model instance with weights loaded, as returned by
            ``_load_as_safetensor`` / ``_load_as_pickle``.

        Raises:
            EntryNotFoundError: Remote repo has neither weights file (raised
                by the fallback download).
            Errors from the loaders propagate unchanged (e.g. a local
            directory containing neither weights file).
        """
        model = cls(**model_kwargs)

        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            safetensors_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            # Pick the format by checking the file explicitly instead of
            # catching FileNotFoundError around the whole load: that relied on
            # the safetensors backend reporting a missing file with exactly
            # that exception, and would also swallow unrelated
            # FileNotFoundErrors raised while loading an existing file.
            if os.path.isfile(safetensors_file):
                return cls._load_as_safetensor(
                    model, safetensors_file, map_location, strict
                )
            pickle_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
            return cls._load_as_pickle(model, pickle_file, map_location, strict)

        # Shared arguments for both download attempts, so the safetensors and
        # pickle requests cannot drift apart.
        download_kwargs = {
            "repo_id": model_id,
            "revision": revision,
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "resume_download": resume_download,
            "token": token,
            "local_files_only": local_files_only,
        }
        try:
            model_file = hf_hub_download(
                filename=SAFETENSORS_SINGLE_FILE, **download_kwargs
            )
        except EntryNotFoundError:
            # No safetensors weights available: fall back to pickle format.
            model_file = hf_hub_download(
                filename=PYTORCH_WEIGHTS_NAME, **download_kwargs
            )
            return cls._load_as_pickle(model, model_file, map_location, strict)
        # Only the download is guarded above; loading happens outside the try.
        return cls._load_as_safetensor(model, model_file, map_location, strict)