import torch
import torch.nn as nn
try:
    from .image_encoder import ImageEncoder
except ImportError:
    ImageEncoder = None

try:
    from .text_encoder import TextEncoder
except ImportError:
    TextEncoder = None

try:
    from .audio_encoder import AudioEncoder
except ImportError:
    AudioEncoder = None

try:
    from .time_series_encoder import TimeSeriesEncoder
except ImportError:
    TimeSeriesEncoder = None

from .fusion import GatedFusion, CrossModalAttention

class MultiModalModel(nn.Module):
    """
    Flexible multimodal model supporting image, text, audio, and time series.
    Supports custom encoders and fusion strategies.
    """
    def __init__(self, config):
        super().__init__()
        # config: dict mapping each modality name ('image', 'text', 'audio',
        # 'time_series') to its encoder kwargs, plus optional
        # '<modality>_encoder' overrides and 'fusion', 'fusion_output_dim',
        # and 'num_classes' keys.
        self.encoders = nn.ModuleDict()
        for modality, default_cls in (
            ('image', ImageEncoder),
            ('text', TextEncoder),
            ('audio', AudioEncoder),
            ('time_series', TimeSeriesEncoder),
        ):
            if modality not in config:
                continue
            custom = config.get(f'{modality}_encoder')
            if custom is not None:
                self.encoders[modality] = custom
            elif default_cls is not None:
                self.encoders[modality] = default_cls(**config[modality])
            else:
                # The default encoder failed to import; fail loudly instead of
                # raising a cryptic "'NoneType' is not callable" later.
                raise ImportError(
                    f'No default encoder available for modality {modality!r}; '
                    f'pass a custom module via config["{modality}_encoder"]')

        self.fusion_type = config.get('fusion', 'concat')
        encoders = list(self.encoders.values())
        if self.fusion_type == 'gated':
            input_dims = [enc.out_features for enc in encoders]
            self.fusion = GatedFusion(input_dims, config['fusion_output_dim'])
            classifier_in = config['fusion_output_dim']
        elif self.fusion_type == 'cross_attention':
            # Cross-modal attention assumes exactly two modalities.
            if len(encoders) != 2:
                raise ValueError('cross_attention fusion requires exactly two modalities')
            self.fusion = CrossModalAttention(
                encoders[0].out_features, encoders[1].out_features,
                config['fusion_output_dim'])
            classifier_in = config['fusion_output_dim']
        else:
            self.fusion = None  # fall back to plain concatenation
            classifier_in = sum(enc.out_features for enc in encoders)

        # The classifier input must match the fused feature size: the fusion
        # output dim for gated/cross-attention fusion, or the concatenated
        # width of all encoder outputs otherwise.
        self.classifier = nn.Linear(classifier_in, config['num_classes'])

    def forward(self, inputs):
        # inputs: dict mapping modality name -> input tensor for that modality
        feats = [encoder(inputs[name]) for name, encoder in self.encoders.items()]
        if self.fusion_type == 'gated':
            fused = self.fusion(feats)
        elif self.fusion_type == 'cross_attention':
            fused = self.fusion(feats[0], feats[1])
        else:
            # 'concat' fallback: concatenate per-modality features on the last dim
            fused = torch.cat(feats, dim=-1)
        return self.classifier(fused)
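
# --- Example usage -----------------------------------------------------------
# A minimal sketch of wiring up the model with custom encoders. The toy encoder
# and all dimensions below are illustrative assumptions, not the signatures of
# the real ImageEncoder/TextEncoder modules. Because of the relative imports
# above, run this as a module inside the package (python -m <package>.<module>)
# rather than as a standalone script.
if __name__ == '__main__':
    class _ToyEncoder(nn.Module):
        """Stand-in encoder so the sketch is self-contained."""
        def __init__(self, in_dim, out_features):
            super().__init__()
            self.out_features = out_features  # MultiModalModel reads this attribute
            self.proj = nn.Linear(in_dim, out_features)

        def forward(self, x):
            return self.proj(x)

    config = {
        'image': {},  # presence of the key enables the modality
        'text': {},
        'image_encoder': _ToyEncoder(2048, 256),  # hypothetical feature dims
        'text_encoder': _ToyEncoder(768, 256),
        'fusion': 'concat',  # plain concatenation needs no fusion module
        'num_classes': 10,
    }
    model = MultiModalModel(config)
    inputs = {
        'image': torch.randn(4, 2048),  # batch of pooled image features
        'text': torch.randn(4, 768),    # batch of pooled text features
    }
    logits = model(inputs)
    print(logits.shape)  # torch.Size([4, 10])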