codic committed
Commit fd45fbc · verified · 1 parent: c66181c

Update app.py

Files changed (1)
  1. app.py +194 -160
app.py CHANGED
@@ -1,181 +1,215 @@
 
from paddleocr import PaddleOCR
from gliner import GLiNER
from PIL import Image
import gradio as gr
import numpy as np
import logging
import tempfile
import pandas as pd
import re
import traceback
- import zxingcpp

- # --------------------------
- # Configuration & Constants
- # --------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

- COUNTRY_CODES = {
-     'SAUDI': {'code': '+966', 'pattern': r'^(\+9665\d{8}|05\d{8})$'},
-     'UAE': {'code': '+971', 'pattern': r'^(\+9715\d{8}|05\d{8})$'}
- }
-
- VALIDATION_PATTERNS = {
-     'email': re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', re.IGNORECASE),
-     'website': re.compile(r'(?:https?://)?(?:www\.)?([A-Za-z0-9-]+\.[A-Za-z]{2,})'),
-     'name': re.compile(r'^[A-Z][a-z]+(?:\s+[A-Z][a-z]+){1,2}$')
- }
-
- # --------------------------
- # Core Processing Functions
- # --------------------------
-
- def process_phone_number(raw_number: str) -> str:
-     """Validate and standardize phone numbers for supported countries"""
-     cleaned = re.sub(r'[^\d+]', '', raw_number)
-
-     for country, config in COUNTRY_CODES.items():
-         if re.match(config['pattern'], cleaned):
-             if cleaned.startswith('0'):
-                 return f"{config['code']}{cleaned[1:]}"
-             if cleaned.startswith('5'):
-                 return f"{config['code']}{cleaned}"
-             return cleaned
    return None

- def extract_contact_info(text: str) -> dict:
-     """Extract and validate all contact information from text"""
-     contacts = {
-         'phones': set(),
-         'emails': set(),
-         'websites': set()
-     }
-
-     # Phone number extraction
-     for match in re.finditer(r'(\+?\d{10,13}|05\d{8})', text):
-         if processed := process_phone_number(match.group()):
-             contacts['phones'].add(processed)
-
-     # Email validation
-     contacts['emails'].update(
-         email.lower() for email in VALIDATION_PATTERNS['email'].findall(text)
-     )
-
-     # Website normalization
-     for match in VALIDATION_PATTERNS['website'].finditer(text):
-         domain = match.group(1).lower()
-         if '.' in domain:
-             contacts['websites'].add(f"www.{domain.split('/')[0]}")
-
-     return {k: list(v) for k, v in contacts.items() if v}
-
- def process_entities(entities: list, ocr_text: list) -> dict:
-     """Process GLiNER entities with validation and fallbacks"""
-     result = {
-         'name': None,
-         'company': None,
-         'title': None,
-         'address': None
-     }
-
-     # Entity extraction
-     for entity in entities:
-         label = entity['label'].lower()
-         text = entity['text'].strip()
-
-         if label == 'person name' and VALIDATION_PATTERNS['name'].match(text):
-             result['name'] = text.title()
-         elif label == 'company name':
-             result['company'] = text
-         elif label == 'job title':
-             result['title'] = text.title()
-         elif label == 'address':
-             result['address'] = text
-
-     # Name fallback from OCR text
-     if not result['name']:
-         for text in ocr_text:
-             if VALIDATION_PATTERNS['name'].match(text):
-                 result['name'] = text.title()
-                 break
-
-     return result
-
- # --------------------------
- # Main Processing Pipeline
- # --------------------------
-
- def process_business_card(img: Image.Image, confidence: float) -> tuple:
-     """Full processing pipeline for business card images"""
    try:
-         # Initialize OCR
-         ocr_engine = PaddleOCR(lang='en', use_gpu=False)
-
-         # OCR Processing
-         ocr_result = ocr_engine.ocr(np.array(img), cls=True)
-         ocr_text = [line[1][0] for line in ocr_result[0]]
-         full_text = " ".join(ocr_text)
-
-         # Entity Recognition
-         labels = ["person name", "company name", "job title",
-                   "phone number", "email address", "address",
-                   "website"]
-         entities = gliner_model.predict_entities(full_text, labels, threshold=confidence)
-
-         # Data Extraction
-         contacts = extract_contact_info(full_text)
-         entity_data = process_entities(entities, ocr_text)
-         qr_data = zxingcpp.read_barcodes(np.array(img.convert('RGB')))
-
-         # Compile Final Results
-         results = {
-             'Person Name': entity_data['name'],
-             'Company Name': entity_data['company'] or (
-                 contacts['emails'][0].split('@')[1].split('.')[0].title()
-                 if contacts['emails'] else None
-             ),
-             'Job Title': entity_data['title'],
-             'Phone Numbers': contacts['phones'],
-             'Email Addresses': contacts['emails'],
-             'Address': entity_data['address'] or next(
-                 (t for t in ocr_text if any(kw in t.lower()
-                  for kw in {'street', 'ave', 'road'})), None
-             ),
-             'Website': contacts['websites'][0] if contacts['websites'] else None,
-             'QR Code': qr_data[0].text if qr_data else None
-         }
-
-         # Generate CSV Output
        with tempfile.NamedTemporaryFile(suffix='.csv', delete=False, mode='w') as f:
-             pd.DataFrame([results]).to_csv(f)
            csv_path = f.name
-
-         return full_text, results, csv_path, ""
-
-     except Exception as e:
-         logger.error(f"Processing Error: {traceback.format_exc()}")
-         return "", {}, None, f"Error: {str(e)}"
-
- # --------------------------
- # Gradio Interface
- # --------------------------
-
- interface = gr.Interface(
-     fn=process_business_card,
-     inputs=[
-         gr.Image(type='pil', label='Upload Business Card'),
-         gr.Slider(0.1, 1.0, value=0.4, label='Confidence Threshold')
-     ],
-     outputs=[
-         gr.Textbox(label='OCR Result'),
-         gr.JSON(label='Structured Data'),
-         gr.File(label='Download CSV'),
-         gr.Textbox(label='Error Log')
-     ],
-     title='Enterprise Business Card Parser',
-     description='Multi-country support with comprehensive validation'
- )

if __name__ == '__main__':
-     interface.launch()

from paddleocr import PaddleOCR
from gliner import GLiNER
from PIL import Image
import gradio as gr
import numpy as np
+ import cv2
import logging
+ import os
import tempfile
import pandas as pd
import re
import traceback
+ import zxingcpp  # QR decoding

+ # Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

+ # Environment setup
+ os.environ['GLINER_HOME'] = './gliner_models'
+
+ # Load GLiNER model
+ try:
+     logger.info("Loading GLiNER model...")
+     gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
+ except Exception:
+     logger.exception("Failed to load GLiNER model")
+     raise
+
+ # Regex patterns
+ EMAIL_REGEX = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
+ WEBSITE_REGEX = re.compile(r"^(?:https?://)?(?:www\.)?([A-Za-z0-9-]+\.[A-Za-z]{2,})(?:/\S*)?$")
+ # Phone number constants and regex
+ SAUDI_CODE = '+966'
+ UAE_CODE = '+971'
+ PHONE_REGEX = re.compile(r'^(?:\+9665\d{8}|\+9715\d{8}|05\d{8}|5\d{8})$')
+
+ # Utility functions
+ def extract_emails(text: str) -> list[str]:
+     return [e.lower() for e in EMAIL_REGEX.findall(text)]
+
+ def extract_websites(text: str) -> list[str]:
+     return [m.lower() for m in WEBSITE_REGEX.findall(text)]
+
+ def clean_phone_number(phone: str) -> str | None:
+     cleaned = re.sub(r"[^\d+]", "", phone)
+     # International formats
+     if cleaned.startswith(SAUDI_CODE + '5') and len(cleaned) == 13:
+         return cleaned
+     if cleaned.startswith(UAE_CODE + '5') and len(cleaned) == 13:
+         return cleaned
+     # Local to international
+     if cleaned.startswith('05') and len(cleaned) == 10:
+         return f"{UAE_CODE}{cleaned[1:]}"
+     if cleaned.startswith('5') and len(cleaned) == 9:
+         return f"{UAE_CODE}{cleaned}"
+     if cleaned.startswith('9665') and len(cleaned) == 12:
+         return f"+{cleaned}"
    return None

+ def process_phone_numbers(text: str) -> list[str]:
+     found = []
+     for match in re.finditer(r'(?:\+?\d{8,13}|05\d{8})', text):
+         raw = match.group().strip()
+         if (c := clean_phone_number(raw)):
+             found.append(c)
+     return list(set(found))
+
+ def normalize_website(url: str) -> str | None:
+     u = url.lower().replace('www.', '').split('/')[0]
+     if re.match(r"^[a-z0-9-]+\.[a-z]{2,}$", u):
+         return f"www.{u}"
+     return None
+
+ def extract_address(ocr_texts: list[str]) -> str | None:
+     keywords = ["block", "street", "ave", "area", "industrial", "road"]
+     parts = [t for t in ocr_texts if any(kw in t.lower() for kw in keywords)]
+     return " ".join(parts) if parts else None
+
+ # QR scanning
+ def scan_qr_code(image: Image.Image) -> str | None:
    try:
+         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+             image.save(tmp, format="PNG")
+             path = tmp.name
+         img_cv = cv2.imread(path)
+         # Direct decode
+         try:
+             res = zxingcpp.read_barcodes(img_cv)
+             if res and res[0].text:
+                 return res[0].text.strip()
+         except Exception:
+             logger.warning("Direct ZXing decode failed")
+         # Fallback recolor
+         default_color = (0, 0, 0)
+         tol = 50
+         pix = list(image.convert('RGB').getdata())
+         new_pix = [default_color if all(abs(p[i] - default_color[i]) <= tol for i in range(3)) else (255, 255, 255) for p in pix]
+         img_conv = Image.new('RGB', image.size)
+         img_conv.putdata(new_pix)
+         cv2.imwrite(path + '_conv.png', cv2.cvtColor(np.array(img_conv), cv2.COLOR_RGB2BGR))
+         res = zxingcpp.read_barcodes(cv2.imread(path + '_conv.png'))
+         if res and res[0].text:
+             return res[0].text.strip()
+     except Exception:
+         logger.exception("QR scan error")
+     return None
+
+ # Deduplication
+ def deduplicate_data(results: dict[str, list[str]]) -> None:
+     def clean_list(items, normalizer=lambda x: x):
+         seen = set(); out = []
+         for raw in items:
+             for part in re.split(r'[;,]\s*', raw):
+                 p = part.strip()
+                 if not p: continue
+                 norm = normalizer(p)
+                 if norm and norm not in seen:
+                     seen.add(norm); out.append(norm)
+         return out
+     # Normalize lists
+     results['Email Address'] = clean_list(results['Email Address'], lambda e: e.lower())
+     results['Website'] = clean_list(results['Website'], normalize_website)
+     results['Phone Number'] = clean_list(results['Phone Number'], clean_phone_number)
+     # Others: simple dedupe
+     for key in ['Person Name', 'Company Name', 'Job Title', 'Address', 'QR Code']:
+         seen = set(); out = []
+         for v in results.get(key, []):
+             vv = v.strip()
+             if vv and vv not in seen:
+                 seen.add(vv); out.append(vv)
+         results[key] = out
+
+ # Inference pipeline
+ def inference(img: Image.Image, confidence: float):
+     try:
+         ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False,
+                         det_model_dir='./models/det/en',
+                         cls_model_dir='./models/cls/en',
+                         rec_model_dir='./models/rec/en')
+         arr = np.array(img)
+         raw = ocr.ocr(arr, cls=True)[0]
+         ocr_texts = [ln[1][0] for ln in raw]
+         full_text = ' '.join(ocr_texts)
+
+         labels = ['person name', 'company name', 'job title', 'phone number', 'email address', 'address', 'website']
+         entities = gliner_model.predict_entities(full_text, labels, threshold=confidence, flat_ner=True)
+
+         results = {k: [] for k in ['Person Name', 'Company Name', 'Job Title', 'Phone Number', 'Email Address', 'Address', 'Website', 'QR Code']}
+         # Entity processing
+         for ent in entities:
+             txt, lbl = ent['text'].strip(), ent['label'].lower()
+             if lbl == 'person name': results['Person Name'].append(txt)
+             elif lbl == 'company name': results['Company Name'].append(txt)
+             elif lbl == 'job title': results['Job Title'].append(txt.title())
+             elif lbl == 'phone number':
+                 if (c := clean_phone_number(txt)): results['Phone Number'].append(c)
+             elif lbl == 'email address' and EMAIL_REGEX.fullmatch(txt):
+                 results['Email Address'].append(txt.lower())
+             elif lbl == 'website' and WEBSITE_REGEX.fullmatch(txt):
+                 if (n := normalize_website(txt)): results['Website'].append(n)
+             elif lbl == 'address': results['Address'].append(txt)
+         # Regex fallbacks
+         results['Email Address'] += extract_emails(full_text)
+         results['Website'] += extract_websites(full_text)
+         # Phone regex fallback
+         results['Phone Number'] += process_phone_numbers(full_text)
+         # QR
+         if qr := scan_qr_code(img): results['QR Code'].append(qr)
+         # Address fallback
+         if not results['Address']:
+             if addr := extract_address(ocr_texts): results['Address'].append(addr)
+         # Dedupe
+         deduplicate_data(results)
+         # Company fallback
+         if not results['Company Name']:
+             if results['Email Address']:
+                 dom = results['Email Address'][0].split('@')[-1].split('.')[0]
+                 results['Company Name'].append(dom.title())
+             elif results['Website']:
+                 dom = results['Website'][0].split('.')[1]
+                 results['Company Name'].append(dom.title())
+         # Name fallback
+         if not results['Person Name']:
+             for t in ocr_texts:
+                 if re.match(r'^(?:[A-Z][a-z]+\s?){2,}$', t):
+                     results['Person Name'].append(t); break
+         # CSV
+         csv_map = {k: '; '.join(v) for k, v in results.items() if v}
        with tempfile.NamedTemporaryFile(suffix='.csv', delete=False, mode='w') as f:
+             pd.DataFrame([csv_map]).to_csv(f, index=False)
            csv_path = f.name
+         return full_text, csv_map, csv_path, ''
+     except Exception:
+         err = traceback.format_exc()
+         logger.error(f"Processing failed: {err}")
+         return '', {}, None, f"Error:\n{err}"

+ # Gradio Interface
if __name__ == '__main__':
+     demo = gr.Interface(
+         inference,
+         [gr.Image(type='pil', label='Upload Business Card'),
+          gr.Slider(0.1, 1, 0.4, step=0.1, label='Confidence Threshold')],
+         [gr.Textbox(label="OCR Result"),
+          gr.JSON(label="Structured Data"),
+          gr.File(label="Download CSV"),
+          gr.Textbox(label="Error Log")],
+         title='Enhanced Business Card Parser',
+         description='Accurate entity extraction with combined AI and regex validation (with Saudi/UAE support)',
+         css=".gr-interface {max-width: 800px !important;}"
+     )
+     demo.launch()
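
For reference, a minimal sketch of how the new helper functions behave on their own, assuming `app.py` is importable from the Space's working directory (importing it runs the module-level GLiNER load); the phone numbers, domain, and address lines below are made-up illustrations, not data from the Space:

```python
# Sketch: exercising the pure text helpers from the new app.py.
# Importing app also loads the GLiNER model declared at module level,
# which is a sizeable one-off download.
from app import clean_phone_number, normalize_website, extract_emails, extract_address

# Phone normalization: local 05XXXXXXXX numbers gain an international prefix,
# already-international Saudi/UAE numbers pass through, everything else is rejected.
print(clean_phone_number("+966 51 234 5678"))   # -> "+966512345678"
print(clean_phone_number("0501234567"))         # -> "+971501234567"
print(clean_phone_number("12345"))              # -> None

# Website normalization strips "www." and any path, then re-adds the "www." prefix.
print(normalize_website("WWW.Example.COM/contact"))   # -> "www.example.com"

# Email extraction lower-cases every regex hit found in the text.
print(extract_emails("Sales: John.Doe@Example.com"))  # -> ["john.doe@example.com"]

# The address fallback keeps OCR lines that contain street/area keywords.
print(extract_address(["Acme Trading LLC", "Office 12, Industrial Area 3", "Dubai"]))
# -> "Office 12, Industrial Area 3"
```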
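
Because the Gradio app only wraps `inference`, the same pipeline can also be driven programmatically; a sketch under the assumption that a local test image exists (the `card.png` path is a placeholder) and that the PaddleOCR model directories referenced in `inference` are present or downloadable on first use:

```python
# Sketch: calling the pipeline without the Gradio UI.
# "card.png" is a placeholder path for a business-card image.
from PIL import Image
from app import inference

full_text, data, csv_path, error = inference(Image.open("card.png"), confidence=0.4)
if error:
    print(error)       # traceback text returned by the pipeline on failure
else:
    print(full_text)   # raw OCR text joined into one string
    print(data)        # dict of '; '-joined fields (Person Name, Phone Number, ...)
    print(csv_path)    # one-row CSV written to a temporary file
```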