File size: 2,268 Bytes
fdd0e8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from typing import Optional
class SegFormer(nn.Module):
    """
    SegFormer model for multi-class semantic segmentation.

    Thin wrapper around ``segmentation_models_pytorch.Segformer``. The default
    setup targets RGB (3 bands), but ``in_channels`` can be set to support
    multispectral inputs (e.g. 13 for Sentinel-2 L1C). ``forward`` returns raw
    logits of shape (B, num_classes, H, W); ``predict`` returns integer label
    maps of shape (B, H, W).
    """

    def __init__(
        self,
        encoder_name: str = "mit_b4",
        encoder_weights: Optional[str] = "imagenet",  # set to None if incompatible with in_channels
        in_channels: int = 3,
        num_classes: int = 4,
        freeze_encoder: bool = False,
    ) -> None:
        """
        Args:
            encoder_name: Encoder name understood by
                ``segmentation_models_pytorch`` (e.g. 'mit_b0' ... 'mit_b5';
                default 'mit_b4').
            encoder_weights: Pretrained weights name (typically 'imagenet'),
                or None for random initialisation. Some encoder / weights /
                ``in_channels`` combinations may be rejected by the library;
                pass None in that case.
            in_channels: Number of input channels (3 for RGB, 13 for
                Sentinel-2, etc.).
            num_classes: Number of output classes, i.e. channels of the
                logit map.
            freeze_encoder: If True, sets ``requires_grad=False`` on every
                encoder parameter so no gradients are computed for them.
                This does not change the encoder's train/eval mode.
        """
        super().__init__()
        self.segformer = smp.Segformer(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=num_classes,
        )
        if freeze_encoder:
            for p in self.segformer.encoder.parameters():
                p.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Raw (unnormalised) logits of shape
            (B, num_classes, H, W).
        """
        return self.segformer(x)

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Inference helper: run the model in eval mode and return label maps.

        The module's previous train/eval mode is restored before returning,
        so it is safe to call from inside a training loop.

        Args:
            x: Tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Integer class labels of shape (B, H, W).
        """
        was_training = self.training
        self.eval()
        try:
            logits = self.forward(x)  # (B, num_classes, H, W)
        finally:
            # Restore the caller's mode; leaving the model in eval mode after
            # a mid-training predict() would silently disable dropout etc.
            self.train(was_training)
        # Softmax is strictly monotonic along dim=1, so the argmax of the raw
        # logits picks the same class without allocating a probability tensor.
        return logits.argmax(dim=1)