AttributeError: 'InternImage' object has no attribute '_initialize_weights'

#1
by 98Rakib - opened

Code:

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score, roc_auc_score, roc_curve
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import copy
import time
import math
from collections import Counter
from tqdm import tqdm
from torch.amp import autocast, GradScaler
from transformers import AutoModel

seed = 42

# Ensure reproducibility: fix RNG seeds for torch (CPU + CUDA) and numpy,
# and force deterministic cuDNN kernels (disables autotuner heuristics).
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)  # safe no-op when CUDA is unavailable
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Define constants
IMAGE_DIR = "/kaggle/input/safenet-ai-final-dataset/Notebooks + CSV + Images/Images/Images"
BATCH_SIZE = 8
NUM_EPOCHS = 200      # upper bound; early stopping usually halts far sooner
LEARNING_RATE = 0.0001
IMAGE_SIZE = 224      # input resolution of the *_224 checkpoint
NUM_CLASSES = 3
PATIENCE = 3          # epochs without val-F1 improvement before stopping
# D_O = .5            # unused dropout setting, kept for reference
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Restrict the task to the three troll classes of interest.
# NOTE(review): these strings must match the dataframe's class names exactly —
# 'Provocative_Trolls' uses an underscore while the others use spaces; verify
# against the CSV before trusting the classification report labels.
target_classes = ['Harmless Trolling', 'Targeted Trolling', 'Provocative_Trolls']
# Alternative 4-class setup kept for reference:
# target_classes = ['Satirical Trolling', 'Targeted Trolling', 'Provocative Trolls', 'Harmless Trolling']

# Custom dataset class

class MemeDataset(Dataset):
    """Map-style dataset over a dataframe with an 'image' column (filename)
    and a 'class_idx' column (integer label); files live under image_dir.

    Returns (image, label) pairs; `transform` (if given) is applied to the
    PIL image before it is returned.
    """

    def __init__(self, dataframe, image_dir, transform=None):
        # Original code had `def init(...)`, which is never called by
        # Python's constructor protocol — it must be `__init__`.
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx]['image']
        img_path = os.path.join(self.image_dir, img_name)

        try:
            image = Image.open(img_path).convert('RGB')
        except Exception:
            # Best-effort fallback: a blank white image stands in for a
            # missing/corrupt file so one bad sample never aborts an epoch.
            image = Image.new('RGB', (IMAGE_SIZE, IMAGE_SIZE), color='white')

        label = self.dataframe.iloc[idx]['class_idx']

        if self.transform:
            image = self.transform(image)

        return image, label

from transformers import AutoImageProcessor, AutoModelForImageClassification

# Load the pretrained InternImage classifier, replacing its 1000-class head
# with a NUM_CLASSES-way head (ignore_mismatched_sizes discards the old head).
# NOTE(review): the "AttributeError: 'InternImage' object has no attribute
# '_initialize_weights'" raised here comes from the repo's trust_remote_code
# model being incompatible with newer transformers releases — pin a
# transformers version the OpenGVLab remote code supports, or load the
# backbone via timm instead. TODO confirm the supported version range.
model_name = "OpenGVLab/internimage_b_1k_224"
# processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(  # was `model = model = ...` (duplicated assignment)
    model_name,
    trust_remote_code=True,  # required: InternImage is not a built-in transformers architecture
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)
model.to(DEVICE)

# Training-time preprocessing: resize to the model's 224x224 input, apply a
# random horizontal flip for augmentation, then normalize with ImageNet stats.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
_train_steps = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
]
train_transforms = transforms.Compose(_train_steps)

# Deterministic preprocessing for validation/test: resize and normalize only
# (no augmentation), using the same ImageNet statistics as training.
_eval_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
val_transforms = transforms.Compose(_eval_steps)

# Create dataset objects — one MemeDataset per split; the training split uses
# the augmenting transform, validation and test use the deterministic one.
train_dataset = MemeDataset(train_df, IMAGE_DIR, transform=train_transforms)
val_dataset = MemeDataset(val_df, IMAGE_DIR, transform=val_transforms)
test_dataset = MemeDataset(test_df, IMAGE_DIR, transform=val_transforms)

# Create data loaders — only the training split is shuffled.
# NOTE(review): num_workers=4 needs a `if __name__ == "__main__":` guard on
# spawn-based platforms (Windows/macOS); fine under Kaggle's Linux fork.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Disabled alternative: log-scaled inverse-frequency class weighting for the
# loss. The block below is a bare triple-quoted string (dead code); to enable
# it, prepend `#` to the opening ''' so it pairs with the closing #''' line.
'''

Compute class weights

class_counts = Counter(train_df['class_idx'])
total_samples = sum(class_counts.values())
num_classes = len(class_counts)

class_weights = torch.tensor([
math.log(total_samples / class_counts[i]) for i in range(num_classes)
], dtype=torch.float32).to(DEVICE)

criterion = nn.CrossEntropyLoss(weight=class_weights)
#'''

# Unweighted cross-entropy loss, an AMP gradient scaler for mixed precision,
# and Adam over all model parameters at the configured learning rate.
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, classification_report, roc_curve

# Early-stopping bookkeeping: best validation macro-F1 so far, the count of
# consecutive epochs without improvement, and a snapshot of the best weights.
best_f1, no_improve_epochs = 0.0, 0
best_model_wts = copy.deepcopy(model.state_dict())

print("Starting training...")
for epoch in range(NUM_EPOCHS):
    print(f'Epoch {epoch+1}/{NUM_EPOCHS}')
    print('-' * 10)

    # ---- Training phase ----
    model.train()
    running_loss = 0.0
    running_corrects = 0
    all_preds_train = []
    all_labels_train = []

    for inputs, labels in tqdm(train_loader, desc="Training"):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()

        # Mixed-precision forward pass.
        with autocast(device_type='cuda'):
            outputs = model(inputs)
            # HF AutoModelForImageClassification returns a ModelOutput, not a
            # tensor — CrossEntropyLoss needs the raw logits. The hasattr
            # fallback keeps this working if the remote code returns a tensor.
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs
            loss = criterion(logits, labels)
            _, preds = torch.max(logits, 1)

        # AMP: scale the loss before backward, then step/update the scaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        all_preds_train.extend(preds.cpu().numpy())
        all_labels_train.extend(labels.cpu().numpy())

    train_loss = running_loss / len(train_loader.dataset)
    train_acc = running_corrects.double() / len(train_loader.dataset)
    train_f1 = f1_score(all_labels_train, all_preds_train, average='macro')
    # NOTE(review): ROC-AUC computed from one-hot *hard* predictions is
    # degenerate — pass softmax probabilities for a meaningful score.
    train_auc = roc_auc_score(all_labels_train, np.eye(NUM_CLASSES)[all_preds_train], multi_class='ovr')

    # ---- Validation phase ----
    model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(inputs)
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs
            _, preds = torch.max(logits, 1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, average='macro')
    val_auc = roc_auc_score(val_labels, np.eye(NUM_CLASSES)[val_preds], multi_class='ovr')

    print(f'Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f} | Train ROC-AUC: {train_auc:.4f}')
    print(f'Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f} | Val ROC-AUC: {val_auc:.4f}')

    # Early stopping keyed on validation macro-F1.
    if val_f1 > best_f1:
        best_f1 = val_f1
        best_model_wts = copy.deepcopy(model.state_dict())
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1
        print(f'Early stopping Count {no_improve_epochs}')

    if no_improve_epochs >= PATIENCE:
        print(f'Early stopping triggered after {epoch+1} epochs')
        break

print(f'Best Validation f1:{best_f1:.4f}')

# Load best model weights before evaluating — the original left this line
# commented out, so the test metrics reflected whatever the last (possibly
# overfit) epoch left in the model instead of the best checkpoint.
model.load_state_dict(best_model_wts)

# Evaluate model on test set
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc='Testing'):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        outputs = model(inputs)
        # HF models return a ModelOutput; take the logits tensor.
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs
        _, preds = torch.max(logits, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Aggregate test-set metrics.
test_acc = accuracy_score(all_labels, all_preds)
test_f1 = f1_score(all_labels, all_preds, average='macro')
# NOTE(review): ROC-AUC from one-hot hard predictions is degenerate; use
# softmax probabilities collected during evaluation instead.
test_auc = roc_auc_score(all_labels, np.eye(NUM_CLASSES)[all_preds], multi_class='ovr')
print('InternImage')  # fixed: previously printed 'Swin Transformer', a stale label from another notebook
print("Test Classification Report:")
print(classification_report(all_labels, all_preds, target_names=target_classes, digits=4))
print(f'Test Acc: {test_acc:.4f} | Test F1: {test_f1:.4f} | Test ROC-AUC: {test_auc:.4f}')

# ROC Curve Visualization — one-vs-rest curve per class.
# NOTE(review): built from binarized hard predictions, so each "curve" has a
# single operating point; feed per-class probabilities for real ROC curves.
plt.figure(figsize=(8, 6))
for i in range(NUM_CLASSES):
    fpr, tpr, _ = roc_curve(np.array(all_labels) == i, np.array(all_preds) == i)
    plt.plot(fpr, tpr, label=f'Class {target_classes[i]}')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend()
plt.show()

# Confusion Matrix — rows are true classes, columns are predictions; the
# figure is both saved to disk and shown inline.
plt.figure(figsize=(10, 8))
cm = confusion_matrix(all_labels, all_preds)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=target_classes, yticklabels=target_classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.tight_layout()
plt.savefig('confusion_matrix.png')
plt.show()

# Save the model (state_dict only; restore with model.load_state_dict(...)).
# NOTE(review): make sure the best validation weights are loaded into `model`
# before this point, otherwise the final-epoch weights are what get saved.
torch.save(model.state_dict(), 'meme_classifier.pth')

Error:

AttributeError Traceback (most recent call last)
/tmp/ipykernel_36/1370559554.py in <cell line: 0>()
93 model_name = "OpenGVLab/internimage_b_1k_224"
94 #processor = AutoImageProcessor.from_pretrained(model_name)
---> 95 model = model = AutoModelForImageClassification.from_pretrained(
96 model_name,
97 trust_remote_code=True,

/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
562 elif type(config) in cls._model_mapping.keys():
563 model_class = _get_model_class(config, cls._model_mapping)
--> 564 return model_class.from_pretrained(
565 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
566 )

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _wrapper(*args, **kwargs)
307
308 def get_state_dict_dtype(state_dict):
--> 309 """
310 Returns the first found floating dtype in state_dict if there is one, otherwise returns the first dtype.
311 """

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)
4572 f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
4573 " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
-> 4574 )
4575 elif len(mismatched_keys) == 0:
4576 logger.info(

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _load_pretrained_model(cls, model, state_dict, checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, device_map, disk_offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_regex, device_mesh, key_mapping, weights_only)
4882 start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
4883
-> 4884 x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
4885 x = self.activation(x)
4886 x = self.LayerNorm(x)

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _initialize_missing_keys(self, loaded_keys, ignore_mismatched_sizes, is_quantized)

/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
117
118 return decorate_context

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in initialize_weights(self)
2554 if token is not None:
2555 kwargs["token"] = token
-> 2556
2557 _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False)
2558

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in smart_apply(self, fn)
2545 "The use_auth_token argument is deprecated and will be removed in v5 of Transformers. Please use token instead.",
2546 FutureWarning,
-> 2547 )
2548 if token is not None:
2549 raise ValueError(

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in getattr(self, name)
1926 if name in modules:
1927 return modules[name]
-> 1928 raise AttributeError(
1929 f"'{type(self).name}' object has no attribute '{name}'"
1930 )

AttributeError: 'InternImage' object has no attribute '_initialize_weights'

Sign up or log in to comment