banao-tech committed (verified)
Commit d0b9c8a · 1 Parent(s): 6294868

Update main.py

Files changed (1)
  1. main.py +86 -47
main.py CHANGED
@@ -1,12 +1,12 @@
 from fastapi import FastAPI, File, UploadFile, HTTPException
-from pydantic import BaseModel
+from pydantic import BaseModel#
 import base64
 import io
 import os
 import logging
 from PIL import Image, UnidentifiedImageError
 import torch
-from celery import Celery
+import asyncio
 from utils import (
     check_ocr_box,
     get_yolo_model,
@@ -19,20 +19,15 @@ from transformers import AutoProcessor, AutoModelForCausalLM
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)

-# Initialize FastAPI app
-app = FastAPI()
-
-# Initialize Celery
-celery = Celery(
-    "tasks",
-    broker="redis://localhost:6379/0",
-    backend="redis://localhost:6379/0"
-)
-
 # Load YOLO model
 yolo_model = get_yolo_model(model_path="weights/best.pt")
+
+# Handle device placement
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-yolo_model = yolo_model.to(device)
+if str(device) == "cuda":
+    yolo_model = yolo_model.cuda()
+else:
+    yolo_model = yolo_model.cpu()

 # Load caption model and processor
 try:
@@ -43,7 +38,7 @@ try:
         "weights/icon_caption_florence",
         torch_dtype=torch.float16,
         trust_remote_code=True,
-    ).to(device)
+    ).to("cuda")
 except Exception as e:
     logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
     model = AutoModelForCausalLM.from_pretrained(
@@ -55,6 +50,12 @@ except Exception as e:
 caption_model_processor = {"processor": processor, "model": model}
 logger.info("Finished loading models!!!")

+# Initialize FastAPI app
+app = FastAPI()
+
+MAX_QUEUE_SIZE = 10  # Set a reasonable limit based on your system capacity
+request_queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
+
 # Define a response model for the processed image
 class ProcessResponse(BaseModel):
     image: str  # Base64 encoded image
@@ -62,14 +63,44 @@ class ProcessResponse(BaseModel):
     label_coordinates: str


-@celery.task
-def process_image_task(image_bytes: bytes, box_threshold: float, iou_threshold: float):
+# Define the async worker function
+async def worker():
+    """
+    Background worker to process tasks from the request queue sequentially.
+    """
+    while True:
+        task = await request_queue.get()  # Get the next task from the queue
+        try:
+            await task  # Process the task
+        except Exception as e:
+            logger.error(f"Error while processing task: {e}")
+        finally:
+            request_queue.task_done()  # Mark the task as done
+
+
+# Start the worker when the application starts
+@app.on_event("startup")
+async def startup_event():
+    logger.info("Starting background worker...")
+
+    asyncio.create_task(worker())  # Start the worker in the background
+
+
+# Define the process function
+async def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
+    """
+    Asynchronously processes an image using YOLO and caption models.
+    """
     try:
-        image_input = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+        # Define the save path and ensure the directory exists
         image_save_path = "imgs/saved_image_demo.png"
         os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
+
+        # Save the image
         image_input.save(image_save_path)
+        logger.debug(f"Image saved to: {image_save_path}")

+        # Perform YOLO and caption model inference
         box_overlay_ratio = image_input.size[0] / 3200
         draw_bbox_config = {
             "text_scale": 0.8 * box_overlay_ratio,
@@ -78,7 +109,8 @@ def process_image_task(image_bytes: bytes, box_threshold: float, iou_threshold:
             "thickness": max(int(3 * box_overlay_ratio), 1),
         }

-        ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
+        ocr_bbox_rslt, is_goal_filtered = await asyncio.to_thread(
+            check_ocr_box,
             image_save_path,
             display_img=False,
             output_bb_format="xyxy",
@@ -88,7 +120,8 @@ def process_image_task(image_bytes: bytes, box_threshold: float, iou_threshold:
         )
         text, ocr_bbox = ocr_bbox_rslt

-        dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
+        dino_labled_img, label_coordinates, parsed_content_list = await asyncio.to_thread(
+            get_som_labeled_img,
             image_save_path,
             yolo_model,
             BOX_TRESHOLD=box_threshold,
@@ -100,48 +133,54 @@ def process_image_task(image_bytes: bytes, box_threshold: float, iou_threshold:
             iou_threshold=iou_threshold,
         )

-        image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
+        # Convert labeled image to base64
+        image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
         buffered = io.BytesIO()
         image.save(buffered, format="PNG")
         img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

+        # Join parsed content list
         parsed_content_list_str = "\n".join([str(item) for item in parsed_content_list])

-        return {
-            "image": img_str,
-            "parsed_content_list": parsed_content_list_str,
-            "label_coordinates": str(label_coordinates),
-        }
+        return ProcessResponse(
+            image=img_str,
+            parsed_content_list=parsed_content_list_str,
+            label_coordinates=str(label_coordinates),
+        )
     except Exception as e:
-        logger.error(f"Error in process_image_task: {e}")
-        return {"error": str(e)}
+        logger.error(f"Error in process function: {e}")
+        raise HTTPException(status_code=500, detail=f"Failed to process the image: {e}")


-@app.post("/process_image")
-async def process_image(image_file: UploadFile = File(...), box_threshold: float = 0.05, iou_threshold: float = 0.1):
+# Define the process_image endpoint
+@app.post("/process_image", response_model=ProcessResponse)
+async def process_image(
+    image_file: UploadFile = File(...),
+    box_threshold: float = 0.05,
+    iou_threshold: float = 0.1,
+):
     try:
-        image_bytes = await image_file.read()
+        # Read the image file
+        contents = await image_file.read()
         try:
-            Image.open(io.BytesIO(image_bytes)).convert("RGB")
+            image_input = Image.open(io.BytesIO(contents)).convert("RGB")
         except UnidentifiedImageError as e:
             logger.error(f"Unsupported image format: {e}")
             raise HTTPException(status_code=400, detail="Unsupported image format.")

-        task = process_image_task.delay(image_bytes, box_threshold, iou_threshold)
-        return {"task_id": task.id, "status": "Processing"}
+        # Create a task for processing
+        task = asyncio.create_task(process(image_input, box_threshold, iou_threshold))
+
+        # Add the task to the queue
+        await request_queue.put(task)
+        logger.info(f"Task added to queue. Current queue size: {request_queue.qsize()}")
+
+        # Wait for the task to complete
+        response = await task
+
+        return response
+    except HTTPException as he:
+        raise he
     except Exception as e:
         logger.error(f"Error processing image: {e}")
-        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
-
-
-@app.get("/task_status/{task_id}")
-def get_task_status(task_id: str):
-    task_result = celery.AsyncResult(task_id)
-    if task_result.state == "PENDING":
-        return {"task_id": task_id, "status": "Processing"}
-    elif task_result.state == "SUCCESS":
-        return {"task_id": task_id, "status": "Completed", "result": task_result.result}
-    elif task_result.state == "FAILURE":
-        return {"task_id": task_id, "status": "Failed", "error": str(task_result.result)}
-    else:
-        return {"task_id": task_id, "status": task_result.state}
+        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")#
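
For context, this commit replaces the Celery/Redis task queue with an in-process asyncio.Queue drained by a single background worker, so GPU-bound work is still processed one request at a time, but /process_image now waits for the result and returns the ProcessResponse directly instead of a task id (the /task_status endpoint is removed). A minimal client-side sketch of calling the reworked endpoint; the host/port and screenshot.png path are illustrative assumptions, not part of this commit:

import base64
import requests  # third-party HTTP client, assumed available

URL = "http://localhost:8000/process_image"  # hypothetical host/port for the running app

with open("screenshot.png", "rb") as f:  # hypothetical input image
    resp = requests.post(
        URL,
        files={"image_file": ("screenshot.png", f, "image/png")},
        params={"box_threshold": 0.05, "iou_threshold": 0.1},
    )
resp.raise_for_status()
payload = resp.json()

# ProcessResponse fields: image (base64-encoded PNG), parsed_content_list, label_coordinates
with open("labeled_output.png", "wb") as out:
    out.write(base64.b64decode(payload["image"]))
print(payload["parsed_content_list"])
print(payload["label_coordinates"])

Because box_threshold and iou_threshold are plain scalar parameters with defaults (not Form fields), FastAPI reads them from the query string, which is why the sketch passes them via params.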