Update app.py
app.py CHANGED
@@ -1,6 +1,3 @@
-import spaces
-import gradio as gr
-from huggingface_hub import InferenceClient
 from torch import nn
 from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
 from pathlib import Path
@@ -9,11 +6,26 @@ import torch.amp.autocast_mode
 from PIL import Image
 import os
 import torchvision.transforms.functional as TVF
+import io
+import json  # For parsing extra_options_json
+
+from fastapi import FastAPI, File, UploadFile, Form, HTTPException
+from pydantic import BaseModel
+from typing import List, Tuple  # Tuple for stream_chat return type hint
+
+
+# FastAPI App Initialization
+app = FastAPI()
+
+# Pydantic model for API response
+class CaptionResponse(BaseModel):
+    prompt_that_was_used: str
+    caption: str
 
 
 CLIP_PATH = "google/siglip-so400m-patch14-384"
 CHECKPOINT_PATH = Path("cgrkzexw-599808")
-TITLE = …
+# TITLE is not used for the API
 CAPTION_TYPE_MAP = {
     "Descriptive": [
         "Write a descriptive caption for this image in a formal tone.",
@@ -62,7 +74,8 @@ CAPTION_TYPE_MAP = {
     ],
 }
 
-HF_TOKEN = os.environ.get("HF_TOKEN", None)
+# HF_TOKEN is not used in the API version
+# HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 
 class ImageAdapter(nn.Module):
@@ -165,8 +178,8 @@ if device.type == 'cuda':
     image_adapter.to(device)
 
 
-@spaces.GPU()
-@torch.no_grad()
+# Keep torch.no_grad(): the endpoint calls stream_chat directly, so inference runs without gradient tracking
+@torch.no_grad()
 def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str, extra_options: list[str], name_input: str, custom_prompt: str) -> tuple[str, str]:
     if device.type == "cuda":
         torch.cuda.empty_cache()
@@ -306,66 +319,50 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str
     return prompt_str, caption.strip()
 
 
-        … (Gradio Blocks UI construction; not recoverable from this view) …
-            ],
-            label="Extra Options"
-        )
-
-        name_input = gr.Textbox(label="Person/Character Name (if applicable)")
-        gr.Markdown("**Note:** Name input is only used if an Extra Option is selected that requires it.")
-
-        custom_prompt = gr.Textbox(label="Custom Prompt (optional, will override all other settings)")
-        gr.Markdown("**Note:** Alpha Two is not a general instruction follower and will not follow prompts outside its training data well. Use this feature with caution.")
-
-        run_button = gr.Button("Caption")
-
-        with gr.Column():
-            output_prompt = gr.Textbox(label="Prompt that was used")
-            output_caption = gr.Textbox(label="Caption")
-
-        run_button.click(fn=stream_chat, inputs=[input_image, caption_type, caption_length, extra_options, name_input, custom_prompt], outputs=[output_prompt, output_caption])
+@app.post("/caption_image/", response_model=CaptionResponse)
+async def caption_image_endpoint(
+    image_file: UploadFile = File(...),
+    caption_type: str = Form(...),
+    caption_length: str = Form(...),
+    extra_options_json: str = Form("[]"),  # Expect a JSON string for the list of options
+    name_input: str = Form(""),
+    custom_prompt: str = Form("")
+):
+    try:
+        # Read the uploaded file into a PIL Image
+        image_bytes = await image_file.read()
+        input_image = Image.open(io.BytesIO(image_bytes))
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}")
+
+    try:
+        # Parse extra_options from the JSON string
+        extra_options = json.loads(extra_options_json)
+        if not isinstance(extra_options, list):
+            raise ValueError("extra_options_json must be a JSON list")
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=f"Invalid extra_options_json: {e}")
+
+    # Call the existing stream_chat function with the parsed inputs
+    try:
+        prompt_used, generated_caption = stream_chat(
+            input_image=input_image,
+            caption_type=caption_type,
+            caption_length=caption_length,
+            extra_options=extra_options,
+            name_input=name_input,
+            custom_prompt=custom_prompt
+        )
+        return CaptionResponse(prompt_that_was_used=prompt_used, caption=generated_caption)
+    except ValueError as e:  # Catch specific errors from stream_chat, e.g. an invalid caption_length
+        raise HTTPException(status_code=400, detail=str(e))
+    except Exception as e:
+        # General catch for unexpected issues during model processing
+        print(f"Error during caption generation: {e}")  # Log for server visibility
+        raise HTTPException(status_code=500, detail="Internal server error during caption generation.")
 
 
 if __name__ == "__main__":
-    demo.launch()
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
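For a quick sanity check of the new endpoint, a client call could look roughly like the sketch below once the server is running. This is an illustration, not part of the commit: it assumes the third-party requests package, a hypothetical local image photo.jpg, and the default host/port from the uvicorn.run call above; the "long" caption length is likewise an assumed value, since the valid choices are defined elsewhere in app.py.

import json
import requests  # third-party HTTP client, assumed installed

URL = "http://localhost:8000/caption_image/"  # default host/port from uvicorn.run above

with open("photo.jpg", "rb") as f:  # hypothetical input image
    resp = requests.post(
        URL,
        files={"image_file": ("photo.jpg", f, "image/jpeg")},
        data={
            "caption_type": "Descriptive",         # must be a key of CAPTION_TYPE_MAP
            "caption_length": "long",              # assumed value; see the app's length options
            "extra_options_json": json.dumps([]),  # JSON-encoded list, per the endpoint contract
            "name_input": "",
            "custom_prompt": "",
        },
        timeout=300,  # model inference can be slow on the first request
    )

resp.raise_for_status()
payload = resp.json()
print(payload["prompt_that_was_used"])
print(payload["caption"])

Note the design choice for list inputs: multipart form fields are flat strings, so the commit sends the extra options as a single JSON-encoded field (extra_options_json) and decodes it server-side with json.loads. FastAPI can also accept repeated form fields via a List[str] = Form(...) parameter, but one JSON field keeps the client contract explicit.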