File size: 5,585 Bytes
b64811b
18add7e
8e8c05b
4df6cc7
c781e06
 
59126da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e8c05b
59126da
 
 
 
 
 
 
 
 
 
6cfaed9
59126da
 
 
c781e06
 
 
3ff958a
7c933ef
 
c781e06
8e8c05b
59126da
 
 
 
8df1ddb
 
3ff958a
8df1ddb
59126da
 
 
8e8c05b
 
c781e06
3ff958a
 
c781e06
c2a4d06
7c933ef
 
c2a4d06
c781e06
 
 
 
 
 
 
 
8df1ddb
e82f64b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8df1ddb
 
 
 
bbc7a9d
f8f3da7
 
8df1ddb
8e8c05b
8df1ddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63cd96e
8df1ddb
 
8e8c05b
 
 
 
 
 
 
 
 
 
 
8df1ddb
f8f3da7
 
 
 
8df1ddb
 
 
 
 
 
 
 
cd79e1e
8df1ddb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import torch
import torchvision
from torch import softmax
from pathlib import Path
from vision_transformer import ViT

def load_model(model: torch.nn.Module,
               model_weights_dir: str,
               model_weights_name: str) -> torch.nn.Module:
    """Loads saved weights (a state dict) into a PyTorch model from a target directory.

    Args:
    model: A target PyTorch model to load the weights into. It is updated
      in place via ``load_state_dict`` and also returned.
    model_weights_dir: A directory where the weights file is located
      (a ``str`` or ``pathlib.Path``).
    model_weights_name: The file name of the weights to load.
      Must end with ".pth" or ".pt".

    Example usage:
    model = load_model(model=model,
                       model_weights_dir="models",
                       model_weights_name="05_going_modular_tingvgg_model.pth")

    Returns:
    The same ``model`` object with the loaded weights.

    Raises:
    ValueError: If ``model_weights_name`` does not end with ".pth" or ".pt".
    """
    # Validate with a real exception: an ``assert`` is removed when Python
    # runs with -O, which would silently skip this check.
    if not model_weights_name.endswith((".pth", ".pt")):
        raise ValueError(
            f"model_weights_name should end with '.pt' or '.pth', got {model_weights_name!r}"
        )

    # Build the full weights path from the directory and file name.
    model_path = Path(model_weights_dir) / model_weights_name

    print(f"[INFO] Loading model from: {model_path}")

    # weights_only=True restricts unpickling to tensors and plain containers;
    # map_location puts every tensor on the CPU regardless of where it was saved.
    state_dict = torch.load(model_path, weights_only=True, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    return model

def create_vitbase_model(
    model_weights_dir:Path,
    model_weights_name:str,
    image_size:int=224,
    num_classes:int=101,
    compile:bool=False
    ):
    """
    Builds a ViT-B/16 classifier on the CPU and loads trained weights into it.

    Args:
        model_weights_dir: Directory containing the weights file.
        model_weights_name: Weights file name (".pth" or ".pt").
        image_size: Input image size passed to ``torchvision.models.vit_b_16``.
        num_classes: Number of output classes of the replacement head.
        compile: If True, wrap the model with ``torch.compile`` using the
            "aot_eager" backend before the weights are loaded.

    Returns:
    The ViT-B/16 model with the trained weights loaded.
    """
    target_device = "cpu"

    backbone = torchvision.models.vit_b_16(image_size=image_size).to(target_device)

    # The whole ``heads`` module is replaced by one linear layer mapping the
    # 768-wide ViT-B embedding to ``num_classes`` logits.
    classifier = torch.nn.Linear(in_features=768, out_features=num_classes)
    backbone.heads = classifier.to(target_device)

    # NOTE(review): compiling happens before the weights are loaded, so the
    # checkpoint keys must match the compiled wrapper's state_dict (PyTorch
    # prefixes them with "_orig_mod.") — confirm against how weights were saved.
    if compile:
        backbone = torch.compile(backbone, backend="aot_eager")

    return load_model(
        model=backbone,
        model_weights_dir=model_weights_dir,
        model_weights_name=model_weights_name,
    )

def create_swin_tiny_model(
    model_weights_dir:Path,
    model_weights_name:str,
    image_size:int=224,
    num_classes:int=101,
    compile:bool=False
    ):

    """
    Creates a Swin-V2-Tiny model with the specified number of classes and
    loads trained weights into it.

    Args:
        model_weights_dir: A directory where the weights file is located.
        model_weights_name: The weights file name; must end with ".pth" or ".pt".
        image_size: Currently unused. ``torchvision.models.swin_v2_t()`` is
            built without it; the argument is only accepted so the signature
            matches ``create_vitbase_model``.
        num_classes: The number of classes for the classification task.
        compile: If True, wrap the model with ``torch.compile`` (backend
            "aot_eager") before the weights are loaded.

    Returns:
    The created Swin-V2-Tiny model with the trained weights loaded, on the CPU.
    """

    # Instantiate the backbone on the CPU
    swint_model = torchvision.models.swin_v2_t().to("cpu")

    # Replace the classification head. The input width is taken from the
    # existing head (768 for swin_v2_t) instead of being hard-coded.
    swint_model.head = torch.nn.Linear(
        in_features=swint_model.head.in_features,
        out_features=num_classes,
    ).to("cpu")

    # NOTE(review): compiling happens before the weights are loaded, so the
    # checkpoint keys must match the compiled wrapper's state_dict (PyTorch
    # prefixes them with "_orig_mod.") — confirm against how weights were saved.
    if compile:
        swint_model = torch.compile(swint_model, backend="aot_eager")

    # Load the trained weights
    swint_model = load_model(
        model=swint_model,
        model_weights_dir=model_weights_dir,
        model_weights_name=model_weights_name
        )

    return swint_model

# Create an EfficientNet-B0 Model
def create_effnetb0(
        model_weights_dir: Path,
        model_weights_name: str,
        num_classes: int=2,
        dropout: float=0.2,
        compile:bool=False
        ):

    """Creates an EfficientNet-B0 classifier and loads trained weights into it.

    Args:
        model_weights_dir: A directory where the weights file is located.
        model_weights_name: The weights file name; must end with ".pth" or ".pt".
        num_classes (int, optional): number of classes in the classifier head.
        dropout (float, optional): Dropout rate. Defaults to 0.2. When 0.0,
            no Dropout layer is added to the classifier head.
        compile (bool, optional): If True, wrap the model with
            ``torch.compile`` (backend "aot_eager") before loading the weights.

    Returns:
        effnetb0_model (torch.nn.Module): EfficientNet-B0 with the trained
        weights loaded, on the CPU.
    """

    # No pretrained weights are requested: load_model() below performs a strict
    # load_state_dict that overwrites every parameter and buffer, so fetching
    # ImageNet weights would only add a network download (and fail offline).
    effnetb0_model = torchvision.models.efficientnet_b0(weights=None).to('cpu')

    # Rebuild the classifier head. Layer indices must match the saved
    # checkpoint: with dropout the Linear layer is classifier[1], without it
    # classifier[0].
    in_features = effnetb0_model.classifier[-1].in_features  # 1280 for B0
    head_layers = []
    if dropout != 0.0:
        head_layers.append(torch.nn.Dropout(p=dropout, inplace=True))
    head_layers.append(
        torch.nn.Linear(in_features=in_features,
                        out_features=num_classes,
                        bias=True))
    effnetb0_model.classifier = torch.nn.Sequential(*head_layers)

    # NOTE(review): compiling happens before the weights are loaded, so the
    # checkpoint keys must match the compiled wrapper's state_dict (PyTorch
    # prefixes them with "_orig_mod.") — confirm against how weights were saved.
    if compile:
        effnetb0_model = torch.compile(effnetb0_model, backend="aot_eager")

    # Reuse the shared helper for path validation and CPU state-dict loading.
    return load_model(
        model=effnetb0_model,
        model_weights_dir=model_weights_dir,
        model_weights_name=model_weights_name,
    )