banao-tech committed on
Commit 0e6d684 · verified · 1 Parent(s): fec0378

Update utils.py

Files changed (1)
  1. utils.py +331 -242
utils.py CHANGED
@@ -1,49 +1,58 @@
1
- # from ultralytics import YOLO
2
  import os
3
  import io
4
  import base64
5
  import time
6
- from PIL import Image, ImageDraw, ImageFont
7
- import json
8
- import requests
9
- # utility function
10
- import os
11
-
12
-
13
  import json
14
  import sys
15
- import os
16
- import cv2
 
 
17
  import numpy as np
18
- # %matplotlib inline
 
19
  from matplotlib import pyplot as plt
 
20
  import easyocr
21
  from paddleocr import PaddleOCR
22
  reader = easyocr.Reader(['en'])
23
  paddle_ocr = PaddleOCR(
24
- lang='en', # other lang also available
25
  use_angle_cls=False,
26
- use_gpu=False, # using cuda will conflict with pytorch in the same process
27
  show_log=False,
28
  max_batch_size=1024,
29
  use_dilation=True, # improves accuracy
30
  det_db_score_mode='slow', # improves accuracy
31
- rec_batch_num=1024)
32
- import time
33
- import base64
34
-
35
- import os
36
- import ast
37
- import torch
38
- from typing import Tuple, List
39
- from torchvision.ops import box_convert
40
- import re
41
- from torchvision.transforms import ToPILImage
42
- import supervision as sv
43
- import torchvision.transforms as T
44
 
45
 
46
  def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
47
  if not device:
48
  device = "cuda" if torch.cuda.is_available() else "cpu"
49
  if model_name == "blip2":
@@ -51,45 +60,53 @@ def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2
51
  processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
52
  if device == 'cpu':
53
  model = Blip2ForConditionalGeneration.from_pretrained(
54
- model_name_or_path, device_map=None, torch_dtype=torch.float32
55
- )
56
  else:
57
  model = Blip2ForConditionalGeneration.from_pretrained(
58
- model_name_or_path, device_map=None, torch_dtype=torch.float16
59
- ).to(device)
60
  elif model_name == "florence2":
61
- from transformers import AutoProcessor, AutoModelForCausalLM
62
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
63
  if device == 'cpu':
64
- model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
 
 
65
  else:
66
- model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
 
 
67
  return {'model': model.to(device), 'processor': processor}
68
 
69
 
70
  def get_yolo_model(model_path):
 
 
 
71
  from ultralytics import YOLO
72
- # Load the model.
73
  model = YOLO(model_path)
74
  return model
75
 
76
 
77
  @torch.inference_mode()
78
  def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=32):
79
- # Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
 
 
80
 
81
  to_pil = ToPILImage()
82
  if starting_idx:
83
  non_ocr_boxes = filtered_boxes[starting_idx:]
84
  else:
85
  non_ocr_boxes = filtered_boxes
86
- croped_pil_image = []
87
- for i, coord in enumerate(non_ocr_boxes):
88
- xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
89
- ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
90
  cropped_image = image_source[ymin:ymax, xmin:xmax, :]
91
- croped_pil_image.append(to_pil(cropped_image))
92
-
93
  model, processor = caption_model_processor['model'], caption_model_processor['processor']
94
  if not prompt:
95
  if 'florence' in model.config.name_or_path:
@@ -99,17 +116,29 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
99
 
100
  generated_texts = []
101
  device = model.device
102
- for i in range(0, len(croped_pil_image), batch_size):
103
- start = time.time()
104
- batch = croped_pil_image[i:i+batch_size]
105
  if model.device.type == 'cuda':
106
- inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device, dtype=torch.float16)
107
  else:
108
- inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
109
  if 'florence' in model.config.name_or_path:
110
- generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=100,num_beams=3, do_sample=False)
111
  else:
112
- generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
113
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
114
  generated_text = [gen.strip() for gen in generated_text]
115
  generated_texts.extend(generated_text)
@@ -118,51 +147,57 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
118
 
119
 
120
 
 
121
  def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
 
 
 
122
  to_pil = ToPILImage()
123
  if ocr_bbox:
124
  non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
125
  else:
126
  non_ocr_boxes = filtered_boxes
127
- croped_pil_image = []
128
- for i, coord in enumerate(non_ocr_boxes):
129
- xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
130
- ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
131
  cropped_image = image_source[ymin:ymax, xmin:xmax, :]
132
- croped_pil_image.append(to_pil(cropped_image))
133
 
134
  model, processor = caption_model_processor['model'], caption_model_processor['processor']
135
  device = model.device
136
- messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
137
  prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
138
 
139
  batch_size = 5 # Number of samples per batch
140
  generated_texts = []
141
 
142
- for i in range(0, len(croped_pil_image), batch_size):
143
- images = croped_pil_image[i:i+batch_size]
144
  image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
145
- inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
146
  texts = [prompt] * len(images)
147
- for i, txt in enumerate(texts):
148
- input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
149
- inputs['input_ids'].append(input['input_ids'])
150
- inputs['attention_mask'].append(input['attention_mask'])
151
- inputs['pixel_values'].append(input['pixel_values'])
152
- inputs['image_sizes'].append(input['image_sizes'])
153
- max_len = max([x.shape[1] for x in inputs['input_ids']])
154
- for i, v in enumerate(inputs['input_ids']):
155
- inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
156
- inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
 
 
157
  inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
158
 
159
- generation_args = {
160
- "max_new_tokens": 25,
161
- "temperature": 0.01,
162
- "do_sample": False,
163
- }
164
- generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
165
- # # remove input tokens
166
  generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
167
  response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
168
  response = [res.strip('\n').strip() for res in response]
@@ -170,7 +205,19 @@ def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, captio
170
 
171
  return generated_texts
172
 
 
173
  def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
174
  assert ocr_bbox is None or isinstance(ocr_bbox, List)
175
 
176
  def box_area(box):
@@ -184,39 +231,30 @@ def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
184
  return max(0, x2 - x1) * max(0, y2 - y1)
185
 
186
  def IoU(box1, box2):
187
- intersection = intersection_area(box1, box2)
188
- union = box_area(box1) + box_area(box2) - intersection + 1e-6
189
- if box_area(box1) > 0 and box_area(box2) > 0:
190
- ratio1 = intersection / box_area(box1)
191
- ratio2 = intersection / box_area(box2)
192
- else:
193
- ratio1, ratio2 = 0, 0
194
- return max(intersection / union, ratio1, ratio2)
195
 
196
  def is_inside(box1, box2):
197
- # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
198
- intersection = intersection_area(box1, box2)
199
- ratio1 = intersection / box_area(box1)
200
- return ratio1 > 0.95
201
 
202
  boxes = boxes.tolist()
203
  filtered_boxes = []
204
  if ocr_bbox:
205
  filtered_boxes.extend(ocr_bbox)
206
- # print('ocr_bbox!!!', ocr_bbox)
207
  for i, box1 in enumerate(boxes):
208
- # if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
209
  is_valid_box = True
210
  for j, box2 in enumerate(boxes):
211
- # keep the smaller box
212
  if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
213
  is_valid_box = False
214
  break
215
  if is_valid_box:
216
- # add the following 2 lines to include ocr bbox
217
  if ocr_bbox:
218
- # only add the box if it does not overlap with any ocr bbox
219
- if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for k, box3 in enumerate(ocr_bbox)):
220
  filtered_boxes.append(box1)
221
  else:
222
  filtered_boxes.append(box1)
@@ -224,11 +262,17 @@ def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
224
 
225
 
226
  def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
227
- '''
228
- ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
229
- boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
230
-
231
- '''
232
  assert ocr_bbox is None or isinstance(ocr_bbox, List)
233
 
234
  def box_area(box):
@@ -242,132 +286,130 @@ def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
242
  return max(0, x2 - x1) * max(0, y2 - y1)
243
 
244
  def IoU(box1, box2):
245
- intersection = intersection_area(box1, box2)
246
- union = box_area(box1) + box_area(box2) - intersection + 1e-6
247
- if box_area(box1) > 0 and box_area(box2) > 0:
248
- ratio1 = intersection / box_area(box1)
249
- ratio2 = intersection / box_area(box2)
250
- else:
251
- ratio1, ratio2 = 0, 0
252
- return max(intersection / union, ratio1, ratio2)
253
 
254
  def is_inside(box1, box2):
255
- # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
256
- intersection = intersection_area(box1, box2)
257
- ratio1 = intersection / box_area(box1)
258
- return ratio1 > 0.80
259
 
260
- # boxes = boxes.tolist()
261
  filtered_boxes = []
262
  if ocr_bbox:
263
  filtered_boxes.extend(ocr_bbox)
264
- # print('ocr_bbox!!!', ocr_bbox)
265
  for i, box1_elem in enumerate(boxes):
266
  box1 = box1_elem['bbox']
267
  is_valid_box = True
268
  for j, box2_elem in enumerate(boxes):
269
- # keep the smaller box
270
  box2 = box2_elem['bbox']
271
  if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
272
  is_valid_box = False
273
  break
274
  if is_valid_box:
275
- # add the following 2 lines to include ocr bbox
276
  if ocr_bbox:
277
- # keep yolo boxes + prioritize ocr label
278
  box_added = False
279
  for box3_elem in ocr_bbox:
280
- if not box_added:
281
- box3 = box3_elem['bbox']
282
- if is_inside(box3, box1): # ocr inside icon
283
- # box_added = True
284
- # delete the box3_elem from ocr_bbox
285
- try:
286
- filtered_boxes.append({'type': 'text', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': box3_elem['content']})
287
- filtered_boxes.remove(box3_elem)
288
- # print('remove ocr bbox:', box3_elem)
289
- except:
290
- continue
291
- # break
292
- elif is_inside(box1, box3): # icon inside ocr
293
- box_added = True
294
- # try:
295
- # filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
296
- # filtered_boxes.remove(box3_elem)
297
- # except:
298
- # continue
299
- break
300
- else:
301
  continue
 
 
 
302
  if not box_added:
303
- filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
304
-
 
 
 
 
305
  else:
306
  filtered_boxes.append(box1)
307
- return filtered_boxes # torch.tensor(filtered_boxes)
308
 
309
 
310
  def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
311
- transform = T.Compose(
312
- [
313
- T.RandomResize([800], max_size=1333),
314
- T.ToTensor(),
315
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
316
- ]
317
- )
318
  image_source = Image.open(image_path).convert("RGB")
319
  image = np.asarray(image_source)
320
  image_transformed, _ = transform(image_source, None)
321
  return image, image_transformed
322
 
323
 
324
- def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
325
- text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
326
- """
327
- This function annotates an image with bounding boxes and labels.
328
-
329
- Parameters:
330
- image_source (np.ndarray): The source image to be annotated.
331
- boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
332
- logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
333
- phrases (List[str]): A list of labels for each bounding box.
334
- text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
335
-
336
- Returns:
337
- np.ndarray: The annotated image.
338
  """
 
 
 
339
  h, w, _ = image_source.shape
340
  boxes = boxes * torch.Tensor([w, h, w, h])
341
  xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
342
  xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
343
  detections = sv.Detections(xyxy=xyxy)
344
 
345
- labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
346
 
347
- from util.box_annotator import BoxAnnotator
348
- box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
 
349
  annotated_frame = image_source.copy()
350
- annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
351
 
352
  label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
353
  return annotated_frame, label_coordinates
354
 
355
 
 
356
  def predict(model, image, caption, box_threshold, text_threshold):
357
- """ Use huggingface model to replace the original model
358
  """
359
- model, processor = model['model'], model['processor']
360
- device = model.device
361
 
362
  inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
363
  with torch.no_grad():
364
- outputs = model(**inputs)
365
 
366
  results = processor.post_process_grounded_object_detection(
367
  outputs,
368
  inputs.input_ids,
369
- box_threshold=box_threshold, # 0.4,
370
- text_threshold=text_threshold, # 0.3,
371
  target_sizes=[image.size[::-1]]
372
  )[0]
373
  boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
@@ -375,78 +417,109 @@ def predict(model, image, caption, box_threshold, text_threshold):
375
 
376
 
377
  def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
378
- """ Use huggingface model to replace the original model
379
  """
380
- # model = model['model']
381
  if scale_img:
382
- result = model.predict(
383
- source=image_path,
384
- conf=box_threshold,
385
- imgsz=imgsz,
386
- iou=iou_threshold, # default 0.7
387
- )
388
- else:
389
- result = model.predict(
390
- source=image_path,
391
- conf=box_threshold,
392
- iou=iou_threshold, # default 0.7
393
- )
394
- boxes = result[0].boxes.xyxy#.tolist() # in pixel space
395
- conf = result[0].boxes.conf
396
- phrases = [str(i) for i in range(len(boxes))]
397
 
398
- return boxes, conf, phrases
 
 
 
399
 
400
 
401
- def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=None):
402
- """ ocr_bbox: list of xyxy format bbox
403
  """
404
  image_source = Image.open(img_path).convert("RGB")
405
  w, h = image_source.size
406
  if not imgsz:
407
  imgsz = (h, w)
408
- # print('image size:', w, h)
409
- xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
 
 
 
410
  xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
411
- image_source = np.asarray(image_source)
412
  phrases = [str(i) for i in range(len(phrases))]
413
 
414
- # annotate the image with labels
415
- h, w, _ = image_source.shape
416
  if ocr_bbox:
417
  ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
418
- ocr_bbox=ocr_bbox.tolist()
419
  else:
420
  print('no ocr bbox!!!')
421
  ocr_bbox = None
422
- # filtered_boxes = remove_overlap(boxes=xyxy, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox)
423
- # starting_idx = len(ocr_bbox)
424
- # print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
425
 
426
- ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt} for box, txt in zip(ocr_bbox, ocr_text)]
427
- xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist()]
 
 
428
  filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
429
 
430
- # sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
431
  filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
432
- # get the index of the first 'content': None
433
  starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
434
- filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
435
 
436
-
437
- # get parsed icon local semantics
 
 
438
  if use_local_semantics:
439
  caption_model = caption_model_processor['model']
440
- if 'phi3_v' in caption_model.config.model_type:
441
- parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
442
  else:
443
- parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
444
  ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
445
  icon_start = len(ocr_text)
446
  parsed_content_icon_ls = []
447
- # fill the filtered_boxes_elem None content with parsed_content_icon in order
448
- for i, box in enumerate(filtered_boxes_elem):
449
- if box['content'] is None:
450
  box['content'] = parsed_content_icon.pop(0)
451
  for i, txt in enumerate(parsed_content_icon):
452
  parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
@@ -455,51 +528,72 @@ def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_
455
  ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
456
  parsed_content_merged = ocr_text
457
 
458
- filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
459
-
460
- phrases = [i for i in range(len(filtered_boxes))]
461
 
462
- # draw boxes
463
  if draw_bbox_config:
464
- annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
 
 
465
  else:
466
- annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
 
 
 
467
 
468
  pil_img = Image.fromarray(annotated_frame)
469
  buffered = io.BytesIO()
470
  pil_img.save(buffered, format="PNG")
471
  encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
 
472
  if output_coord_in_ratio:
473
- # h, w, _ = image_source.shape
474
- label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
475
  assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
476
 
477
  return encoded_image, label_coordinates, filtered_boxes_elem
478
 
479
 
480
  def get_xywh(input):
481
- x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
482
- x, y, w, h = int(x), int(y), int(w), int(h)
483
- return x, y, w, h
484
 
485
  def get_xyxy(input):
486
- x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
487
- x, y, xp, yp = int(x), int(y), int(xp), int(yp)
488
- return x, y, xp, yp
489
 
490
  def get_xywh_yolo(input):
491
- x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
492
- x, y, w, h = int(x), int(y), int(w), int(h)
493
- return x, y, w, h
494
-
 
 
 
495
 
496
 
497
- def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
498
  if use_paddleocr:
499
- if easyocr_args is None:
500
- text_threshold = 0.5
501
- else:
502
- text_threshold = easyocr_args['text_threshold']
503
  result = paddle_ocr.ocr(image_path, cls=False)[0]
504
  conf = [item[1] for item in result]
505
  coord = [item[0] for item in result if item[1][1] > text_threshold]
@@ -508,26 +602,21 @@ def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_
508
  if easyocr_args is None:
509
  easyocr_args = {}
510
  result = reader.readtext(image_path, **easyocr_args)
511
- # print('goal filtering pred:', result[-5:])
512
  coord = [item[0] for item in result]
513
  text = [item[1] for item in result]
514
- # read the image using cv2
515
  if display_img:
516
  opencv_img = cv2.imread(image_path)
517
  opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
518
  bb = []
519
  for item in coord:
520
  x, y, a, b = get_xywh(item)
521
- # print(x, y, a, b)
522
  bb.append((x, y, a, b))
523
- cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
524
-
525
- # Display the image
526
  plt.imshow(opencv_img)
527
  else:
528
  if output_bb_format == 'xywh':
529
  bb = [get_xywh(item) for item in coord]
530
  elif output_bb_format == 'xyxy':
531
  bb = [get_xyxy(item) for item in coord]
532
- # print('bounding box!!!', bb)
533
- return (text, bb), goal_filtering
 
1
+ """
2
+ utils.py
3
+
4
+ This module contains utility functions for:
5
+ - Loading and processing images
6
+ - Object detection with YOLO
7
+ - OCR with EasyOCR / PaddleOCR
8
+ - Image annotation and bounding box manipulation
9
+ - Captioning / semantic parsing of detected icons
10
+ """
11
+
12
  import os
13
  import io
14
  import base64
15
16
  import json
17
  import sys
18
+ import re
19
+ from typing import Tuple, List
20
+
21
+ import torch
22
  import numpy as np
23
+ import cv2
24
+ from PIL import Image, ImageDraw, ImageFont
25
  from matplotlib import pyplot as plt
26
+
27
  import easyocr
28
  from paddleocr import PaddleOCR
29
+ import supervision as sv
30
+ import torchvision.transforms as T
31
+ from torchvision.transforms import ToPILImage
32
+ from torchvision.ops import box_convert
33
+
34
+ # Optional: import AzureOpenAI if used
35
+ from openai import AzureOpenAI
36
+
37
+ # Initialize OCR readers
38
  reader = easyocr.Reader(['en'])
39
  paddle_ocr = PaddleOCR(
40
+ lang='en', # other languages available
41
  use_angle_cls=False,
42
+ use_gpu=False, # using cuda might conflict with PyTorch in the same process
43
  show_log=False,
44
  max_batch_size=1024,
45
  use_dilation=True, # improves accuracy
46
  det_db_score_mode='slow', # improves accuracy
47
+ rec_batch_num=1024
48
+ )
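# Illustrative smoke test for the module-level OCR readers ('screenshot.png' is an
# assumed local file). PaddleOCR returns [[box, (text, score)], ...] per image and
# EasyOCR returns [(box, text, score), ...]; check_ocr_box() below wraps both.
#   paddle_ocr.ocr("screenshot.png", cls=False)[0]
#   reader.readtext("screenshot.png")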
49
 
50
 
51
  def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
52
+ """
53
+ Loads the captioning model and processor.
54
+ Supports either BLIP2 or Florence-2 models.
55
+ """
56
  if not device:
57
  device = "cuda" if torch.cuda.is_available() else "cpu"
58
  if model_name == "blip2":
 
60
  processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
61
  if device == 'cpu':
62
  model = Blip2ForConditionalGeneration.from_pretrained(
63
+ model_name_or_path, device_map=None, torch_dtype=torch.float32
64
+ )
65
  else:
66
  model = Blip2ForConditionalGeneration.from_pretrained(
67
+ model_name_or_path, device_map=None, torch_dtype=torch.float16
68
+ ).to(device)
69
  elif model_name == "florence2":
70
+ from transformers import AutoProcessor, AutoModelForCausalLM
71
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
72
  if device == 'cpu':
73
+ model = AutoModelForCausalLM.from_pretrained(
74
+ model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True
75
+ )
76
  else:
77
+ model = AutoModelForCausalLM.from_pretrained(
78
+ model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True
79
+ ).to(device)
80
  return {'model': model.to(device), 'processor': processor}
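# Illustrative usage: load the Florence-2 caption model on CPU. The checkpoint path
# is an assumption; any local or hub path accepted by AutoModelForCausalLM works.
#   caption_model_processor = get_caption_model_processor(
#       "florence2", model_name_or_path="microsoft/Florence-2-base", device="cpu"
#   )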
81
 
82
 
83
  def get_yolo_model(model_path):
84
+ """
85
+ Loads a YOLO model from a given model_path using ultralytics.
86
+ """
87
  from ultralytics import YOLO
 
88
  model = YOLO(model_path)
89
  return model
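# Illustrative: load the icon detector once at startup (the weights path is an assumption).
#   som_model = get_yolo_model("weights/icon_detect/best.pt")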
90
 
91
 
92
  @torch.inference_mode()
93
  def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=32):
94
+ # Ensure batch_size is an integer
95
+ if batch_size is None:
96
+ batch_size = 32
97
 
98
  to_pil = ToPILImage()
99
  if starting_idx:
100
  non_ocr_boxes = filtered_boxes[starting_idx:]
101
  else:
102
  non_ocr_boxes = filtered_boxes
103
+ cropped_pil_images = []
104
+ for coord in non_ocr_boxes:
105
+ xmin, xmax = int(coord[0] * image_source.shape[1]), int(coord[2] * image_source.shape[1])
106
+ ymin, ymax = int(coord[1] * image_source.shape[0]), int(coord[3] * image_source.shape[0])
107
  cropped_image = image_source[ymin:ymax, xmin:xmax, :]
108
+ cropped_pil_images.append(to_pil(cropped_image))
109
+
110
  model, processor = caption_model_processor['model'], caption_model_processor['processor']
111
  if not prompt:
112
  if 'florence' in model.config.name_or_path:
 
116
 
117
  generated_texts = []
118
  device = model.device
119
+ for i in range(0, len(cropped_pil_images), batch_size):
120
+ batch = cropped_pil_images[i:i + batch_size]
 
121
  if model.device.type == 'cuda':
122
+ inputs = processor(images=batch, text=[prompt] * len(batch), return_tensors="pt").to(device=device, dtype=torch.float16)
123
  else:
124
+ inputs = processor(images=batch, text=[prompt] * len(batch), return_tensors="pt").to(device=device)
125
  if 'florence' in model.config.name_or_path:
126
+ generated_ids = model.generate(
127
+ input_ids=inputs["input_ids"],
128
+ pixel_values=inputs["pixel_values"],
129
+ max_new_tokens=100,
130
+ num_beams=3,
131
+ do_sample=False
132
+ )
133
  else:
134
+ generated_ids = model.generate(
135
+ **inputs,
136
+ max_length=100,
137
+ num_beams=5,
138
+ no_repeat_ngram_size=2,
139
+ early_stopping=True,
140
+ num_return_sequences=1
141
+ )
142
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
143
  generated_text = [gen.strip() for gen in generated_text]
144
  generated_texts.extend(generated_text)
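# Note on get_parsed_content_icon's batch_size: per the earlier in-code comment,
# a batch of ~256 crops takes roughly 23 GB of GPU memory with the Florence model,
# so the default of 32 is a conservative starting point.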
 
147
 
148
 
149
 
150
+
151
  def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
152
+ """
153
+ Generates parsed textual content for detected icons using the phi3_v model variant.
154
+ """
155
  to_pil = ToPILImage()
156
  if ocr_bbox:
157
  non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
158
  else:
159
  non_ocr_boxes = filtered_boxes
160
+ cropped_pil_images = []
161
+ for coord in non_ocr_boxes:
162
+ xmin, xmax = int(coord[0] * image_source.shape[1]), int(coord[2] * image_source.shape[1])
163
+ ymin, ymax = int(coord[1] * image_source.shape[0]), int(coord[3] * image_source.shape[0])
164
  cropped_image = image_source[ymin:ymax, xmin:xmax, :]
165
+ cropped_pil_images.append(to_pil(cropped_image))
166
 
167
  model, processor = caption_model_processor['model'], caption_model_processor['processor']
168
  device = model.device
169
+ messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
170
  prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
171
 
172
  batch_size = 5 # Number of samples per batch
173
  generated_texts = []
174
 
175
+ for i in range(0, len(cropped_pil_images), batch_size):
176
+ images = cropped_pil_images[i:i+batch_size]
177
  image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
178
+ inputs = {'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
179
  texts = [prompt] * len(images)
180
+ for idx, txt in enumerate(texts):
181
+ inp = processor._convert_images_texts_to_inputs(image_inputs[idx], txt, return_tensors="pt")
182
+ inputs['input_ids'].append(inp['input_ids'])
183
+ inputs['attention_mask'].append(inp['attention_mask'])
184
+ inputs['pixel_values'].append(inp['pixel_values'])
185
+ inputs['image_sizes'].append(inp['image_sizes'])
186
+ max_len = max(x.shape[1] for x in inputs['input_ids'])
187
+ for idx, v in enumerate(inputs['input_ids']):
188
+ pad_tensor = processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long)
189
+ inputs['input_ids'][idx] = torch.cat([pad_tensor, v], dim=1)
190
+ pad_att = torch.zeros(1, max_len - v.shape[1], dtype=torch.long)
191
+ inputs['attention_mask'][idx] = torch.cat([pad_att, inputs['attention_mask'][idx]], dim=1)
192
  inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
193
 
194
+ generation_args = {
195
+ "max_new_tokens": 25,
196
+ "temperature": 0.01,
197
+ "do_sample": False,
198
+ }
199
+ generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
200
+ # Remove input tokens from the generated sequence
201
  generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
202
  response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
203
  response = [res.strip('\n').strip() for res in response]
 
205
 
206
  return generated_texts
207
 
208
+
209
  def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
210
+ """
211
+ Removes overlapping bounding boxes based on IoU and optionally considers OCR boxes.
212
+
213
+ Args:
214
+ boxes: Tensor of bounding boxes (in xyxy format).
215
+ iou_threshold: IoU threshold to determine overlaps.
216
+ ocr_bbox: Optional list of OCR bounding boxes.
217
+
218
+ Returns:
219
+ Filtered boxes as a torch.Tensor.
220
+ """
221
  assert ocr_bbox is None or isinstance(ocr_bbox, List)
222
 
223
  def box_area(box):
 
231
  return max(0, x2 - x1) * max(0, y2 - y1)
232
 
233
  def IoU(box1, box2):
234
+ inter = intersection_area(box1, box2)
235
+ union = box_area(box1) + box_area(box2) - inter + 1e-6
236
+ ratio1 = inter / box_area(box1) if box_area(box1) > 0 else 0
237
+ ratio2 = inter / box_area(box2) if box_area(box2) > 0 else 0
238
+ return max(inter / union, ratio1, ratio2)
 
 
 
239
 
240
  def is_inside(box1, box2):
241
+ inter = intersection_area(box1, box2)
242
+ return (inter / box_area(box1)) > 0.95
 
 
243
 
244
  boxes = boxes.tolist()
245
  filtered_boxes = []
246
  if ocr_bbox:
247
  filtered_boxes.extend(ocr_bbox)
 
248
  for i, box1 in enumerate(boxes):
 
249
  is_valid_box = True
250
  for j, box2 in enumerate(boxes):
 
251
  if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
252
  is_valid_box = False
253
  break
254
  if is_valid_box:
 
255
  if ocr_bbox:
256
+ # Only add the box if it does not overlap with any OCR box
257
+ if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for box3 in ocr_bbox):
258
  filtered_boxes.append(box1)
259
  else:
260
  filtered_boxes.append(box1)
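# Worked example for the IoU helper used above (illustrative numbers):
#   box1 = [0, 0, 10, 10], box2 = [5, 5, 15, 15]
#   intersection = 5 * 5 = 25, union = 100 + 100 - 25 = 175
#   IoU(box1, box2) = max(25/175, 25/100, 25/100) = 0.25
# The containment ratios dominate plain IoU, making the overlap filter stricter
# than a standard IoU test for partially overlapping boxes.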
 
262
 
263
 
264
  def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
265
+ """
266
+ Removes overlapping boxes with OCR priority.
267
+
268
+ Args:
269
+ boxes: List of dictionaries, each with keys: 'type', 'bbox', 'interactivity', 'content'.
270
+ iou_threshold: IoU threshold for removal.
271
+ ocr_bbox: List of OCR box dictionaries.
272
+
273
+ Returns:
274
+ A list of filtered box dictionaries.
275
+ """
276
  assert ocr_bbox is None or isinstance(ocr_bbox, List)
277
 
278
  def box_area(box):
 
286
  return max(0, x2 - x1) * max(0, y2 - y1)
287
 
288
  def IoU(box1, box2):
289
+ inter = intersection_area(box1, box2)
290
+ union = box_area(box1) + box_area(box2) - inter + 1e-6
291
+ ratio1 = inter / box_area(box1) if box_area(box1) > 0 else 0
292
+ ratio2 = inter / box_area(box2) if box_area(box2) > 0 else 0
293
+ return max(inter / union, ratio1, ratio2)
 
 
 
294
 
295
  def is_inside(box1, box2):
296
+ inter = intersection_area(box1, box2)
297
+ return (inter / box_area(box1)) > 0.80
 
 
298
 
 
299
  filtered_boxes = []
300
  if ocr_bbox:
301
  filtered_boxes.extend(ocr_bbox)
 
302
  for i, box1_elem in enumerate(boxes):
303
  box1 = box1_elem['bbox']
304
  is_valid_box = True
305
  for j, box2_elem in enumerate(boxes):
 
306
  box2 = box2_elem['bbox']
307
  if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
308
  is_valid_box = False
309
  break
310
  if is_valid_box:
 
311
  if ocr_bbox:
 
312
  box_added = False
313
  for box3_elem in ocr_bbox:
314
+ box3 = box3_elem['bbox']
315
+ if is_inside(box3, box1):
316
+ try:
317
+ filtered_boxes.append({
318
+ 'type': 'text',
319
+ 'bbox': box1_elem['bbox'],
320
+ 'interactivity': True,
321
+ 'content': box3_elem['content']
322
+ })
323
+ filtered_boxes.remove(box3_elem)
324
+ except Exception:
325
  continue
326
+ elif is_inside(box1, box3):
327
+ box_added = True
328
+ break
329
  if not box_added:
330
+ filtered_boxes.append({
331
+ 'type': 'icon',
332
+ 'bbox': box1_elem['bbox'],
333
+ 'interactivity': True,
334
+ 'content': None
335
+ })
336
  else:
337
  filtered_boxes.append(box1)
338
+ return filtered_boxes # Optionally, you could return torch.tensor(filtered_boxes) if needed
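# Illustrative call (normalized xyxy coordinates are made up):
#   icons = [{'type': 'icon', 'bbox': [0.40, 0.40, 0.45, 0.45], 'interactivity': True, 'content': None}]
#   texts = [{'type': 'text', 'bbox': [0.35, 0.35, 0.50, 0.50], 'interactivity': False, 'content': 'Settings'}]
#   remove_overlap_new(icons, iou_threshold=0.9, ocr_bbox=texts)
# Here the icon lies inside the OCR box, so only the OCR element is kept.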
339
 
340
 
341
  def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
342
+ """
343
+ Loads an image and applies transformations.
344
+
345
+ Returns:
346
+ image: Original image as a NumPy array.
347
+ image_transformed: Transformed tensor.
348
+ """
349
+ transform = T.Compose([
350
+ T.RandomResize([800], max_size=1333),
351
+ T.ToTensor(),
352
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
353
+ ])
354
  image_source = Image.open(image_path).convert("RGB")
355
  image = np.asarray(image_source)
356
  image_transformed, _ = transform(image_source, None)
357
  return image, image_transformed
358
 
359
 
360
+ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str],
361
+ text_scale: float, text_padding=5, text_thickness=2, thickness=3) -> Tuple[np.ndarray, dict]:
362
+ """
363
+ Annotates an image with bounding boxes and labels.
364
  """
365
+ # Validate phrases input
366
+ phrases = [str(phrase) if not isinstance(phrase, str) else phrase for phrase in phrases]
367
+
368
  h, w, _ = image_source.shape
369
  boxes = boxes * torch.Tensor([w, h, w, h])
370
  xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
371
  xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
372
  detections = sv.Detections(xyxy=xyxy)
373
 
374
+ labels = [f"{phrase}" for phrase in phrases]
375
 
376
+ from util.box_annotator import BoxAnnotator
377
+ box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,
378
+ text_thickness=text_thickness, thickness=thickness)
379
  annotated_frame = image_source.copy()
380
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w, h))
381
 
382
  label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
383
  return annotated_frame, label_coordinates
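# Illustrative call: annotate() expects boxes normalized to [0, 1] in cxcywh order
# (they are scaled by the image size internally). 'img_np' is an assumed HxWx3 array.
#   demo_boxes = torch.tensor([[0.5, 0.5, 0.2, 0.1]])  # one centered box, 20% x 10%
#   frame, coords = annotate(image_source=img_np, boxes=demo_boxes, logits=None,
#                            phrases=["0"], text_scale=0.4)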
384
 
385
 
386
+
387
  def predict(model, image, caption, box_threshold, text_threshold):
 
388
  """
389
+ Uses a Hugging Face model to perform grounded object detection.
390
+
391
+ Args:
392
+ model: Dictionary with 'model' and 'processor'.
393
+ image: Input PIL image.
394
+ caption: Caption text.
395
+ box_threshold: Confidence threshold for boxes.
396
+ text_threshold: Threshold for text detection.
397
+
398
+ Returns:
399
+ boxes, logits, phrases from the detection.
400
+ """
401
+ model_obj, processor = model['model'], model['processor']
402
+ device = model_obj.device
403
 
404
  inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
405
  with torch.no_grad():
406
+ outputs = model_obj(**inputs)
407
 
408
  results = processor.post_process_grounded_object_detection(
409
  outputs,
410
  inputs.input_ids,
411
+ box_threshold=box_threshold,
412
+ text_threshold=text_threshold,
413
  target_sizes=[image.size[::-1]]
414
  )[0]
415
  boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
 
417
 
418
 
419
  def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
 
420
  """
421
+ Uses a YOLO model for object detection.
422
+
423
+ Args:
424
+ model: YOLO model instance.
425
+ image_path: Path to the image.
426
+ box_threshold: Confidence threshold.
427
+ imgsz: Image size for scaling (if scale_img is True).
428
+ scale_img: Boolean flag to scale the image.
429
+ iou_threshold: IoU threshold for non-max suppression.
430
+
431
+ Returns:
432
+ Bounding boxes, confidence scores, and placeholder phrases.
433
+ """
434
+ kwargs = {
435
+ 'conf': box_threshold, # Confidence threshold
436
+ 'iou': iou_threshold, # IoU threshold
437
+ 'verbose': False
438
+ }
439
  if scale_img:
440
+ kwargs['imgsz'] = imgsz
441
 
442
+ results = model.predict(image_path, **kwargs)
443
+ boxes = results[0].boxes.xyxy
444
+ conf = results[0].boxes.conf
445
+ return boxes, conf, [str(i) for i in range(len(boxes))]
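# Illustrative detection pass ('som_model', the image path and thresholds are assumptions):
#   boxes, conf, phrases = predict_yolo(som_model, "screenshot.png",
#                                       box_threshold=0.05, imgsz=640,
#                                       scale_img=True, iou_threshold=0.7)
# 'boxes' is an (N, 4) tensor of pixel-space xyxy coordinates, one row per detection.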
446
 
447
 
448
+ def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None,
449
+ text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None,
450
+ ocr_text=[], use_local_semantics=True, iou_threshold=0.9, prompt=None, scale_img=False,
451
+ imgsz=None, batch_size=None):
452
+ """
453
+ Processes an image to generate semantic (SOM) labels.
454
+
455
+ Args:
456
+ img_path: Path to the image.
457
+ model: YOLO model for detection.
458
+ BOX_TRESHOLD: Confidence threshold for box prediction.
459
+ output_coord_in_ratio: If True, output coordinates in ratio.
460
+ ocr_bbox: OCR bounding boxes.
461
+ text_scale, text_padding: Parameters for drawing annotations.
462
+ draw_bbox_config: Custom configuration for bounding box drawing.
463
+ caption_model_processor: Dictionary with caption model and processor.
464
+ ocr_text: List of OCR-detected texts.
465
+ use_local_semantics: Whether to use local semantic processing.
466
+ iou_threshold: IoU threshold for filtering overlaps.
467
+ prompt: Optional caption prompt.
468
+ scale_img: Whether to scale the image.
469
+ imgsz: Image size for YOLO.
470
+ batch_size: Batch size for captioning.
471
+
472
+ Returns:
473
+ Encoded annotated image, label coordinates, and filtered boxes.
474
  """
475
  image_source = Image.open(img_path).convert("RGB")
476
  w, h = image_source.size
477
  if not imgsz:
478
  imgsz = (h, w)
479
+ # Run YOLO detection
480
+ xyxy, logits, phrases = predict_yolo(
481
+ model=model, image_path=img_path, box_threshold=BOX_TRESHOLD,
482
+ imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1
483
+ )
484
  xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
485
+ image_source_np = np.asarray(image_source)
486
  phrases = [str(i) for i in range(len(phrases))]
487
 
488
+ # Process OCR bounding boxes (if any)
 
489
  if ocr_bbox:
490
  ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
491
+ ocr_bbox = ocr_bbox.tolist()
492
  else:
493
  print('no ocr bbox!!!')
494
  ocr_bbox = None
 
 
 
495
 
496
+ ocr_bbox_elem = [{'type': 'text', 'bbox': box, 'interactivity': False, 'content': txt}
497
+ for box, txt in zip(ocr_bbox, ocr_text)]
498
+ xyxy_elem = [{'type': 'icon', 'bbox': box, 'interactivity': True, 'content': None}
499
+ for box in xyxy.tolist()]
500
  filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
501
 
502
+ # Sort filtered boxes so that boxes with 'content' == None are at the end
503
  filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
 
504
  starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
505
+ filtered_boxes_tensor = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
506
 
507
+ if batch_size is None:
508
+ batch_size = 32
509
+
510
+ # Generate parsed icon semantics if required
511
  if use_local_semantics:
512
  caption_model = caption_model_processor['model']
513
+ if 'phi3_v' in caption_model.config.model_type:
514
+ parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes_tensor, ocr_bbox, image_source_np, caption_model_processor)
515
  else:
516
+ parsed_content_icon = get_parsed_content_icon(filtered_boxes_tensor, starting_idx, image_source_np, caption_model_processor, prompt=prompt, batch_size=batch_size)
517
  ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
518
  icon_start = len(ocr_text)
519
  parsed_content_icon_ls = []
520
+ # Fill boxes with no OCR content with parsed icon content
521
+ for box in filtered_boxes_elem:
522
+ if box['content'] is None and parsed_content_icon:
523
  box['content'] = parsed_content_icon.pop(0)
524
  for i, txt in enumerate(parsed_content_icon):
525
  parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
 
528
  ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
529
  parsed_content_merged = ocr_text
530
 
531
+ filtered_boxes_cxcywh = box_convert(boxes=filtered_boxes_tensor, in_fmt="xyxy", out_fmt="cxcywh")
532
+ phrases = [i for i in range(len(filtered_boxes_cxcywh))]
 
533
 
534
+ # Annotate image with bounding boxes and labels
535
  if draw_bbox_config:
536
+ annotated_frame, label_coordinates = annotate(
537
+ image_source=image_source_np, boxes=filtered_boxes_cxcywh, logits=logits, phrases=phrases, **draw_bbox_config
538
+ )
539
  else:
540
+ annotated_frame, label_coordinates = annotate(
541
+ image_source=image_source_np, boxes=filtered_boxes_cxcywh, logits=logits, phrases=phrases,
542
+ text_scale=text_scale, text_padding=text_padding
543
+ )
544
 
545
  pil_img = Image.fromarray(annotated_frame)
546
  buffered = io.BytesIO()
547
  pil_img.save(buffered, format="PNG")
548
  encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
549
+
550
  if output_coord_in_ratio:
551
+ label_coordinates = {k: [v[0] / w, v[1] / h, v[2] / w, v[3] / h] for k, v in label_coordinates.items()}
 
552
  assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
553
 
554
  return encoded_image, label_coordinates, filtered_boxes_elem
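# End-to-end sketch (paths, thresholds and model handles are assumptions; see the
# loader sketches above for 'som_model' and 'caption_model_processor'):
#   (ocr_text, ocr_bbox), _ = check_ocr_box("screenshot.png", display_img=False,
#                                           output_bb_format="xyxy", use_paddleocr=True)
#   encoded, coords, elems = get_som_labeled_img(
#       "screenshot.png", model=som_model, BOX_TRESHOLD=0.05,
#       output_coord_in_ratio=True, ocr_bbox=ocr_bbox, ocr_text=ocr_text,
#       caption_model_processor=caption_model_processor,
#       use_local_semantics=True, iou_threshold=0.9, batch_size=32)
# 'encoded' is a base64-encoded PNG of the annotated screenshot; 'elems' is the list
# of text/icon dictionaries with their parsed content.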
555
 
556
 
557
  def get_xywh(input):
558
+ """
559
+ Converts a bounding box given as four corner points (EasyOCR format) into (x, y, width, height), using the first and third corners.
560
+ """
561
+ x, y = input[0][0], input[0][1]
562
+ w = input[2][0] - input[0][0]
563
+ h = input[2][1] - input[0][1]
564
+ return int(x), int(y), int(w), int(h)
565
+
566
 
567
  def get_xyxy(input):
568
+ """
569
+ Converts a bounding box given as four corner points into (x, y, x2, y2), using the first and third corners.
570
+ """
571
+ x, y = input[0][0], input[0][1]
572
+ x2, y2 = input[2][0], input[2][1]
573
+ return int(x), int(y), int(x2), int(y2)
574
+
575
 
576
  def get_xywh_yolo(input):
577
+ """
578
+ Converts a YOLO-style bounding box (x1, y1, x2, y2) into (x, y, width, height).
579
+ """
580
+ x, y = input[0], input[1]
581
+ w = input[2] - input[0]
582
+ h = input[3] - input[1]
583
+ return int(x), int(y), int(w), int(h)
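# Worked example: EasyOCR returns quadrilateral corner points, e.g.
#   pts = [[10, 20], [110, 20], [110, 60], [10, 60]]
#   get_xywh(pts)  -> (10, 20, 100, 40)
#   get_xyxy(pts)  -> (10, 20, 110, 60)
# while YOLO boxes are already xyxy:
#   get_xywh_yolo([10, 20, 110, 60]) -> (10, 20, 100, 40)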
584
 
585
 
586
+ def check_ocr_box(image_path, display_img=True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
587
+ """
588
+ Runs OCR on the given image using PaddleOCR or EasyOCR and optionally displays annotated results.
589
+
590
+ Returns:
591
+ A tuple containing:
592
+ - A tuple (text, bounding boxes)
593
+ - The goal_filtering parameter (unchanged)
594
+ """
595
  if use_paddleocr:
596
+ text_threshold = 0.5 if easyocr_args is None else easyocr_args.get('text_threshold', 0.5)
 
 
 
597
  result = paddle_ocr.ocr(image_path, cls=False)[0]
598
  conf = [item[1] for item in result]
599
  coord = [item[0] for item in result if item[1][1] > text_threshold]
 
602
  if easyocr_args is None:
603
  easyocr_args = {}
604
  result = reader.readtext(image_path, **easyocr_args)
 
605
  coord = [item[0] for item in result]
606
  text = [item[1] for item in result]
607
+
608
  if display_img:
609
  opencv_img = cv2.imread(image_path)
610
  opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
611
  bb = []
612
  for item in coord:
613
  x, y, a, b = get_xywh(item)
 
614
  bb.append((x, y, a, b))
615
+ cv2.rectangle(opencv_img, (x, y), (x + a, y + b), (0, 255, 0), 2)
 
 
616
  plt.imshow(opencv_img)
617
  else:
618
  if output_bb_format == 'xywh':
619
  bb = [get_xywh(item) for item in coord]
620
  elif output_bb_format == 'xyxy':
621
  bb = [get_xyxy(item) for item in coord]
622
+ return (text, bb), goal_filtering
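# Illustrative OCR-only call using the EasyOCR branch (the image path and arguments
# are assumptions; 'paragraph' and 'text_threshold' are standard easyocr.readtext kwargs):
#   (texts, boxes), _ = check_ocr_box("screenshot.png", display_img=False,
#                                     output_bb_format="xywh", use_paddleocr=False,
#                                     easyocr_args={"paragraph": False, "text_threshold": 0.9})
# 'texts' holds the recognized strings and 'boxes' the matching (x, y, w, h) tuples.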