LarFii
committed on
Commit
·
a928fde
1
Parent(s):
3c73644
Update __version__
Browse files
- examples/lightrag_zhipu_demo.py +2 -8
- lightrag/__init__.py +1 -1
- lightrag/kg/milvus_impl.py +3 -1
- lightrag/llm.py +20 -28
- lightrag/storage.py +3 -1
examples/lightrag_zhipu_demo.py
CHANGED
@@ -1,9 +1,6 @@
|
|
1 |
-
import asyncio
|
2 |
import os
|
3 |
-
import inspect
|
4 |
import logging
|
5 |
|
6 |
-
from dotenv import load_dotenv
|
7 |
|
8 |
from lightrag import LightRAG, QueryParam
|
9 |
from lightrag.llm import zhipu_complete, zhipu_embedding
|
@@ -21,7 +18,6 @@ if api_key is None:
|
|
21 |
raise Exception("Please set ZHIPU_API_KEY in your environment")
|
22 |
|
23 |
|
24 |
-
|
25 |
rag = LightRAG(
|
26 |
working_dir=WORKING_DIR,
|
27 |
llm_model_func=zhipu_complete,
|
@@ -31,9 +27,7 @@ rag = LightRAG(
|
|
31 |
embedding_func=EmbeddingFunc(
|
32 |
embedding_dim=2048, # Zhipu embedding-3 dimension
|
33 |
max_token_size=8192,
|
34 |
-
func=lambda texts: zhipu_embedding(
|
35 |
-
texts
|
36 |
-
),
|
37 |
),
|
38 |
)
|
39 |
|
@@ -58,4 +52,4 @@ print(
|
|
58 |
# Perform hybrid search
|
59 |
print(
|
60 |
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
61 |
-
)
|
|
|
|
|
1 |
import os
|
|
|
2 |
import logging
|
3 |
|
|
|
4 |
|
5 |
from lightrag import LightRAG, QueryParam
|
6 |
from lightrag.llm import zhipu_complete, zhipu_embedding
|
|
|
18 |
raise Exception("Please set ZHIPU_API_KEY in your environment")
|
19 |
|
20 |
|
|
|
21 |
rag = LightRAG(
|
22 |
working_dir=WORKING_DIR,
|
23 |
llm_model_func=zhipu_complete,
|
|
|
27 |
embedding_func=EmbeddingFunc(
|
28 |
embedding_dim=2048, # Zhipu embedding-3 dimension
|
29 |
max_token_size=8192,
|
30 |
+
func=lambda texts: zhipu_embedding(texts),
|
|
|
|
|
31 |
),
|
32 |
)
|
33 |
|
|
|
52 |
# Perform hybrid search
|
53 |
print(
|
54 |
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
55 |
+
)
|
lightrag/__init__.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
|
2 |
|
3 |
-
__version__ = "1.0.
|
4 |
__author__ = "Zirui Guo"
|
5 |
__url__ = "https://github.com/HKUDS/LightRAG"
|
|
|
1 |
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
|
2 |
|
3 |
+
__version__ = "1.0.6"
|
4 |
__author__ = "Zirui Guo"
|
5 |
__url__ = "https://github.com/HKUDS/LightRAG"
|
lightrag/kg/milvus_impl.py
CHANGED
@@ -63,7 +63,9 @@ class MilvusVectorDBStorge(BaseVectorStorage):
|
|
63 |
return result
|
64 |
|
65 |
embedding_tasks = [wrapped_task(batch) for batch in batches]
|
66 |
-
pbar = tqdm_async(
|
|
|
|
|
67 |
embeddings_list = await asyncio.gather(*embedding_tasks)
|
68 |
|
69 |
embeddings = np.concatenate(embeddings_list)
|
|
|
63 |
return result
|
64 |
|
65 |
embedding_tasks = [wrapped_task(batch) for batch in batches]
|
66 |
+
pbar = tqdm_async(
|
67 |
+
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
|
68 |
+
)
|
69 |
embeddings_list = await asyncio.gather(*embedding_tasks)
|
70 |
|
71 |
embeddings = np.concatenate(embeddings_list)
|
lightrag/llm.py
CHANGED
@@ -604,11 +604,11 @@ async def ollama_model_complete(
|
|
604 |
)
|
605 |
async def zhipu_complete_if_cache(
|
606 |
prompt: Union[str, List[Dict[str, str]]],
|
607 |
-
model: str = "glm-4-flashx",
|
608 |
api_key: Optional[str] = None,
|
609 |
system_prompt: Optional[str] = None,
|
610 |
history_messages: List[Dict[str, str]] = [],
|
611 |
-
**kwargs
|
612 |
) -> str:
|
613 |
# dynamically load ZhipuAI
|
614 |
try:
|
@@ -640,13 +640,11 @@ async def zhipu_complete_if_cache(
|
|
640 |
logger.debug(f"System prompt: {system_prompt}")
|
641 |
|
642 |
# Remove unsupported kwargs
|
643 |
-
kwargs = {
|
|
|
|
|
644 |
|
645 |
-
response = client.chat.completions.create(
|
646 |
-
model=model,
|
647 |
-
messages=messages,
|
648 |
-
**kwargs
|
649 |
-
)
|
650 |
|
651 |
return response.choices[0].message.content
|
652 |
|
@@ -663,13 +661,13 @@ async def zhipu_complete(
|
|
663 |
Please analyze the content and extract two types of keywords:
|
664 |
1. High-level keywords: Important concepts and main themes
|
665 |
2. Low-level keywords: Specific details and supporting elements
|
666 |
-
|
667 |
Return your response in this exact JSON format:
|
668 |
{
|
669 |
"high_level_keywords": ["keyword1", "keyword2"],
|
670 |
"low_level_keywords": ["keyword1", "keyword2", "keyword3"]
|
671 |
}
|
672 |
-
|
673 |
Only return the JSON, no other text."""
|
674 |
|
675 |
# Combine with existing system prompt if any
|
@@ -683,15 +681,15 @@ async def zhipu_complete(
|
|
683 |
prompt=prompt,
|
684 |
system_prompt=system_prompt,
|
685 |
history_messages=history_messages,
|
686 |
-
**kwargs
|
687 |
)
|
688 |
-
|
689 |
# Try to parse as JSON
|
690 |
try:
|
691 |
data = json.loads(response)
|
692 |
return GPTKeywordExtractionFormat(
|
693 |
high_level_keywords=data.get("high_level_keywords", []),
|
694 |
-
low_level_keywords=data.get("low_level_keywords", [])
|
695 |
)
|
696 |
except json.JSONDecodeError:
|
697 |
# If direct JSON parsing fails, try to extract JSON from text
|
@@ -701,13 +699,15 @@ async def zhipu_complete(
|
|
701 |
data = json.loads(match.group())
|
702 |
return GPTKeywordExtractionFormat(
|
703 |
high_level_keywords=data.get("high_level_keywords", []),
|
704 |
-
low_level_keywords=data.get("low_level_keywords", [])
|
705 |
)
|
706 |
except json.JSONDecodeError:
|
707 |
pass
|
708 |
-
|
709 |
# If all parsing fails, log warning and return empty format
|
710 |
-
logger.warning(
|
|
|
|
|
711 |
return GPTKeywordExtractionFormat(
|
712 |
high_level_keywords=[], low_level_keywords=[]
|
713 |
)
|
@@ -722,7 +722,7 @@ async def zhipu_complete(
|
|
722 |
prompt=prompt,
|
723 |
system_prompt=system_prompt,
|
724 |
history_messages=history_messages,
|
725 |
-
**kwargs
|
726 |
)
|
727 |
|
728 |
|
@@ -733,13 +733,9 @@ async def zhipu_complete(
|
|
733 |
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
734 |
)
|
735 |
async def zhipu_embedding(
|
736 |
-
texts: list[str],
|
737 |
-
model: str = "embedding-3",
|
738 |
-
api_key: str = None,
|
739 |
-
**kwargs
|
740 |
) -> np.ndarray:
|
741 |
-
|
742 |
-
# dynamically load ZhipuAI
|
743 |
try:
|
744 |
from zhipuai import ZhipuAI
|
745 |
except ImportError:
|
@@ -758,11 +754,7 @@ async def zhipu_embedding(
|
|
758 |
embeddings = []
|
759 |
for text in texts:
|
760 |
try:
|
761 |
-
response = client.embeddings.create(
|
762 |
-
model=model,
|
763 |
-
input=[text],
|
764 |
-
**kwargs
|
765 |
-
)
|
766 |
embeddings.append(response.data[0].embedding)
|
767 |
except Exception as e:
|
768 |
raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
|
|
|
604 |
)
|
605 |
async def zhipu_complete_if_cache(
|
606 |
prompt: Union[str, List[Dict[str, str]]],
|
607 |
+
model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
|
608 |
api_key: Optional[str] = None,
|
609 |
system_prompt: Optional[str] = None,
|
610 |
history_messages: List[Dict[str, str]] = [],
|
611 |
+
**kwargs,
|
612 |
) -> str:
|
613 |
# dynamically load ZhipuAI
|
614 |
try:
|
|
|
640 |
logger.debug(f"System prompt: {system_prompt}")
|
641 |
|
642 |
# Remove unsupported kwargs
|
643 |
+
kwargs = {
|
644 |
+
k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
|
645 |
+
}
|
646 |
|
647 |
+
response = client.chat.completions.create(model=model, messages=messages, **kwargs)
|
|
|
|
|
|
|
|
|
648 |
|
649 |
return response.choices[0].message.content
|
650 |
|
|
|
661 |
Please analyze the content and extract two types of keywords:
|
662 |
1. High-level keywords: Important concepts and main themes
|
663 |
2. Low-level keywords: Specific details and supporting elements
|
664 |
+
|
665 |
Return your response in this exact JSON format:
|
666 |
{
|
667 |
"high_level_keywords": ["keyword1", "keyword2"],
|
668 |
"low_level_keywords": ["keyword1", "keyword2", "keyword3"]
|
669 |
}
|
670 |
+
|
671 |
Only return the JSON, no other text."""
|
672 |
|
673 |
# Combine with existing system prompt if any
|
|
|
681 |
prompt=prompt,
|
682 |
system_prompt=system_prompt,
|
683 |
history_messages=history_messages,
|
684 |
+
**kwargs,
|
685 |
)
|
686 |
+
|
687 |
# Try to parse as JSON
|
688 |
try:
|
689 |
data = json.loads(response)
|
690 |
return GPTKeywordExtractionFormat(
|
691 |
high_level_keywords=data.get("high_level_keywords", []),
|
692 |
+
low_level_keywords=data.get("low_level_keywords", []),
|
693 |
)
|
694 |
except json.JSONDecodeError:
|
695 |
# If direct JSON parsing fails, try to extract JSON from text
|
|
|
699 |
data = json.loads(match.group())
|
700 |
return GPTKeywordExtractionFormat(
|
701 |
high_level_keywords=data.get("high_level_keywords", []),
|
702 |
+
low_level_keywords=data.get("low_level_keywords", []),
|
703 |
)
|
704 |
except json.JSONDecodeError:
|
705 |
pass
|
706 |
+
|
707 |
# If all parsing fails, log warning and return empty format
|
708 |
+
logger.warning(
|
709 |
+
f"Failed to parse keyword extraction response: {response}"
|
710 |
+
)
|
711 |
return GPTKeywordExtractionFormat(
|
712 |
high_level_keywords=[], low_level_keywords=[]
|
713 |
)
|
|
|
722 |
prompt=prompt,
|
723 |
system_prompt=system_prompt,
|
724 |
history_messages=history_messages,
|
725 |
+
**kwargs,
|
726 |
)
|
727 |
|
728 |
|
|
|
733 |
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
734 |
)
|
735 |
async def zhipu_embedding(
|
736 |
+
texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
|
|
|
|
|
|
|
737 |
) -> np.ndarray:
|
738 |
+
# dynamically load ZhipuAI
|
|
|
739 |
try:
|
740 |
from zhipuai import ZhipuAI
|
741 |
except ImportError:
|
|
|
754 |
embeddings = []
|
755 |
for text in texts:
|
756 |
try:
|
757 |
+
response = client.embeddings.create(model=model, input=[text], **kwargs)
|
|
|
|
|
|
|
|
|
758 |
embeddings.append(response.data[0].embedding)
|
759 |
except Exception as e:
|
760 |
raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
|
lightrag/storage.py
CHANGED
@@ -103,7 +103,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
103 |
return result
|
104 |
|
105 |
embedding_tasks = [wrapped_task(batch) for batch in batches]
|
106 |
-
pbar = tqdm_async(
|
|
|
|
|
107 |
embeddings_list = await asyncio.gather(*embedding_tasks)
|
108 |
|
109 |
embeddings = np.concatenate(embeddings_list)
|
|
|
103 |
return result
|
104 |
|
105 |
embedding_tasks = [wrapped_task(batch) for batch in batches]
|
106 |
+
pbar = tqdm_async(
|
107 |
+
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
|
108 |
+
)
|
109 |
embeddings_list = await asyncio.gather(*embedding_tasks)
|
110 |
|
111 |
embeddings = np.concatenate(embeddings_list)
|