|
|
|
import inspect |
|
import json |
|
import shutil |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import torch |
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
from fuse_clip.fuse_clip_arch import FuseCLIP |
|
from open_clip import get_input_dtype, SimpleTokenizer |
|
|
|
|
|
class FuseLIP(FuseCLIP, PyTorchModelHubMixin):
    """FuseLIP with save_pretrained / from_pretrained / push_to_hub.

    Mixes :class:`huggingface_hub.PyTorchModelHubMixin` into the FuseCLIP
    architecture so a checkpoint (weights + config + this source file) can be
    round-tripped through the Hugging Face Hub.
    """

    def _save_pretrained(self, save_directory: Path, **kwargs):
        """Write weights, config, and this module's source into *save_directory*.

        Produces three files:
          * ``pytorch_model.bin`` -- ``self.state_dict()`` serialized by torch
          * ``config.json``       -- ``self.get_config()`` as indented JSON
          * ``fuse_clip_hub.py``  -- a copy of the file defining this class,
            so the checkpoint is loadable without the original repo layout
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")
        # Explicit UTF-8: write_text otherwise uses the platform's locale
        # encoding, which corrupts/round-trip-breaks non-ASCII config values
        # (e.g. on Windows where the default is cp1252).
        (save_directory / "config.json").write_text(
            json.dumps(self.get_config(), indent=2),
            encoding="utf-8",
        )

        # Ship the defining source file alongside the weights.
        source_path = Path(inspect.getfile(FuseLIP))
        shutil.copy(source_path, save_directory / "fuse_clip_hub.py")

    @classmethod
    def _from_pretrained(cls, save_directory: Path, **kwargs):
        """Rebuild a model from a directory written by :meth:`_save_pretrained`.

        Reads ``config.json``, translates it into constructor kwargs, builds
        the model, and loads ``pytorch_model.bin`` strictly.

        NOTE(review): huggingface_hub's ``PyTorchModelHubMixin.from_pretrained``
        invokes ``_from_pretrained`` with keyword arguments (``model_id=...``,
        ``revision=...``, ...), which would not bind to ``save_directory`` here
        and would leak into ``cls(**kwargs)``. Verify this override is only
        called directly with a local directory, not via the mixin machinery.
        """
        cfg = json.loads(
            Path(save_directory, "config.json").read_text(encoding="utf-8")
        )

        # Recreate the text tokenizer used at training time.
        # NOTE(review): this tokenizer is never passed to ``cls`` below;
        # confirm whether the constructor is expected to receive it or
        # whether this block exists only for side effects.
        tokenizer = SimpleTokenizer(context_length=cfg["context_length"])
        tokenizer.pad_token_id = 0

        if cfg["mlm_probability"] > 0:
            # Masked-language-modeling was enabled at training time, so the
            # vocabulary must contain a [MASK] token.
            MASK_TOKEN = "[MASK]"
            if MASK_TOKEN not in tokenizer.encoder:
                # Append [MASK] with the next free id and keep encoder,
                # decoder, special-id list, and vocab size consistent.
                mask_token_id = max(tokenizer.encoder.values()) + 1
                tokenizer.encoder[MASK_TOKEN] = mask_token_id
                tokenizer.decoder[mask_token_id] = MASK_TOKEN
                tokenizer.all_special_ids.append(mask_token_id)
                tokenizer.mask_token = mask_token_id
                tokenizer.vocab_size += 1
                print(f"Added `[MASK]` token with ID {mask_token_id}")
            else:
                mask_token_id = tokenizer.encoder[MASK_TOKEN]
                print(f"`[MASK]` token already exists with ID {mask_token_id}")

        # Translate the serialized config into constructor kwargs: rename or
        # derive the fields the saved JSON does not store 1:1, then drop the
        # keys the constructor does not accept.
        cfg["image_tokenizer_path"] = cfg["image_tokenizer"]
        cfg["init_logit_scale"] = np.log(10)  # fixed init, not read from cfg
        cfg["init_logit_bias"] = -10          # fixed init, not read from cfg
        cfg["input_dtype"] = get_input_dtype("fp32")
        del cfg["text_config"]
        del cfg["image_tokenizer"]
        del cfg["context_length"]

        model = cls(**cfg, **kwargs)
        # NOTE(review): torch.load uses pickle under the hood; for untrusted
        # Hub checkpoints prefer ``weights_only=True`` (torch >= 1.13).
        state = torch.load(
            Path(save_directory, "pytorch_model.bin"),
            map_location="cpu",
        )
        model.load_state_dict(state, strict=True)
        return model
|
|