iqramukhtiar commited on
Commit
06ee41a
·
verified ·
1 Parent(s): 17e606f

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +3 -2
train_model.py CHANGED
@@ -56,7 +56,8 @@ def train_model():
56
  print(f"Validation set size: {dataset_sizes['val']}")
57
 
58
  # Load a pre-trained model
59
- model = models.resnet50(pretrained=True)
 
60
 
61
  # Modify the final layer for our number of classes
62
  num_ftrs = model.fc.in_features
@@ -75,7 +76,7 @@ def train_model():
75
 
76
  for epoch in range(num_epochs):
77
  print(f'Epoch {epoch+1}/{num_epochs}')
78
- print('-' * 15)
79
 
80
  # Each epoch has a training and validation phase
81
  for phase in ['train', 'val']:
 
56
  print(f"Validation set size: {dataset_sizes['val']}")
57
 
58
  # Load a pre-trained model
59
+ from torchvision.models import ResNet50_Weights
60
+ model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
61
 
62
  # Modify the final layer for our number of classes
63
  num_ftrs = model.fc.in_features
 
76
 
77
  for epoch in range(num_epochs):
78
  print(f'Epoch {epoch+1}/{num_epochs}')
79
+ print('-' * 10)
80
 
81
  # Each epoch has a training and validation phase
82
  for phase in ['train', 'val']: