Upload 4 files
Browse files- fusion.py +40 -0
- requirements.txt +6 -0
- time_series_encoder.py +17 -0
- training_utils.py +29 -0
fusion.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class GatedFusion(nn.Module):
|
5 |
+
"""
|
6 |
+
Gated fusion for two or more modalities.
|
7 |
+
"""
|
8 |
+
def __init__(self, input_dims, output_dim):
|
9 |
+
super().__init__()
|
10 |
+
self.gates = nn.ModuleList([nn.Linear(d, output_dim) for d in input_dims])
|
11 |
+
self.fcs = nn.ModuleList([nn.Linear(d, output_dim) for d in input_dims])
|
12 |
+
|
13 |
+
def forward(self, features):
|
14 |
+
# features: list of tensors [batch, dim]
|
15 |
+
gated = []
|
16 |
+
for i, feat in enumerate(features):
|
17 |
+
gate = torch.sigmoid(self.gates[i](feat))
|
18 |
+
proj = self.fcs[i](feat)
|
19 |
+
gated.append(gate * proj)
|
20 |
+
return sum(gated)
|
21 |
+
|
22 |
+
class CrossModalAttention(nn.Module):
|
23 |
+
"""
|
24 |
+
Cross-modal attention for two modalities.
|
25 |
+
"""
|
26 |
+
def __init__(self, dim_q, dim_kv, dim_out):
|
27 |
+
super().__init__()
|
28 |
+
self.query = nn.Linear(dim_q, dim_out)
|
29 |
+
self.key = nn.Linear(dim_kv, dim_out)
|
30 |
+
self.value = nn.Linear(dim_kv, dim_out)
|
31 |
+
self.softmax = nn.Softmax(dim=-1)
|
32 |
+
|
33 |
+
def forward(self, q, kv):
|
34 |
+
# q: (batch, dim_q), kv: (batch, dim_kv)
|
35 |
+
Q = self.query(q).unsqueeze(1) # (batch, 1, dim_out)
|
36 |
+
K = self.key(kv).unsqueeze(1) # (batch, 1, dim_out)
|
37 |
+
V = self.value(kv).unsqueeze(1) # (batch, 1, dim_out)
|
38 |
+
attn = self.softmax(torch.bmm(Q, K.transpose(1,2)))
|
39 |
+
out = torch.bmm(attn, V).squeeze(1)
|
40 |
+
return out
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.0.0
|
2 |
+
transformers>=4.40.0
|
3 |
+
scikit-learn
|
4 |
+
pandas
|
5 |
+
numpy
|
6 |
+
tqdm
|
time_series_encoder.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class TimeSeriesEncoder(nn.Module):
|
5 |
+
"""
|
6 |
+
Simple time series encoder using LSTM.
|
7 |
+
"""
|
8 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
|
9 |
+
super().__init__()
|
10 |
+
self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
|
11 |
+
self.fc = nn.Linear(hidden_dim, output_dim)
|
12 |
+
|
13 |
+
def forward(self, x):
|
14 |
+
# x: (batch, seq_len, input_dim)
|
15 |
+
_, (h_n, _) = self.lstm(x)
|
16 |
+
out = self.fc(h_n[-1])
|
17 |
+
return out
|
training_utils.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class EarlyStopping:
|
4 |
+
"""
|
5 |
+
Early stops the training if validation loss doesn't improve after a given patience.
|
6 |
+
"""
|
7 |
+
def __init__(self, patience=5, delta=0):
|
8 |
+
self.patience = patience
|
9 |
+
self.delta = delta
|
10 |
+
self.counter = 0
|
11 |
+
self.best_loss = None
|
12 |
+
self.early_stop = False
|
13 |
+
|
14 |
+
def __call__(self, val_loss):
|
15 |
+
if self.best_loss is None or val_loss < self.best_loss - self.delta:
|
16 |
+
self.best_loss = val_loss
|
17 |
+
self.counter = 0
|
18 |
+
else:
|
19 |
+
self.counter += 1
|
20 |
+
if self.counter >= self.patience:
|
21 |
+
self.early_stop = True
|
22 |
+
|
23 |
+
def get_scheduler(optimizer, scheduler_type='plateau', **kwargs):
|
24 |
+
if scheduler_type == 'plateau':
|
25 |
+
return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **kwargs)
|
26 |
+
elif scheduler_type == 'step':
|
27 |
+
return torch.optim.lr_scheduler.StepLR(optimizer, **kwargs)
|
28 |
+
else:
|
29 |
+
raise ValueError(f"Unknown scheduler type: {scheduler_type}")
|