codic committed on
Commit
24a8aeb
·
verified ·
1 Parent(s): f78f660

First update

Browse files
Files changed (1) hide show
  1. app.py +85 -38
app.py CHANGED
@@ -4,54 +4,103 @@ from PIL import Image
4
  import gradio as gr
5
  import numpy as np
6
  import cv2
 
 
 
 
 
 
 
7
 
8
- # 获取随机的颜色
9
  def get_random_color():
10
  c = tuple(np.random.randint(0, 256, 3).tolist())
11
  return c
12
 
13
- # 绘制ocr识别结果
14
  def draw_ocr_bbox(image, boxes, colors):
15
- print(colors)
16
- box_num = len(boxes)
17
- for i in range(box_num):
18
- box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
19
- image = cv2.polylines(np.array(image), [box], True, colors[i], 2)
 
 
 
 
 
20
  return image
21
 
22
- # torch.hub.download_url_to_file('https://i.imgur.com/aqMBT0i.jpg', 'example.jpg')
23
-
24
  def inference(img: Image.Image, lang, confidence):
 
25
  ocr = PaddleOCR(use_angle_cls=True, lang=lang, use_gpu=False,
26
  det_model_dir=f'./models/det/{lang}',
27
  cls_model_dir=f'./models/cls/{lang}',
28
  rec_model_dir=f'./models/rec/{lang}')
29
- # img_path = img.name
 
30
  img2np = np.array(img)
31
- result = ocr.ocr(img2np, cls=True)[0]
32
- # rgb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  image = img.convert('RGB')
34
- boxes = [line[0] for line in result]
35
- txts = [line[1][0] for line in result]
36
- scores = [line[1][1] for line in result]
37
 
38
- # 识别结果
39
- final_result = [dict(boxes=box, txt=txt, score=score, _c=get_random_color()) for box, txt, score in zip(boxes, txts, scores)]
40
- # 过滤 score < 0.5 的
41
- final_result = [item for item in final_result if item['score'] > confidence]
42
-
43
- im_show = draw_ocr_bbox(image, [item['boxes'] for item in final_result], [item['_c'] for item in final_result])
44
- im_show = Image.fromarray(im_show)
45
- data = [[json.dumps(item['boxes']), round(item['score'], 3), item['txt']] for item in final_result]
46
- return im_show, data
47
 
48
- title = 'PaddleOCR'
49
- description = 'Gradio demo for PaddleOCR.'
50
 
51
  examples = [
52
- ['example_imgs/example.jpg','en', 0.5],
53
- ['example_imgs/ch.jpg','ch', 0.7],
54
- ['example_imgs/demo003.jpeg','en', 0.7],
55
  ]
56
 
57
  css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
@@ -59,17 +108,15 @@ css = ".output_image, .input_image {height: 40rem !important; width: 100% !impor
59
  if __name__ == '__main__':
60
  demo = gr.Interface(
61
  inference,
62
- [gr.Image(type='pil', label='Input'),
63
- gr.Dropdown(choices=['ch', 'en', 'fr', 'german', 'korean', 'japan'], value='ch', label='language'),
64
- gr.Slider(0.1, 1, 0.5, step=0.1, label='confidence_threshold')
 
65
  ],
66
- # 输出
67
- [gr.Image(type='pil', label='Output'), gr.Dataframe(headers=[ 'bbox', 'score', 'text'], label='Result')],
68
  title=title,
69
  description=description,
70
  examples=examples,
71
- css=css,
72
- cache_examples=True # 添加缓存选项
73
  )
74
- demo.queue(max_size=10)
75
- demo.launch()
 
4
import gradio as gr
import numpy as np
import cv2
from gliner import GLiNER

# Load the GLiNER NER model once at import time so every request reuses
# the same instance instead of re-downloading/re-initializing per call.
gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")

# Entity labels the NER model should extract from the OCR text
# (business-card fields, including website).
labels = ["person name", "company name", "job title", "phone", "email", "address", "website"]
14
 
 
15
def get_random_color():
    """Return a random RGB color as a tuple of three ints in [0, 255]."""
    r, g, b = np.random.randint(0, 256, 3).tolist()
    return (r, g, b)
18
 
 
19
def draw_ocr_bbox(image, boxes, colors):
    """Draw closed polygon outlines for each bounding box on the image.

    Args:
        image: PIL image or ndarray to draw on.
        boxes: sequence of polygons, each a sequence of (x, y) points.
               Empty boxes (e.g. entities that carry no location) are skipped.
        colors: per-box BGR/RGB color tuples, zipped with ``boxes``.

    Returns:
        np.ndarray: the image with outlines drawn. Always an ndarray, even
        when no boxes were drawn, so callers can pass the result straight
        to ``Image.fromarray``.
    """
    # Convert once up front; the original converted inside the loop on
    # every iteration and returned the raw input when no box was valid.
    canvas = np.array(image)
    for box, color in zip(boxes, colors):
        if len(box) == 0:  # no location for this item — nothing to draw
            continue
        pts = np.array(box).reshape(-1, 1, 2).astype(np.int64)
        canvas = cv2.polylines(canvas, [pts], True, color, 2)
    return canvas
31
 
 
 
32
def _ocr_items(ocr_result, confidence):
    """Convert raw PaddleOCR output lines into result dicts above the threshold.

    Each PaddleOCR line is ``[box, (text, score)]`` (as consumed by the
    original code); lines with ``score <= confidence`` are dropped.
    """
    if not ocr_result:
        return []
    return [
        {'boxes': line[0], 'txt': line[1][0], 'score': line[1][1], '_c': get_random_color()}
        for line in ocr_result
        if line[1][1] > confidence
    ]


def _gliner_items(texts):
    """Run GLiNER entity extraction over the joined OCR texts.

    Extracted entities have no bounding box (``boxes=[]``) and a fixed
    score of 1.0, matching the display convention of the OCR rows.
    """
    combined_text = " ".join(texts)
    entities = gliner_model.predict_entities(combined_text, labels, threshold=0.3)
    return [
        {'boxes': [], 'txt': f"{ent['text']} ({ent['label']})", 'score': 1.0, '_c': get_random_color()}
        for ent in entities
    ]


def _qr_items(img2np):
    """Detect and decode QR codes; each non-empty payload becomes a boxed row."""
    items = []
    detector = cv2.QRCodeDetector()
    retval, decoded_info, points, _ = detector.detectAndDecodeMulti(img2np)
    if retval:
        for i, payload in enumerate(decoded_info):
            if payload:  # detectAndDecodeMulti yields '' for undecodable codes
                items.append({
                    'boxes': points[i].reshape(-1, 2).tolist(),
                    'txt': payload,
                    'score': 1.0,
                    '_c': get_random_color(),
                })
    return items


def inference(img: Image.Image, lang, confidence):
    """Run OCR, GLiNER entity extraction, and QR decoding on one image.

    Args:
        img: input PIL image.
        lang: PaddleOCR language code (selects the model dirs under ./models).
        confidence: minimum OCR score; lower-scoring lines are discarded.

    Returns:
        tuple: (annotated PIL image, table rows ``[bbox-json, score, text]``).
    """
    # NOTE(review): PaddleOCR is re-instantiated on every request; caching
    # per-language instances would avoid repeated model loading — confirm
    # memory budget before changing.
    ocr = PaddleOCR(use_angle_cls=True, lang=lang, use_gpu=False,
                    det_model_dir=f'./models/det/{lang}',
                    cls_model_dir=f'./models/cls/{lang}',
                    rec_model_dir=f'./models/rec/{lang}')

    img2np = np.array(img)
    ocr_result = ocr.ocr(img2np, cls=True)[0]

    # OCR rows (threshold-filtered), then NER over their text, then QR codes.
    ocr_rows = _ocr_items(ocr_result, confidence)
    final_result = (
        ocr_rows
        + _gliner_items([item['txt'] for item in ocr_rows])
        + _qr_items(img2np)
    )

    # Draw every located item; entity rows with empty boxes are skipped
    # by draw_ocr_bbox.
    image = img.convert('RGB')
    image_with_boxes = draw_ocr_bbox(image,
                                     [item['boxes'] for item in final_result],
                                     [item['_c'] for item in final_result])

    data = [
        [json.dumps(item['boxes']), round(item['score'], 3), item['txt']]
        for item in final_result
    ]

    return Image.fromarray(image_with_boxes), data
 
 
 
97
 
98
title = 'Enhanced Business Card Scanner'
description = 'Combines OCR, entity recognition, and QR scanning'

examples = [
    ['example_imgs/example.jpg', 'en', 0.5],
    ['example_imgs/demo003.jpeg', 'en', 0.7],
]

css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"

if __name__ == '__main__':
    # Build the UI pieces up front so the Interface call stays readable.
    input_widgets = [
        gr.Image(type='pil', label='Input'),
        gr.Dropdown(choices=['en', 'fr', 'german', 'korean', 'japan'], value='en', label='Language'),
        gr.Slider(0.1, 1, 0.5, step=0.1, label='Confidence Threshold')
    ]
    output_widgets = [
        gr.Image(type='pil', label='Output'),
        gr.Dataframe(headers=['bbox', 'score', 'text'], label='Results'),
    ]
    demo = gr.Interface(
        inference,
        input_widgets,
        output_widgets,
        title=title,
        description=description,
        examples=examples,
        css=css
    )
    demo.launch()