|
import torch |
|
import torchvision |
|
from torch import softmax |
|
from pathlib import Path |
|
from vision_transformer import ViT |
|
|
|
def load_model(model: torch.nn.Module,
               model_weights_dir: str,
               model_weights_name: str) -> torch.nn.Module:
    """Load a saved state dict from disk into an existing PyTorch model.

    The weights file is read with ``weights_only=True`` and mapped onto the
    CPU, then copied into ``model`` in place via ``load_state_dict``.

    Args:
        model: The PyTorch model whose parameters are overwritten.
        model_weights_dir: Directory containing the weights file. It is
            wrapped in ``pathlib.Path``, so a ``Path`` works as well as a str.
        model_weights_name: File name of the weights. Must end with
            ".pth" or ".pt".

    Returns:
        The same ``model`` object, with the loaded weights.

    Raises:
        ValueError: If ``model_weights_name`` does not end with ".pt" or
            ".pth".

    Example usage:
        model = load_model(model=model,
                           model_weights_dir="models",
                           model_weights_name="05_going_modular_tinyvgg_model.pth")
    """
    # Explicit raise instead of `assert`: assertions are removed when Python
    # runs with -O, which would silently disable this check.
    if not model_weights_name.endswith((".pth", ".pt")):
        raise ValueError("model_name should end with '.pt' or '.pth'")

    model_path = Path(model_weights_dir) / model_weights_name

    print(f"[INFO] Loading model from: {model_path}")

    state_dict = torch.load(model_path, weights_only=True, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    return model
|
|
|
def create_vitbase_model(
    model_weights_dir: Path,
    model_weights_name: str,
    image_size: int = 224,
    num_classes: int = 101,
    compile: bool = False
):
    """Build a torchvision ViT-B/16 on the CPU and load saved weights into it.

    The model's ``heads`` attribute is replaced by a single linear layer
    (768 input features, ``num_classes`` outputs) before the weights are
    loaded through ``load_model``.

    Args:
        model_weights_dir: Directory containing the weights file.
        model_weights_name: File name of the weights to load.
        image_size: Input image size forwarded to
            ``torchvision.models.vit_b_16``.
        num_classes: Number of outputs of the replacement head.
        compile: If True, wrap the model with ``torch.compile`` using the
            "aot_eager" backend before the weights are loaded.

    Returns:
        The model returned by ``load_model``.
    """
    net = torchvision.models.vit_b_16(image_size=image_size).to("cpu")
    net.heads = torch.nn.Linear(768, num_classes).to("cpu")

    # NOTE(review): compilation happens before loading, so the state dict is
    # loaded into the compiled wrapper. Presumably checkpoints used with
    # compile=True were saved from a compiled model -- confirm.
    if compile:
        net = torch.compile(net, backend="aot_eager")

    return load_model(net, model_weights_dir, model_weights_name)
|
|
|
def create_swin_tiny_model(
    model_weights_dir: Path,
    model_weights_name: str,
    image_size: int = 224,
    num_classes: int = 101,
    compile: bool = False
):
    """Build a torchvision Swin-V2-Tiny on the CPU and load saved weights.

    The model's ``head`` attribute is replaced by a linear layer (768 input
    features, ``num_classes`` outputs) before the weights are loaded through
    ``load_model``.

    Args:
        model_weights_dir: Directory containing the weights file.
        model_weights_name: File name of the weights to load.
        image_size: Accepted for signature parity with the other factory
            functions; this function does not use it.
        num_classes: Number of outputs of the replacement head.
        compile: If True, wrap the model with ``torch.compile`` using the
            "aot_eager" backend before the weights are loaded.

    Returns:
        The Swin-V2-Tiny model returned by ``load_model``.
    """
    net = torchvision.models.swin_v2_t().to("cpu")
    net.head = torch.nn.Linear(768, num_classes).to("cpu")

    # NOTE(review): compilation happens before loading, so the state dict is
    # loaded into the compiled wrapper. Presumably checkpoints used with
    # compile=True were saved from a compiled model -- confirm.
    if compile:
        net = torch.compile(net, backend="aot_eager")

    return load_model(net, model_weights_dir, model_weights_name)
|
|
|
|
|
def create_effnetb0(
    model_weights_dir: Path,
    model_weights_name: str,
    num_classes: int = 2,
    dropout: float = 0.2,
    compile: bool = False
):
    """Build an EfficientNet-B0 on the CPU and load saved weights into it.

    The classifier is replaced by a ``torch.nn.Sequential`` holding a linear
    layer (1280 input features, ``num_classes`` outputs), preceded by a
    dropout layer when ``dropout`` is non-zero. File-name validation and the
    actual loading are delegated to ``load_model``, which also prints the
    path being loaded.

    Args:
        model_weights_dir: Directory containing the weights file.
        model_weights_name: File name of the weights to load.
        num_classes: Number of outputs of the classifier head. Defaults to 2.
        dropout: Dropout probability placed before the linear layer. When it
            is 0.0 no dropout layer is added at all. Defaults to 0.2.
        compile: If True, wrap the model with ``torch.compile`` using the
            "aot_eager" backend before the weights are loaded.

    Returns:
        The EfficientNet-B0 model with the loaded weights. No transforms are
        returned.
    """
    # weights=None: the load below uses load_state_dict's default strict mode,
    # so every tensor is overwritten by the checkpoint. Fetching the
    # pretrained ImageNet weights first would only cost a download.
    effnetb0_model = torchvision.models.efficientnet_b0(weights=None).to('cpu')

    # The dropout layer is omitted (not set to p=0) when dropout == 0.0:
    # the position of the linear layer inside the Sequential determines its
    # state-dict key, so the layout must match the saved checkpoint.
    classifier_layers = []
    if dropout != 0.0:
        classifier_layers.append(torch.nn.Dropout(p=dropout, inplace=True))
    classifier_layers.append(
        torch.nn.Linear(in_features=1280, out_features=num_classes, bias=True)
    )
    effnetb0_model.classifier = torch.nn.Sequential(*classifier_layers)

    # NOTE(review): compilation happens before loading, so the state dict is
    # loaded into the compiled wrapper. Presumably checkpoints used with
    # compile=True were saved from a compiled model -- confirm.
    if compile:
        effnetb0_model = torch.compile(effnetb0_model, backend="aot_eager")

    return load_model(
        model=effnetb0_model,
        model_weights_dir=model_weights_dir,
        model_weights_name=model_weights_name,
    )
|
|