refine-the-codebase (#5)
feat: encode, prefixes, matryoshka, etc (df46e74a27635c8ed0e13a53e06588c2d1f933ff)
- README.md +68 -8
- config.json +2 -1
- modeling_jina_embeddings_v4.py +116 -54
README.md
CHANGED
@@ -1,24 +1,84 @@

Old version:

# Jina Embeddings V4

-

```python
from transformers import AutoModel
model = AutoModel.from_pretrained('jinaai/jina-embeddings-v4', trust_remote_code=True)
-

-

-```python
-text_embedding = model.encode_texts(['test'])
```

-
```python
from PIL import Image

-
-
```

New version:

# Jina Embeddings V4

+
+## Examples
+
+Encode functions:

```python
+import torch
from transformers import AutoModel
+from PIL import Image
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Load model
model = AutoModel.from_pretrained('jinaai/jina-embeddings-v4', trust_remote_code=True)
+model = model.to(device)

+# Sample data
+texts = ["Here is some sample code", "This is a matching text"]
+image_paths = ['/<path_to_image>']
+images = [Image.open(path) for path in image_paths]
+
+# Example 1: Text matching task with single vector embeddings
+model.set_task(task='text-matching')
+
+# Generate embeddings with dimension truncation (256)
+img_embeddings = model.encode_images(images=images, truncate_dim=256)
+text_embeddings = model.encode_texts(texts=texts, truncate_dim=256, max_length=512)
+
+# Example 2: Retrieval task with multi-vector embeddings
+model.set_task(task='retrieval')
+
+# Generate multi-vector embeddings
+img_embeddings = model.encode_images(images=images, vector_type='multi_vector')
+text_embeddings = model.encode_texts(texts=texts, vector_type='multi_vector', text_type='passage')
+
+# Example 3: Code task with single vector embeddings
+model.set_task(task='code')
+
+code = ["def hello_world():\n print('Hello, World!')"]
+code_embeddings = model.encode_texts(texts=code)

```

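The truncated single-vector embeddings from Example 1 keep the leading Matryoshka dimensions but are not renormalized, so a similarity measure that normalizes internally is the safe way to compare them. A minimal sketch, not part of the diff, assuming the calls above return one 256-dim tensor per text and a single tensor for the lone image:

```python
import torch
import torch.nn.functional as F

# text_embeddings: list with one (256,) tensor per text; img_embeddings: a single (256,) tensor
text_matrix = torch.stack(list(text_embeddings)).float()          # (num_texts, 256)
scores = F.cosine_similarity(text_matrix, img_embeddings.float().unsqueeze(0))
print(scores)  # one image-to-text similarity per text
```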
+Using the model forward:
+
```python
+import torch
+from transformers import AutoModel, AutoProcessor
from PIL import Image

+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Load model and processor
+model = AutoModel.from_pretrained('jinaai/jina-embeddings-v4', trust_remote_code=True)
+model = model.to(device)
+processor = AutoProcessor.from_pretrained('jinaai/jina-embeddings-v4', trust_remote_code=True)
+
+
+# Sample data
+texts = ["Here is some sample code", "This is a matching text"]
+image_paths = ['/<path_to_image>']
+
+# Process text and images
+text_batch = processor.process_texts(texts=texts, prefix="Query", max_length=512)
+images = [Image.open(path) for path in image_paths]
+image_batch = processor.process_images(images=images)
+
+# Forward pass
+model.eval()
+with torch.no_grad():
+    text_batch = {k: v.to(device) for k, v in text_batch.items()}
+    image_batch = {k: v.to(device) for k, v in image_batch.items()}
+
+    with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
+        # Get embeddings
+        text_embeddings = model.model(**text_batch).single_vec_emb
+        img_embeddings = model.model(**image_batch).single_vec_emb
+
+
```

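The multi-vector embeddings produced in Example 2 of the first block (vector_type='multi_vector') are meant for late-interaction scoring rather than a single dot product. A sketch for illustration only: maxsim_score is a hypothetical helper, not part of the released API, and it assumes each encode call returns one (num_tokens, dim) tensor per input.

```python
import torch
import torch.nn.functional as F

def maxsim_score(query_vecs: torch.Tensor, doc_vecs: torch.Tensor) -> torch.Tensor:
    """ColBERT-style MaxSim: best-matching document token per query token, summed."""
    q = F.normalize(query_vecs.float(), dim=-1)
    d = F.normalize(doc_vecs.float(), dim=-1)
    sim = q @ d.T                      # (num_query_tokens, num_doc_tokens)
    return sim.max(dim=1).values.sum()

# Score each multi-vector text embedding from Example 2 against the single image embedding
scores = [maxsim_score(t, img_embeddings) for t in text_embeddings]
```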
config.json
CHANGED
@@ -53,5 +53,6 @@

Old version:

  "vision_end_token_id": 151653,
  "vision_start_token_id": 151652,
  "vision_token_id": 151654,
-  "vocab_size": 151936
}

New version:

  "vision_end_token_id": 151653,
  "vision_start_token_id": 151652,
  "vision_token_id": 151654,
+  "vocab_size": 151936,
+  "truncate_dim": null
}

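The new truncate_dim entry gives the encode methods a config-level default for Matryoshka truncation; it ships as null, so vectors are only truncated when truncate_dim is passed explicitly (as in the README examples). A quick way to inspect the value, a sketch using the standard Transformers AutoConfig pattern rather than anything shown in this diff:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained('jinaai/jina-embeddings-v4', trust_remote_code=True)
print(config.truncate_dim)  # None by default; per-call truncate_dim arguments still apply
```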
modeling_jina_embeddings_v4.py
CHANGED
Old version:

@@ -1,4 +1,6 @@
-
import os
from dataclasses import dataclass
from enum import Enum

@@ -15,7 +17,6 @@ from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BatchFeature
-from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2_5_vl import (Qwen2_5_VLForConditionalGeneration,
                                            Qwen2_5_VLProcessor)

@@ -33,27 +34,17 @@ class TaskType(str, Enum):
    text_matching = "text-matching"


class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
    def __init__(self, *args, **kwargs) -> None:
        Qwen2_5_VLProcessor.__init__(self, *args, **kwargs)
        self.assistant_prefix_len = 58
        self.text_max_length = 8192

-    @staticmethod
-    def round_by_factor(number: float, factor: int) -> int:
-        """Returns the closest integer to 'number' that is divisible by 'factor'."""
-        return round(number / factor) * factor
-
-    @staticmethod
-    def ceil_by_factor(number: float, factor: int) -> int:
-        """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
-        return math.ceil(number / factor) * factor
-
-    @staticmethod
-    def floor_by_factor(number: float, factor: int) -> int:
-        """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
-        return math.floor(number / factor) * factor
-
    def process_images(
        self,
        images: Union[List[Image.Image], List[List[Image.Image]]],

@@ -175,7 +166,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            [pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0
        )

-        position_ids, rope_deltas = super().get_rope_index(
            input_ids=input_ids,
            image_grid_thw=kwargs.get("image_grid_thw", None),
            attention_mask=attention_mask,

@@ -267,10 +258,10 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        **kwargs,
    ) -> JinaEmbeddingsV4ModelOutput:
        """
-        Forward pass through
        Args:
-            input_ids (torch.
-            attention_mask (torch.
        Returns:
            JinaEmbeddingsV4ModelOutput:
                single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).

@@ -302,17 +293,17 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        data: List[Union[str, Image.Image]],
        processor_fn: Callable,
        desc: str,
-        vector_type:
        return_numpy: bool = False,
-
    ) -> Union[np.ndarray, List[torch.Tensor]]:
        dataloader = DataLoader(
            dataset=data,
-            batch_size=
            shuffle=False,
            collate_fn=processor_fn,
        )
-        vector_type = vector_type or "single_vector"
        results = []
        self.eval()
        for batch in tqdm(dataloader, desc=desc):

@@ -322,8 +313,11 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            embeddings = self(**batch)
            if vector_type == "single_vector":
                embeddings = embeddings.single_vec_emb
            else:
                embeddings = embeddings.multi_vec_emb
            results.append(
                embeddings.cpu()
                if return_numpy

@@ -333,44 +327,98 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            return np.concatenate([result.numpy() for result in results], axis=0)
        return [item for sublist in results for item in sublist]

    def encode_texts(
        self,
-
        max_length: int = 8192,
        batch_size: int = 8,
        vector_type: Optional[str] = None,
-
-
    ) -> List[torch.Tensor]:
        processor_fn = partial(
-            self.processor.process_texts,
        )
-
-
            processor_fn=processor_fn,
-            desc=
-
            batch_size=batch_size,
-            **
        )

    def encode_images(
        self,
-
        batch_size: int = 8,
        vector_type: Optional[str] = None,
-
-
    ) -> List[torch.Tensor]:
-
-
            processor_fn=self.processor.process_images,
-            desc=
-            vector_type=vector_type,
            batch_size=batch_size,
-
        )

    @classmethod
    def from_pretrained(
        cls,

@@ -381,9 +429,15 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        if "torch_dtype" not in kwargs:
            kwargs["torch_dtype"] = "auto"

-

-        # Get the base model first
        base_model = super().from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )

@@ -397,36 +451,44 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        )
        adapter_dir = os.path.join(adapter_cache_path, "adapters")

-        # Store adapter directory for later use with set_task
        base_model.adapter_dir = adapter_dir

        # Create the PEFT model with the requested task adapter
        peft_model = PeftModel.from_pretrained(
-            base_model, os.path.join(adapter_dir, task)
        )

        # Add set_task method to the PEFT model instance
-        def set_task_method(self,
            """
            Set the task adapter for the model.

            Args:
-
                one of ['retrieval', 'text-matching', 'code']
            """
-            if isinstance(
                try:
-
                except ValueError:
                    valid_tasks = [t.value for t in TaskType]
                    raise ValueError(
-                        f"Invalid task: {
                    )

-
-

-        # Bind the
        peft_model.set_task = set_task_method.__get__(peft_model, type(peft_model))

        return peft_model

New version:

@@ -1,4 +1,6 @@
+# Jina Embeddings V4 Model implementation was inspired by the ColPali codebase:
+# https://github.com/illuin-tech/colpali
+
import os
from dataclasses import dataclass
from enum import Enum

@@ -15,7 +17,6 @@ from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BatchFeature
from transformers.models.qwen2_5_vl import (Qwen2_5_VLForConditionalGeneration,
                                            Qwen2_5_VLProcessor)

@@ -33,27 +34,17 @@ class TaskType(str, Enum):
    text_matching = "text-matching"


+PREFIX_DICT = {"query": "Query", "passage": "Passage"}
+TRUNCATE_DIMS = [128, 256, 512, 1024]
+VECTOR_TYPES = ["single_vector", "multi_vector"]
+
+
class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
    def __init__(self, *args, **kwargs) -> None:
        Qwen2_5_VLProcessor.__init__(self, *args, **kwargs)
        self.assistant_prefix_len = 58
        self.text_max_length = 8192

    def process_images(
        self,
        images: Union[List[Image.Image], List[List[Image.Image]]],

@@ -175,7 +166,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            [pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0
        )

+        position_ids, rope_deltas = super().get_rope_index(
            input_ids=input_ids,
            image_grid_thw=kwargs.get("image_grid_thw", None),
            attention_mask=attention_mask,

@@ -267,10 +258,10 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        **kwargs,
    ) -> JinaEmbeddingsV4ModelOutput:
        """
+        Forward pass through the model. Returns both single-vector and multi-vector embeddings.
        Args:
+            input_ids (torch.Tensor): The input tokens tensor.
+            attention_mask (torch.Tensor): The attention mask tensor.
        Returns:
            JinaEmbeddingsV4ModelOutput:
                single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).

@@ -302,17 +293,17 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        data: List[Union[str, Image.Image]],
        processor_fn: Callable,
        desc: str,
+        vector_type: str = "single_vector",
        return_numpy: bool = False,
+        batch_size: int = 32,
+        truncate_dim: Optional[int] = None,
    ) -> Union[np.ndarray, List[torch.Tensor]]:
        dataloader = DataLoader(
            dataset=data,
+            batch_size=batch_size,
            shuffle=False,
            collate_fn=processor_fn,
        )
        results = []
        self.eval()
        for batch in tqdm(dataloader, desc=desc):

@@ -322,8 +313,11 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            embeddings = self(**batch)
            if vector_type == "single_vector":
                embeddings = embeddings.single_vec_emb
+                if truncate_dim is not None:
+                    embeddings = embeddings[:, :truncate_dim]
            else:
                embeddings = embeddings.multi_vec_emb
+
            results.append(
                embeddings.cpu()
                if return_numpy

@@ -333,44 +327,98 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            return np.concatenate([result.numpy() for result in results], axis=0)
        return [item for sublist in results for item in sublist]

+    def _validate_encoding_params(
+        self,
+        vector_type: Optional[str] = None,
+        truncate_dim: Optional[int] = None,
+        text_type: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        encode_kwargs = {}
+        if text_type is not None:
+            if text_type not in PREFIX_DICT:
+                raise ValueError(
+                    f"Invalid text_type: {text_type}. Must be one of {list(PREFIX_DICT.keys())}."
+                )
+            else:
+                encode_kwargs["prefix"] = (
+                    PREFIX_DICT[text_type]
+                    if self.task != TaskType.text_matching
+                    else PREFIX_DICT["query"]
+                )
+
+        vector_type = vector_type or "single_vector"
+        if vector_type not in VECTOR_TYPES:
+            raise ValueError(
+                f"Invalid vector_type: {vector_type}. Must be one of {VECTOR_TYPES}."
+            )
+        else:
+            encode_kwargs["vector_type"] = vector_type
+
+        truncate_dim = truncate_dim or self.config.truncate_dim
+        if truncate_dim is not None and truncate_dim not in TRUNCATE_DIMS:
+            raise ValueError(
+                f"Invalid truncate_dim: {truncate_dim}. Must be one of {TRUNCATE_DIMS}."
+            )
+        else:
+            encode_kwargs["truncate_dim"] = truncate_dim
+
+        return encode_kwargs
+
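For reference, the prefix rule added in _validate_encoding_params above resolves like this; the snippet is a standalone mirror for illustration only (resolve_prefix is a hypothetical helper, not part of the model API):

```python
# Standalone mirror of the prefix resolution above, for illustration only.
PREFIX_DICT = {"query": "Query", "passage": "Passage"}

def resolve_prefix(task: str, text_type: str = "query") -> str:
    if text_type not in PREFIX_DICT:
        raise ValueError(f"Invalid text_type: {text_type}")
    # text-matching is symmetric, so both sides are encoded with the query prefix
    return PREFIX_DICT["query"] if task == "text-matching" else PREFIX_DICT[text_type]

assert resolve_prefix("retrieval", "passage") == "Passage"
assert resolve_prefix("text-matching", "passage") == "Query"
```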
    def encode_texts(
        self,
+        texts: List[str],
        max_length: int = 8192,
        batch_size: int = 8,
        vector_type: Optional[str] = None,
+        return_numpy: bool = False,
+        truncate_dim: Optional[int] = None,
+        text_type: Optional[str] = None,
    ) -> List[torch.Tensor]:
+        text_type = text_type or "query"
+        encode_kwargs = self._validate_encoding_params(
+            vector_type, truncate_dim, text_type
+        )
+
        processor_fn = partial(
+            self.processor.process_texts,
+            max_length=max_length,
+            prefix=encode_kwargs.pop("prefix"),
        )
+
+        is_single = len(texts) == 1
+        embeddings = self._process_batches(
+            data=texts,
            processor_fn=processor_fn,
+            desc="Encoding texts...",
+            return_numpy=return_numpy,
            batch_size=batch_size,
+            **encode_kwargs,
        )

+        return embeddings[0] if is_single else embeddings
+
    def encode_images(
        self,
+        images: List[Image.Image],
        batch_size: int = 8,
        vector_type: Optional[str] = None,
+        return_numpy: bool = False,
+        truncate_dim: Optional[int] = None,
    ) -> List[torch.Tensor]:
+        encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
+
+        is_single = len(images) == 1
+        embeddings = self._process_batches(
+            data=images,
            processor_fn=self.processor.process_images,
+            desc="Encoding images...",
            batch_size=batch_size,
+            return_numpy=return_numpy,
+            **encode_kwargs,
        )

+        return embeddings[0] if is_single else embeddings
+
    @classmethod
    def from_pretrained(
        cls,

@@ -381,9 +429,15 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        if "torch_dtype" not in kwargs:
            kwargs["torch_dtype"] = "auto"

+        task_value = kwargs.pop("task", "retrieval")
+        try:
+            task = TaskType(task_value)
+        except ValueError:
+            valid_tasks = [t.value for t in TaskType]
+            raise ValueError(
+                f"Invalid task: {task_value}. Must be one of {valid_tasks}."
+            )

        base_model = super().from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )

@@ -397,36 +451,44 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        )
        adapter_dir = os.path.join(adapter_cache_path, "adapters")

        base_model.adapter_dir = adapter_dir
+        base_model.task = task

        # Create the PEFT model with the requested task adapter
        peft_model = PeftModel.from_pretrained(
+            base_model, os.path.join(adapter_dir, task.value)
        )

        # Add set_task method to the PEFT model instance
+        def set_task_method(self, task: Union[str, TaskType]):
            """
            Set the task adapter for the model.

            Args:
+                task (Union[str, TaskType]): The task name. Must be one of TaskType values or
                one of ['retrieval', 'text-matching', 'code']
            """
+            if isinstance(task, str):
                try:
+                    task = TaskType(task)
                except ValueError:
                    valid_tasks = [t.value for t in TaskType]
                    raise ValueError(
+                        f"Invalid task: {task}. Must be one of {valid_tasks}"
                    )
+            if self.model.task != task:
+                adapter_path = os.path.join(self.adapter_dir, task.value)
+                hotswap_adapter(self, adapter_path, adapter_name="default")
+            self.model.task = task

+        def get_task_method(self):
+            """
+            Get the task adapter for the model.
+            """
+            return self.model.task.value

+        # Bind the methods to the instance
        peft_model.set_task = set_task_method.__get__(peft_model, type(peft_model))
+        peft_model.get_task = get_task_method.__get__(peft_model, type(peft_model))

        return peft_model

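Taken together, from_pretrained now accepts a task keyword (defaulting to 'retrieval'), and the returned PEFT model exposes set_task/get_task, with set_task hot-swapping the LoRA adapter in place. A usage sketch, assuming AutoModel forwards the task keyword through to this custom from_pretrained:

```python
from transformers import AutoModel

# 'task' picks which adapter is loaded first; it defaults to 'retrieval'
model = AutoModel.from_pretrained(
    'jinaai/jina-embeddings-v4', trust_remote_code=True, task='code'
)
print(model.get_task())   # 'code'

# Switching tasks hot-swaps the adapter instead of reloading the whole model
model.set_task('retrieval')
print(model.get_task())   # 'retrieval'
```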