tlemagueresse
committed on
Commit
·
45ee714
1
Parent(s):
cac82bc
Replace pkl by joblib
Browse files
- model.py +12 -14
- pipeline.pkl → pipeline.joblib +1 -1
model.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import os
|
| 2 |
import struct
|
| 3 |
-
import pickle
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Literal, Union
|
| 6 |
|
|
@@ -9,6 +8,7 @@ import torch
|
|
| 9 |
import lightgbm as lgb
|
| 10 |
import torchaudio
|
| 11 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 12 |
from sklearn.exceptions import NotFittedError
|
| 13 |
from torch import Tensor
|
| 14 |
from torchaudio.transforms import Spectrogram
|
|
@@ -366,7 +366,7 @@ class FastModelHuggingFace:
|
|
| 366 |
Methods
|
| 367 |
-------
|
| 368 |
from_pretrained(repo_id: str, revision: str = "main",
|
| 369 |
-
pipeline_file_name: str = "pipeline.pkl",
|
| 370 |
model_file_name: str = "model_lightgbm.txt") -> "FastModelHuggingFace":
|
| 371 |
Loads the FastModel pipeline and model from the Hugging Face Hub.
|
| 372 |
predict(input_data: Union[str, "HuggingFaceDataset"], get_proba: bool = False) -> np.ndarray:
|
|
@@ -392,7 +392,7 @@ class FastModelHuggingFace:
|
|
| 392 |
cls,
|
| 393 |
repo_id: str,
|
| 394 |
revision: str = "main",
|
| 395 |
-
pipeline_file_name: str = "pipeline.pkl",
|
| 396 |
model_file_name: str = "model_lightgbm.txt",
|
| 397 |
) -> "FastModelHuggingFace":
|
| 398 |
"""
|
|
@@ -405,7 +405,7 @@ class FastModelHuggingFace:
|
|
| 405 |
revision : str, optional
|
| 406 |
The specific revision of the repository to use (default is "main").
|
| 407 |
pipeline_file_name : str, optional
|
| 408 |
-
The filename of the serialized pipeline (default is "pipeline.pkl").
|
| 409 |
model_file_name : str, optional
|
| 410 |
The filename of the LightGBM model (default is "model_lightgbm.txt").
|
| 411 |
|
|
@@ -424,8 +424,7 @@ class FastModelHuggingFace:
|
|
| 424 |
|
| 425 |
if not os.path.exists(pipeline_path):
|
| 426 |
raise FileNotFoundError(f"Pipeline file {pipeline_path} is missing or corrupted.")
|
| 427 |
-
|
| 428 |
-
pipeline = pickle.load(f)
|
| 429 |
|
| 430 |
if not os.path.exists(model_lgbm_path):
|
| 431 |
raise FileNotFoundError(
|
|
@@ -512,10 +511,10 @@ def save_pipeline(
|
|
| 512 |
lgbm_file_name : str, optional
|
| 513 |
The filename for saving the LightGBM model (default is "model_fast_model.txt").
|
| 514 |
pipeline_file_name : str, optional
|
| 515 |
-
The filename for saving the pipeline (default is "pipeline.pkl").
|
| 516 |
"""
|
| 517 |
lgbm_file_name = lgbm_file_name or "model_lightgbm.txt"
|
| 518 |
-
pipeline_file_name = pipeline_file_name or "pipeline.pkl"
|
| 519 |
|
| 520 |
lightgbm_path = Path(path) / lgbm_file_name
|
| 521 |
if model_class_instance.model:
|
|
@@ -523,8 +522,7 @@ def save_pipeline(
|
|
| 523 |
model_class_instance.model.save_model(model_class_instance.model_file_name)
|
| 524 |
|
| 525 |
pipeline_path = Path(path) / pipeline_file_name
|
| 526 |
-
|
| 527 |
-
pickle.dump(model_class_instance, f)
|
| 528 |
|
| 529 |
|
| 530 |
def load_pipeline(
|
|
@@ -540,7 +538,7 @@ def load_pipeline(
|
|
| 540 |
lgbm_file_name : str, optional
|
| 541 |
The filename for the LightGBM model (default is "model_fast_model.txt").
|
| 542 |
pipeline_file_name : str, optional
|
| 543 |
-
The filename for the pipeline (default is "pipeline.pkl").
|
| 544 |
|
| 545 |
Returns
|
| 546 |
-------
|
|
@@ -553,13 +551,13 @@ def load_pipeline(
|
|
| 553 |
If either the LightGBM model or pipeline file is not found.
|
| 554 |
"""
|
| 555 |
lgbm_file_name = lgbm_file_name or "model_fast_model.txt"
|
| 556 |
-
pipeline_file_name = pipeline_file_name or "pipeline.pkl"
|
| 557 |
|
| 558 |
pipeline_path = Path(path) / pipeline_file_name
|
| 559 |
if not pipeline_path.exists():
|
| 560 |
raise FileNotFoundError(f"Pipeline file {pipeline_path} not found.")
|
| 561 |
-
|
| 562 |
-
|
| 563 |
|
| 564 |
lightgbm_path = Path(path) / lgbm_file_name
|
| 565 |
if not lightgbm_path.exists():
|
|
|
|
| 1 |
import os
|
| 2 |
import struct
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Literal, Union
|
| 5 |
|
|
|
|
| 8 |
import lightgbm as lgb
|
| 9 |
import torchaudio
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
+
from joblib import dump, load
|
| 12 |
from sklearn.exceptions import NotFittedError
|
| 13 |
from torch import Tensor
|
| 14 |
from torchaudio.transforms import Spectrogram
|
|
|
|
| 366 |
Methods
|
| 367 |
-------
|
| 368 |
from_pretrained(repo_id: str, revision: str = "main",
|
| 369 |
+
pipeline_file_name: str = "pipeline.joblib",
|
| 370 |
model_file_name: str = "model_lightgbm.txt") -> "FastModelHuggingFace":
|
| 371 |
Loads the FastModel pipeline and model from the Hugging Face Hub.
|
| 372 |
predict(input_data: Union[str, "HuggingFaceDataset"], get_proba: bool = False) -> np.ndarray:
|
|
|
|
| 392 |
cls,
|
| 393 |
repo_id: str,
|
| 394 |
revision: str = "main",
|
| 395 |
+
pipeline_file_name: str = "pipeline.joblib",
|
| 396 |
model_file_name: str = "model_lightgbm.txt",
|
| 397 |
) -> "FastModelHuggingFace":
|
| 398 |
"""
|
|
|
|
| 405 |
revision : str, optional
|
| 406 |
The specific revision of the repository to use (default is "main").
|
| 407 |
pipeline_file_name : str, optional
|
| 408 |
+
The filename of the serialized pipeline (default is "pipeline.joblib").
|
| 409 |
model_file_name : str, optional
|
| 410 |
The filename of the LightGBM model (default is "model_lightgbm.txt").
|
| 411 |
|
|
|
|
| 424 |
|
| 425 |
if not os.path.exists(pipeline_path):
|
| 426 |
raise FileNotFoundError(f"Pipeline file {pipeline_path} is missing or corrupted.")
|
| 427 |
+
pipeline = load(pipeline_path)
|
|
|
|
| 428 |
|
| 429 |
if not os.path.exists(model_lgbm_path):
|
| 430 |
raise FileNotFoundError(
|
|
|
|
| 511 |
lgbm_file_name : str, optional
|
| 512 |
The filename for saving the LightGBM model (default is "model_fast_model.txt").
|
| 513 |
pipeline_file_name : str, optional
|
| 514 |
+
The filename for saving the pipeline (default is "pipeline.joblib").
|
| 515 |
"""
|
| 516 |
lgbm_file_name = lgbm_file_name or "model_lightgbm.txt"
|
| 517 |
+
pipeline_file_name = pipeline_file_name or "pipeline.joblib"
|
| 518 |
|
| 519 |
lightgbm_path = Path(path) / lgbm_file_name
|
| 520 |
if model_class_instance.model:
|
|
|
|
| 522 |
model_class_instance.model.save_model(model_class_instance.model_file_name)
|
| 523 |
|
| 524 |
pipeline_path = Path(path) / pipeline_file_name
|
| 525 |
+
dump(model_class_instance, pipeline_path)
|
|
|
|
| 526 |
|
| 527 |
|
| 528 |
def load_pipeline(
|
|
|
|
| 538 |
lgbm_file_name : str, optional
|
| 539 |
The filename for the LightGBM model (default is "model_fast_model.txt").
|
| 540 |
pipeline_file_name : str, optional
|
| 541 |
+
The filename for the pipeline (default is "pipeline.joblib").
|
| 542 |
|
| 543 |
Returns
|
| 544 |
-------
|
|
|
|
| 551 |
If either the LightGBM model or pipeline file is not found.
|
| 552 |
"""
|
| 553 |
lgbm_file_name = lgbm_file_name or "model_fast_model.txt"
|
| 554 |
+
pipeline_file_name = pipeline_file_name or "pipeline.joblib"
|
| 555 |
|
| 556 |
pipeline_path = Path(path) / pipeline_file_name
|
| 557 |
if not pipeline_path.exists():
|
| 558 |
raise FileNotFoundError(f"Pipeline file {pipeline_path} not found.")
|
| 559 |
+
|
| 560 |
+
model_class_instance = load(pipeline_path)
|
| 561 |
|
| 562 |
lightgbm_path = Path(path) / lgbm_file_name
|
| 563 |
if not lightgbm_path.exists():
|
pipeline.pkl → pipeline.joblib
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 834053
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:04a292b51ec618f28089ee0933b30e6623f3abff3e282aafaca15b13c402a847
|
| 3 |
size 834053
|