# Stockai / multimodal_model.py
# (uploaded by rmanzo28 — "Upload 4 files", revision 2134603, verified)
import torch
import torch.nn as nn
try:
from .image_encoder import ImageEncoder
except ImportError:
ImageEncoder = None
try:
from .text_encoder import TextEncoder
except ImportError:
TextEncoder = None
try:
from .audio_encoder import AudioEncoder
except ImportError:
AudioEncoder = None
try:
from .time_series_encoder import TimeSeriesEncoder
except ImportError:
TimeSeriesEncoder = None
from .fusion import GatedFusion, CrossModalAttention
class MultiModalModel(nn.Module):
    """
    Flexible multimodal model supporting image, text, audio, and time series.

    Each modality present in ``config`` gets an encoder: either a custom
    module supplied as ``config['<modality>_encoder']`` or a default encoder
    built from ``config['<modality>']`` kwargs. Encoder modules must expose
    an ``out_features`` attribute.

    Fusion strategies (``config['fusion']``):
      - ``'gated'``           -> GatedFusion over all encoder outputs
      - ``'cross_attention'`` -> CrossModalAttention over the first two encoders
      - anything else         -> plain feature concatenation (default)

    config keys:
        <modality>:          dict of kwargs for the default encoder
        <modality>_encoder:  optional pre-built encoder module
        fusion:              fusion strategy name (default 'concat')
        fusion_output_dim:   fused feature width (required for gated /
                             cross_attention)
        num_classes:         output width of the final linear classifier
    """

    @staticmethod
    def _make_default_encoder(encoder_cls, modality, kwargs):
        """Instantiate the default encoder for *modality*.

        Raises ImportError with a clear message when the default encoder
        class could not be imported (it is None) and no custom encoder was
        supplied — instead of the opaque ``TypeError: 'NoneType' object is
        not callable`` the bare call would produce.
        """
        if encoder_cls is None:
            raise ImportError(
                f"No encoder available for modality '{modality}': the default "
                f"encoder failed to import and no '{modality}_encoder' was "
                f"provided in the config."
            )
        return encoder_cls(**kwargs)

    def __init__(self, config):
        super().__init__()
        # config: dict with keys for each modality and fusion.
        # NOTE: the default-encoder call sits on the right of `or`, so it is
        # only evaluated when no custom encoder is supplied.
        self.encoders = nn.ModuleDict()
        if 'image' in config:
            self.encoders['image'] = config.get('image_encoder') or \
                self._make_default_encoder(ImageEncoder, 'image', config['image'])
        if 'text' in config:
            self.encoders['text'] = config.get('text_encoder') or \
                self._make_default_encoder(TextEncoder, 'text', config['text'])
        if 'audio' in config:
            self.encoders['audio'] = config.get('audio_encoder') or \
                self._make_default_encoder(AudioEncoder, 'audio', config['audio'])
        if 'time_series' in config:
            self.encoders['time_series'] = config.get('time_series_encoder') or \
                self._make_default_encoder(TimeSeriesEncoder, 'time_series',
                                           config['time_series'])

        self.fusion_type = config.get('fusion', 'concat')
        if self.fusion_type == 'gated':
            input_dims = [enc.out_features for enc in self.encoders.values()]
            self.fusion = GatedFusion(input_dims, config['fusion_output_dim'])
        elif self.fusion_type == 'cross_attention':
            # Assume two modalities for simplicity (uses the first two encoders).
            encoder_list = list(self.encoders.values())
            self.fusion = CrossModalAttention(encoder_list[0].out_features,
                                              encoder_list[1].out_features,
                                              config['fusion_output_dim'])
        else:
            self.fusion = None  # fallback to concat

        # BUGFIX: the classifier input width must match the fused feature
        # width. For gated / cross_attention fusion that is
        # config['fusion_output_dim']; only the concat fallback produces
        # sum(out_features). The original always used the sum, crashing at
        # the Linear layer for any non-concat fusion.
        if self.fusion_type in ('gated', 'cross_attention'):
            classifier_in = config['fusion_output_dim']
        else:
            classifier_in = sum(enc.out_features for enc in self.encoders.values())
        self.classifier = nn.Linear(classifier_in, config['num_classes'])

    def forward(self, inputs):
        """Encode each modality, fuse, and classify.

        Args:
            inputs: dict of modality_name -> tensor; must contain one entry
                per configured encoder (missing keys raise KeyError).

        Returns:
            Logits tensor of shape (..., num_classes).
        """
        feats = [encoder(inputs[name]) for name, encoder in self.encoders.items()]
        if self.fusion_type == 'gated':
            fused = self.fusion(feats)
        elif self.fusion_type == 'cross_attention':
            fused = self.fusion(feats[0], feats[1])
        else:
            fused = torch.cat(feats, dim=-1)
        return self.classifier(fused)