Commit 5695d43
Parent(s): 3103208

feat: add option assume_text_inputs in sentence transformers

custom_st.py · CHANGED · +37 -19
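In short: the commit threads an `assume_text_inputs` flag through the custom `Transformer` module as a new constructor argument, a docstring entry, and an attribute stored in `__init__`. When the flag is on, the input-routing loop sends every string straight to the text tokenizer instead of sniffing it as a URL, `data:image/` URI, or local file path. The rewritten default path also wraps the HTTP image fetch in a try/except that falls back to treating the sample as text when the download fails. A hedged usage sketch follows the diff.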
@@ -22,6 +22,7 @@ class Transformer(nn.Module):
         model_args: Optional[Dict[str, Any]] = None,
         tokenizer_args: Optional[Dict[str, Any]] = None,
         image_processor_args: Optional[Dict[str, Any]] = None,
+        assume_text_inputs: bool = False,
         cache_dir: Optional[str] = None,
         backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
         **_,
@@ -56,6 +57,8 @@ class Transformer(nn.Module):
             image_processor_args (Dict[str, Any], optional): Additional image processor
                 configuration parameters to be passed to the Hugging Face Transformers
                 image processor
+            assume_text_inputs (bool, optional): If set to `True`, all inputs are
+                treated as texts. Defaults to `False`
             cache_dir (str, optional): The Hugging Face Hub cache directory
             backend (str, optional): Computational backend, only 'torch' is supported

@@ -119,6 +122,7 @@ class Transformer(nn.Module):
             cache_dir=cache_dir,
             **image_processor_kwargs,
         )
+        self.assume_text_inputs = assume_text_inputs

         # No max_seq_length set. Try to infer from model
         if max_seq_length is None:
@@ -151,26 +155,40 @@ class Transformer(nn.Module):
         _images = []
         _texts = []
         _image_or_text_descriptors = []
-        for sample in texts:
-            if isinstance(sample, str):
-                if sample.startswith('http'):
-                    response = requests.get(sample)
-                    _images.append(Image.open(BytesIO(response.content)).convert('RGB'))
-                    _image_or_text_descriptors.append(0)
-                elif sample.startswith('data:image/'):
-                    _images.append(self._decode_data_image(sample).convert('RGB'))
-                    _image_or_text_descriptors.append(0)
-                else:
-                    try:
-                        _images.append(Image.open(sample).convert('RGB'))
+
+        if self.assume_text_inputs:
+            for sample in texts:
+                if isinstance(sample, str):
+                    _texts.append(sample)
+                    _image_or_text_descriptors.append(1)
+        else:
+            for sample in texts:
+                if isinstance(sample, str):
+                    if sample.startswith('http'):
+                        try:
+                            response = requests.get(sample)
+                            _images.append(
+                                Image.open(BytesIO(response.content)).convert('RGB')
+                            )
+                            _image_or_text_descriptors.append(0)
+                        except Exception as e:
+                            _ = str(e)
+                            _texts.append(sample)
+                            _image_or_text_descriptors.append(1)
+                    elif sample.startswith('data:image/'):
+                        _images.append(self._decode_data_image(sample).convert('RGB'))
                         _image_or_text_descriptors.append(0)
-                    except Exception as e:
-                        _ = str(e)
-                        _texts.append(sample)
-                        _image_or_text_descriptors.append(1)
-            elif isinstance(sample, Image.Image):
-                _images.append(sample.convert('RGB'))
-                _image_or_text_descriptors.append(0)
+                    else:
+                        try:
+                            _images.append(Image.open(sample).convert('RGB'))
+                            _image_or_text_descriptors.append(0)
+                        except Exception as e:
+                            _ = str(e)
+                            _texts.append(sample)
+                            _image_or_text_descriptors.append(1)
+                elif isinstance(sample, Image.Image):
+                    _images.append(sample.convert('RGB'))
+                    _image_or_text_descriptors.append(0)

         encoding = {}
         if len(_texts):
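For context, a minimal usage sketch under stated assumptions: the diff never shows the module's full signature, the repository it belongs to, or the method containing the routing loop, so the first positional argument, the model name 'jinaai/jina-clip-v2', and the `tokenize` method name (the sentence-transformers module convention) are all assumptions rather than facts from this commit.

# Hypothetical usage sketch -- not from this commit. Assumes custom_st.py is
# importable and that the module takes a model name/path as its first argument.
from custom_st import Transformer

# Default behaviour: strings are sniffed; URLs, data URIs, and readable image
# paths are loaded as images, everything else falls through to the tokenizer.
multimodal = Transformer('jinaai/jina-clip-v2')

# With the new flag, every string is treated as text, so no HTTP request,
# data-URI decode, or file open is ever attempted on string inputs.
text_only = Transformer('jinaai/jina-clip-v2', assume_text_inputs=True)

# Assuming the changed loop lives in tokenize(): both samples below are
# tokenized as text, including the URL-looking one.
features = text_only.tokenize(['https://example.com/cat.png', 'a photo of a cat'])

One asymmetry worth noting, visible in the diff itself: on the `assume_text_inputs=True` path only `str` samples are collected, so `PIL.Image.Image` inputs are silently dropped, whereas the default path still collects them as images.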