Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
from fastapi import FastAPI, File, UploadFile, HTTPException
|
2 |
-
from pydantic import BaseModel
|
3 |
import base64
|
4 |
import io
|
5 |
import os
|
6 |
import logging
|
7 |
from PIL import Image, UnidentifiedImageError
|
8 |
import torch
|
9 |
-
|
10 |
from utils import (
|
11 |
check_ocr_box,
|
12 |
get_yolo_model,
|
@@ -19,20 +19,15 @@ from transformers import AutoProcessor, AutoModelForCausalLM
|
|
19 |
logging.basicConfig(level=logging.DEBUG)
|
20 |
logger = logging.getLogger(__name__)
|
21 |
|
22 |
-
# Initialize FastAPI app
|
23 |
-
app = FastAPI()
|
24 |
-
|
25 |
-
# Initialize Celery
|
26 |
-
celery = Celery(
|
27 |
-
"tasks",
|
28 |
-
broker="redis://localhost:6379/0",
|
29 |
-
backend="redis://localhost:6379/0"
|
30 |
-
)
|
31 |
-
|
32 |
# Load YOLO model
|
33 |
yolo_model = get_yolo_model(model_path="weights/best.pt")
|
|
|
|
|
34 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
35 |
-
|
|
|
|
|
|
|
36 |
|
37 |
# Load caption model and processor
|
38 |
try:
|
@@ -43,7 +38,7 @@ try:
|
|
43 |
"weights/icon_caption_florence",
|
44 |
torch_dtype=torch.float16,
|
45 |
trust_remote_code=True,
|
46 |
-
).to(
|
47 |
except Exception as e:
|
48 |
logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
|
49 |
model = AutoModelForCausalLM.from_pretrained(
|
@@ -55,6 +50,12 @@ except Exception as e:
|
|
55 |
caption_model_processor = {"processor": processor, "model": model}
|
56 |
logger.info("Finished loading models!!!")
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
# Define a response model for the processed image
|
59 |
class ProcessResponse(BaseModel):
|
60 |
image: str # Base64 encoded image
|
@@ -62,14 +63,44 @@ class ProcessResponse(BaseModel):
|
|
62 |
label_coordinates: str
|
63 |
|
64 |
|
65 |
-
|
66 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
try:
|
68 |
-
|
69 |
image_save_path = "imgs/saved_image_demo.png"
|
70 |
os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
|
|
|
|
|
71 |
image_input.save(image_save_path)
|
|
|
72 |
|
|
|
73 |
box_overlay_ratio = image_input.size[0] / 3200
|
74 |
draw_bbox_config = {
|
75 |
"text_scale": 0.8 * box_overlay_ratio,
|
@@ -78,7 +109,8 @@ def process_image_task(image_bytes: bytes, box_threshold: float, iou_threshold:
|
|
78 |
"thickness": max(int(3 * box_overlay_ratio), 1),
|
79 |
}
|
80 |
|
81 |
-
ocr_bbox_rslt, is_goal_filtered =
|
|
|
82 |
image_save_path,
|
83 |
display_img=False,
|
84 |
output_bb_format="xyxy",
|
@@ -88,7 +120,8 @@ def process_image_task(image_bytes: bytes, box_threshold: float, iou_threshold:
|
|
88 |
)
|
89 |
text, ocr_bbox = ocr_bbox_rslt
|
90 |
|
91 |
-
|
|
|
92 |
image_save_path,
|
93 |
yolo_model,
|
94 |
BOX_TRESHOLD=box_threshold,
|
@@ -100,48 +133,54 @@ def process_image_task(image_bytes: bytes, box_threshold: float, iou_threshold:
|
|
100 |
iou_threshold=iou_threshold,
|
101 |
)
|
102 |
|
103 |
-
image
|
|
|
104 |
buffered = io.BytesIO()
|
105 |
image.save(buffered, format="PNG")
|
106 |
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
107 |
|
|
|
108 |
parsed_content_list_str = "\n".join([str(item) for item in parsed_content_list])
|
109 |
|
110 |
-
return
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
except Exception as e:
|
116 |
-
logger.error(f"Error in
|
117 |
-
|
118 |
|
119 |
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
122 |
try:
|
123 |
-
|
|
|
124 |
try:
|
125 |
-
Image.open(io.BytesIO(
|
126 |
except UnidentifiedImageError as e:
|
127 |
logger.error(f"Unsupported image format: {e}")
|
128 |
raise HTTPException(status_code=400, detail="Unsupported image format.")
|
129 |
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
except Exception as e:
|
133 |
logger.error(f"Error processing image: {e}")
|
134 |
-
raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
|
135 |
-
|
136 |
-
|
137 |
-
@app.get("/task_status/{task_id}")
|
138 |
-
def get_task_status(task_id: str):
|
139 |
-
task_result = celery.AsyncResult(task_id)
|
140 |
-
if task_result.state == "PENDING":
|
141 |
-
return {"task_id": task_id, "status": "Processing"}
|
142 |
-
elif task_result.state == "SUCCESS":
|
143 |
-
return {"task_id": task_id, "status": "Completed", "result": task_result.result}
|
144 |
-
elif task_result.state == "FAILURE":
|
145 |
-
return {"task_id": task_id, "status": "Failed", "error": str(task_result.result)}
|
146 |
-
else:
|
147 |
-
return {"task_id": task_id, "status": task_result.state}
|
|
|
1 |
from fastapi import FastAPI, File, UploadFile, HTTPException
|
2 |
+
from pydantic import BaseModel#
|
3 |
import base64
|
4 |
import io
|
5 |
import os
|
6 |
import logging
|
7 |
from PIL import Image, UnidentifiedImageError
|
8 |
import torch
|
9 |
+
import asyncio
|
10 |
from utils import (
|
11 |
check_ocr_box,
|
12 |
get_yolo_model,
|
|
|
19 |
logging.basicConfig(level=logging.DEBUG)
|
20 |
logger = logging.getLogger(__name__)
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
# Load YOLO model
|
23 |
yolo_model = get_yolo_model(model_path="weights/best.pt")
|
24 |
+
|
25 |
+
# Handle device placement
|
26 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
27 |
+
if str(device) == "cuda":
|
28 |
+
yolo_model = yolo_model.cuda()
|
29 |
+
else:
|
30 |
+
yolo_model = yolo_model.cpu()
|
31 |
|
32 |
# Load caption model and processor
|
33 |
try:
|
|
|
38 |
"weights/icon_caption_florence",
|
39 |
torch_dtype=torch.float16,
|
40 |
trust_remote_code=True,
|
41 |
+
).to("cuda")
|
42 |
except Exception as e:
|
43 |
logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
|
44 |
model = AutoModelForCausalLM.from_pretrained(
|
|
|
50 |
caption_model_processor = {"processor": processor, "model": model}
|
51 |
logger.info("Finished loading models!!!")
|
52 |
|
53 |
+
# Initialize FastAPI app
|
54 |
+
app = FastAPI()
|
55 |
+
|
56 |
+
MAX_QUEUE_SIZE = 10 # Set a reasonable limit based on your system capacity
|
57 |
+
request_queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
|
58 |
+
|
59 |
# Define a response model for the processed image
|
60 |
class ProcessResponse(BaseModel):
|
61 |
image: str # Base64 encoded image
|
|
|
63 |
label_coordinates: str
|
64 |
|
65 |
|
66 |
+
# Define the async worker function
|
67 |
+
async def worker():
    """
    Background worker that drains the request queue one item at a time.

    Each queued item must be awaitable (a coroutine or an ``asyncio.Task``).
    Ordinary failures are logged and swallowed so a single bad request can
    never stop the loop. A queued job that was cancelled by its owner (for
    example because the client disconnected) is also tolerated; only a
    cancellation aimed at the worker itself ends the loop.
    """
    while True:
        task = await request_queue.get()  # Get the next task from the queue
        # Normalise to a Task/Future so its cancellation state can be inspected.
        job = asyncio.ensure_future(task)
        try:
            # shield() separates "the job was cancelled" from "this worker
            # was cancelled": only in the latter case is `job` still alive.
            await asyncio.shield(job)
        except asyncio.CancelledError:
            if not job.cancelled():
                # The worker itself is being cancelled (e.g. at shutdown):
                # stop the in-flight job as well and propagate.
                job.cancel()
                raise
            # CancelledError is a BaseException, so without this branch a
            # cancelled request would silently kill the worker.
            logger.warning("Queued task was cancelled before it finished")
        except Exception as e:
            logger.error(f"Error while processing task: {e}")
        finally:
            request_queue.task_done()  # Mark the task as done
|
79 |
+
|
80 |
+
|
81 |
+
# Start the worker when the application starts
|
82 |
+
@app.on_event("startup")
async def startup_event():
    """Start the queue worker once the application's event loop is running."""
    logger.info("Starting background worker...")

    # The event loop only keeps weak references to tasks, so a task whose
    # handle is discarded may be garbage-collected mid-execution. Keep a
    # strong reference on the application state for the app's lifetime.
    app.state.worker_task = asyncio.create_task(worker())
|
87 |
+
|
88 |
+
|
89 |
+
# Define the process function
|
90 |
+
async def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
|
91 |
+
"""
|
92 |
+
Asynchronously processes an image using YOLO and caption models.
|
93 |
+
"""
|
94 |
try:
|
95 |
+
# Define the save path and ensure the directory exists
|
96 |
image_save_path = "imgs/saved_image_demo.png"
|
97 |
os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
|
98 |
+
|
99 |
+
# Save the image
|
100 |
image_input.save(image_save_path)
|
101 |
+
logger.debug(f"Image saved to: {image_save_path}")
|
102 |
|
103 |
+
# Perform YOLO and caption model inference
|
104 |
box_overlay_ratio = image_input.size[0] / 3200
|
105 |
draw_bbox_config = {
|
106 |
"text_scale": 0.8 * box_overlay_ratio,
|
|
|
109 |
"thickness": max(int(3 * box_overlay_ratio), 1),
|
110 |
}
|
111 |
|
112 |
+
ocr_bbox_rslt, is_goal_filtered = await asyncio.to_thread(
|
113 |
+
check_ocr_box,
|
114 |
image_save_path,
|
115 |
display_img=False,
|
116 |
output_bb_format="xyxy",
|
|
|
120 |
)
|
121 |
text, ocr_bbox = ocr_bbox_rslt
|
122 |
|
123 |
+
dino_labled_img, label_coordinates, parsed_content_list = await asyncio.to_thread(
|
124 |
+
get_som_labeled_img,
|
125 |
image_save_path,
|
126 |
yolo_model,
|
127 |
BOX_TRESHOLD=box_threshold,
|
|
|
133 |
iou_threshold=iou_threshold,
|
134 |
)
|
135 |
|
136 |
+
# Convert labeled image to base64
|
137 |
+
image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
|
138 |
buffered = io.BytesIO()
|
139 |
image.save(buffered, format="PNG")
|
140 |
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
141 |
|
142 |
+
# Join parsed content list
|
143 |
parsed_content_list_str = "\n".join([str(item) for item in parsed_content_list])
|
144 |
|
145 |
+
return ProcessResponse(
|
146 |
+
image=img_str,
|
147 |
+
parsed_content_list=parsed_content_list_str,
|
148 |
+
label_coordinates=str(label_coordinates),
|
149 |
+
)
|
150 |
except Exception as e:
|
151 |
+
logger.error(f"Error in process function: {e}")
|
152 |
+
raise HTTPException(status_code=500, detail=f"Failed to process the image: {e}")
|
153 |
|
154 |
|
155 |
+
# Define the process_image endpoint
|
156 |
+
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """
    Accept an uploaded image, queue it for processing and return the result.

    The upload is decoded to an RGB image, then a processing job is placed on
    the request queue. The job is enqueued as an *unstarted* coroutine so that
    it only begins running when the background worker awaits it; this is what
    makes processing sequential and lets the bounded queue apply back-pressure.
    The result is handed back to this request through a future.

    Raises:
        HTTPException: 400 if the upload is not a recognisable image, or
            whatever ``process`` raises (500 on processing failure); any other
            unexpected error is reported as a 500.
    """
    try:
        # Read the image file
        contents = await image_file.read()
        try:
            image_input = Image.open(io.BytesIO(contents)).convert("RGB")
        except UnidentifiedImageError as e:
            logger.error(f"Unsupported image format: {e}")
            raise HTTPException(status_code=400, detail="Unsupported image format.") from e

        # Future through which the worker-run job reports back to this request.
        result_future = asyncio.get_running_loop().create_future()

        async def run_job():
            # Runs inside the worker; never raises ordinary exceptions so the
            # outcome (success or failure) always reaches the waiting request.
            try:
                outcome = await process(image_input, box_threshold, iou_threshold)
            except asyncio.CancelledError:
                # The job was cancelled (e.g. shutdown): release the waiter.
                result_future.cancel()
                raise
            except Exception as exc:
                # done() is True when the request was already cancelled.
                if not result_future.done():
                    result_future.set_exception(exc)
            else:
                if not result_future.done():
                    result_future.set_result(outcome)

        # Add the job to the queue; this waits while the queue is full.
        job = run_job()
        try:
            await request_queue.put(job)
        except BaseException:
            # Never enqueued (e.g. request cancelled while waiting for room):
            # close the coroutine to avoid a "never awaited" warning.
            job.close()
            raise
        logger.info(f"Task added to queue. Current queue size: {request_queue.qsize()}")

        # Wait for the worker to finish this job
        response = await result_future

        return response
    except HTTPException:
        # Already an HTTP-level error: pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|