File size: 4,339 Bytes
f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 2d79e15 f6a036d 842c1e0 f6a036d 2d79e15 f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 2d79e15 f6a036d 842c1e0 f6a036d 842c1e0 f6a036d 842c1e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import re
import sys
import warnings
def _check_torch_version():
"""Check if PyTorch version is >= 2.7.1"""
try:
import torch
# Simple version comparison
version_str = torch.__version__.split("+")[0] # Remove any suffixes like +cu118
version_parts = version_str.split(".")
# Compare major version
if int(version_parts[0]) > 2:
return True
# Compare minor version
elif int(version_parts[0]) == 2 and int(version_parts[1]) > 7:
return True
# Compare patch version
elif (
int(version_parts[0]) == 2
and int(version_parts[1]) == 7
and int(version_parts[2]) >= 1
):
return True
return False
except (ImportError, AttributeError, IndexError, ValueError):
return False
def _check_transformers_version():
"""Check if Transformers version is >= 4.51.1"""
try:
import transformers
# Simple version comparison
version_str = transformers.__version__.split("+")[0] # Remove any suffixes
version_parts = version_str.split(".")
# Compare major version
if int(version_parts[0]) > 4:
return True
# Compare minor version
elif int(version_parts[0]) == 4 and int(version_parts[1]) > 51:
return True
# Compare patch version
elif (
int(version_parts[0]) == 4
and int(version_parts[1]) == 51
and int(version_parts[2]) >= 1
):
return True
return False
except (ImportError, AttributeError, IndexError, ValueError):
return False
class DramaModelWrapper:
    """
    Factory class for DramaModel that picks an implementation at load time.

    ``from_pretrained`` returns the nested tensor implementation only when all
    of the following hold:

    * the running Python is >= ``MIN_NESTED_PYTHON_VERSION``;
    * ``_check_torch_version()`` reports PyTorch >= 2.7.1;
    * ``_check_transformers_version()`` reports Transformers >= 4.51.1.

    Otherwise it emits a ``UserWarning`` naming the first unmet requirement
    and returns the non-nested implementation.
    """

    # Minimum Python version for the nested tensor implementation.
    # NOTE(review): the docs and warning text previously said 3.12 while the
    # actual check used (3, 15). The check's value is kept so dispatch
    # behavior is unchanged; confirm which threshold is intended and change
    # only this constant (the warning text is derived from it).
    MIN_NESTED_PYTHON_VERSION = (3, 15)

    @staticmethod
    def _load_non_nested(message, pretrained_model_name_or_path, model_args, kwargs):
        """
        Warn with *message*, then load the non-nested DramaModel implementation.

        ``model_args`` and ``kwargs`` are received as a tuple and a dict (not
        unpacked) so caller keyword arguments can never collide with this
        helper's own parameter names.
        """
        warnings.warn(message)
        from .modeling_drama_non_nested import DramaModel as NonNestedDramaModel

        return NonNestedDramaModel.from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Instantiate a pretrained model from a pre-trained model configuration.

        This method is required by the transformers library's auto model
        loading mechanism.

        Args:
            pretrained_model_name_or_path: Path to the pretrained model or its name.
            *model_args: Additional positional arguments forwarded to the
                selected implementation's ``from_pretrained``.
            **kwargs: Additional keyword arguments forwarded likewise.

        Returns:
            An instance of the appropriate DramaModel implementation.
        """
        min_python = cls.MIN_NESTED_PYTHON_VERSION
        if sys.version_info < min_python:
            # Build the message from the real threshold so it cannot drift
            # out of sync with the check again.
            min_python_str = ".".join(str(part) for part in min_python)
            return cls._load_non_nested(
                f"Python version < {min_python_str} detected. "
                "Using non-nested implementation.",
                pretrained_model_name_or_path,
                model_args,
                kwargs,
            )

        if not _check_torch_version():
            return cls._load_non_nested(
                "PyTorch version < 2.7.1 detected. Falling back to non-nested implementation.",
                pretrained_model_name_or_path,
                model_args,
                kwargs,
            )

        if not _check_transformers_version():
            return cls._load_non_nested(
                "Transformers version < 4.51.1 detected. Falling back to non-nested implementation.",
                pretrained_model_name_or_path,
                model_args,
                kwargs,
            )

        # All requirements met: use the nested tensor implementation.
        from .modeling_drama_nested import DramaModel as NestedDramaModel

        return NestedDramaModel.from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )
|