Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,181 +1,215 @@
|
|
|
|
1 |
from paddleocr import PaddleOCR
|
2 |
from gliner import GLiNER
|
3 |
from PIL import Image
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
|
|
6 |
import logging
|
|
|
7 |
import tempfile
|
8 |
import pandas as pd
|
9 |
import re
|
10 |
import traceback
|
11 |
-
import zxingcpp
|
12 |
|
13 |
-
#
|
14 |
-
# Configuration & Constants
|
15 |
-
# --------------------------
|
16 |
logging.basicConfig(level=logging.INFO)
|
17 |
logger = logging.getLogger(__name__)
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
#
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
return None
|
46 |
|
47 |
-
def
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
)
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
if '.' in domain:
|
69 |
-
contacts['websites'].add(f"www.{domain.split('/')[0]}")
|
70 |
-
|
71 |
-
return {k: list(v) for k, v in contacts.items() if v}
|
72 |
-
|
73 |
-
def process_entities(entities: list, ocr_text: list) -> dict:
    """Map GLiNER entities onto name/company/title/address fields.

    Args:
        entities: Entity dicts, each providing a ``'label'`` and a ``'text'``.
        ocr_text: Raw OCR strings, scanned for a name when no entity
            supplied one.

    Returns:
        A dict with the keys ``name``, ``company``, ``title`` and
        ``address``; fields that were never filled stay ``None``. When a
        label occurs several times the last occurrence wins.
    """
    result = dict.fromkeys(('name', 'company', 'title', 'address'))

    # Entity extraction
    for entity in entities:
        kind = entity['label'].lower()
        value = entity['text'].strip()

        if kind == 'person name':
            # A person name is only accepted when it passes validation.
            if VALIDATION_PATTERNS['name'].match(value):
                result['name'] = value.title()
        elif kind == 'company name':
            result['company'] = value
        elif kind == 'job title':
            result['title'] = value.title()
        elif kind == 'address':
            result['address'] = value

    # Name fallback: first OCR string that passes the name validation
    if not result['name']:
        candidate = next(
            (line for line in ocr_text if VALIDATION_PATTERNS['name'].match(line)),
            None,
        )
        if candidate is not None:
            result['name'] = candidate.title()

    return result
|
104 |
-
|
105 |
-
# --------------------------
|
106 |
-
# Main Processing Pipeline
|
107 |
-
# --------------------------
|
108 |
-
|
109 |
-
def process_business_card(img: Image.Image, confidence: float) -> tuple:
|
110 |
-
"""Full processing pipeline for business card images"""
|
111 |
try:
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
with tempfile.NamedTemporaryFile(suffix='.csv', delete=False, mode='w') as f:
|
151 |
-
pd.DataFrame([
|
152 |
csv_path = f.name
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
return "", {}, None, f"Error: {str(e)}"
|
159 |
-
|
160 |
-
# --------------------------
|
161 |
-
# Gradio Interface
|
162 |
-
# --------------------------
|
163 |
-
|
164 |
-
# Single-function UI: an image and a confidence threshold go in; OCR text,
# structured data, a CSV download and an error log come out.
interface = gr.Interface(
    title='Enterprise Business Card Parser',
    description='Multi-country support with comprehensive validation',
    fn=process_business_card,
    inputs=[
        gr.Image(type='pil', label='Upload Business Card'),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.4, label='Confidence Threshold'),
    ],
    outputs=[
        gr.Textbox(label='OCR Result'),
        gr.JSON(label='Structured Data'),
        gr.File(label='Download CSV'),
        gr.Textbox(label='Error Log'),
    ],
)
|
179 |
|
|
|
180 |
if __name__ == '__main__':
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
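"""Business-card parser: PaddleOCR text extraction, GLiNER entity recognition and ZXing QR decoding behind a Gradio UI."""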
|
2 |
from paddleocr import PaddleOCR
|
3 |
from gliner import GLiNER
|
4 |
from PIL import Image
|
5 |
import gradio as gr
|
6 |
import numpy as np
|
7 |
+
import cv2
|
8 |
import logging
|
9 |
+
import os
|
10 |
import tempfile
|
11 |
import pandas as pd
|
12 |
import re
|
13 |
import traceback
|
14 |
+
import zxingcpp # QR decoding
|
15 |
|
16 |
+
# Configure logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment setup
# NOTE(review): presumably tells the gliner package where to keep its model
# files — confirm that GLINER_HOME is actually honoured by the library.
os.environ['GLINER_HOME'] = './gliner_models'

# Load GLiNER model once at import time. A failure is logged with its
# traceback and re-raised, so the module does not finish importing without
# a usable model.
try:
    logger.info("Loading GLiNER model...")
    gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
except Exception:
    logger.exception("Failed to load GLiNER model")
    raise

# Regex patterns
# E-mail address bounded by word boundaries; has no capture groups, so
# findall() yields the full addresses.
EMAIL_REGEX = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
# Whole-string URL matcher (anchored with ^ and $): optional scheme, optional
# "www.", then a captured single-label "name.tld" domain and an optional path.
# Because of the anchors it only matches when the entire input is one URL.
WEBSITE_REGEX = re.compile(r"^(?:https?://)?(?:www\.)?([A-Za-z0-9-]+\.[A-Za-z]{2,})(?:/\S*)?$")
# Phone number constants and regex
SAUDI_CODE = '+966'
UAE_CODE = '+971'
# Accepts "+9665"/"+9715" followed by 8 digits (13 characters in total), or
# the local forms "05" + 8 digits and "5" + 8 digits.
# NOTE(review): not referenced by any other code visible in this file.
PHONE_REGEX = re.compile(r'^(?:\+9665\d{8}|\+9715\d{8}|05\d{8}|5\d{8})$')
|
38 |
+
|
39 |
+
# Utility functions
|
40 |
+
def extract_emails(text: str) -> list[str]:
    """Return every e-mail address found in *text*, lower-cased.

    Addresses are reported in order of appearance; duplicates are kept.
    """
    addresses = EMAIL_REGEX.findall(text)
    return list(map(str.lower, addresses))
|
42 |
+
|
43 |
+
def extract_websites(text: str) -> list[str]:
    """Return the lower-cased domains of all website-looking tokens in *text*.

    ``WEBSITE_REGEX`` is anchored to the whole string, so running it over a
    multi-word text with ``findall`` can only succeed when the entire text is
    a single URL. The text is therefore split on whitespace and every token
    is matched on its own, after trimming surrounding punctuation.

    Args:
        text: Free text, e.g. the space-joined OCR output.

    Returns:
        Captured domains (``name.tld``) in order of appearance; duplicates
        are kept.
    """
    domains = []
    for token in text.split():
        # OCR tokens often carry trailing commas or brackets that would
        # defeat the anchored pattern.
        match = WEBSITE_REGEX.match(token.strip(".,;:()[]<>\"'"))
        if match:
            domains.append(match.group(1).lower())
    return domains
|
45 |
+
|
46 |
+
def clean_phone_number(phone: str) -> str | None:
    """Normalise a Saudi/UAE mobile number to ``+<country code>5XXXXXXXX``.

    Everything except digits and ``+`` is removed first. Accepted inputs:

    * ``+9665XXXXXXXX`` / ``+9715XXXXXXXX`` (already normalised),
    * the same with a ``00`` international prefix or without the ``+``,
    * local ``05XXXXXXXX`` and ``5XXXXXXXX``. These carry no country
      information and are assumed to be UAE numbers.

    The function is idempotent: feeding a returned value back in yields the
    same value. This matters because results are normalised a second time
    during de-duplication.

    Args:
        phone: Raw phone text, possibly containing spaces or punctuation.

    Returns:
        The normalised number, or ``None`` when the input is not a
        recognised Saudi/UAE mobile number.
    """
    cleaned = re.sub(r"[^\d+]", "", phone)

    # "00" is the international call prefix; treat it like "+".
    if cleaned.startswith('00'):
        cleaned = '+' + cleaned[2:]

    # International formats: country code followed by a 9-digit mobile
    # number starting with 5 (13 characters including the "+").
    for code in (SAUDI_CODE, UAE_CODE):
        if cleaned.startswith(code + '5') and len(cleaned) == len(code) + 9:
            return cleaned
        bare = code[1:]
        if cleaned.startswith(bare + '5') and len(cleaned) == len(bare) + 9:
            return f"+{cleaned}"

    # Local to international
    if cleaned.startswith('05') and len(cleaned) == 10:
        return f"{UAE_CODE}{cleaned[1:]}"
    if cleaned.startswith('5') and len(cleaned) == 9:
        return f"{UAE_CODE}{cleaned}"
    return None
|
61 |
|
62 |
+
def process_phone_numbers(text: str) -> list[str]:
    """Find and normalise all phone numbers contained in *text*.

    Candidates are runs of digits that may contain spaces, dots, dashes or
    parentheses, so printed forms such as ``+971 50 123 4567`` are found as
    well as contiguous digit strings. When a whole run is not a valid number
    (for example two numbers printed next to each other), its
    separator-delimited pieces are tried individually.

    Args:
        text: Free text, e.g. the space-joined OCR output.

    Returns:
        Unique normalised numbers in order of first appearance.
    """
    found: dict[str, None] = {}  # dict keeps insertion order while de-duplicating
    for run in re.finditer(r'\+?\d[\d\s().-]{6,}\d', text):
        candidate = run.group()
        if (number := clean_phone_number(candidate)):
            found[number] = None
            continue
        # The run may be several numbers glued together by separators.
        for piece in re.split(r'[\s().-]+', candidate):
            if (number := clean_phone_number(piece)):
                found[number] = None
    return list(found)
|
69 |
+
|
70 |
+
def normalize_website(url: str) -> str | None:
|
71 |
+
u = url.lower().replace('www.', '').split('/')[0]
|
72 |
+
if re.match(r"^[a-z0-9-]+\.[a-z]{2,}$", u):
|
73 |
+
return f"www.{u}"
|
74 |
+
return None
|
75 |
+
|
76 |
+
def extract_address(ocr_texts: list[str]) -> str | None:
|
77 |
+
keywords = ["block","street","ave","area","industrial","road"]
|
78 |
+
parts = [t for t in ocr_texts if any(kw in t.lower() for kw in keywords)]
|
79 |
+
return " ".join(parts) if parts else None
|
80 |
+
|
81 |
+
# QR scanning
|
82 |
+
def scan_qr_code(image: Image.Image) -> str | None:
    """Decode the first barcode/QR code found in *image*.

    Two attempts are made, both entirely in memory so that no temporary
    files are left behind:

    1. decode the image as it is;
    2. decode a black/white version in which every pixel whose R, G and B
       values are all within ``tol`` of black becomes black and every other
       pixel becomes white.

    Args:
        image: The business-card image.

    Returns:
        The stripped text of the first decoded symbol, or ``None`` when
        nothing could be decoded or an error occurred (errors are logged).
    """
    try:
        rgb = np.array(image.convert('RGB'))

        # Direct decode
        try:
            res = zxingcpp.read_barcodes(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
            if res and res[0].text:
                return res[0].text.strip()
        except Exception:
            # Best effort: fall through to the thresholded attempt.
            logger.warning("Direct ZXing decode failed", exc_info=True)

        # Fallback recolor: near-black pixels -> black, everything else -> white
        tol = 50
        near_black = (rgb <= tol).all(axis=-1)
        recolored = np.where(near_black[..., None], 0, 255).astype(np.uint8)
        recolored = np.repeat(recolored, 3, axis=-1)
        res = zxingcpp.read_barcodes(cv2.cvtColor(recolored, cv2.COLOR_RGB2BGR))
        if res and res[0].text:
            return res[0].text.strip()
    except Exception:
        logger.exception("QR scan error")
    return None
|
109 |
+
|
110 |
+
# Deduplication
|
111 |
+
def deduplicate_data(results: dict[str, list[str]]) -> None:
    """Normalise and de-duplicate the collected field values in place.

    E-mail, website and phone entries are first split on ``;`` / ``,`` and
    each piece is normalised (lower-casing, ``normalize_website``,
    ``clean_phone_number``); pieces whose normalisation is empty are dropped.
    The remaining fields are only stripped. First-seen order is preserved
    everywhere.

    Args:
        results: Mapping of field name to collected values. The three
            normalised fields must be present; the others default to empty.
    """
    def unique(candidates, transform=None):
        # A dict doubles as an insertion-ordered set.
        kept = {}
        for candidate in candidates:
            value = candidate.strip()
            if not value:
                continue
            if transform is not None:
                value = transform(value)
            if value:
                kept.setdefault(value, None)
        return list(kept)

    # Normalize lists
    normalizers = (
        ('Email Address', str.lower),
        ('Website', normalize_website),
        ('Phone Number', clean_phone_number),
    )
    for field, normalizer in normalizers:
        pieces = (
            piece
            for raw in results[field]
            for piece in re.split(r'[;,]\s*', raw)
        )
        results[field] = unique(pieces, normalizer)

    # Others: simple dedupe
    for field in ('Person Name', 'Company Name', 'Job Title', 'Address', 'QR Code'):
        results[field] = unique(results.get(field, []))
|
134 |
+
|
135 |
+
# Inference pipeline
|
136 |
+
_OCR_ENGINE = None


def _get_ocr_engine():
    """Create the PaddleOCR engine on first use and reuse it afterwards."""
    global _OCR_ENGINE
    if _OCR_ENGINE is None:
        _OCR_ENGINE = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False,
                                det_model_dir='./models/det/en',
                                cls_model_dir='./models/cls/en',
                                rec_model_dir='./models/rec/en')
    return _OCR_ENGINE


def inference(img: Image.Image, confidence: float):
    """Run the full business-card pipeline on one image.

    Steps: OCR, GLiNER entity extraction, regex fallbacks for e-mail,
    website and phone, QR decoding, address fallback, de-duplication,
    company and name fallbacks, and CSV export.

    Args:
        img: The uploaded card image, or ``None`` when nothing was uploaded.
        confidence: Score threshold passed to GLiNER.

    Returns:
        A 4-tuple ``(ocr_text, structured, csv_path, error)``:
        the space-joined OCR text, a dict mapping each non-empty field to
        its values joined by ``'; '``, the path of a CSV file holding that
        dict as one row, and an error message (empty on success). On failure
        the first three items are ``''``, ``{}`` and ``None``.
    """
    if img is None:
        return '', {}, None, 'Error: no image provided'
    try:
        ocr = _get_ocr_engine()
        arr = np.array(img)
        pages = ocr.ocr(arr, cls=True)
        # PaddleOCR reports a page without any detected text as None.
        raw = (pages[0] if pages else None) or []
        ocr_texts = [ln[1][0] for ln in raw]
        full_text = ' '.join(ocr_texts)

        labels = ['person name', 'company name', 'job title', 'phone number',
                  'email address', 'address', 'website']
        # Nothing to tag when OCR found no text.
        entities = (
            gliner_model.predict_entities(full_text, labels, threshold=confidence, flat_ner=True)
            if full_text.strip() else []
        )

        results = {k: [] for k in ['Person Name', 'Company Name', 'Job Title', 'Phone Number',
                                   'Email Address', 'Address', 'Website', 'QR Code']}
        # Entity processing
        for ent in entities:
            txt, lbl = ent['text'].strip(), ent['label'].lower()
            if lbl == 'person name':
                results['Person Name'].append(txt)
            elif lbl == 'company name':
                results['Company Name'].append(txt)
            elif lbl == 'job title':
                results['Job Title'].append(txt.title())
            elif lbl == 'phone number':
                if (c := clean_phone_number(txt)):
                    results['Phone Number'].append(c)
            elif lbl == 'email address' and EMAIL_REGEX.fullmatch(txt):
                results['Email Address'].append(txt.lower())
            elif lbl == 'website' and WEBSITE_REGEX.fullmatch(txt):
                if (n := normalize_website(txt)):
                    results['Website'].append(n)
            elif lbl == 'address':
                results['Address'].append(txt)

        # Regex fallbacks
        results['Email Address'] += extract_emails(full_text)
        results['Website'] += extract_websites(full_text)
        results['Phone Number'] += process_phone_numbers(full_text)

        # QR
        if qr := scan_qr_code(img):
            results['QR Code'].append(qr)

        # Address fallback
        if not results['Address']:
            if addr := extract_address(ocr_texts):
                results['Address'].append(addr)

        # Dedupe
        deduplicate_data(results)

        # Company fallback: e-mail domain first, then the website name
        # (websites are stored as "www.<name>.<tld>", hence index 1).
        if not results['Company Name']:
            if results['Email Address']:
                dom = results['Email Address'][0].split('@')[-1].split('.')[0]
                results['Company Name'].append(dom.title())
            elif results['Website']:
                dom = results['Website'][0].split('.')[1]
                results['Company Name'].append(dom.title())

        # Name fallback: first OCR line made of two or more capitalised words
        if not results['Person Name']:
            for t in ocr_texts:
                if re.match(r'^(?:[A-Z][a-z]+\s?){2,}$', t):
                    results['Person Name'].append(t)
                    break

        # CSV: the file must outlive this call so Gradio can serve it, hence
        # delete=False. newline='' stops the text layer from translating the
        # line endings that the CSV writer emits.
        csv_map = {k: '; '.join(v) for k, v in results.items() if v}
        with tempfile.NamedTemporaryFile(suffix='.csv', delete=False, mode='w',
                                         newline='', encoding='utf-8') as f:
            pd.DataFrame([csv_map]).to_csv(f, index=False)
            csv_path = f.name
        return full_text, csv_map, csv_path, ''
    except Exception:
        err = traceback.format_exc()
        logger.error("Processing failed: %s", err)
        return '', {}, None, f"Error:\n{err}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
|
200 |
+
# Gradio Interface
|
201 |
# Build and launch the UI only when the file is executed as a script.
if __name__ == '__main__':
    demo = gr.Interface(
        inference,
        # Inputs: the card image (as a PIL image) and the GLiNER threshold.
        [gr.Image(type='pil', label='Upload Business Card'),
         gr.Slider(0.1, 1, 0.4, step=0.1, label='Confidence Threshold')],
        # Outputs, in the order of the tuple returned by inference().
        [gr.Textbox(label="OCR Result"),
         gr.JSON(label="Structured Data"),
         gr.File(label="Download CSV"),
         gr.Textbox(label="Error Log")],
        title='Enhanced Business Card Parser',
        description='Accurate entity extraction with combined AI and regex validation (with Saudi/UAE support)',
        css=".gr-interface {max-width: 800px !important;}"
    )
    demo.launch()
|