zrguo committed on
Commit
2d7cde9
·
unverified ·
2 Parent(s): d15359e 45562ea

Merge pull request #423 from davidleon/feature/jina_embedding

Browse files
Files changed (2) hide show
  1. lightrag/llm.py +34 -0
  2. lightrag_jinaai_demo.py +114 -0
lightrag/llm.py CHANGED
@@ -583,6 +583,40 @@ async def openai_embedding(
583
  return np.array([dp.embedding for dp in response.data])
584
 
585
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
586
  @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
587
  @retry(
588
  stop=stop_after_attempt(3),
 
583
  return np.array([dp.embedding for dp in response.data])
584
 
585
 
586
+ async def fetch_data(url, headers, data):
587
+ async with aiohttp.ClientSession() as session:
588
+ async with session.post(url, headers=headers, json=data) as response:
589
+ response_json = await response.json()
590
+ data_list = response_json.get("data", [])
591
+ return data_list
592
+
593
+
594
+ async def jina_embedding(
595
+ texts: list[str],
596
+ dimensions: int = 1024,
597
+ late_chunking: bool = False,
598
+ base_url: str = None,
599
+ api_key: str = None,
600
+ ) -> np.ndarray:
601
+ if api_key:
602
+ os.environ["JINA_API_KEY"] = api_key
603
+ url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
604
+ headers = {
605
+ "Content-Type": "application/json",
606
+ "Authorization": f"Bearer {os.environ["JINA_API_KEY"]}",
607
+ }
608
+ data = {
609
+ "model": "jina-embeddings-v3",
610
+ "normalized": True,
611
+ "embedding_type": "float",
612
+ "dimensions": f"{dimensions}",
613
+ "late_chunking": late_chunking,
614
+ "input": texts,
615
+ }
616
+ data_list = await fetch_data(url, headers, data)
617
+ return np.array([dp["embedding"] for dp in data_list])
618
+
619
+
620
  @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
621
  @retry(
622
  stop=stop_after_attempt(3),
lightrag_jinaai_demo.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from lightrag import LightRAG, QueryParam
3
+ from lightrag.utils import EmbeddingFunc
4
+ from lightrag.llm import jina_embedding, openai_complete_if_cache
5
+ import os
6
+ import asyncio
7
+
8
+
9
+ async def embedding_func(texts: list[str]) -> np.ndarray:
10
+ return await jina_embedding(texts, api_key="YourJinaAPIKey")
11
+
12
+
13
+ WORKING_DIR = "./dickens"
14
+
15
+ if not os.path.exists(WORKING_DIR):
16
+ os.mkdir(WORKING_DIR)
17
+
18
+
19
+ async def llm_model_func(
20
+ prompt, system_prompt=None, history_messages=[], **kwargs
21
+ ) -> str:
22
+ return await openai_complete_if_cache(
23
+ "solar-mini",
24
+ prompt,
25
+ system_prompt=system_prompt,
26
+ history_messages=history_messages,
27
+ api_key=os.getenv("UPSTAGE_API_KEY"),
28
+ base_url="https://api.upstage.ai/v1/solar",
29
+ **kwargs,
30
+ )
31
+
32
+
33
+ rag = LightRAG(
34
+ working_dir=WORKING_DIR,
35
+ llm_model_func=llm_model_func,
36
+ embedding_func=EmbeddingFunc(
37
+ embedding_dim=1024, max_token_size=8192, func=embedding_func
38
+ ),
39
+ )
40
+
41
+
42
+ async def lightraginsert(file_path, semaphore):
43
+ async with semaphore:
44
+ try:
45
+ with open(file_path, "r", encoding="utf-8") as f:
46
+ content = f.read()
47
+ except UnicodeDecodeError:
48
+ # If UTF-8 decoding fails, try other encodings
49
+ with open(file_path, "r", encoding="gbk") as f:
50
+ content = f.read()
51
+ await rag.ainsert(content)
52
+
53
+
54
+ async def process_files(directory, concurrency_limit):
55
+ semaphore = asyncio.Semaphore(concurrency_limit)
56
+ tasks = []
57
+ for root, dirs, files in os.walk(directory):
58
+ for f in files:
59
+ file_path = os.path.join(root, f)
60
+ if f.startswith("."):
61
+ continue
62
+ tasks.append(lightraginsert(file_path, semaphore))
63
+ await asyncio.gather(*tasks)
64
+
65
+
66
+ async def main():
67
+ try:
68
+ rag = LightRAG(
69
+ working_dir=WORKING_DIR,
70
+ llm_model_func=llm_model_func,
71
+ embedding_func=EmbeddingFunc(
72
+ embedding_dim=1024,
73
+ max_token_size=8192,
74
+ func=embedding_func,
75
+ ),
76
+ )
77
+
78
+ asyncio.run(process_files(WORKING_DIR, concurrency_limit=4))
79
+
80
+ # Perform naive search
81
+ print(
82
+ await rag.aquery(
83
+ "What are the top themes in this story?", param=QueryParam(mode="naive")
84
+ )
85
+ )
86
+
87
+ # Perform local search
88
+ print(
89
+ await rag.aquery(
90
+ "What are the top themes in this story?", param=QueryParam(mode="local")
91
+ )
92
+ )
93
+
94
+ # Perform global search
95
+ print(
96
+ await rag.aquery(
97
+ "What are the top themes in this story?",
98
+ param=QueryParam(mode="global"),
99
+ )
100
+ )
101
+
102
+ # Perform hybrid search
103
+ print(
104
+ await rag.aquery(
105
+ "What are the top themes in this story?",
106
+ param=QueryParam(mode="hybrid"),
107
+ )
108
+ )
109
+ except Exception as e:
110
+ print(f"An error occurred: {e}")
111
+
112
+
113
+ if __name__ == "__main__":
114
+ asyncio.run(main())