gzdaniel commited on
Commit
00de9e2
·
1 Parent(s): ef7a62b

feat: Integrate Jina embeddings API support

Browse files

- Implemented Jina embedding function
- Add new EMBEDDING_BINDING type of jina for LightRAG Server
- Add env var sample

env.example CHANGED
@@ -123,7 +123,7 @@ LLM_BINDING_API_KEY=your_api_key
123
  ####################################################################################
124
  ### Embedding Configuration (Should not be changed after the first file processed)
125
  ####################################################################################
126
- ### Embedding Binding type: openai, ollama, lollms, azure_openai
127
  EMBEDDING_BINDING=ollama
128
  EMBEDDING_MODEL=bge-m3:latest
129
  EMBEDDING_DIM=1024
@@ -139,6 +139,13 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
139
  # AZURE_EMBEDDING_ENDPOINT=your_endpoint
140
  # AZURE_EMBEDDING_API_KEY=your_api_key
141
 
 
 
 
 
 
 
 
142
  ############################
143
  ### Data storage selection
144
  ############################
 
123
  ####################################################################################
124
  ### Embedding Configuration (Should not be changed after the first file processed)
125
  ####################################################################################
126
+ ### Embedding Binding type: openai, ollama, lollms, azure_openai, jina
127
  EMBEDDING_BINDING=ollama
128
  EMBEDDING_MODEL=bge-m3:latest
129
  EMBEDDING_DIM=1024
 
139
  # AZURE_EMBEDDING_ENDPOINT=your_endpoint
140
  # AZURE_EMBEDDING_API_KEY=your_api_key
141
 
142
+ ### Jina AI Embedding
143
+ EMBEDDING_BINDING=jina
144
+ EMBEDDING_BINDING_HOST=https://api.jina.ai/v1/embeddings
145
+ EMBEDDING_MODEL=jina-embeddings-v4
146
+ EMBEDDING_DIM=2048
147
+ EMBEDDING_BINDING_API_KEY=your_api_key
148
+
149
  ############################
150
  ### Data storage selection
151
  ############################
lightrag/api/lightrag_server.py CHANGED
@@ -89,7 +89,7 @@ def create_app(args):
89
  ]:
90
  raise Exception("llm binding not supported")
91
 
92
- if args.embedding_binding not in ["lollms", "ollama", "openai", "azure_openai"]:
93
  raise Exception("embedding binding not supported")
94
 
95
  # Set default hosts if not provided
@@ -213,6 +213,8 @@ def create_app(args):
213
  if args.llm_binding_host == "openai-ollama" or args.embedding_binding == "ollama":
214
  from lightrag.llm.openai import openai_complete_if_cache
215
  from lightrag.llm.ollama import ollama_embed
 
 
216
 
217
  async def openai_alike_model_complete(
218
  prompt,
@@ -284,6 +286,13 @@ def create_app(args):
284
  api_key=args.embedding_binding_api_key,
285
  )
286
  if args.embedding_binding == "azure_openai"
 
 
 
 
 
 
 
287
  else openai_embed(
288
  texts,
289
  model=args.embedding_model,
 
89
  ]:
90
  raise Exception("llm binding not supported")
91
 
92
+ if args.embedding_binding not in ["lollms", "ollama", "openai", "azure_openai", "jina"]:
93
  raise Exception("embedding binding not supported")
94
 
95
  # Set default hosts if not provided
 
213
  if args.llm_binding_host == "openai-ollama" or args.embedding_binding == "ollama":
214
  from lightrag.llm.openai import openai_complete_if_cache
215
  from lightrag.llm.ollama import ollama_embed
216
+ if args.embedding_binding == "jina":
217
+ from lightrag.llm.jina import jina_embed
218
 
219
  async def openai_alike_model_complete(
220
  prompt,
 
286
  api_key=args.embedding_binding_api_key,
287
  )
288
  if args.embedding_binding == "azure_openai"
289
+ else jina_embed(
290
+ texts,
291
+ dimensions=args.embedding_dim,
292
+ base_url=args.embedding_binding_host,
293
+ api_key=args.embedding_binding_api_key,
294
+ )
295
+ if args.embedding_binding == "jina"
296
  else openai_embed(
297
  texts,
298
  model=args.embedding_model,
lightrag/llm/jina.py CHANGED
@@ -2,45 +2,111 @@ import os
2
  import pipmaster as pm # Pipmaster for dynamic library install
3
 
4
  # install specific modules
5
- if not pm.is_installed("lmdeploy"):
6
- pm.install("lmdeploy")
7
  if not pm.is_installed("tenacity"):
8
  pm.install("tenacity")
9
 
10
-
11
  import numpy as np
12
  import aiohttp
 
 
 
 
 
 
 
13
 
14
 
15
  async def fetch_data(url, headers, data):
16
  async with aiohttp.ClientSession() as session:
17
  async with session.post(url, headers=headers, json=data) as response:
 
 
 
 
 
 
 
 
 
18
  response_json = await response.json()
19
  data_list = response_json.get("data", [])
20
  return data_list
21
 
22
 
 
 
 
 
 
 
 
 
 
23
  async def jina_embed(
24
  texts: list[str],
25
- dimensions: int = 1024,
26
  late_chunking: bool = False,
27
  base_url: str = None,
28
  api_key: str = None,
29
  ) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  if api_key:
31
  os.environ["JINA_API_KEY"] = api_key
32
- url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
 
 
 
 
33
  headers = {
34
  "Content-Type": "application/json",
35
  "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
36
  }
37
  data = {
38
- "model": "jina-embeddings-v3",
39
- "normalized": True,
40
- "embedding_type": "float",
41
- "dimensions": f"{dimensions}",
42
- "late_chunking": late_chunking,
43
  "input": texts,
44
  }
45
- data_list = await fetch_data(url, headers, data)
46
- return np.array([dp["embedding"] for dp in data_list])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import pipmaster as pm # Pipmaster for dynamic library install
3
 
4
  # install specific modules
5
+ if not pm.is_installed("aiohttp"):
6
+ pm.install("aiohttp")
7
  if not pm.is_installed("tenacity"):
8
  pm.install("tenacity")
9
 
 
10
  import numpy as np
11
  import aiohttp
12
+ from tenacity import (
13
+ retry,
14
+ stop_after_attempt,
15
+ wait_exponential,
16
+ retry_if_exception_type,
17
+ )
18
+ from lightrag.utils import wrap_embedding_func_with_attrs, logger
19
 
20
 
21
  async def fetch_data(url, headers, data):
22
  async with aiohttp.ClientSession() as session:
23
  async with session.post(url, headers=headers, json=data) as response:
24
+ if response.status != 200:
25
+ error_text = await response.text()
26
+ logger.error(f"Jina API error {response.status}: {error_text}")
27
+ raise aiohttp.ClientResponseError(
28
+ request_info=response.request_info,
29
+ history=response.history,
30
+ status=response.status,
31
+ message=f"Jina API error: {error_text}"
32
+ )
33
  response_json = await response.json()
34
  data_list = response_json.get("data", [])
35
  return data_list
36
 
37
 
38
+ @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
39
+ @retry(
40
+ stop=stop_after_attempt(3),
41
+ wait=wait_exponential(multiplier=1, min=4, max=60),
42
+ retry=(
43
+ retry_if_exception_type(aiohttp.ClientError)
44
+ | retry_if_exception_type(aiohttp.ClientResponseError)
45
+ ),
46
+ )
47
  async def jina_embed(
48
  texts: list[str],
49
+ dimensions: int = 2048,
50
  late_chunking: bool = False,
51
  base_url: str = None,
52
  api_key: str = None,
53
  ) -> np.ndarray:
54
+ """Generate embeddings for a list of texts using Jina AI's API.
55
+
56
+ Args:
57
+ texts: List of texts to embed.
58
+ dimensions: The embedding dimensions (default: 2048 for jina-embeddings-v4).
59
+ late_chunking: Whether to use late chunking.
60
+ base_url: Optional base URL for the Jina API.
61
+ api_key: Optional Jina API key. If None, uses the JINA_API_KEY environment variable.
62
+
63
+ Returns:
64
+ A numpy array of embeddings, one per input text.
65
+
66
+ Raises:
67
+ aiohttp.ClientError: If there is a connection error with the Jina API.
68
+ aiohttp.ClientResponseError: If the Jina API returns an error response.
69
+ """
70
  if api_key:
71
  os.environ["JINA_API_KEY"] = api_key
72
+
73
+ if "JINA_API_KEY" not in os.environ:
74
+ raise ValueError("JINA_API_KEY environment variable is required")
75
+
76
+ url = base_url or "https://api.jina.ai/v1/embeddings"
77
  headers = {
78
  "Content-Type": "application/json",
79
  "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
80
  }
81
  data = {
82
+ "model": "jina-embeddings-v4",
83
+ "task": "text-matching",
84
+ "dimensions": dimensions,
 
 
85
  "input": texts,
86
  }
87
+
88
+ # Only add optional parameters if they have non-default values
89
+ if late_chunking:
90
+ data["late_chunking"] = late_chunking
91
+
92
+ logger.debug(f"Jina embedding request: {len(texts)} texts, dimensions: {dimensions}")
93
+
94
+ try:
95
+ data_list = await fetch_data(url, headers, data)
96
+
97
+ if not data_list:
98
+ logger.error("Jina API returned empty data list")
99
+ raise ValueError("Jina API returned empty data list")
100
+
101
+ if len(data_list) != len(texts):
102
+ logger.error(f"Jina API returned {len(data_list)} embeddings for {len(texts)} texts")
103
+ raise ValueError(f"Jina API returned {len(data_list)} embeddings for {len(texts)} texts")
104
+
105
+ embeddings = np.array([dp["embedding"] for dp in data_list])
106
+ logger.debug(f"Jina embeddings generated: shape {embeddings.shape}")
107
+
108
+ return embeddings
109
+
110
+ except Exception as e:
111
+ logger.error(f"Jina embedding error: {e}")
112
+ raise