rmanzo28 committed on
Commit
ecc5214
·
verified ·
1 Parent(s): 2b586d1

Update stock_ai.py

Browse files
Files changed (1) hide show
  1. stock_ai.py +5 -3
stock_ai.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import torch.nn as nn
3
- from transformers import AutoModel, AutoConfig
4
  from .fusion import GatedFusion, CrossModalAttention
5
  from .audio_encoder import AudioEncoder
6
  from .time_series_encoder import TimeSeriesEncoder
@@ -8,7 +8,7 @@ from .interpretability import compute_feature_importance
8
 
9
  class MultimodalStockPredictor(nn.Module):
10
  def __init__(self,
11
- text_model_name="bert-large-uncased",
12
  vision_model_name=None,
13
  tabular_dim=64, # <-- Change this to match your data, e.g., tabular_dim=5
14
  audio_dim=None,
@@ -50,6 +50,8 @@ class MultimodalStockPredictor(nn.Module):
50
  use_mixed_precision (bool): Enable mixed precision training.
51
  ...existing code...
52
  """
 
 
53
  super().__init__()
54
  self.tabular_dim = tabular_dim # Save for runtime check
55
  # Text encoder (large transformer)
@@ -563,4 +565,4 @@ if __name__ == "__main__":
563
  test_fusion_cross_attention()
564
  test_input_validation()
565
  test_interpretability()
566
- test_mixed_precision()
 
1
  import torch
2
  import torch.nn as nn
3
+ from transformers import AutoModel, AutoConfig, AutoTokenizer
4
  from .fusion import GatedFusion, CrossModalAttention
5
  from .audio_encoder import AudioEncoder
6
  from .time_series_encoder import TimeSeriesEncoder
 
8
 
9
  class MultimodalStockPredictor(nn.Module):
10
  def __init__(self,
11
+ text_model_name="albert-large-v2",
12
  vision_model_name=None,
13
  tabular_dim=64, # <-- Change this to match your data, e.g., tabular_dim=5
14
  audio_dim=None,
 
50
  use_mixed_precision (bool): Enable mixed precision training.
51
  ...existing code...
52
  """
53
  super().__init__()
54
+ # Load the tokenizer for ALBERT (must come AFTER super().__init__():
+ # nn.Module forbids attribute assignment before Module.__init__ runs)
55
+ self.tokenizer = AutoTokenizer.from_pretrained(text_model_name)
56
  self.tabular_dim = tabular_dim # Save for runtime check
57
  # Text encoder (large transformer)
 
565
  test_fusion_cross_attention()
566
  test_input_validation()
567
  test_interpretability()
568
+ test_mixed_precision()