# fuse_clip_hub.py (keep the rest of your code unchanged)
import inspect
import json
import shutil
from pathlib import Path

import numpy as np
import torch
from huggingface_hub import PyTorchModelHubMixin

from fuse_clip.fuse_clip_arch import FuseCLIP
from open_clip import get_input_dtype, SimpleTokenizer

class FuseLIP(FuseCLIP, PyTorchModelHubMixin):
    """FuseLIP with save_pretrained / from_pretrained / push_to_hub."""

    # ---------- save ----------
    def _save_pretrained(self, save_directory: Path, **kwargs):
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")
        (save_directory / "config.json").write_text(
            json.dumps(self.get_config(), indent=2)
        )

        # copy TiTok VQ-VAE weights so offline loading works
        # shutil.copy(
        #     self.image_tokenizer.tokenizer_path,
        #     save_directory / "titok_image_tokenizer.pt"
        # )

        # publish fuse_clip_hub.py
        source_path = Path(inspect.getfile(FuseLIP))  # absolute path of this file
        shutil.copy(source_path, save_directory / "fuse_clip_hub.py")
    # ---------- load ----------
    @classmethod
    def _from_pretrained(cls, save_directory: Path, **kwargs):
        cfg = json.loads(Path(save_directory, "config.json").read_text())

        # rebuild the text tokenizer with the stored context length
        tokenizer = SimpleTokenizer(context_length=cfg["context_length"])
        tokenizer.pad_token_id = 0

        if cfg["mlm_probability"] > 0:
            MASK_TOKEN = "[MASK]"
            if MASK_TOKEN not in tokenizer.encoder:
                # Assign a new token ID
                mask_token_id = max(tokenizer.encoder.values()) + 1  # Get a new unique ID

                # Add to tokenizer's vocabulary
                tokenizer.encoder[MASK_TOKEN] = mask_token_id
                tokenizer.decoder[mask_token_id] = MASK_TOKEN
                tokenizer.all_special_ids.append(mask_token_id)
                tokenizer.mask_token = mask_token_id
                tokenizer.vocab_size += 1

                print(f"Added `[MASK]` token with ID {mask_token_id}")
            else:
                mask_token_id = tokenizer.encoder[MASK_TOKEN]
                print(f"`[MASK]` token already exists with ID {mask_token_id}")

        # translate the stored config into FuseCLIP constructor arguments
        cfg["image_tokenizer_path"] = cfg["image_tokenizer"]
        cfg["init_logit_scale"] = np.log(10)
        cfg["init_logit_bias"] = -10
        cfg["input_dtype"] = get_input_dtype("fp32")
        del cfg["text_config"]
        del cfg["image_tokenizer"]
        del cfg["context_length"]

        model = cls(**cfg, **kwargs)  # device / dtype can be injected via kwargs

        # restore the weights written by _save_pretrained
        state = torch.load(
            Path(save_directory, "pytorch_model.bin"),
            map_location="cpu"
        )
        model.load_state_dict(state, strict=True)
        return model
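
# ---------- usage sketch ----------
# A minimal round-trip, assuming a FuseLIP instance `model` built with the
# project's training configuration; the checkpoint directory and repo id below
# are placeholders, not values fixed by this file:
#
#   model.save_pretrained("checkpoints/fuselip")             # weights + config.json + this file
#   model.push_to_hub("your-username/FuseLIP-S-CC3M-MM")     # optional Hub upload
#   reloaded = FuseLIP.from_pretrained("checkpoints/fuselip")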