Update stock_ai.py
Browse files- stock_ai.py +5 -3
stock_ai.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
-
from transformers import AutoModel, AutoConfig
|
| 4 |
from .fusion import GatedFusion, CrossModalAttention
|
| 5 |
from .audio_encoder import AudioEncoder
|
| 6 |
from .time_series_encoder import TimeSeriesEncoder
|
|
@@ -8,7 +8,7 @@ from .interpretability import compute_feature_importance
|
|
| 8 |
|
| 9 |
class MultimodalStockPredictor(nn.Module):
|
| 10 |
def __init__(self,
|
| 11 |
-
text_model_name="
|
| 12 |
vision_model_name=None,
|
| 13 |
tabular_dim=64, # <-- Change this to match your data, e.g., tabular_dim=5
|
| 14 |
audio_dim=None,
|
|
@@ -50,6 +50,8 @@ class MultimodalStockPredictor(nn.Module):
|
|
| 50 |
use_mixed_precision (bool): Enable mixed precision training.
|
| 51 |
...existing code...
|
| 52 |
"""
|
|
|
|
|
|
|
| 53 |
super().__init__()
|
| 54 |
self.tabular_dim = tabular_dim # Save for runtime check
|
| 55 |
# Text encoder (large transformer)
|
|
@@ -563,4 +565,4 @@ if __name__ == "__main__":
|
|
| 563 |
test_fusion_cross_attention()
|
| 564 |
test_input_validation()
|
| 565 |
test_interpretability()
|
| 566 |
-
test_mixed_precision()
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
+
from transformers import AutoModel, AutoConfig,AutoTokenizer
|
| 4 |
from .fusion import GatedFusion, CrossModalAttention
|
| 5 |
from .audio_encoder import AudioEncoder
|
| 6 |
from .time_series_encoder import TimeSeriesEncoder
|
|
|
|
| 8 |
|
| 9 |
class MultimodalStockPredictor(nn.Module):
|
| 10 |
def __init__(self,
|
| 11 |
+
text_model_name="albert-large-v2",
|
| 12 |
vision_model_name=None,
|
| 13 |
tabular_dim=64, # <-- Change this to match your data, e.g., tabular_dim=5
|
| 14 |
audio_dim=None,
|
|
|
|
| 50 |
use_mixed_precision (bool): Enable mixed precision training.
|
| 51 |
...existing code...
|
| 52 |
"""
|
| 53 |
+
# Load the tokenizer for ALBERT
|
| 54 |
+
self.tokenizer = AutoTokenizer.from_pretrained(text_model_name)
|
| 55 |
super().__init__()
|
| 56 |
self.tabular_dim = tabular_dim # Save for runtime check
|
| 57 |
# Text encoder (large transformer)
|
|
|
|
| 565 |
test_fusion_cross_attention()
|
| 566 |
test_input_validation()
|
| 567 |
test_interpretability()
|
| 568 |
+
test_mixed_precision()
|