moshel commited on
Commit
6839f51
·
1 Parent(s): a3093be
Files changed (1) hide show
  1. app.py +18 -0
app.py CHANGED
@@ -3,12 +3,30 @@ import gradio as gr
3
 
4
  import torch
5
  import torchvision
 
6
 
7
  checkpoint = torch.load('v4-epoch=19-val_loss=0.6964-val_accuracy=0.8964.ckpt', map_location=torch.device('cpu'))
8
  state_dict = checkpoint["state_dict"]
9
  model_weights = state_dict
10
  for key in list(model_weights):
11
  model_weights[key.replace("backbone.", "")] = model_weights.pop(key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  model.load_state_dict(model_weights).eval()
13
 
14
  import requests
 
3
 
4
  import torch
5
  import torchvision
6
+ import timm
7
 
8
  checkpoint = torch.load('v4-epoch=19-val_loss=0.6964-val_accuracy=0.8964.ckpt', map_location=torch.device('cpu'))
9
  state_dict = checkpoint["state_dict"]
10
  model_weights = state_dict
11
  for key in list(model_weights):
12
  model_weights[key.replace("backbone.", "")] = model_weights.pop(key)
13
+
14
+
15
+ def get_model():
16
+ model = timm.create_model('tf_efficientnet_b1', pretrained=True, num_classes=2, global_pool='catavgmax')
17
+ num_in_features = model.get_classifier().in_features
18
+ from torch import nn
19
+
20
+ model.fc = nn.Sequential(
21
+ nn.Linear(in_features=num_in_features, out_features=1024, bias=False),
22
+ nn.ReLU(),
23
+ nn.Linear(in_features=1024, out_features=2, bias=False),
24
+ )
25
+
26
+ return model
27
+
28
+ model = get_model()
29
+
30
  model.load_state_dict(model_weights).eval()
31
 
32
  import requests