feat: Integrate Jina embeddings API support
Browse files
- Implemented Jina embedding function
- Add new EMBEDDING_BINDING type of jina for LightRAG Server
- Add env var sample
- env.example +8 -1
- lightrag/api/lightrag_server.py +10 -1
- lightrag/llm/jina.py +78 -12
env.example
CHANGED
@@ -123,7 +123,7 @@ LLM_BINDING_API_KEY=your_api_key
|
|
123 |
####################################################################################
|
124 |
### Embedding Configuration (Should not be changed after the first file processed)
|
125 |
####################################################################################
|
126 |
-
### Embedding Binding type: openai, ollama, lollms, azure_openai
|
127 |
EMBEDDING_BINDING=ollama
|
128 |
EMBEDDING_MODEL=bge-m3:latest
|
129 |
EMBEDDING_DIM=1024
|
@@ -139,6 +139,13 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
|
|
139 |
# AZURE_EMBEDDING_ENDPOINT=your_endpoint
|
140 |
# AZURE_EMBEDDING_API_KEY=your_api_key
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
############################
|
143 |
### Data storage selection
|
144 |
############################
|
|
|
123 |
####################################################################################
|
124 |
### Embedding Configuration (Should not be changed after the first file processed)
|
125 |
####################################################################################
|
126 |
+
### Embedding Binding type: openai, ollama, lollms, azure_openai, jina
|
127 |
EMBEDDING_BINDING=ollama
|
128 |
EMBEDDING_MODEL=bge-m3:latest
|
129 |
EMBEDDING_DIM=1024
|
|
|
139 |
# AZURE_EMBEDDING_ENDPOINT=your_endpoint
|
140 |
# AZURE_EMBEDDING_API_KEY=your_api_key
|
141 |
|
142 |
+
### Jina AI Embedding
### Sample values only: uncomment them (and drop the ollama settings above) to use jina.
### Left active, they would define EMBEDDING_BINDING / EMBEDDING_MODEL / EMBEDDING_DIM
### a second time with values that conflict with the defaults earlier in this file.
# EMBEDDING_BINDING=jina
# EMBEDDING_BINDING_HOST=https://api.jina.ai/v1/embeddings
# EMBEDDING_MODEL=jina-embeddings-v4
# EMBEDDING_DIM=2048
# EMBEDDING_BINDING_API_KEY=your_api_key
|
148 |
+
|
149 |
############################
|
150 |
### Data storage selection
|
151 |
############################
|
lightrag/api/lightrag_server.py
CHANGED
@@ -89,7 +89,7 @@ def create_app(args):
|
|
89 |
]:
|
90 |
raise Exception("llm binding not supported")
|
91 |
|
92 |
-
if args.embedding_binding not in ["lollms", "ollama", "openai", "azure_openai"]:
|
93 |
raise Exception("embedding binding not supported")
|
94 |
|
95 |
# Set default hosts if not provided
|
@@ -213,6 +213,8 @@ def create_app(args):
|
|
213 |
if args.llm_binding_host == "openai-ollama" or args.embedding_binding == "ollama":
|
214 |
from lightrag.llm.openai import openai_complete_if_cache
|
215 |
from lightrag.llm.ollama import ollama_embed
|
|
|
|
|
216 |
|
217 |
async def openai_alike_model_complete(
|
218 |
prompt,
|
@@ -284,6 +286,13 @@ def create_app(args):
|
|
284 |
api_key=args.embedding_binding_api_key,
|
285 |
)
|
286 |
if args.embedding_binding == "azure_openai"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
else openai_embed(
|
288 |
texts,
|
289 |
model=args.embedding_model,
|
|
|
89 |
]:
|
90 |
raise Exception("llm binding not supported")
|
91 |
|
92 |
+
if args.embedding_binding not in ["lollms", "ollama", "openai", "azure_openai", "jina"]:
|
93 |
raise Exception("embedding binding not supported")
|
94 |
|
95 |
# Set default hosts if not provided
|
|
|
213 |
if args.llm_binding_host == "openai-ollama" or args.embedding_binding == "ollama":
|
214 |
from lightrag.llm.openai import openai_complete_if_cache
|
215 |
from lightrag.llm.ollama import ollama_embed
|
216 |
+
if args.embedding_binding == "jina":
|
217 |
+
from lightrag.llm.jina import jina_embed
|
218 |
|
219 |
async def openai_alike_model_complete(
|
220 |
prompt,
|
|
|
286 |
api_key=args.embedding_binding_api_key,
|
287 |
)
|
288 |
if args.embedding_binding == "azure_openai"
|
289 |
+
else jina_embed(
|
290 |
+
texts,
|
291 |
+
dimensions=args.embedding_dim,
|
292 |
+
base_url=args.embedding_binding_host,
|
293 |
+
api_key=args.embedding_binding_api_key,
|
294 |
+
)
|
295 |
+
if args.embedding_binding == "jina"
|
296 |
else openai_embed(
|
297 |
texts,
|
298 |
model=args.embedding_model,
|
lightrag/llm/jina.py
CHANGED
@@ -2,45 +2,111 @@ import os
|
|
2 |
import pipmaster as pm # Pipmaster for dynamic library install
|
3 |
|
4 |
# install specific modules
|
5 |
-
if not pm.is_installed("
|
6 |
-
pm.install("
|
7 |
if not pm.is_installed("tenacity"):
|
8 |
pm.install("tenacity")
|
9 |
|
10 |
-
|
11 |
import numpy as np
|
12 |
import aiohttp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
|
15 |
async def fetch_data(url, headers, data):
|
16 |
async with aiohttp.ClientSession() as session:
|
17 |
async with session.post(url, headers=headers, json=data) as response:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
response_json = await response.json()
|
19 |
data_list = response_json.get("data", [])
|
20 |
return data_list
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
async def jina_embed(
|
24 |
texts: list[str],
|
25 |
-
dimensions: int =
|
26 |
late_chunking: bool = False,
|
27 |
base_url: str = None,
|
28 |
api_key: str = None,
|
29 |
) -> np.ndarray:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
if api_key:
|
31 |
os.environ["JINA_API_KEY"] = api_key
|
32 |
-
|
|
|
|
|
|
|
|
|
33 |
headers = {
|
34 |
"Content-Type": "application/json",
|
35 |
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
|
36 |
}
|
37 |
data = {
|
38 |
-
"model": "jina-embeddings-
|
39 |
-
"
|
40 |
-
"
|
41 |
-
"dimensions": f"{dimensions}",
|
42 |
-
"late_chunking": late_chunking,
|
43 |
"input": texts,
|
44 |
}
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import pipmaster as pm # Pipmaster for dynamic library install
|
3 |
|
4 |
# install specific modules
|
5 |
+
if not pm.is_installed("aiohttp"):
|
6 |
+
pm.install("aiohttp")
|
7 |
if not pm.is_installed("tenacity"):
|
8 |
pm.install("tenacity")
|
9 |
|
|
|
10 |
import numpy as np
|
11 |
import aiohttp
|
12 |
+
from tenacity import (
|
13 |
+
retry,
|
14 |
+
stop_after_attempt,
|
15 |
+
wait_exponential,
|
16 |
+
retry_if_exception_type,
|
17 |
+
)
|
18 |
+
from lightrag.utils import wrap_embedding_func_with_attrs, logger
|
19 |
|
20 |
|
21 |
async def fetch_data(url, headers, data):
    """POST ``data`` as JSON to ``url`` and return the ``"data"`` entry of the reply.

    Args:
        url: Endpoint to send the request to.
        headers: HTTP headers passed to the POST request.
        data: JSON-serialisable request body.

    Returns:
        The value stored under the ``"data"`` key of the JSON response, or an
        empty list when that key is absent.

    Raises:
        aiohttp.ClientResponseError: If the response status is not 200. The
            response body is logged and embedded in the error message.
    """
    async with aiohttp.ClientSession() as session, session.post(
        url, headers=headers, json=data
    ) as response:
        # Success path first: decode the body and hand back the payload list.
        if response.status == 200:
            payload = await response.json()
            return payload.get("data", [])

        # Any other status is reported with the raw body for diagnosis.
        error_text = await response.text()
        logger.error(f"Jina API error {response.status}: {error_text}")
        raise aiohttp.ClientResponseError(
            request_info=response.request_info,
            history=response.history,
            status=response.status,
            message=f"Jina API error: {error_text}",
        )
|
36 |
|
37 |
|
38 |
+
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    # aiohttp.ClientResponseError is a subclass of aiohttp.ClientError, so a
    # single predicate covers both connection failures and HTTP error replies.
    retry=retry_if_exception_type(aiohttp.ClientError),
)
async def jina_embed(
    texts: list[str],
    dimensions: int = 2048,
    late_chunking: bool = False,
    base_url: str = None,
    api_key: str = None,
    model: str = "jina-embeddings-v4",
) -> np.ndarray:
    """Generate embeddings for a list of texts using Jina AI's API.

    Args:
        texts: List of texts to embed.
        dimensions: The embedding dimensions requested from the API
            (default: 2048).
        late_chunking: Whether to use late chunking. The flag is only sent
            when it is truthy.
        base_url: Optional endpoint URL. Defaults to
            ``https://api.jina.ai/v1/embeddings``.
        api_key: Optional Jina API key. When given it is written to the
            ``JINA_API_KEY`` environment variable; otherwise that variable
            must already be set.
        model: Name of the Jina embedding model sent in the request
            (default: ``jina-embeddings-v4``).

    Returns:
        A numpy array of embeddings, one per input text, in the order the API
        returned them.

    Raises:
        ValueError: If no API key is available, or if the API returns an
            empty list or a different number of embeddings than texts.
            These errors are not retried.
        tenacity.RetryError: If every attempt fails with an
            ``aiohttp.ClientError`` (including ``ClientResponseError`` for
            non-200 replies). ``reraise`` is not enabled on the retry
            decorator, so the last aiohttp error is wrapped rather than
            raised directly.
    """
    if api_key:
        # Existing behaviour kept for compatibility: the key is published to
        # the process environment and stays set after this call returns.
        os.environ["JINA_API_KEY"] = api_key

    # Treat an empty value like a missing one so no blank bearer token is sent.
    token = os.environ.get("JINA_API_KEY")
    if not token:
        raise ValueError("JINA_API_KEY environment variable is required")

    url = base_url or "https://api.jina.ai/v1/embeddings"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }
    data = {
        "model": model,
        "task": "text-matching",
        "dimensions": dimensions,
        "input": texts,
    }

    # Only add optional parameters if they have non-default values
    if late_chunking:
        data["late_chunking"] = late_chunking

    logger.debug(
        f"Jina embedding request: {len(texts)} texts, dimensions: {dimensions}"
    )

    try:
        data_list = await fetch_data(url, headers, data)

        if not data_list:
            raise ValueError("Jina API returned empty data list")

        if len(data_list) != len(texts):
            raise ValueError(
                f"Jina API returned {len(data_list)} embeddings for {len(texts)} texts"
            )

        embeddings = np.array([dp["embedding"] for dp in data_list])
    except Exception as e:
        # Single logging point for every failure of this attempt; the
        # exception itself is propagated unchanged.
        logger.error(f"Jina embedding error: {e}")
        raise

    logger.debug(f"Jina embeddings generated: shape {embeddings.shape}")
    return embeddings
|