# FoodVisionMini / model.py
# fbrynpk's picture
# Initial Commit
# b418fb1
# raw
# history blame contribute delete
# 819 Bytes
import torch
import torchvision
from torch import nn
def create_vit(pretrained_weights: torchvision.models.Weights,
               model: torchvision.models,
               in_features: int,
               out_features: int,
               device: torch.device):
    """Build a frozen torchvision Vision Transformer with a new linear head.

    Args:
        pretrained_weights: Weights entry passed to the model builder as
            ``weights=``. Its ``transforms()`` result is returned next to the
            model, so it must not be ``None``.
        model: Model builder function to call, for example
            ``torchvision.models.vit_b_16``. Any other value (``None``, the
            ``torchvision.models`` module, a class, an ``nn.Module``
            instance) falls back to ``torchvision.models.vit_b_16``.
        in_features: Input width of the replacement head. It is not validated
            here; presumably it has to match the embedding width of the
            chosen backbone -- confirm against the builder you pass.
        out_features: Output width of the replacement head.
        device: Device the returned model is moved to.

    Returns:
        A ``(model, transforms)`` tuple: the model with every pre-existing
        parameter frozen and ``heads`` replaced by a single trainable
        ``nn.Linear`` wrapped in ``nn.Sequential``, plus the preprocessing
        transforms that belong to ``pretrained_weights``.
    """
    # Previously ``model`` was overwritten unconditionally, so the argument
    # had no effect and ViT-B/16 was always built. Honour a builder function
    # when one is supplied; classes and module instances are also callable
    # but do not accept ``weights=``, so they keep the ViT-B/16 fallback.
    is_builder = callable(model) and not isinstance(model, (type, nn.Module))
    build = model if is_builder else torchvision.models.vit_b_16

    # Create the pretrained model and fetch its matching preprocessing.
    vit = build(weights=pretrained_weights)
    transforms = pretrained_weights.transforms()

    # Freeze the feature extractor. This runs before the head swap so the
    # new head below keeps the default requires_grad=True.
    for param in vit.parameters():
        param.requires_grad = False

    # Change the head of the ViT. The nn.Sequential wrapper is kept so the
    # parameter names (``heads.0.*``) stay the same as before.
    vit.heads = nn.Sequential(
        nn.Linear(in_features=in_features, out_features=out_features)
    )

    # One device transfer for backbone and new head together, instead of
    # moving the soon-to-be-discarded pretrained head as well.
    return vit.to(device), transforms