Spaces:
Running
Running
Gil Stetler
commited on
Commit
·
218f038
1
Parent(s):
0262268
fix
Browse files- train_autogluon.py +13 -13
train_autogluon.py
CHANGED
|
@@ -7,35 +7,35 @@ def train_bolt_small(
|
|
| 7 |
start="2015-01-01",
|
| 8 |
interval="1d",
|
| 9 |
prediction_length=30,
|
| 10 |
-
time_limit=900, #
|
| 11 |
):
|
| 12 |
"""
|
| 13 |
-
|
| 14 |
-
|
| 15 |
"""
|
| 16 |
print(f"[AutoFT] Lade {ticker} ...")
|
| 17 |
close = fetch_close_series(ticker, start=start, interval=interval)
|
| 18 |
rv = realized_vol(close)
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
df = rv_to_autogluon_df(rv)
|
| 22 |
|
| 23 |
-
#
|
| 24 |
tsdf = TimeSeriesDataFrame.from_data_frame(
|
| 25 |
df,
|
| 26 |
id_column="item_id",
|
| 27 |
timestamp_column="timestamp",
|
| 28 |
-
target_column
|
| 29 |
-
freq="B",
|
| 30 |
)
|
| 31 |
-
#
|
| 32 |
tsdf = tsdf.convert_frequency("B")
|
| 33 |
|
| 34 |
predictor = TimeSeriesPredictor(
|
| 35 |
path="/mnt/data/AutogluonChronosBoltSmall",
|
| 36 |
prediction_length=prediction_length,
|
| 37 |
eval_metric="WQL",
|
| 38 |
-
freq="B",
|
| 39 |
verbosity=2,
|
| 40 |
)
|
| 41 |
|
|
@@ -47,13 +47,13 @@ def train_bolt_small(
|
|
| 47 |
"Chronos": {
|
| 48 |
"model_path": "autogluon/chronos-bolt-small",
|
| 49 |
"fine_tune": True,
|
| 50 |
-
"fine_tune_steps": 200, #
|
| 51 |
"fine_tune_lr": 1e-4,
|
| 52 |
-
"context_length": 128, #
|
| 53 |
"quantile_levels": [0.1, 0.5, 0.9],
|
| 54 |
}
|
| 55 |
},
|
| 56 |
-
time_limit=time_limit, #
|
| 57 |
)
|
| 58 |
|
| 59 |
print("✅ Training abgeschlossen. Modellpfad:", predictor.path)
|
|
|
|
| 7 |
start="2015-01-01",
|
| 8 |
interval="1d",
|
| 9 |
prediction_length=30,
|
| 10 |
+
time_limit=900, # Sekunden (15 Min). Bei Bedarf anpassen.
|
| 11 |
):
|
| 12 |
"""
|
| 13 |
+
Trainiert Chronos-Bolt-Small auf CPU via AutoGluon mit CPU-freundlichen Limits.
|
| 14 |
+
Explizite Business-Day-Frequenz ('B') verhindert Frequency-Fehler.
|
| 15 |
"""
|
| 16 |
print(f"[AutoFT] Lade {ticker} ...")
|
| 17 |
close = fetch_close_series(ticker, start=start, interval=interval)
|
| 18 |
rv = realized_vol(close)
|
| 19 |
|
| 20 |
+
# tidy DataFrame: columns = item_id, timestamp, target
|
| 21 |
+
df = rv_to_autogluon_df(rv)
|
| 22 |
|
| 23 |
+
# TimeSeriesDataFrame mit expliziter Frequenz erzeugen
|
| 24 |
tsdf = TimeSeriesDataFrame.from_data_frame(
|
| 25 |
df,
|
| 26 |
id_column="item_id",
|
| 27 |
timestamp_column="timestamp",
|
| 28 |
+
# KEIN target_column-Argument in AG 1.4.0 – 'target' wird implizit erkannt
|
| 29 |
+
freq="B",
|
| 30 |
)
|
| 31 |
+
# auf reguläres Business-Day-Gitter bringen (Lücken = NaN)
|
| 32 |
tsdf = tsdf.convert_frequency("B")
|
| 33 |
|
| 34 |
predictor = TimeSeriesPredictor(
|
| 35 |
path="/mnt/data/AutogluonChronosBoltSmall",
|
| 36 |
prediction_length=prediction_length,
|
| 37 |
eval_metric="WQL",
|
| 38 |
+
freq="B",
|
| 39 |
verbosity=2,
|
| 40 |
)
|
| 41 |
|
|
|
|
| 47 |
"Chronos": {
|
| 48 |
"model_path": "autogluon/chronos-bolt-small",
|
| 49 |
"fine_tune": True,
|
| 50 |
+
"fine_tune_steps": 200, # klein halten für CPU
|
| 51 |
"fine_tune_lr": 1e-4,
|
| 52 |
+
"context_length": 128, # klein halten für CPU
|
| 53 |
"quantile_levels": [0.1, 0.5, 0.9],
|
| 54 |
}
|
| 55 |
},
|
| 56 |
+
time_limit=time_limit, # harter Cap, damit HF nicht timeoutet
|
| 57 |
)
|
| 58 |
|
| 59 |
print("✅ Training abgeschlossen. Modellpfad:", predictor.path)
|