# fish-freshness-classifier / 2_segment_with_sam.py
# Author: roqueselopeta
# Commit message: "Initial commit with clean project files"
# Commit: c17bef1
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from torchvision import transforms
from transformers import SamModel, SamProcessor
from models.efficientnet_b4 import EfficientNetB4Classifier
from glob import glob
# ===============================================================
# Runtime configuration
# ===============================================================
# Run on the default CUDA device when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Every Grad-CAM + SAM overlay is written into this folder; it is created up
# front and an already-existing folder is not an error.
RESULTS_DIR = "results/gradcam_sam"
os.makedirs(name=RESULTS_DIR, exist_ok=True)
# -------------------------------
# Grad-CAM
# -------------------------------
class GradCAM:
    """Compute Grad-CAM heatmaps for ``model`` at ``target_layer``.

    A forward hook stores the layer's output activations and a full backward
    hook stores the gradient flowing into that output. The CAM is the ReLU of
    the channel-wise sum of activations weighted by their spatially averaged
    gradients. It is bilinearly upsampled to the input's spatial size and
    min-max normalised to [0, 1].
    """

    def __init__(self, model, target_layer):
        """Attach the capture hooks.

        Args:
            model: module called as ``model(x)``; its output is back-propagated.
            target_layer: submodule of ``model`` to explain. Its output is assumed
                to be 4-D (batch, channels, height, width), because ``__call__``
                averages gradients over dims 2 and 3.
        """
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        def forward_hook(module, input, output):
            # Keep the layer output (still attached to the graph) for the CAM.
            self.activations = output

        def backward_hook(module, grad_input, grad_output):
            # grad_output[0] is the gradient of the objective w.r.t. the layer output.
            self.gradients = grad_output[0]

        # Keep the handles so the hooks can be detached with remove_hooks().
        self._hook_handles = [
            target_layer.register_forward_hook(forward_hook),
            target_layer.register_full_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        """Detach the hooks installed on ``target_layer``; safe to call repeatedly."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def __call__(self, x, class_idx=None):
        """Return the Grad-CAM map for input ``x`` as a NumPy array.

        Args:
            x: input batch for ``model``. It is assumed to be (batch, channels, H, W),
                since ``x.shape[2:]`` sets the output size.
            class_idx: which outputs to back-propagate from.
                - None (default): every output element (a gradient seed of ones).
                - An integer: only column ``class_idx`` of the model output.
                  This assumes 2-D (batch, num_classes) outputs.

        Returns:
            The squeezed CAM, which is (H, W) for a single image. Values lie in
            [0, 1]. A constant map (e.g. all zeros after the ReLU) is returned as
            zeros instead of NaN.
        """
        # Clear parameter gradients left over from a previous call.
        self.model.zero_grad()
        out = self.model(x)
        if class_idx is None:
            grad_seed = torch.ones_like(out)
        else:
            grad_seed = torch.zeros_like(out)
            grad_seed[:, class_idx] = 1.0
        out.backward(grad_seed)

        # Channel importance = spatial mean of the gradient for that channel.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.nn.functional.relu(cam)
        cam = torch.nn.functional.interpolate(cam, size=x.shape[2:], mode='bilinear', align_corners=False)
        cam = cam.squeeze().detach().cpu().numpy()

        # Min-max normalise; a flat map has zero range and would divide 0 by 0.
        cam_min = cam.min()
        span = cam.max() - cam_min
        if span > 0:
            cam = (cam - cam_min) / span
        else:
            cam = np.zeros_like(cam)
        return cam
# ===============================================================
# Classifier (EfficientNetB4Classifier) wrapped with Grad-CAM
# ===============================================================
# NOTE(review): torch.load unpickles the checkpoint, so only load trusted files
# (PyTorch >= 1.13 also accepts weights_only=True to restrict unpickling).
model = EfficientNetB4Classifier()
model.load_state_dict(torch.load("efficientnetb4_best8595.pth", map_location=DEVICE))
model.to(DEVICE)
model.eval()

# Grad-CAM explains the last entry of the backbone's feature stack.
# Assumes EfficientNetB4Classifier exposes an indexable `base_model.features` — confirm.
target_layer = model.base_model.features[-1]
cam_extractor = GradCAM(model, target_layer)

# ===============================================================
# Segment Anything model + its pre/post-processor (Hugging Face hub)
# ===============================================================
sam_model = SamModel.from_pretrained("facebook/sam-vit-base")
sam_model.to(DEVICE)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# -------------------------------
# Apply Grad-CAM + SAM
# -------------------------------
def _cam_peak_point(cam, image_size):
    """Map the hottest Grad-CAM cell to (x, y) pixel coordinates in the original image.

    ``cam`` is a 2-D map over the classifier's resized input, and ``image_size``
    is PIL's ``(width, height)``. The centre of the peak cell is rescaled to
    that size. For a flat map, argmax picks the first cell.
    """
    width, height = image_size
    row, col = np.unravel_index(int(np.argmax(cam)), cam.shape)
    x = (col + 0.5) * width / cam.shape[1]
    y = (row + 0.5) * height / cam.shape[0]
    return float(x), float(y)


def _sam_best_mask(image, point):
    """Segment ``image`` with SAM from one point prompt; return a 0/255 uint8 (H, W) mask.

    SAM's ``pred_masks`` are low-resolution logits holding several candidates
    per prompt. ``post_process_masks`` upsamples them to the original image
    size and binarises them. The candidate with the highest predicted IoU is kept.
    """
    inputs = sam_processor(image, input_points=[[list(point)]], return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = sam_model(**inputs)
    masks = sam_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[0]: (point_batch, num_candidates, H, W); iou_scores: (batch, point_batch, num_candidates)
    best = int(outputs.iou_scores[0, 0].argmax())
    return np.where(masks[0][0, best].numpy(), 255, 0).astype(np.uint8)


def apply_gradcam_sam(image_path, save_path):
    """Save a Grad-CAM overlay with a SAM mask outline for one image.

    Steps:
    1. Compute Grad-CAM on a 380x380 ImageNet-normalised copy of the image.
    2. Blend the JET-coloured heatmap 50/50 with the original image.
    3. Prompt SAM with the CAM's peak location and keep its best-scoring mask.
    4. Add the mask's Canny edges on top.
    5. Save the result as ``overlay_<basename>`` inside ``save_path``.

    Args:
        image_path: path to an image readable by PIL; it is converted to RGB.
        save_path: existing directory that receives the overlay.
    """
    # Context manager closes the file handle; convert() returns an independent copy.
    with Image.open(image_path) as img:
        raw = img.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((380, 380)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    input_tensor = transform(raw).unsqueeze(0).to(DEVICE)

    # Grad-CAM heatmap (380x380), expected in [0, 1]; NaN would corrupt the uint8 cast.
    cam = np.nan_to_num(cam_extractor(input_tensor), nan=0.0)
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV colormaps are BGR
    heatmap = cv2.resize(heatmap, raw.size)  # PIL (W, H) matches cv2's dsize order
    overlay = cv2.addWeighted(np.array(raw), 0.5, heatmap, 0.5, 0)

    # SAM mask at the original resolution, prompted where the classifier looks.
    sam_mask = _sam_best_mask(raw, _cam_peak_point(cam, raw.size))

    # Mask outline (white edges, saturating add) on top of the heatmap overlay.
    mask_outline = cv2.Canny(sam_mask, 100, 200)
    mask_outline = cv2.cvtColor(mask_outline, cv2.COLOR_GRAY2RGB)
    final = cv2.addWeighted(overlay, 1, mask_outline, 1, 0)

    # Save output
    fname = os.path.basename(image_path)
    out_path = os.path.join(save_path, f"overlay_{fname}")
    Image.fromarray(final).save(out_path)
    print(f"✅ Saved Grad-CAM+SAM: {out_path}")
# -------------------------------
# Run on up to 5 random images
# -------------------------------
if __name__ == "__main__":
    image_paths = glob("dataset/**/*.jpg", recursive=True)
    if not image_paths:
        raise SystemExit("No .jpg images found under dataset/")
    # choice(replace=False) raises if asked for more samples than there are images.
    selected = np.random.choice(image_paths, size=min(5, len(image_paths)), replace=False)
    for input_path in selected:
        apply_gradcam_sam(input_path, RESULTS_DIR)