AttributeError: 'InternImage' object has no attribute '_initialize_weights'
Code:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score, roc_auc_score, roc_curve
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import copy
import time
import math
from collections import Counter
from tqdm import tqdm
from torch.amp import autocast, GradScaler
from transformers import AutoModel
seed = 42
# Ensure reproducibility: seed every RNG the pipeline touches and force
# deterministic cuDNN kernels (benchmark=False avoids non-deterministic
# autotuning).  The bare prose line here previously raised a NameError.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Define constants: training hyper-parameters and dataset paths.
# (These header lines were bare prose before and raised NameErrors.)
IMAGE_DIR = "/kaggle/input/safenet-ai-final-dataset/Notebooks + CSV + Images/Images/Images"
BATCH_SIZE = 8
NUM_EPOCHS = 200      # upper bound; early stopping (PATIENCE) usually halts sooner
LEARNING_RATE = 0.0001
IMAGE_SIZE = 224      # fallback image size; InternImage-B checkpoint name says 224
NUM_CLASSES = 3
PATIENCE = 3          # epochs without val-F1 improvement before early stop
#D_O=.5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Filter dataframe to only include the classes we're interested in.
# NOTE(review): list index i is presumably the 'class_idx' encoding used by
# the dataframes — confirm against the labeling cell.
target_classes = ['Harmless Trolling','Targeted Trolling','Provocative_Trolls']#,] 'Satirical Trolling'
#target_classes = ['Satirical Trolling','Targeted Trolling','Provocative Trolls','Harmless Trolling'] #,Provocative Trolls','Explicit Harassment Trolls'
# Custom dataset class
class MemeDataset(Dataset):
    """Meme images + integer class labels read from a dataframe.

    Expects `dataframe` to have an 'image' column (filename relative to
    `image_dir`) and a 'class_idx' column (integer label).
    """

    def __init__(self, dataframe, image_dir, transform=None):
        # BUG FIX: the method was named `init`, so construction raised a
        # TypeError and the attributes were never stored; it must be
        # the dunder `__init__`.
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx]['image']
        img_path = os.path.join(self.image_dir, img_name)
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception:
            # Unreadable/missing files become a blank white image instead of
            # crashing a DataLoader worker mid-epoch.
            image = Image.new('RGB', (IMAGE_SIZE, IMAGE_SIZE), color='white')
        label = self.dataframe.iloc[idx]['class_idx']
        if self.transform:
            image = self.transform(image)
        return image, label
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Load processor and model
model_name = "OpenGVLab/internimage_b_1k_224"
#processor = AutoImageProcessor.from_pretrained(model_name)
# NOTE(review): this is the line that raises
# "AttributeError: 'InternImage' object has no attribute '_initialize_weights'".
# The traceback shows recent transformers' weight-init path
# (`_initialize_missing_keys` -> `initialize_weights` -> `smart_apply`) calling
# a hook the repo's remote code does not define — presumably a version
# incompatibility between the hub code and the installed transformers release.
# Try pinning an older transformers version (or a timm InternImage build);
# verify against the model card.
# BUG FIX: removed the duplicated `model = model =` assignment.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    trust_remote_code=True,        # executes model code shipped with the checkpoint
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,  # 1000-class head is re-initialized for 3 classes
)
model.to(DEVICE)
# ImageNet normalization statistics used by the pre-trained backbone.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: the same deterministic resize/normalize as validation,
# plus a random horizontal flip for light augmentation.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Validation/test pipeline: deterministic preprocessing only, no augmentation.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])
# Create dataset objects
# Wrap each split in a MemeDataset; train gets the augmenting transform,
# val/test share the deterministic one.
# NOTE(review): train_df/val_df/test_df are not defined anywhere in this
# snippet — presumably produced by an earlier train_test_split cell; confirm.
train_dataset = MemeDataset(train_df, IMAGE_DIR, transform=train_transforms)
val_dataset = MemeDataset(val_df, IMAGE_DIR, transform=val_transforms)
test_dataset = MemeDataset(test_df, IMAGE_DIR, transform=val_transforms)
# Create data loaders
# Mini-batch loaders; only the training split is shuffled, and 4 worker
# processes handle image decoding/transforms.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
# Disabled alternative: log-scaled class weighting for imbalanced classes.
# The triple-quoted string parks the code without executing it; the closing
# delimiter starts with '#' so deleting the opening quotes re-enables it.
'''
Compute class weights
class_counts = Counter(train_df['class_idx'])
total_samples = sum(class_counts.values())
num_classes = len(class_counts)
class_weights = torch.tensor([
math.log(total_samples / class_counts[i]) for i in range(num_classes)
], dtype=torch.float32).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_weights)
#'''
# Active setup: unweighted cross-entropy, Adam, and an AMP gradient scaler
# for mixed-precision training.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = GradScaler()
# (Redundant re-imports removed: tqdm and the sklearn metrics are already
# imported at the top of this file.)

# Early-stopping state: track the best validation macro-F1 seen so far and
# keep a deep copy of the corresponding weights for later restoration.
best_f1 = 0.0
no_improve_epochs = 0
best_model_wts = copy.deepcopy(model.state_dict())
print("Starting training...")
for epoch in range(NUM_EPOCHS):
    print(f'Epoch {epoch+1}/{NUM_EPOCHS}')
    print('-' * 10)

    # ---- Training phase ----
    model.train()
    running_loss = 0.0
    running_corrects = 0
    all_preds_train = []
    all_labels_train = []
    all_probs_train = []
    for inputs, labels in tqdm(train_loader, desc="Training"):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        optimizer.zero_grad()
        with autocast(device_type='cuda'):
            # BUG FIX: AutoModelForImageClassification returns a ModelOutput,
            # not a raw tensor — take .logits before loss/argmax, otherwise
            # criterion(outputs, labels) fails.
            logits = model(inputs).logits
            loss = criterion(logits, labels)
            _, preds = torch.max(logits, 1)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels)
        all_preds_train.extend(preds.cpu().numpy())
        all_labels_train.extend(labels.cpu().numpy())
        all_probs_train.extend(torch.softmax(logits.float(), dim=1).detach().cpu().numpy())
    train_loss = running_loss / len(train_loader.dataset)
    train_acc = running_corrects.double() / len(train_loader.dataset)
    train_f1 = f1_score(all_labels_train, all_preds_train, average='macro')
    # BUG FIX: ROC-AUC needs class probabilities; the previous one-hot hard
    # predictions produce a degenerate score.
    train_auc = roc_auc_score(all_labels_train, np.array(all_probs_train), multi_class='ovr')

    # ---- Validation phase ----
    model.eval()
    val_preds, val_labels, val_probs = [], [], []
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            logits = model(inputs).logits
            _, preds = torch.max(logits, 1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
            val_probs.extend(torch.softmax(logits.float(), dim=1).cpu().numpy())
    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, average='macro')
    val_auc = roc_auc_score(val_labels, np.array(val_probs), multi_class='ovr')
    print(f'Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f} | Train ROC-AUC: {train_auc:.4f}')
    print(f'Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f} | Val ROC-AUC: {val_auc:.4f}')

    # Early stopping based on validation macro F1 score.
    if val_f1 > best_f1:
        best_f1 = val_f1
        best_model_wts = copy.deepcopy(model.state_dict())
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1
        print(f'Early stopping Count {no_improve_epochs}')
        if no_improve_epochs >= PATIENCE:
            print(f'Early stopping triggered after {epoch+1} epochs')
            break
print(f'Best Validation f1:{best_f1:.4f}')
# Load best model weights
# BUG FIX: this restore was commented out, so the test metrics were computed
# with the last epoch's (possibly overfit) weights instead of the
# best-validation-F1 checkpoint that the early-stopping logic saved.
model.load_state_dict(best_model_wts)

# Evaluate model on the held-out test set.
model.eval()
all_preds, all_labels, all_probs = [], [], []
with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc='Testing'):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        # HF classifier output -> take the .logits tensor before argmax.
        logits = model(inputs).logits
        _, preds = torch.max(logits, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(torch.softmax(logits.float(), dim=1).cpu().numpy())
test_acc = accuracy_score(all_labels, all_preds)
test_f1 = f1_score(all_labels, all_preds, average='macro')
# ROC-AUC from softmax probabilities (hard predictions give a degenerate score).
test_auc = roc_auc_score(all_labels, np.array(all_probs), multi_class='ovr')
print('InternImage')  # BUG FIX: label said 'Swin Transformer' but the model is InternImage
print("Test Classification Report:")
print(classification_report(all_labels, all_preds, target_names=target_classes, digits=4))
print(f'Test Acc: {test_acc:.4f} | Test F1: {test_f1:.4f} | Test ROC-AUC: {test_auc:.4f}')
# ROC Curve Visualization
# One-vs-rest ROC curve per class.
# NOTE(review): these curves are built from hard 0/1 predictions
# (all_preds == i), not class probabilities, so each "curve" collapses to a
# single operating point — plot from softmax scores for a meaningful curve.
plt.figure(figsize=(8, 6))
for i in range(NUM_CLASSES):
    fpr, tpr, _ = roc_curve(np.array(all_labels) == i, np.array(all_preds) == i)
    plt.plot(fpr, tpr, label=f'Class {target_classes[i]}')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend()
plt.show()
# Confusion Matrix
# Confusion-matrix heatmap over the target classes; written to disk before
# being displayed.
plt.figure(figsize=(10, 8))
cm = confusion_matrix(all_labels, all_preds)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=target_classes, yticklabels=target_classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.tight_layout()
plt.savefig('confusion_matrix.png')
plt.show()
# Save the model
# Persist the weights only (state_dict); the model class must be rebuilt
# before these can be loaded again.
torch.save(model.state_dict(), 'meme_classifier.pth')
Error:
AttributeError Traceback (most recent call last)
/tmp/ipykernel_36/1370559554.py in <cell line: 0>()
93 model_name = "OpenGVLab/internimage_b_1k_224"
94 #processor = AutoImageProcessor.from_pretrained(model_name)
---> 95 model = model = AutoModelForImageClassification.from_pretrained(
96 model_name,
97 trust_remote_code=True,
/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
562 elif type(config) in cls._model_mapping.keys():
563 model_class = _get_model_class(config, cls._model_mapping)
--> 564 return model_class.from_pretrained(
565 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
566 )
/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _wrapper(*args, **kwargs)
307
308 def get_state_dict_dtype(state_dict):
--> 309 """
310 Returns the first found floating dtype in state_dict if there is one, otherwise returns the first dtype.
311 """
/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)
4572 f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
4573 " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
-> 4574 )
4575 elif len(mismatched_keys) == 0:
4576 logger.info(
/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _load_pretrained_model(cls, model, state_dict, checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, device_map, disk_offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_regex, device_mesh, key_mapping, weights_only)
4882 start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
4883
-> 4884 x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
4885 x = self.activation(x)
4886 x = self.LayerNorm(x)
/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _initialize_missing_keys(self, loaded_keys, ignore_mismatched_sizes, is_quantized)
/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
117
118 return decorate_context
/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in initialize_weights(self)
2554 if token is not None:
2555 kwargs["token"] = token
-> 2556
2557 _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False)
2558
/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in smart_apply(self, fn)
2545 "The use_auth_token argument is deprecated and will be removed in v5 of Transformers. Please use token instead.",
2546 FutureWarning,
-> 2547 )
2548 if token is not None:
2549 raise ValueError(
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in getattr(self, name)
1926 if name in modules:
1927 return modules[name]
-> 1928 raise AttributeError(
1929 f"'{type(self).name}' object has no attribute '{name}'"
1930 )
AttributeError: 'InternImage' object has no attribute '_initialize_weights'