from torch import nn
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
from pathlib import Path
import torch
import torch.amp.autocast_mode
from PIL import Image
import os
import torchvision.transforms.functional as TVF
import io
import json
from tempfile import TemporaryDirectory

from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from pydantic import BaseModel
from typing import List, Tuple


app = FastAPI()

class CaptionResponse(BaseModel):
    prompt_that_was_used: str
    caption: str

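# CHECKPOINT_PATH points at the fine-tuned checkpoint directory and is expected to
# contain clip_model.pt (vision tower weights), text_model/ (tokenizer + causal LM),
# and image_adapter.pt, all of which are loaded below.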
CLIP_PATH = "google/siglip-so400m-patch14-384"
CHECKPOINT_PATH = Path("cgrkzexw-599808")

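# Each caption type maps to three prompt templates, indexed by how the caption length
# is specified: [0] no length given, [1] a numeric {word_count}, [2] a named {length}.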
CAPTION_TYPE_MAP = {
    "Descriptive": [
        "Write a descriptive caption for this image in a formal tone.",
        "Write a descriptive caption for this image in a formal tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a formal tone.",
    ],
    "Descriptive (Informal)": [
        "Write a descriptive caption for this image in a casual tone.",
        "Write a descriptive caption for this image in a casual tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a casual tone.",
    ],
    "Training Prompt": [
        "Write a stable diffusion prompt for this image.",
        "Write a stable diffusion prompt for this image within {word_count} words.",
        "Write a {length} stable diffusion prompt for this image.",
    ],
    "MidJourney": [
        "Write a MidJourney prompt for this image.",
        "Write a MidJourney prompt for this image within {word_count} words.",
        "Write a {length} MidJourney prompt for this image.",
    ],
    "Booru tag list": [
        "Write a list of Booru tags for this image.",
        "Write a list of Booru tags for this image within {word_count} words.",
        "Write a {length} list of Booru tags for this image.",
    ],
    "Booru-like tag list": [
        "Write a list of Booru-like tags for this image.",
        "Write a list of Booru-like tags for this image within {word_count} words.",
        "Write a {length} list of Booru-like tags for this image.",
    ],
    "Art Critic": [
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc.",
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it within {word_count} words.",
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it {length}.",
    ],
    "Product Listing": [
        "Write a caption for this image as though it were a product listing.",
        "Write a caption for this image as though it were a product listing. Keep it under {word_count} words.",
        "Write a {length} caption for this image as though it were a product listing.",
    ],
    "Social Media Post": [
        "Write a caption for this image as if it were being used for a social media post.",
        "Write a caption for this image as if it were being used for a social media post. Limit the caption to {word_count} words.",
        "Write a {length} caption for this image as if it were being used for a social media post.",
    ],
}

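# ImageAdapter projects the vision tower's hidden states into the text model's
# embedding space (two linear layers with a GELU in between) and wraps the image
# tokens with two learned boundary embeddings; a third learned embedding is exposed
# through get_eot_embedding().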
class ImageAdapter(nn.Module):
    def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
        super().__init__()
        self.deep_extract = deep_extract

        if self.deep_extract:
            input_features = input_features * 5

        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
        self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))

        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs: torch.Tensor):
        if self.deep_extract:
            x = torch.concat((
                vision_outputs[-2],
                vision_outputs[3],
                vision_outputs[7],
                vision_outputs[13],
                vision_outputs[20],
            ), dim=-1)
            assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}"
            assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
        else:
            x = vision_outputs[-2]

        x = self.ln1(x)

        if self.pos_emb is not None:
            assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
            x = x + self.pos_emb

        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)

        other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
        assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
        x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)

        return x

    def get_eot_embedding(self):
        return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)

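# The models below are loaded once at module import time, so every request handled by
# the FastAPI app reuses the same weights already resident on the selected device.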
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print("Loading CLIP")
clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
clip_model = AutoModel.from_pretrained(CLIP_PATH)
clip_model = clip_model.vision_model

assert (CHECKPOINT_PATH / "clip_model.pt").exists()
print("Loading VLM's custom vision model")
checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu')
checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
clip_model.load_state_dict(checkpoint)
del checkpoint

clip_model.eval()
clip_model.requires_grad_(False)
clip_model.to(device)
if device.type == 'cuda':
    clip_model = clip_model.to(dtype=torch.bfloat16)
elif device.type == 'cpu':
    clip_model = clip_model.to(dtype=torch.float32)

print("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH / "text_model", use_fast=True)
assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"

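# LLM loading strategy: on CUDA, try a plain bfloat16 load first and fall back to disk
# offloading if the weights do not fit; on CPU, load in float32 with disk offloading
# from the start.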
print("Loading LLM") |
|
print("Loading VLM's custom text model") |
|
|
|
if device.type == 'cuda': |
|
|
|
try: |
|
print("Attempting to load LLM on CUDA with bfloat16...") |
|
text_model = AutoModelForCausalLM.from_pretrained( |
|
CHECKPOINT_PATH / "text_model", |
|
device_map="auto", |
|
torch_dtype=torch.bfloat16 |
|
) |
|
except ValueError as ve: |
|
if "offload_dir" in str(ve): |
|
print(f"CUDA bfloat16 loading failed, needing offload_dir: {ve}") |
|
print("Attempting to load LLM on CUDA with disk offloading...") |
|
model_offload_dir = TemporaryDirectory().name |
|
text_model = AutoModelForCausalLM.from_pretrained( |
|
CHECKPOINT_PATH / "text_model", |
|
device_map="auto", |
|
torch_dtype=torch.bfloat16, |
|
offload_folder=model_offload_dir, |
|
offload_state_dict=True |
|
) |
|
print(f"LLM loaded on CUDA with offloading to {model_offload_dir}. WARNING: This may be slow.") |
|
else: |
|
raise |
|
except Exception as e: |
|
print(f"Failed to load LLM on CUDA: {e}") |
|
raise |
|
else: |
|
|
|
print("Attempting to load LLM on CPU directly with disk offloading (float32)...") |
|
try: |
|
model_offload_dir_cpu = TemporaryDirectory().name |
|
text_model = AutoModelForCausalLM.from_pretrained( |
|
CHECKPOINT_PATH / "text_model", |
|
device_map="auto", |
|
torch_dtype=torch.float32, |
|
offload_folder=model_offload_dir_cpu, |
|
offload_state_dict=True |
|
) |
|
print(f"LLM loaded on CPU with offloading to {model_offload_dir_cpu}. WARNING: This will be very slow.") |
|
except Exception as e_cpu_offload: |
|
print(f"CPU loading with disk offloading failed: {e_cpu_offload}") |
|
raise |
|
|
|
text_model.eval() |
|
|
|
|
|
print("Loading image adapter") |
|
image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False) |
|
image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu")) |
|
image_adapter.eval() |
|
image_adapter.to(device) |
|
if device.type == 'cuda': |
|
image_adapter = image_adapter.to(dtype=torch.bfloat16) |
|
elif device.type == 'cpu': |
|
image_adapter = image_adapter.to(dtype=torch.float32) |
|
|
|
|
|
|
|
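# stream_chat runs the full captioning pipeline for a single image: build the prompt,
# preprocess the image, encode it with the vision tower, project the features through
# the image adapter, splice the resulting image embeddings into the tokenized chat
# template, and generate the caption with the LLM.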
@torch.no_grad()
def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str, extra_options: list[str], name_input: str, custom_prompt: str) -> tuple[str, str]:
    if device.type == "cuda":
        torch.cuda.empty_cache()

    length = None if caption_length == "any" else caption_length

    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            pass

    # Pick the prompt template: 0 = no length constraint, 1 = numeric word count, 2 = named length.
    if length is None:
        map_idx = 0
    elif isinstance(length, int):
        map_idx = 1
    elif isinstance(length, str):
        map_idx = 2
    else:
        raise ValueError(f"Invalid caption length: {length}")

    # Reject unknown caption types with a ValueError so the endpoint can return a 400.
    if caption_type not in CAPTION_TYPE_MAP:
        raise ValueError(f"Invalid caption type: {caption_type}")

    prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]

    if len(extra_options) > 0:
        prompt_str += " " + " ".join(extra_options)

    prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

    # A non-empty custom prompt overrides everything built above.
    if custom_prompt.strip() != "":
        prompt_str = custom_prompt.strip()

    print(f"Prompt: {prompt_str}")

    # Preprocess to the 384x384 input the SigLIP vision tower expects, normalized to [-1, 1].
    image = input_image.resize((384, 384), Image.LANCZOS)
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    if device.type == 'cuda':
        pixel_values = pixel_values.to(device, dtype=torch.bfloat16)
    else:
        pixel_values = pixel_values.to(device, dtype=torch.float32)

    # Autocast is only useful on CUDA here; on CPU the models already run in float32.
    with torch.amp.autocast_mode.autocast(device.type, enabled=(device.type == 'cuda')):
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        embedded_images = image_adapter(vision_outputs.hidden_states)

    embedded_images = embedded_images.to(device)

    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner.",
        },
        {
            "role": "user",
            "content": prompt_str,
        },
    ]

    convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    assert isinstance(convo_string, str)

    convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
    prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
    assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
    convo_tokens = convo_tokens.squeeze(0)
    prompt_tokens = prompt_tokens.squeeze(0)

    # The chat template emits one <|eot_id|> after the system turn and one after the user turn.
    eot_id_indices = (convo_tokens.cpu() == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
    assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"

    # Everything before the user prompt is the "preamble"; the image embeddings are spliced in just before the prompt.
    preamble_len = eot_id_indices[1] - prompt_tokens.shape[0]

    convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(text_model.device))

    input_embeds = torch.cat([
        convo_embeds[:, :preamble_len],
        embedded_images.to(dtype=convo_embeds.dtype, device=convo_embeds.device),
        convo_embeds[:, preamble_len:],
    ], dim=1)

    # Dummy (zero) token ids stand in for the image positions so input_ids and input_embeds stay the same length.
    input_ids = torch.cat([
        convo_tokens[:preamble_len].unsqueeze(0),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long, device=convo_tokens.device),
        convo_tokens[preamble_len:].unsqueeze(0),
    ], dim=1)

    attention_mask = torch.ones_like(input_ids)

    print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")

    generate_ids = text_model.generate(input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)

    # Trim the prompt portion from the output and strip a trailing <|eot_id|>/eos token, if present.
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
        generate_ids = generate_ids[:, :-1]

    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]

    return prompt_str, caption.strip()

@app.post("/caption_image/", response_model=CaptionResponse) |
|
async def caption_image_endpoint( |
|
image_file: UploadFile = File(...), |
|
caption_type: str = Form(...), |
|
caption_length: str = Form(...), |
|
extra_options_json: str = Form("[]"), |
|
name_input: str = Form(""), |
|
custom_prompt: str = Form("") |
|
): |
|
try: |
|
|
|
image_bytes = await image_file.read() |
|
input_image = Image.open(io.BytesIO(image_bytes)) |
|
except Exception as e: |
|
raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") |
|
|
|
try: |
|
|
|
extra_options = json.loads(extra_options_json) |
|
if not isinstance(extra_options, list): |
|
raise ValueError("extra_options_json must be a JSON list") |
|
except ValueError as e: |
|
raise HTTPException(status_code=400, detail=f"Invalid extra_options_json: {e}") |
|
|
|
|
|
|
|
try: |
|
prompt_used, generated_caption = stream_chat( |
|
input_image=input_image, |
|
caption_type=caption_type, |
|
caption_length=caption_length, |
|
extra_options=extra_options, |
|
name_input=name_input, |
|
custom_prompt=custom_prompt |
|
) |
|
return CaptionResponse(prompt_that_was_used=prompt_used, caption=generated_caption) |
|
except ValueError as e: |
|
raise HTTPException(status_code=400, detail=str(e)) |
|
except Exception as e: |
|
|
|
print(f"Error during caption generation: {e}") |
|
raise HTTPException(status_code=500, detail="Internal server error during caption generation.") |
|
|
|
|
|
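# Example client call (a minimal sketch, assuming the server is running locally on
# port 7860 and that "example.jpg" is a stand-in for a real image on disk):
#
#   import requests
#
#   with open("example.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/caption_image/",
#           files={"image_file": f},
#           data={
#               "caption_type": "Descriptive",
#               "caption_length": "any",
#               "extra_options_json": "[]",
#               "name_input": "",
#               "custom_prompt": "",
#           },
#       )
#   print(resp.json()["caption"])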
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)