torchvision
Browse files- app.py +8 -6
- requirements.txt +1 -0
app.py
CHANGED
@@ -3,21 +3,23 @@ import gradio as gr
|
|
3 |
|
4 |
import torch
|
5 |
|
6 |
-
model = torch.
|
7 |
|
8 |
import requests
|
9 |
from PIL import Image
|
10 |
from torchvision import transforms
|
11 |
|
12 |
# Download human-readable labels for ImageNet.
|
13 |
-
|
14 |
-
labels = response.text.split("\n")
|
15 |
|
16 |
def predict(inp):
|
17 |
-
|
|
|
|
|
|
|
18 |
with torch.no_grad():
|
19 |
-
prediction = torch.nn.functional.softmax(model(
|
20 |
-
confidences = {labels[i]: float(prediction[i]) for i in range(
|
21 |
return confidences
|
22 |
|
23 |
import gradio as gr
|
|
|
3 |
|
4 |
import torch
|
5 |
|
6 |
+
model = torch.load('v4-epoch=19-val_loss=0.6964-val_accuracy=0.8964.ckpt')
|
7 |
|
8 |
import requests
|
9 |
from PIL import Image
|
10 |
from torchvision import transforms
|
11 |
|
12 |
# Download human-readable labels for ImageNet.
|
13 |
+
labels = ['good', 'ill']
|
|
|
14 |
|
15 |
def predict(inp):
|
16 |
+
img = transforms.ToTensor()(inp)
|
17 |
+
img = torchvision.transforms.Resize((800, 800))(img)
|
18 |
+
img = torchvision.transforms.CenterCrop(CROP)(img)
|
19 |
+
img = img..unsqueeze(0)
|
20 |
with torch.no_grad():
|
21 |
+
prediction = torch.nn.functional.softmax(model(img)[0], dim=0)
|
22 |
+
confidences = {labels[i]: float(prediction[i]) for i in range(2)}
|
23 |
return confidences
|
24 |
|
25 |
import gradio as gr
|
requirements.txt
CHANGED
@@ -1 +1,2 @@
|
|
1 |
torch
|
|
|
|
1 |
torch
|
2 |
+
torchvision
|