banao-tech committed on
Commit
056fb25
·
verified ·
1 Parent(s): 0e6d684

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +72 -51
main.py CHANGED
@@ -1,76 +1,71 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from fastapi.responses import JSONResponse
3
  from pydantic import BaseModel
4
- from typing import Optional
5
  import base64
6
  import io
 
 
7
  from PIL import Image
8
  import torch
9
- import numpy as np
10
- import os
11
 
12
  # Existing imports
13
- import numpy as np
14
- import torch
15
- from PIL import Image
16
- import io
17
-
18
  from utils import (
19
  check_ocr_box,
20
  get_yolo_model,
21
  get_caption_model_processor,
22
  get_som_labeled_img,
23
  )
24
- import torch
25
-
26
- # yolo_model = get_yolo_model(model_path='/data/icon_detect/best.pt')
27
- # caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="/data/icon_caption_florence")
28
-
29
- from ultralytics import YOLO
30
 
31
- # if not os.path.exists("/data/icon_detect"):
32
- # os.makedirs("/data/icon_detect")
 
33
 
34
- try:
35
- yolo_model = YOLO("weights/best.pt").to("cuda")
36
- except:
37
- yolo_model = YOLO("weights/best.pt")
38
 
39
- from transformers import AutoProcessor, AutoModelForCausalLM
40
-
41
- processor = AutoProcessor.from_pretrained(
42
- "microsoft/Florence-2-base", trust_remote_code=True
43
- )
 
44
 
 
45
  try:
 
 
 
46
  model = AutoModelForCausalLM.from_pretrained(
47
  "weights/icon_caption_florence",
48
  torch_dtype=torch.float16,
49
  trust_remote_code=True,
50
  ).to("cuda")
51
- except:
 
52
  model = AutoModelForCausalLM.from_pretrained(
53
  "weights/icon_caption_florence",
54
  torch_dtype=torch.float16,
55
  trust_remote_code=True,
56
  )
 
57
  caption_model_processor = {"processor": processor, "model": model}
58
- print("finish loading model!!!")
59
 
60
  app = FastAPI()
61
 
62
-
63
  class ProcessResponse(BaseModel):
64
  image: str # Base64 encoded image
65
  parsed_content_list: str
66
  label_coordinates: str
67
 
68
-
69
- def process(
70
- image_input: Image.Image, box_threshold: float, iou_threshold: float
71
- ) -> ProcessResponse:
72
  image_save_path = "imgs/saved_image_demo.png"
 
73
  image_input.save(image_save_path)
 
 
 
 
74
  image = Image.open(image_save_path)
75
  box_overlay_ratio = image.size[0] / 3200
76
  draw_bbox_config = {
@@ -80,6 +75,7 @@ def process(
80
  "thickness": max(int(3 * box_overlay_ratio), 1),
81
  }
82
 
 
83
  ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
84
  image_save_path,
85
  display_img=False,
@@ -89,33 +85,40 @@ def process(
89
  use_paddleocr=True,
90
  )
91
  text, ocr_bbox = ocr_bbox_rslt
92
- dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
93
- image_save_path,
94
- yolo_model,
95
- BOX_TRESHOLD=box_threshold,
96
- output_coord_in_ratio=True,
97
- ocr_bbox=ocr_bbox,
98
- draw_bbox_config=draw_bbox_config,
99
- caption_model_processor=caption_model_processor,
100
- ocr_text=text,
101
- iou_threshold=iou_threshold,
102
- )
 
 
 
 
 
 
 
 
 
 
103
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
104
- print("finish processing")
105
  parsed_content_list_str = "\n".join(parsed_content_list)
106
 
107
- # Encode image to base64
108
  buffered = io.BytesIO()
109
  image.save(buffered, format="PNG")
110
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
111
 
112
  return ProcessResponse(
113
  image=img_str,
114
- parsed_content_list=str(parsed_content_list_str),
115
  label_coordinates=str(label_coordinates),
116
  )
117
 
118
-
119
  @app.post("/process_image", response_model=ProcessResponse)
120
  async def process_image(
121
  image_file: UploadFile = File(...),
@@ -125,8 +128,26 @@ async def process_image(
125
  try:
126
  contents = await image_file.read()
127
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  except Exception as e:
129
- raise HTTPException(status_code=400, detail="Invalid image file")
 
 
 
130
 
131
- response = process(image_input, box_threshold, iou_threshold)
132
- return response
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
 
2
  from pydantic import BaseModel
 
3
  import base64
4
  import io
5
+ import os
6
+ import logging
7
  from PIL import Image
8
  import torch
 
 
9
 
10
  # Existing imports
 
 
 
 
 
11
  from utils import (
12
  check_ocr_box,
13
  get_yolo_model,
14
  get_caption_model_processor,
15
  get_som_labeled_img,
16
  )
17
+ from transformers import AutoProcessor, AutoModelForCausalLM
 
 
 
 
 
18
 
19
+ # Configure logging
20
+ logging.basicConfig(level=logging.DEBUG) # Changed to DEBUG for more verbosity
21
+ logger = logging.getLogger(__name__)
22
 
# Load the YOLO detector from the bundled weights file.
yolo_model = get_yolo_model(model_path="weights/best.pt")

# Choose the device once (CUDA when torch reports it as available, CPU
# otherwise) and place the detector on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolo_model = yolo_model.cuda() if device.type == "cuda" else yolo_model.cpu()
32
 
# Load caption model and processor.
#
# The processor is loaded outside the GPU try/except: it does not depend on
# the device, and loading it inside the guarded block meant that a processor
# failure was reported as a "GPU" failure and then left `processor` undefined
# when the result dict below was built.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)


def _load_caption_model(dtype):
    """Load the fine-tuned caption weights from disk with the given torch dtype."""
    return AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=dtype,
        trust_remote_code=True,
    )


model = None
if torch.cuda.is_available():
    # Only attempt GPU placement when CUDA is actually present; otherwise the
    # model would be loaded once just to fail on `.to("cuda")` and be loaded
    # a second time.
    try:
        model = _load_caption_model(torch.float16).to("cuda")
    except Exception as e:
        logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")

if model is None:
    # CPU path: use float32 rather than float16.
    # NOTE(review): assumes the captioning code in `utils` feeds float32
    # inputs when the model lives on CPU -- confirm against utils.
    model = _load_caption_model(torch.float32)

caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!!!")
53
 
54
  app = FastAPI()
55
 
 
class ProcessResponse(BaseModel):
    """Response body of the ``/process_image`` endpoint.

    All three fields are plain strings so the payload serializes directly
    to JSON.
    """

    # Base64 encoded image (the annotated result, re-encoded as PNG)
    image: str
    # Parsed content entries joined with newlines into a single string
    parsed_content_list: str
    # str() rendering of the label coordinates returned by the labeling step
    label_coordinates: str
60
 
61
+ def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
 
 
 
62
  image_save_path = "imgs/saved_image_demo.png"
63
+ os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
64
  image_input.save(image_save_path)
65
+
66
+ logger.info(f"Saved image for processing: {image_save_path}")
67
+
68
+ # Open image and prepare it for further processing
69
  image = Image.open(image_save_path)
70
  box_overlay_ratio = image.size[0] / 3200
71
  draw_bbox_config = {
 
75
  "thickness": max(int(3 * box_overlay_ratio), 1),
76
  }
77
 
78
+ # OCR and YOLO box processing
79
  ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
80
  image_save_path,
81
  display_img=False,
 
85
  use_paddleocr=True,
86
  )
87
  text, ocr_bbox = ocr_bbox_rslt
88
+
89
+ # Process image and get result
90
+ try:
91
+ dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
92
+ image_save_path,
93
+ yolo_model,
94
+ BOX_TRESHOLD=box_threshold,
95
+ output_coord_in_ratio=True,
96
+ ocr_bbox=ocr_bbox,
97
+ draw_bbox_config=draw_bbox_config,
98
+ caption_model_processor=caption_model_processor,
99
+ ocr_text=text,
100
+ iou_threshold=iou_threshold,
101
+ )
102
+ except Exception as e:
103
+ logger.error(f"Error during labeling and captioning: {e}")
104
+ raise
105
+
106
+ logger.info("Finished processing image with YOLO and captioning.")
107
+
108
+ # Convert the image to base64 string
109
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
 
110
  parsed_content_list_str = "\n".join(parsed_content_list)
111
 
 
112
  buffered = io.BytesIO()
113
  image.save(buffered, format="PNG")
114
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
115
 
116
  return ProcessResponse(
117
  image=img_str,
118
+ parsed_content_list=parsed_content_list_str,
119
  label_coordinates=str(label_coordinates),
120
  )
121
 
 
122
  @app.post("/process_image", response_model=ProcessResponse)
123
  async def process_image(
124
  image_file: UploadFile = File(...),
 
128
  try:
129
  contents = await image_file.read()
130
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
131
+
132
+ logger.info(f"Processing image: {image_file.filename}")
133
+ logger.info(f"Image size: {image_input.size}")
134
+
135
+ # Debugging the input image
136
+ if not image_input:
137
+ raise ValueError("Image input is empty or invalid.")
138
+
139
+ response = process(image_input, box_threshold, iou_threshold)
140
+
141
+ # Ensure the response contains an image
142
+ if not response.image:
143
+ raise ValueError("Empty image in response")
144
+
145
+ logger.info("Processing complete, returning response.")
146
+ return response
147
+
148
  except Exception as e:
149
+ logger.error(f"Error processing image: {e}")
150
+ import traceback
151
+ traceback.print_exc()
152
+ raise HTTPException(status_code=500, detail=str(e))
153