banao-tech commited on
Commit
b89e6d8
·
verified ·
1 Parent(s): 7ecde71

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +41 -14
main.py CHANGED
@@ -50,28 +50,54 @@ except Exception as e:
50
  caption_model_processor = {"processor": processor, "model": model}
51
  logger.info("Finished loading models!!!")
52
 
 
53
  app = FastAPI()
54
 
 
 
 
 
55
  class ProcessResponse(BaseModel):
56
  image: str # Base64 encoded image
57
  parsed_content_list: str
58
  label_coordinates: str
59
 
60
- # Create a queue for sequential processing
61
- request_queue = asyncio.Queue()
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  async def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
64
  """
65
  Asynchronously processes an image using YOLO and caption models.
66
  """
67
  try:
 
68
  image_save_path = "imgs/saved_image_demo.png"
69
  os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
70
-
71
- # Save the image asynchronously
72
- buffer = io.BytesIO()
73
- image_input.save(buffer, format="PNG")
74
- buffer.seek(0)
75
 
76
  # Perform YOLO and caption model inference
77
  box_overlay_ratio = image_input.size[0] / 3200
@@ -106,7 +132,7 @@ async def process(image_input: Image.Image, box_threshold: float, iou_threshold:
106
  iou_threshold=iou_threshold,
107
  )
108
 
109
- # Convert image to base64
110
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
111
  buffered = io.BytesIO()
112
  image.save(buffered, format="PNG")
@@ -124,6 +150,8 @@ async def process(image_input: Image.Image, box_threshold: float, iou_threshold:
124
  logger.error(f"Error in process function: {e}")
125
  raise
126
 
 
 
127
  @app.post("/process_image", response_model=ProcessResponse)
128
  async def process_image(
129
  image_file: UploadFile = File(...),
@@ -135,16 +163,15 @@ async def process_image(
135
  contents = await image_file.read()
136
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
137
 
 
 
 
138
  # Add the task to the queue
139
- task = asyncio.create_task(
140
- process(image_input, box_threshold, iou_threshold)
141
- )
142
  await request_queue.put(task)
 
143
 
144
- # Process the next task in the queue
145
- task = await request_queue.get()
146
  response = await task
147
- request_queue.task_done()
148
 
149
  return response
150
  except Exception as e:
 
50
  caption_model_processor = {"processor": processor, "model": model}
51
  logger.info("Finished loading models!!!")
52
 
53
+ # Initialize FastAPI app
54
  app = FastAPI()
55
 
56
+ # Define a queue for request processing
57
+ request_queue = asyncio.Queue()
58
+
59
+ # Define a response model for the processed image
60
  class ProcessResponse(BaseModel):
61
  image: str # Base64 encoded image
62
  parsed_content_list: str
63
  label_coordinates: str
64
 
 
 
65
 
66
+ # Define the async worker function
67
+ async def worker():
68
+ """
69
+ Background worker to process tasks from the request queue sequentially.
70
+ """
71
+ while True:
72
+ task = await request_queue.get() # Get the next task from the queue
73
+ try:
74
+ await task # Process the task
75
+ except Exception as e:
76
+ logger.error(f"Error while processing task: {e}")
77
+ finally:
78
+ request_queue.task_done() # Mark the task as done
79
+
80
+
81
+ # Start the worker when the application starts
82
+ @app.on_event("startup")
83
+ async def startup_event():
84
+ logger.info("Starting background worker...")
85
+ asyncio.create_task(worker()) # Start the worker in the background
86
+
87
+
88
+ # Define the process function
89
  async def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
90
  """
91
  Asynchronously processes an image using YOLO and caption models.
92
  """
93
  try:
94
+ # Define the save path and ensure the directory exists
95
  image_save_path = "imgs/saved_image_demo.png"
96
  os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
97
+
98
+ # Save the image
99
+ image_input.save(image_save_path)
100
+ logger.debug(f"Image saved to: {image_save_path}")
 
101
 
102
  # Perform YOLO and caption model inference
103
  box_overlay_ratio = image_input.size[0] / 3200
 
132
  iou_threshold=iou_threshold,
133
  )
134
 
135
+ # Convert labeled image to base64
136
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
137
  buffered = io.BytesIO()
138
  image.save(buffered, format="PNG")
 
150
  logger.error(f"Error in process function: {e}")
151
  raise
152
 
153
+
154
+ # Define the process_image endpoint
155
  @app.post("/process_image", response_model=ProcessResponse)
156
  async def process_image(
157
  image_file: UploadFile = File(...),
 
163
  contents = await image_file.read()
164
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
165
 
166
+ # Create a task for processing
167
+ task = asyncio.create_task(process(image_input, box_threshold, iou_threshold))
168
+
169
  # Add the task to the queue
 
 
 
170
  await request_queue.put(task)
171
+ logger.info(f"Task added to queue. Current queue size: {request_queue.qsize()}")
172
 
173
+ # Wait for the task to complete
 
174
  response = await task
 
175
 
176
  return response
177
  except Exception as e: