royyy74 committed on
Commit 4bd4500 · verified · 1 Parent(s): 84d516b

Update app.py

Files changed (1)
  1. app.py +64 -67
app.py CHANGED
@@ -1,6 +1,3 @@
- import spaces
- import gradio as gr
- from huggingface_hub import InferenceClient
  from torch import nn
  from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
  from pathlib import Path
@@ -9,11 +6,26 @@ import torch.amp.autocast_mode
  from PIL import Image
  import os
  import torchvision.transforms.functional as TVF
+ import io
+ import json  # For parsing extra_options_json
+
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
+ from pydantic import BaseModel
+ from typing import List, Tuple  # Tuple for stream_chat return type hint
+
+
+ # FastAPI App Initialization
+ app = FastAPI()
+
+ # Pydantic model for API response
+ class CaptionResponse(BaseModel):
+     prompt_that_was_used: str
+     caption: str


  CLIP_PATH = "google/siglip-so400m-patch14-384"
  CHECKPOINT_PATH = Path("cgrkzexw-599808")
- TITLE = "<h1><center>JoyCaption Alpha Two (2024-09-26a)</center></h1>"
+ # TITLE is not used for API
  CAPTION_TYPE_MAP = {
      "Descriptive": [
          "Write a descriptive caption for this image in a formal tone.",
@@ -62,7 +74,8 @@ CAPTION_TYPE_MAP = {
      ],
  }

- HF_TOKEN = os.environ.get("HF_TOKEN", None)
+ # HF_TOKEN is not used in the API version
+ # HF_TOKEN = os.environ.get("HF_TOKEN", None)


  class ImageAdapter(nn.Module):
@@ -165,8 +178,8 @@ if device.type == 'cuda':
  image_adapter.to(device)


- @spaces.GPU()  # We keep this decorator for now, assuming GPU is preferred if available
- @torch.no_grad()
+ # Keep @torch.no_grad(): caption generation never needs gradients, whether stream_chat is called directly or from the API endpoint
+ @torch.no_grad()
  def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str, extra_options: list[str], name_input: str, custom_prompt: str) -> tuple[str, str]:
      if device.type == "cuda":
          torch.cuda.empty_cache()
@@ -306,66 +319,50 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str
      return prompt_str, caption.strip()


- with gr.Blocks() as demo:
-     gr.HTML(TITLE)
-
-     if device.type == 'cpu':
-         gr.Markdown("**Warning: Running on CPU.** Captions may take a very long time to generate (potentially several minutes). For faster performance, please use a Space with GPU hardware.")
-
-     with gr.Row():
-         with gr.Column():
-             input_image = gr.Image(type="pil", label="Input Image")
-
-             caption_type = gr.Dropdown(
-                 choices=["Descriptive", "Descriptive (Informal)", "Training Prompt", "MidJourney", "Booru tag list", "Booru-like tag list", "Art Critic", "Product Listing", "Social Media Post"],
-                 label="Caption Type",
-                 value="Descriptive",
-             )
-
-             caption_length = gr.Dropdown(
-                 choices=["any", "very short", "short", "medium-length", "long", "very long"] +
-                         [str(i) for i in range(20, 261, 10)],
-                 label="Caption Length",
-                 value="long",
-             )
-
-             extra_options = gr.CheckboxGroup(
-                 choices=[
-                     "If there is a person/character in the image you must refer to them as {name}.",
-                     "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
-                     "Include information about lighting.",
-                     "Include information about camera angle.",
-                     "Include information about whether there is a watermark or not.",
-                     "Include information about whether there are JPEG artifacts or not.",
-                     "If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.",
-                     "Do NOT include anything sexual; keep it PG.",
-                     "Do NOT mention the image's resolution.",
-                     "You MUST include information about the subjective aesthetic quality of the image from low to very high.",
-                     "Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.",
-                     "Do NOT mention any text that is in the image.",
-                     "Specify the depth of field and whether the background is in focus or blurred.",
-                     "If applicable, mention the likely use of artificial or natural lighting sources.",
-                     "Do NOT use any ambiguous language.",
-                     "Include whether the image is sfw, suggestive, or nsfw.",
-                     "ONLY describe the most important elements of the image."
-                 ],
-                 label="Extra Options"
-             )
-
-             name_input = gr.Textbox(label="Person/Character Name (if applicable)")
-             gr.Markdown("**Note:** Name input is only used if an Extra Option is selected that requires it.")
-
-             custom_prompt = gr.Textbox(label="Custom Prompt (optional, will override all other settings)")
-             gr.Markdown("**Note:** Alpha Two is not a general instruction follower and will not follow prompts outside its training data well. Use this feature with caution.")
-
-             run_button = gr.Button("Caption")
-
-         with gr.Column():
-             output_prompt = gr.Textbox(label="Prompt that was used")
-             output_caption = gr.Textbox(label="Caption")
-
-     run_button.click(fn=stream_chat, inputs=[input_image, caption_type, caption_length, extra_options, name_input, custom_prompt], outputs=[output_prompt, output_caption])
+ @app.post("/caption_image/", response_model=CaptionResponse)
+ async def caption_image_endpoint(
+     image_file: UploadFile = File(...),
+     caption_type: str = Form(...),
+     caption_length: str = Form(...),
+     extra_options_json: str = Form("[]"),  # Expect a JSON string for list of options
+     name_input: str = Form(""),
+     custom_prompt: str = Form("")
+ ):
+     try:
+         # Read image file
+         image_bytes = await image_file.read()
+         input_image = Image.open(io.BytesIO(image_bytes))
+     except Exception as e:
+         raise HTTPException(status_code=400, detail=f"Invalid image file: {e}")
+
+     try:
+         # Parse extra_options from JSON string
+         extra_options = json.loads(extra_options_json)
+         if not isinstance(extra_options, list):
+             raise ValueError("extra_options_json must be a JSON list")
+     except ValueError as e:
+         raise HTTPException(status_code=400, detail=f"Invalid extra_options_json: {e}")
+
+     # Call the existing stream_chat function
+     # Ensure stream_chat is compatible with these inputs
+     try:
+         prompt_used, generated_caption = stream_chat(
+             input_image=input_image,
+             caption_type=caption_type,
+             caption_length=caption_length,
+             extra_options=extra_options,
+             name_input=name_input,
+             custom_prompt=custom_prompt
+         )
+         return CaptionResponse(prompt_that_was_used=prompt_used, caption=generated_caption)
+     except ValueError as e:  # Catch specific errors from stream_chat like invalid caption_length
+         raise HTTPException(status_code=400, detail=str(e))
+     except Exception as e:
+         # General error catch for unexpected issues during model processing
+         print(f"Error during caption generation: {e}")  # Log for server visibility
+         raise HTTPException(status_code=500, detail="Internal server error during caption generation.")


  if __name__ == "__main__":
-     demo.launch(share=True)
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
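For reference, a minimal client sketch for the new /caption_image/ endpoint, assuming the server is running locally on the port 8000 configured in the __main__ block. The form-field and response-field names come from caption_image_endpoint and CaptionResponse in the diff above; the requests library, the example image path, and the particular option strings chosen are illustrative assumptions, not part of the commit.

# Hypothetical client call to the /caption_image/ endpoint added in this commit.
# Assumes `pip install requests` and a local server started with `python app.py`.
import json
import requests

# Option strings mirror entries from the removed Gradio CheckboxGroup choices;
# the endpoint expects them as one JSON-encoded string in the extra_options_json field.
extra_options = [
    "Include information about lighting.",
    "Do NOT mention any text that is in the image.",
]

with open("example.jpg", "rb") as f:  # placeholder image path
    response = requests.post(
        "http://localhost:8000/caption_image/",
        files={"image_file": f},
        data={
            "caption_type": "Descriptive",
            "caption_length": "long",
            "extra_options_json": json.dumps(extra_options),
            "name_input": "",
            "custom_prompt": "",
        },
    )

response.raise_for_status()
result = response.json()  # {"prompt_that_was_used": "...", "caption": "..."}
print(result["caption"])

A malformed image or a non-list extra_options_json comes back as HTTP 400, and unexpected failures inside stream_chat as HTTP 500, matching the HTTPException branches in the endpoint.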