leynessa commited on
Commit
2edff24
·
verified ·
1 Parent(s): 66294ea

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +45 -122
streamlit_app.py CHANGED
@@ -52,12 +52,12 @@ inference_transform = A.Compose([
52
  # Enhanced model loading function
53
  @st.cache_resource
54
  def load_model():
55
- """Enhanced model loading with architecture detection and fallback options"""
56
 
57
  # Try different model file names
58
  model_files = [
59
  "butterfly_classifier.pth",
60
- "best_butterfly_model_v3.pth",
61
  "best_butterfly_model.pth"
62
  ]
63
 
@@ -68,7 +68,7 @@ def load_model():
68
  break
69
 
70
  if MODEL_PATH is None:
71
- st.error("No model file found! Please ensure one of these files exists: " + ", ".join(model_files))
72
  return None
73
 
74
  st.info(f"Loading model from: {MODEL_PATH}")
@@ -80,87 +80,42 @@ def load_model():
80
  # Extract model state dict
81
  if 'model_state_dict' in checkpoint:
82
  model_state_dict = checkpoint['model_state_dict']
83
- if 'class_names' in checkpoint:
84
- st.info(f"Model trained on {len(checkpoint['class_names'])} classes")
85
  else:
86
  model_state_dict = checkpoint
87
-
88
  num_classes = len(class_names)
89
 
90
- # Architecture detection based on model state dict
91
- def detect_model_architecture(state_dict):
92
- """Detect model architecture from state dict"""
93
-
94
- # Check for EfficientNet variants by looking at key layer dimensions
95
- architecture_indicators = {
96
- 'conv_head.weight': 'efficientnet',
97
- 'head.weight': 'efficientnet_v2',
98
- 'classifier.weight': 'other'
99
- }
100
-
101
- # Look for specific layer patterns
102
- for key in state_dict.keys():
103
- if 'conv_head.weight' in key:
104
- shape = state_dict[key].shape
105
- if len(shape) >= 2:
106
- feature_dim = shape[1]
107
- # EfficientNet feature dimensions
108
- efficientnet_map = {
109
- 1280: 'efficientnet_b0',
110
- 1408: 'efficientnet_b1',
111
- 1536: 'efficientnet_b2',
112
- 1792: 'efficientnet_b3',
113
- 1920: 'efficientnet_b4',
114
- 2048: 'efficientnet_b5',
115
- 2304: 'efficientnet_b6',
116
- 2560: 'efficientnet_b7'
117
- }
118
- return efficientnet_map.get(feature_dim, 'efficientnet_b3')
119
-
120
- if 'head.weight' in key:
121
- shape = state_dict[key].shape
122
- if len(shape) >= 2:
123
- feature_dim = shape[1]
124
- # EfficientNetV2 feature dimensions
125
- efficientnetv2_map = {
126
- 1280: 'tf_efficientnetv2_s',
127
- 1408: 'tf_efficientnetv2_m',
128
- 1792: 'tf_efficientnetv2_l'
129
- }
130
- return efficientnetv2_map.get(feature_dim, 'tf_efficientnetv2_s')
131
-
132
- # Fallback: check bn2 layer for EfficientNet variants
133
- for key in state_dict.keys():
134
- if key.endswith("bn2.weight"):
135
- bn2_shape = state_dict[key].shape[0]
136
- feature_map = {
137
- 1280: 'efficientnet_b0',
138
- 1408: 'efficientnet_b1',
139
- 1536: 'efficientnet_b2',
140
- 1792: 'efficientnet_b3',
141
- 1920: 'efficientnet_b4',
142
- 2048: 'efficientnet_b5',
143
- 2304: 'efficientnet_b6',
144
- 2560: 'efficientnet_b7'
145
  }
146
- return feature_map.get(bn2_shape, 'efficientnet_b3')
147
-
148
- return 'efficientnet_b3' # Default fallback
149
 
150
- # Detect architecture
151
- detected_arch = detect_model_architecture(model_state_dict)
152
- st.info(f"Detected model architecture: {detected_arch}")
153
 
154
- # List of architectures to try in order
155
- architectures_to_try = [
156
- detected_arch,
157
- 'efficientnet_b3',
158
- 'efficientnet_b2',
159
- 'efficientnet_b0',
160
- 'efficientnet_b1',
161
- 'efficientnet_b4',
162
  'tf_efficientnetv2_s',
163
- 'tf_efficientnetv2_m'
 
 
 
 
 
164
  ]
165
 
166
  # Remove duplicates while preserving order
@@ -175,61 +130,30 @@ def load_model():
175
  try:
176
  st.info(f"Trying architecture: {arch}")
177
 
178
- # Create model with the detected/guessed architecture
179
  model = timm.create_model(
180
  arch,
181
  pretrained=False,
182
  num_classes=num_classes,
183
- drop_rate=0.4,
184
- drop_path_rate=0.3
185
  )
186
 
187
- # Check if the model has a custom head/classifier in the checkpoint
188
- if any('head.' in key for key in model_state_dict.keys()):
189
- # Model has custom head - try to load it
 
 
 
 
 
190
  try:
191
  model.load_state_dict(model_state_dict, strict=False)
192
- st.success(f" Successfully loaded model with architecture: {arch}")
193
  successful_arch = arch
194
  break
195
- except Exception as e:
196
- st.warning(f"Failed to load custom head for {arch}: {str(e)}")
197
- continue
198
-
199
- elif any('classifier.' in key for key in model_state_dict.keys()):
200
- # Model has custom classifier - try to load it
201
- try:
202
- model.load_state_dict(model_state_dict, strict=False)
203
- st.success(f"✅ Successfully loaded model with architecture: {arch}")
204
- successful_arch = arch
205
- break
206
- except Exception as e:
207
- st.warning(f"Failed to load custom classifier for {arch}: {str(e)}")
208
- continue
209
-
210
- else:
211
- # Try to create custom head/classifier and load backbone
212
- try:
213
- # Load backbone weights (ignore head/classifier mismatches)
214
- backbone_dict = {k: v for k, v in model_state_dict.items()
215
- if not (k.startswith('head.') or k.startswith('classifier.'))}
216
-
217
- model.load_state_dict(backbone_dict, strict=False)
218
-
219
- # Create new head/classifier
220
- if hasattr(model, 'classifier'):
221
- in_features = model.classifier.in_features
222
- model.classifier = torch.nn.Linear(in_features, num_classes)
223
- elif hasattr(model, 'head'):
224
- in_features = model.head.in_features
225
- model.head = torch.nn.Linear(in_features, num_classes)
226
-
227
- st.warning(f"⚠️ Loaded {arch} with new head/classifier (backbone weights only)")
228
- successful_arch = arch
229
- break
230
-
231
- except Exception as e:
232
- st.warning(f"Failed to load backbone for {arch}: {str(e)}")
233
  continue
234
 
235
  except Exception as e:
@@ -254,7 +178,6 @@ def load_model():
254
 
255
  except Exception as e:
256
  st.error(f"❌ Error loading model: {str(e)}")
257
- st.error("Please check your model file and ensure it's compatible")
258
  return None
259
 
260
  # Load model
 
52
  # Enhanced model loading function
53
  @st.cache_resource
54
  def load_model():
55
+ """Enhanced model loading with better architecture detection"""
56
 
57
  # Try different model file names
58
  model_files = [
59
  "butterfly_classifier.pth",
60
+ "best_butterfly_model_v3.pth",
61
  "best_butterfly_model.pth"
62
  ]
63
 
 
68
  break
69
 
70
  if MODEL_PATH is None:
71
+ st.error("No model file found!")
72
  return None
73
 
74
  st.info(f"Loading model from: {MODEL_PATH}")
 
80
  # Extract model state dict
81
  if 'model_state_dict' in checkpoint:
82
  model_state_dict = checkpoint['model_state_dict']
 
 
83
  else:
84
  model_state_dict = checkpoint
85
+
86
  num_classes = len(class_names)
87
 
88
+ # Better architecture detection based on conv_stem channels
89
+ def detect_architecture_by_channels(state_dict):
90
+ """Detect architecture by examining conv_stem channels"""
91
+ for key, tensor in state_dict.items():
92
+ if key.endswith('conv_stem.weight'):
93
+ channels = tensor.shape[0] # Output channels
94
+ # Map channels to likely architectures
95
+ channel_map = {
96
+ 24: ['tf_efficientnetv2_s', 'efficientnet_b0'],
97
+ 32: ['tf_efficientnetv2_s', 'efficientnet_b1'],
98
+ 40: ['efficientnet_b3', 'efficientnet_b2'],
99
+ 48: ['efficientnet_b4', 'tf_efficientnetv2_m'],
100
+ 56: ['efficientnet_b5'],
101
+ 64: ['efficientnet_b6', 'tf_efficientnetv2_l'],
102
+ 72: ['efficientnet_b7']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  }
104
+ return channel_map.get(channels, ['tf_efficientnetv2_s'])
105
+ return ['tf_efficientnetv2_s']
 
106
 
107
+ # Get likely architectures based on channels
108
+ likely_architectures = detect_architecture_by_channels(model_state_dict)
 
109
 
110
+ # Expanded list of architectures to try
111
+ architectures_to_try = likely_architectures + [
 
 
 
 
 
 
112
  'tf_efficientnetv2_s',
113
+ 'efficientnet_b0',
114
+ 'efficientnet_b1',
115
+ 'efficientnet_b2',
116
+ 'efficientnet_b3',
117
+ 'tf_efficientnetv2_m',
118
+ 'efficientnet_b4'
119
  ]
120
 
121
  # Remove duplicates while preserving order
 
130
  try:
131
  st.info(f"Trying architecture: {arch}")
132
 
133
+ # Create model
134
  model = timm.create_model(
135
  arch,
136
  pretrained=False,
137
  num_classes=num_classes,
138
+ drop_rate=0.0, # Set to 0 for inference
139
+ drop_path_rate=0.0 # Set to 0 for inference
140
  )
141
 
142
+ # Try to load the state dict
143
+ try:
144
+ model.load_state_dict(model_state_dict, strict=True)
145
+ st.success(f"✅ Successfully loaded model with architecture: {arch}")
146
+ successful_arch = arch
147
+ break
148
+ except Exception as e:
149
+ # Try with strict=False
150
  try:
151
  model.load_state_dict(model_state_dict, strict=False)
152
+ st.warning(f"⚠️ Loaded {arch} with some mismatched weights")
153
  successful_arch = arch
154
  break
155
+ except Exception as e2:
156
+ st.warning(f"Failed to load {arch}: {str(e2)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  continue
158
 
159
  except Exception as e:
 
178
 
179
  except Exception as e:
180
  st.error(f"❌ Error loading model: {str(e)}")
 
181
  return None
182
 
183
  # Load model