codic committed on
Commit
6bba885
·
verified ·
1 Parent(s): 24a8aeb

trying the other update

Browse files
Files changed (1) hide show
  1. app.py +106 -91
app.py CHANGED
@@ -1,106 +1,119 @@
1
  from paddleocr import PaddleOCR
 
2
  import json
3
  from PIL import Image
4
  import gradio as gr
5
  import numpy as np
6
  import cv2
7
- from gliner import GLiNER
 
 
 
 
 
 
 
 
 
 
 
8
 
9
# Shared GLiNER model, loaded once when the module is imported.
gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")

# Entity types requested from GLiNER (the list includes "website").
labels = [
    "person name",
    "company name",
    "job title",
    "phone",
    "email",
    "address",
    "website",
]
 
 
 
 
 
14
 
 
15
def get_random_color():
    """Return a random colour as a 3-tuple of Python ints, each in [0, 255]."""
    channels = np.random.randint(0, 256, 3)
    return tuple(channels.tolist())
18
 
 
19
def draw_ocr_bbox(image, boxes, colors):
    """Outline every non-empty box on a copy of ``image``.

    Args:
        image: Anything ``np.array`` can convert (e.g. a PIL image or ndarray).
        boxes: Iterable of point sequences; each is reshaped to (-1, 1, 2).
            Entries with length 0 are skipped.
        colors: Iterable of colours, paired with ``boxes`` by position.

    Returns:
        A NumPy array holding the image with closed polylines of thickness 2.
        An array is returned even when nothing is drawn, so callers always
        receive the same type.
    """
    # Copy once up front: the caller's image is never modified and the
    # per-box full-image copy of the previous implementation is avoided.
    canvas = np.array(image)
    for box, color in zip(boxes, colors):
        if len(box) == 0:
            # Entries without geometry have nothing to draw.
            continue
        points = np.array(box).reshape(-1, 1, 2).astype(np.int64)
        canvas = cv2.polylines(canvas, [points], True, color, 2)
    return canvas
31
 
32
def inference(img: Image.Image, lang, confidence):
    """Run OCR, GLiNER entity extraction and QR detection on one image.

    Args:
        img: Input PIL image.
        lang: PaddleOCR language code; also selects the ``./models/*/{lang}``
            model directories.
        confidence: OCR lines are kept only when their score is strictly
            greater than this value.

    Returns:
        A tuple ``(annotated_image, rows)``: the PIL image with boxes drawn,
        and one ``[bbox_json, rounded_score, text]`` row per result.
    """
    # A PaddleOCR engine is built on every call from the local model folders.
    engine = PaddleOCR(
        use_angle_cls=True,
        lang=lang,
        use_gpu=False,
        det_model_dir=f'./models/det/{lang}',
        cls_model_dir=f'./models/cls/{lang}',
        rec_model_dir=f'./models/rec/{lang}',
    )

    pixels = np.array(img)
    detections = engine.ocr(pixels, cls=True)[0]

    # Keep OCR lines above the threshold; each gets its own random colour.
    ocr_items = []
    if detections:
        for entry in detections:
            score = entry[1][1]
            if score > confidence:
                ocr_items.append({
                    'boxes': entry[0],
                    'txt': entry[1][0],
                    'score': score,
                    '_c': get_random_color(),
                })

    # GLiNER works on the concatenated text, so its entities carry no box.
    combined_text = " ".join(item['txt'] for item in ocr_items)
    gliner_items = []
    for ent in gliner_model.predict_entities(combined_text, labels, threshold=0.3):
        gliner_items.append({
            'boxes': [],
            'txt': f"{ent['text']} ({ent['label']})",
            'score': 1.0,
            '_c': get_random_color(),
        })

    # QR codes: every non-empty decoded payload becomes a result with a box.
    qr_items = []
    found, payloads, corners, _ = cv2.QRCodeDetector().detectAndDecodeMulti(pixels)
    if found:
        for idx, payload in enumerate(payloads):
            if not payload:
                continue
            qr_items.append({
                'boxes': corners[idx].reshape(-1, 2).tolist(),
                'txt': payload,
                'score': 1.0,
                '_c': get_random_color(),
            })

    final_result = [*ocr_items, *gliner_items, *qr_items]

    annotated = draw_ocr_bbox(
        img.convert('RGB'),
        [item['boxes'] for item in final_result],
        [item['_c'] for item in final_result],
    )

    rows = []
    for item in final_result:
        rows.append([json.dumps(item['boxes']), round(item['score'], 3), item['txt']])

    return Image.fromarray(annotated), rows
97
 
98
# Heading and sub-heading shown on the Gradio page.
title = 'Enhanced Business Card Scanner'
description = 'Combines OCR, entity recognition, and QR scanning'

# Example rows: [image path, language, confidence threshold].
examples = [['example_imgs/example.jpg', 'en', 0.5], ['example_imgs/demo003.jpeg', 'en', 0.7]]

# Custom CSS handed to the Gradio interface.
css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
@@ -108,15 +121,17 @@ css = ".output_image, .input_image {height: 40rem !important; width: 100% !impor
108
if __name__ == '__main__':
    # Inputs map positionally onto inference(img, lang, confidence).
    input_widgets = [
        gr.Image(type='pil', label='Input'),
        gr.Dropdown(choices=['en', 'fr', 'german', 'korean', 'japan'], value='en', label='Language'),
        gr.Slider(0.1, 1, 0.5, step=0.1, label='Confidence Threshold'),
    ]
    # Outputs map onto the (image, rows) tuple returned by inference.
    output_widgets = [
        gr.Image(type='pil', label='Output'),
        gr.Dataframe(headers=['bbox', 'score', 'text'], label='Results'),
    ]
    demo = gr.Interface(
        inference,
        input_widgets,
        output_widgets,
        title=title,
        description=description,
        examples=examples,
        css=css,
    )
    demo.launch()
 
 
1
  from paddleocr import PaddleOCR
2
+ from gliner import GLiNER
3
  import json
4
  from PIL import Image
5
  import gradio as gr
6
  import numpy as np
7
  import cv2
8
+ import logging
9
+ import os
10
+ from pathlib import Path
11
+ import tempfile
12
+ import pandas as pd
13
+ import io
14
+ import re
15
+ import traceback
16
+
17
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up GLiNER environment variables (adjust if needed).
# NOTE(review): nothing in this file reads GLINER_HOME; confirm the library
# actually honours it.
os.environ['GLINER_HOME'] = './gliner_models'

# Load the GLiNER model once at import time (do not change the model).
try:
    logger.info("Loading GLiNER model...")
    gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
except Exception:
    # logger.exception records the traceback alongside the message, and a
    # bare `raise` re-raises the active exception without rebinding it.
    logger.exception("Failed to load GLiNER model")
    raise
31
 
32
def get_random_color():
    """Return a random colour (used for bounding boxes) as three ints in [0, 255]."""
    return tuple(int(channel) for channel in np.random.randint(0, 256, 3))
36
 
37
def draw_ocr_bbox(image, boxes, colors):
    """Outline each OCR box on a copy of ``image`` (debugging/visualisation aid).

    Args:
        image: Anything ``np.array`` can convert (e.g. a PIL image or ndarray).
        boxes: Iterable of point sequences; each is reshaped to (-1, 1, 2).
        colors: Iterable of colours, paired with ``boxes`` by position.

    Returns:
        A NumPy array holding the image with closed polylines of thickness 2.
        An array is returned even when ``boxes`` is empty, so the result can
        always be handed to array-based consumers.
    """
    # Copy once up front: the caller's image is never modified and the
    # per-box full-image copy of the previous implementation is avoided.
    canvas = np.array(image)
    for box, color in zip(boxes, colors):
        points = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        canvas = cv2.polylines(canvas, [points], True, color, 2)
    return canvas
43
 
44
def scan_qr_code(image):
    """Try to decode one QR code from ``image`` with OpenCV's QRCodeDetector.

    Args:
        image: A NumPy array, or any object ``np.array`` can convert.

    Returns:
        The decoded text with surrounding whitespace stripped, or ``None``
        when nothing was decoded or when any exception occurred (the
        exception is logged, not raised).
    """
    try:
        # Arrays are used as-is; everything else is converted first.
        if isinstance(image, np.ndarray):
            pixels = image
        else:
            pixels = np.array(image)
        decoded, _corners, _straight = cv2.QRCodeDetector().detectAndDecode(pixels)
        return decoded.strip() if decoded else None
    except Exception as e:
        reason = str(e)
        logger.error("QR code scanning failed: " + reason)
        return None
57
+
58
def inference(img: Image.Image, confidence):
    """Extract text, named entities and QR data from a business-card image.

    Args:
        img: Input PIL image.
        confidence: Threshold passed to GLiNER's ``predict_entities``.

    Returns:
        A 4-tuple ``(ocr_text, entities, csv_path, error)``:
        ``ocr_text`` is the space-joined OCR text; ``entities`` maps each
        title-cased label (plus ``"QR"`` when a code was decoded) to a
        ``"; "``-joined string; ``csv_path`` is a temporary CSV file holding
        the same mapping as a single row; ``error`` is ``""`` on success.
        On any exception the tuple is ``("", {}, None, <error text>)``.
    """
    try:
        # English-only OCR using the locally stored models.
        ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False,
                        det_model_dir='./models/det/en',
                        cls_model_dir='./models/cls/en',
                        rec_model_dir='./models/rec/en')
        # The page result can be None when no text is detected; treat that
        # as "no lines" so a card with only a QR code is still processed.
        lines = ocr.ocr(np.array(img), cls=True)[0] or []
        ocr_text = " ".join(line[1][0] for line in lines)

        # Entity labels requested from GLiNER (includes 'website').
        labels = ["person name", "company name", "job title", "phone", "email", "address", "website"]
        results = {label.title(): [] for label in labels}
        # With no text there is nothing to tag, so the model call is skipped.
        if ocr_text.strip():
            entities = gliner_model.predict_entities(
                ocr_text, labels, threshold=confidence, flat_ner=True)
        else:
            entities = []
        for entity in entities:
            key = entity["label"].title()
            if key in results:
                results[key].append(entity["text"])

        # Scan the original image for a QR code and add it if found.
        qr_data = scan_qr_code(img)
        if qr_data:
            results["QR"] = [qr_data]

        # One flattened mapping feeds both the JSON output and the CSV row.
        flat = {key: "; ".join(values) for key, values in results.items()}

        # delete=False keeps the file on disk after the handle closes so the
        # returned path stays valid for the caller.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False,
                                         newline="", encoding="utf-8") as tmp_file:
            pd.DataFrame([flat]).to_csv(tmp_file, index=False)
            csv_path = tmp_file.name

        # Return tuple: (OCR text, JSON entities, CSV file path, error message)
        return ocr_text, flat, csv_path, ""
    except Exception as e:
        # Deliberate catch-all: failures are reported through the fourth
        # output instead of being raised. logger.exception adds the traceback.
        logger.exception("Processing failed")
        return "", {}, None, f"Error: {str(e)}\n{traceback.format_exc()}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
# Heading and sub-heading shown on the Gradio page.
title = 'Business Card Information Extractor'
description = 'Extracts text using PaddleOCR and entities using GLiNER (with added website label) along with QR code scanning.'

# Example rows: [image path, confidence threshold].
examples = [['example_imgs/example.jpg', 0.5], ['example_imgs/demo003.jpeg', 0.7]]

# Custom CSS handed to the Gradio interface.
css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
 
121
if __name__ == '__main__':
    # Inputs map positionally onto inference(img, confidence).
    input_widgets = [
        gr.Image(type='pil', label='Upload Business Card'),
        gr.Slider(0.1, 1, 0.5, step=0.1, label='Confidence Threshold'),
    ]
    # Outputs map onto inference's (text, entities, csv path, error) tuple.
    output_widgets = [
        gr.Textbox(label="Extracted Text"),
        gr.JSON(label="Entities"),
        gr.File(label="Download CSV"),
        gr.Textbox(label="Error Details"),
    ]
    demo = gr.Interface(
        inference,
        input_widgets,
        output_widgets,
        title=title,
        description=description,
        examples=examples,
        css=css,
        cache_examples=True,
    )
    # Enable the request queue (max_size=10) before starting the app.
    demo.queue(max_size=10)
    demo.launch()