Update train_model.py
Browse files- train_model.py +3 -2
train_model.py
CHANGED
|
@@ -56,7 +56,8 @@ def train_model():
|
|
| 56 |
print(f"Validation set size: {dataset_sizes['val']}")
|
| 57 |
|
| 58 |
# Load a pre-trained model
|
| 59 |
-
|
|
|
|
| 60 |
|
| 61 |
# Modify the final layer for our number of classes
|
| 62 |
num_ftrs = model.fc.in_features
|
|
@@ -75,7 +76,7 @@ def train_model():
|
|
| 75 |
|
| 76 |
for epoch in range(num_epochs):
|
| 77 |
print(f'Epoch {epoch+1}/{num_epochs}')
|
| 78 |
-
print('-' *
|
| 79 |
|
| 80 |
# Each epoch has a training and validation phase
|
| 81 |
for phase in ['train', 'val']:
|
|
|
|
| 56 |
print(f"Validation set size: {dataset_sizes['val']}")
|
| 57 |
|
| 58 |
# Load a pre-trained model
|
| 59 |
+
from torchvision.models import ResNet50_Weights
|
| 60 |
+
model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
|
| 61 |
|
| 62 |
# Modify the final layer for our number of classes
|
| 63 |
num_ftrs = model.fc.in_features
|
|
|
|
| 76 |
|
| 77 |
for epoch in range(num_epochs):
|
| 78 |
print(f'Epoch {epoch+1}/{num_epochs}')
|
| 79 |
+
print('-' * 10)
|
| 80 |
|
| 81 |
# Each epoch has a training and validation phase
|
| 82 |
for phase in ['train', 'val']:
|