Chanlefe commited on
Commit
47d0c84
Β·
verified Β·
1 Parent(s): f70931b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -161
app.py CHANGED
@@ -1,159 +1,53 @@
1
  import torch
2
  import os
3
- import glob
4
  from PIL import Image
5
- from transformers import AutoProcessor, AutoModelForImageClassification
6
  import gradio as gr
7
  import pytesseract
8
 
9
- def find_model_files():
10
- """Find model files in the current directory structure"""
11
- print("=== Searching for model files ===")
12
-
13
- # Look for key model files
14
- config_files = glob.glob("**/config.json", recursive=True)
15
- model_files = glob.glob("**/pytorch_model.bin", recursive=True) + glob.glob("**/model.safetensors", recursive=True)
16
- preprocessor_files = glob.glob("**/preprocessor_config.json", recursive=True)
17
-
18
- print(f"Found config.json files: {config_files}")
19
- print(f"Found model weight files: {model_files}")
20
- print(f"Found preprocessor_config.json files: {preprocessor_files}")
21
-
22
- # Find the directory that contains all necessary files
23
- for config_file in config_files:
24
- model_dir = os.path.dirname(config_file)
25
- if not model_dir: # If config.json is in root
26
- model_dir = "."
27
-
28
- # Check if this directory has all required files
29
- has_model = any(os.path.dirname(f) == model_dir or (not os.path.dirname(f) and model_dir == ".") for f in model_files)
30
- has_preprocessor = any(os.path.dirname(f) == model_dir or (not os.path.dirname(f) and model_dir == ".") for f in preprocessor_files)
31
-
32
- if has_model and has_preprocessor:
33
- print(f"Found complete model in directory: {model_dir}")
34
- return model_dir
35
- elif has_model:
36
- print(f"Found model with config but missing preprocessor in: {model_dir}")
37
- return model_dir # Try anyway, might work
38
-
39
- print("No complete model directory found")
40
- return None
41
-
42
- # Search for model files
43
- MODEL_PATH = find_model_files()
44
- if MODEL_PATH is None:
45
- MODEL_PATH = "." # Fallback to current directory
46
- print("Falling back to current directory")
47
 
48
  try:
49
- # Load model and processor from detected path
50
- print(f"=== Attempting to load model from: {MODEL_PATH} ===")
51
- print(f"Current working directory: {os.getcwd()}")
52
 
53
- # List all files in the detected model directory
54
- if MODEL_PATH == ".":
55
- print("Files in root directory:")
56
- for item in os.listdir("."):
57
- if os.path.isfile(item):
58
- print(f" File: {item}")
59
- else:
60
- print(f" Directory: {item}/")
61
- try:
62
- sub_files = os.listdir(item)[:5] # Show first 5 files
63
- print(f" Contains: {sub_files}{'...' if len(os.listdir(item)) > 5 else ''}")
64
- except:
65
- pass
66
- else:
67
- print(f"Files in {MODEL_PATH}:")
68
- print(f" {os.listdir(MODEL_PATH)}")
69
-
70
- # Try to load the model
71
  print("Loading model...")
72
  model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
73
- print("Model loaded successfully!")
74
 
75
- print("Loading processor...")
 
76
  try:
77
- processor = AutoProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
78
- print("Processor loaded successfully!")
79
- except Exception as proc_error:
80
- print(f"Error loading processor from local files: {proc_error}")
81
- print("Attempting to load just the image processor...")
82
-
83
- # Try to load just the image processor from your model
84
- try:
85
- from transformers import SiglipImageProcessor
86
- processor = SiglipImageProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
87
- print("Image processor loaded successfully from local files!")
88
- except Exception as img_proc_error:
89
- print(f"Error loading local image processor: {img_proc_error}")
90
- print("Loading image processor from base SigLIP model...")
91
-
92
- # Try to load processor from the base SigLIP model
93
- try:
94
- from transformers import SiglipImageProcessor
95
- processor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")
96
- print("Image processor loaded from base SigLIP model!")
97
- except Exception as base_error:
98
- print(f"Error loading base processor: {base_error}")
99
- print("Using CLIP processor as fallback...")
100
-
101
- # As a last resort, try to create a minimal processor
102
- from transformers import CLIPImageProcessor
103
- processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
104
- print("Using CLIP processor as fallback!")
105
 
106
- # Get labels - handle case where id2label might not exist
107
  if hasattr(model.config, 'id2label') and model.config.id2label:
108
  labels = model.config.id2label
 
109
  else:
110
  # Create generic labels if none exist
111
- num_labels = model.config.num_labels if hasattr(model.config, 'num_labels') else 1000
112
  labels = {i: f"class_{i}" for i in range(num_labels)}
 
113
 
114
- print(f"Model loaded successfully. Number of classes: {len(labels)}")
115
 
116
  except Exception as e:
117
- print(f"=== ERROR loading model from {MODEL_PATH} ===")
118
- print(f"Error: {e}")
119
- print("\n=== Debugging Information ===")
120
- print("All files in Space:")
121
-
122
- def list_all_files(directory=".", prefix=""):
123
- """Recursively list all files"""
124
- try:
125
- items = sorted(os.listdir(directory))
126
- for item in items:
127
- item_path = os.path.join(directory, item)
128
- if os.path.isfile(item_path):
129
- size = os.path.getsize(item_path)
130
- print(f"{prefix}πŸ“„ {item} ({size} bytes)")
131
- elif os.path.isdir(item_path) and not item.startswith('.'):
132
- print(f"{prefix}πŸ“ {item}/")
133
- if len(prefix) < 6: # Limit recursion depth
134
- list_all_files(item_path, prefix + " ")
135
- except PermissionError:
136
- print(f"{prefix}❌ Permission denied")
137
- except Exception as ex:
138
- print(f"{prefix}❌ Error: {ex}")
139
-
140
- list_all_files()
141
-
142
- print("\n=== Required Files for Model ===")
143
- print("βœ… config.json - Model configuration")
144
- print("βœ… pytorch_model.bin OR model.safetensors - Model weights")
145
- print("βœ… preprocessor_config.json - Image processor config")
146
- print("βœ… tokenizer.json (if applicable) - Tokenizer")
147
-
148
- print("\n=== Solutions ===")
149
- print("1. Make sure all model files are uploaded to your Space")
150
- print("2. Check that files aren't corrupted during upload")
151
- print("3. Try uploading to a 'model' subfolder")
152
- print("4. Verify the model was saved correctly during training")
153
-
154
  raise
155
 
156
- # Classify meme and extract text
157
  def classify_meme(image: Image.Image):
158
  """
159
  Classify meme and extract text using OCR
@@ -162,60 +56,61 @@ def classify_meme(image: Image.Image):
162
  # OCR: extract text from image
163
  extracted_text = pytesseract.image_to_string(image)
164
 
165
- # Process image with the model
166
  inputs = processor(images=image, return_tensors="pt")
167
 
168
- # Move inputs to same device as model if needed
169
- if torch.cuda.is_available() and next(model.parameters()).is_cuda:
170
- inputs = {k: v.to('cuda') for k, v in inputs.items()}
171
-
172
  with torch.no_grad():
173
  outputs = model(**inputs)
174
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
175
 
176
- # Get top predictions
177
- top_k = min(10, len(labels)) # Show top 10 or all if fewer
178
- top_probs, top_indices = torch.topk(probs[0], top_k)
179
-
180
  predictions = {}
181
- for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
182
- label = labels.get(idx.item(), f"class_{idx.item()}")
183
- predictions[label] = float(prob)
 
 
 
184
 
185
- # Debug prints (these will show in the console/logs)
186
- print("Extracted Text:", extracted_text.strip())
187
- print("Top Predictions:", predictions)
 
 
 
188
 
189
- return predictions, extracted_text.strip()
190
 
191
  except Exception as e:
192
- print(f"Error in classification: {e}")
193
- return {"Error": 1.0}, f"Error processing image: {str(e)}"
 
194
 
195
- # Gradio interface
196
  demo = gr.Interface(
197
  fn=classify_meme,
198
  inputs=gr.Image(type="pil", label="Upload Meme Image"),
199
  outputs=[
200
  gr.Label(num_top_classes=5, label="Meme Classification"),
201
- gr.Textbox(label="Extracted Text from OCR", lines=3)
202
  ],
203
- title="Meme Classifier with OCR",
204
  description="""
205
  Upload a meme image to:
206
- 1. Classify its content using your trained SigLIP2_77 model
207
- 2. Extract text using OCR (Optical Character Recognition)
208
 
209
- Note: Make sure all model files are properly uploaded to your Space.
210
  """,
211
  examples=None,
212
  allow_flagging="never"
213
  )
214
 
215
  if __name__ == "__main__":
216
- print("Starting Gradio interface...")
217
  demo.launch(
218
- server_name="0.0.0.0", # Allow external connections in HF Spaces
219
- server_port=7860, # Standard port for HF Spaces
220
- share=False # HF Spaces handles sharing
221
- )
 
1
  import torch
2
  import os
 
3
  from PIL import Image
4
+ from transformers import AutoModelForImageClassification, SiglipImageProcessor
5
  import gradio as gr
6
  import pytesseract
7
 
8
+ # Model path
9
+ MODEL_PATH = "./model"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  try:
12
+ print(f"=== Loading model from: {MODEL_PATH} ===")
13
+ print(f"Available files: {os.listdir(MODEL_PATH)}")
 
14
 
15
+ # Load the model (this should work with your files)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  print("Loading model...")
17
  model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
18
+ print("βœ… Model loaded successfully!")
19
 
20
+ # Load just the image processor (not the full AutoProcessor)
21
+ print("Loading image processor...")
22
  try:
23
+ # Try to load the image processor from your local files
24
+ processor = SiglipImageProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
25
+ print("βœ… Image processor loaded from local files!")
26
+ except Exception as e:
27
+ print(f"⚠️ Could not load local processor: {e}")
28
+ print("Loading image processor from base SigLIP model...")
29
+ # Fallback: load processor from base model online
30
+ processor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")
31
+ print("βœ… Image processor loaded from base model!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # Get labels from your model config
34
  if hasattr(model.config, 'id2label') and model.config.id2label:
35
  labels = model.config.id2label
36
+ print(f"βœ… Found {len(labels)} labels in model config")
37
  else:
38
  # Create generic labels if none exist
39
+ num_labels = model.config.num_labels if hasattr(model.config, 'num_labels') else 2
40
  labels = {i: f"class_{i}" for i in range(num_labels)}
41
+ print(f"βœ… Created {len(labels)} generic labels")
42
 
43
+ print("πŸŽ‰ Model setup complete!")
44
 
45
  except Exception as e:
46
+ print(f"❌ Error loading model: {e}")
47
+ print("\n=== Debug Information ===")
48
+ print(f"Files in model directory: {os.listdir(MODEL_PATH)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  raise
50
 
 
51
  def classify_meme(image: Image.Image):
52
  """
53
  Classify meme and extract text using OCR
 
56
  # OCR: extract text from image
57
  extracted_text = pytesseract.image_to_string(image)
58
 
59
+ # Process image for the model
60
  inputs = processor(images=image, return_tensors="pt")
61
 
62
+ # Run inference
 
 
 
63
  with torch.no_grad():
64
  outputs = model(**inputs)
65
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
66
 
67
+ # Get predictions
 
 
 
68
  predictions = {}
69
+ for i in range(len(labels)):
70
+ label = labels.get(i, f"class_{i}")
71
+ predictions[label] = float(probs[0][i])
72
+
73
+ # Sort predictions by confidence
74
+ sorted_predictions = dict(sorted(predictions.items(), key=lambda x: x[1], reverse=True))
75
 
76
+ # Debug prints
77
+ print("=== Classification Results ===")
78
+ print(f"Extracted Text: '{extracted_text.strip()}'")
79
+ print("Top 3 Predictions:")
80
+ for i, (label, prob) in enumerate(list(sorted_predictions.items())[:3]):
81
+ print(f" {i+1}. {label}: {prob:.4f}")
82
 
83
+ return sorted_predictions, extracted_text.strip()
84
 
85
  except Exception as e:
86
+ error_msg = f"Error processing image: {str(e)}"
87
+ print(f"❌ {error_msg}")
88
+ return {"Error": 1.0}, error_msg
89
 
90
+ # Create Gradio interface
91
  demo = gr.Interface(
92
  fn=classify_meme,
93
  inputs=gr.Image(type="pil", label="Upload Meme Image"),
94
  outputs=[
95
  gr.Label(num_top_classes=5, label="Meme Classification"),
96
+ gr.Textbox(label="Extracted Text", lines=3)
97
  ],
98
+ title="🎭 Meme Classifier with OCR",
99
  description="""
100
  Upload a meme image to:
101
+ 1. **Classify** its content using your trained SigLIP2_77 model
102
+ 2. **Extract text** using OCR (Optical Character Recognition)
103
 
104
+ Your model was trained on meme data and will predict the category/sentiment of the uploaded meme.
105
  """,
106
  examples=None,
107
  allow_flagging="never"
108
  )
109
 
110
  if __name__ == "__main__":
111
+ print("πŸš€ Starting Gradio interface...")
112
  demo.launch(
113
+ server_name="0.0.0.0",
114
+ server_port=7860,
115
+ share=False
116
+ )