Updated stuff
Browse files- sagvit.bin → SAG-ViT.pth +2 -2
- model.safetensors +2 -2
- model_components.py +25 -4
- push_model_to_hfhub.py +12 -0
sagvit.bin → SAG-ViT.pth
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8bdfa65b805d74c284af254153960436014a6e26b740f25cf2c6c9289234d9d3
|
| 3 |
+
size 27137010
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:234c74e97091f54931e7050f775e56465b0abff0596906ad015c45df31e7b12a
|
| 3 |
+
size 27009200
|
model_components.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
-
from torch_geometric.nn import GATConv, global_mean_pool
|
| 5 |
|
| 6 |
from torchvision import models
|
| 7 |
|
|
@@ -52,15 +52,36 @@ class GATGNN(nn.Module):
|
|
| 52 |
This module corresponds to the Graph Attention stage (Section 3.3),
|
| 53 |
refining local relationships between patches in a learned manner.
|
| 54 |
"""
|
| 55 |
-
def __init__(self, in_channels, hidden_channels, out_channels, heads=
|
| 56 |
super(GATGNN, self).__init__()
|
| 57 |
# GAT layers:
|
| 58 |
# First layer maps raw patch embeddings to a higher-level representation.
|
| 59 |
self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
|
| 60 |
-
#
|
| 61 |
-
self.conv2 =
|
| 62 |
self.pool = global_mean_pool
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
def forward(self, data):
|
| 65 |
"""
|
| 66 |
Input:
|
|
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
+
from torch_geometric.nn import GATConv, global_mean_pool, GCNConv
|
| 5 |
|
| 6 |
from torchvision import models
|
| 7 |
|
|
|
|
| 52 |
This module corresponds to the Graph Attention stage (Section 3.3),
|
| 53 |
refining local relationships between patches in a learned manner.
|
| 54 |
"""
|
| 55 |
+
def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
|
| 56 |
super(GATGNN, self).__init__()
|
| 57 |
# GAT layers:
|
| 58 |
# First layer maps raw patch embeddings to a higher-level representation.
|
| 59 |
self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
|
| 60 |
+
# Final GCN layer for refined representation
|
| 61 |
+
self.conv2 = GCNConv(hidden_channels * heads, out_channels)
|
| 62 |
self.pool = global_mean_pool
|
| 63 |
|
| 64 |
+
def forward(self, data):
|
| 65 |
+
"""
|
| 66 |
+
Input:
|
| 67 |
+
- data (PyG Data): Contains x (node features), edge_index (graph edges), and batch indexing.
|
| 68 |
+
|
| 69 |
+
Output:
|
| 70 |
+
- x (Tensor): Aggregated graph-level embedding after mean pooling.
|
| 71 |
+
"""
|
| 72 |
+
x, edge_index, batch = data.x, data.edge_index, data.batch
|
| 73 |
+
|
| 74 |
+
# GAT layer with ReLU activation
|
| 75 |
+
x = F.relu(self.conv1(x, edge_index))
|
| 76 |
+
|
| 77 |
+
# GCN layer for further aggregation
|
| 78 |
+
x = self.conv2(x, edge_index)
|
| 79 |
+
|
| 80 |
+
# Global mean pooling to obtain graph-level representation
|
| 81 |
+
out = self.pool(x, batch)
|
| 82 |
+
|
| 83 |
+
return out
|
| 84 |
+
|
| 85 |
def forward(self, data):
|
| 86 |
"""
|
| 87 |
Input:
|
push_model_to_hfhub.py
CHANGED
|
@@ -1,13 +1,25 @@
|
|
|
|
|
| 1 |
from transformers import AutoConfig, AutoModel
|
| 2 |
from modeling_sagvit import SAGViTClassifier, SAGViTConfig
|
| 3 |
|
| 4 |
|
|
|
|
| 5 |
AutoConfig.register("sagvit", SAGViTConfig)
|
| 6 |
AutoModel.register(SAGViTConfig, SAGViTClassifier)
|
|
|
|
| 7 |
|
| 8 |
# Load config and model
|
| 9 |
config = SAGViTConfig()
|
| 10 |
model = SAGViTClassifier(config)
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
# Push model and code
|
|
|
|
| 13 |
model.push_to_hub("shravvvv/SAG-ViT")
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
from transformers import AutoConfig, AutoModel
|
| 3 |
from modeling_sagvit import SAGViTClassifier, SAGViTConfig
|
| 4 |
|
| 5 |
|
| 6 |
+
print("Registering model...")
|
| 7 |
AutoConfig.register("sagvit", SAGViTConfig)
|
| 8 |
AutoModel.register(SAGViTConfig, SAGViTClassifier)
|
| 9 |
+
print("Registered model...")
|
| 10 |
|
| 11 |
# Load config and model
|
| 12 |
config = SAGViTConfig()
|
| 13 |
model = SAGViTClassifier(config)
|
| 14 |
|
| 15 |
+
# Load the state dict into the model
|
| 16 |
+
print("Loading model weights...")
|
| 17 |
+
state_dict = torch.load('SAG-ViT.pth')
|
| 18 |
+
model.load_state_dict(state_dict)
|
| 19 |
+
print("Loaded model weights...")
|
| 20 |
+
|
| 21 |
# Push model and code
|
| 22 |
+
model.save_pretrained('.')
|
| 23 |
model.push_to_hub("shravvvv/SAG-ViT")
|
| 24 |
+
|
| 25 |
+
print("Pushed model to Hugging Face hub...")
|