Jason Guo
committed on
Commit
·
577f5ec
1
Parent(s):
1e5c642
Modify the chat_complete method to support keywords extraction.
Browse files- examples/lightrag_zhipu_demo.py +61 -0
- lightrag/llm.py +174 -1
examples/lightrag_zhipu_demo.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import os
|
3 |
+
import inspect
|
4 |
+
import logging
|
5 |
+
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
|
8 |
+
from lightrag import LightRAG, QueryParam
|
9 |
+
from lightrag.llm import zhipu_complete, zhipu_embedding
|
10 |
+
from lightrag.utils import EmbeddingFunc
|
11 |
+
|
12 |
+
# Demo: index ./book.txt with LightRAG using the Zhipu chat and embedding
# backends, then ask the same question in every retrieval mode.

WORKING_DIR = "./dickens"

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

# Load variables from a local .env file (if present) before the API key lookup
# below; without this call the imported load_dotenv helper was never used.
load_dotenv()

# exist_ok avoids the check-then-create race of os.path.exists() + os.mkdir().
os.makedirs(WORKING_DIR, exist_ok=True)

api_key = os.environ.get("ZHIPUAI_API_KEY")
if api_key is None:
    # The message must name the same variable that is read above.
    raise Exception("Please set ZHIPUAI_API_KEY in your environment")


rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=zhipu_complete,
    llm_model_name="glm-4-flashx",  # Using the most cost/performance balance model, but you can change it here.
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    embedding_func=EmbeddingFunc(
        embedding_dim=2048,  # Zhipu embedding-3 dimension
        max_token_size=8192,
        func=lambda texts: zhipu_embedding(texts),
    ),
)

with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

QUESTION = "What are the top themes in this story?"

# Perform naive, local, global and hybrid search, in that order.
for mode in ("naive", "local", "global", "hybrid"):
    print(rag.query(QUESTION, param=QueryParam(mode=mode)))
|
lightrag/llm.py
CHANGED
@@ -4,7 +4,7 @@ import json
|
|
4 |
import os
|
5 |
import struct
|
6 |
from functools import lru_cache
|
7 |
-
from typing import List, Dict, Callable, Any, Union
|
8 |
import aioboto3
|
9 |
import aiohttp
|
10 |
import numpy as np
|
@@ -596,6 +596,179 @@ async def ollama_model_complete(
|
|
596 |
)
|
597 |
|
598 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
599 |
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
600 |
@retry(
|
601 |
stop=stop_after_attempt(3),
|
|
|
4 |
import os
|
5 |
import struct
|
6 |
from functools import lru_cache
|
7 |
+
from typing import List, Dict, Callable, Any, Union, Optional
|
8 |
import aioboto3
|
9 |
import aiohttp
|
10 |
import numpy as np
|
|
|
596 |
)
|
597 |
|
598 |
|
599 |
+
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def zhipu_complete_if_cache(
    prompt: Union[str, List[Dict[str, str]]],
    model: str = "glm-4-flashx",  # The most cost/performance balance model in glm-4 series
    api_key: Optional[str] = None,
    system_prompt: Optional[str] = None,
    history_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    """Send one chat-completion request to ZhipuAI and return the reply text.

    The message list is built as: system prompt, then ``history_messages``,
    then ``prompt`` as the final user message.

    Args:
        prompt: Content of the final user message.
        model: ZhipuAI chat model name.
        api_key: Explicit API key. When falsy, ``ZhipuAI()`` is constructed
            without arguments (the original comments say to set
            ``ZHIPUAI_API_KEY`` in the environment for that case).
        system_prompt: System message; a built-in default is used when falsy.
        history_messages: Earlier ``{"role", "content"}`` messages inserted
            between the system and user messages. ``None`` means no history.
        **kwargs: Forwarded to ``client.chat.completions.create`` after
            ``hashing_kv`` and ``keyword_extraction`` are removed.

    Returns:
        ``response.choices[0].message.content`` of the completion.

    Raises:
        ImportError: If the ``zhipuai`` package is not installed.
    """
    # NOTE(review): despite the name, no cache lookup happens in this function;
    # ``hashing_kv`` is simply dropped below.
    # NOTE(review): the retry decorator only reacts to RateLimitError,
    # APIConnectionError and Timeout, none of which is imported from zhipuai
    # in this block -- confirm the zhipuai SDK actually raises these types.

    # Function-scope import so the block does not rely on a module-level one.
    import asyncio

    # dynamically load ZhipuAI
    try:
        from zhipuai import ZhipuAI
    except ImportError as err:
        raise ImportError(
            "Please install zhipuai before initialize zhipuai backend."
        ) from err

    if api_key:
        client = ZhipuAI(api_key=api_key)
    else:
        # please set ZHIPUAI_API_KEY in your environment
        # os.environ["ZHIPUAI_API_KEY"]
        client = ZhipuAI()

    if not system_prompt:
        system_prompt = "You are a helpful assistant. 注意内容里的敏感词用***替换。"

    # A system prompt always exists at this point (caller's or the default).
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(history_messages or [])
    messages.append({"role": "user", "content": prompt})

    # Add debug logging
    logger.debug("===== Query Input to LLM =====")
    logger.debug(f"Query: {prompt}")
    logger.debug(f"System prompt: {system_prompt}")

    # Remove unsupported kwargs
    kwargs = {
        k: v
        for k, v in kwargs.items()
        if k not in ("hashing_kv", "keyword_extraction")
    }

    # The SDK call is synchronous; run it in a worker thread so this coroutine
    # does not block the event loop (and other concurrent LLM calls) while
    # waiting on the network.
    response = await asyncio.to_thread(
        client.chat.completions.create,
        model=model,
        messages=messages,
        **kwargs,
    )

    return response.choices[0].message.content
|
651 |
+
|
652 |
+
|
653 |
+
# System prompt that asks the model to answer with the keyword JSON object.
_ZHIPU_KEYWORD_EXTRACTION_PROMPT = """You are a helpful assistant that extracts keywords from text.
Please analyze the content and extract two types of keywords:
1. High-level keywords: Important concepts and main themes
2. Low-level keywords: Specific details and supporting elements

Return your response in this exact JSON format:
{
"high_level_keywords": ["keyword1", "keyword2"],
"low_level_keywords": ["keyword1", "keyword2", "keyword3"]
}

Only return the JSON, no other text."""


def _zhipu_keywords_from_json(text):
    """Decode *text* as JSON and wrap its keyword lists in the result format.

    Raises ``json.JSONDecodeError`` when *text* is not valid JSON.
    """
    data = json.loads(text)
    return GPTKeywordExtractionFormat(
        high_level_keywords=data.get("high_level_keywords", []),
        low_level_keywords=data.get("low_level_keywords", []),
    )


async def zhipu_complete(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
):
    """Complete *prompt* with ZhipuAI, optionally extracting keywords.

    Args:
        prompt: User prompt forwarded to ``zhipu_complete_if_cache``.
        system_prompt: Optional system prompt. In keyword-extraction mode the
            extraction instructions are appended to it.
        history_messages: Earlier chat messages; ``None`` means no history.
        keyword_extraction: When truthy, ask the model for a JSON object of
            keywords and parse the reply.
        **kwargs: Forwarded unchanged to ``zhipu_complete_if_cache``.

    Returns:
        The raw response of ``zhipu_complete_if_cache`` when
        ``keyword_extraction`` is falsy. Otherwise a
        ``GPTKeywordExtractionFormat``; its keyword lists are empty when the
        request fails or the reply contains no parseable JSON.
    """
    # ``keyword_extraction`` is a named parameter, so it can never appear in
    # ``kwargs``. The previous ``kwargs.pop("keyword_extraction", None)``
    # therefore always overwrote the caller's value with None and made the
    # extraction branch unreachable.
    # NOTE(review): with the branch reachable again, callers passing
    # keyword_extraction=True receive a GPTKeywordExtractionFormat instead of
    # a raw string -- confirm call sites handle that.

    # Always hand the callee a list, whatever default it declares itself.
    history_messages = [] if history_messages is None else history_messages

    if not keyword_extraction:
        # For non-keyword-extraction, just return the raw response string
        return await zhipu_complete_if_cache(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )

    # Combine with existing system prompt if any
    if system_prompt:
        system_prompt = f"{system_prompt}\n\n{_ZHIPU_KEYWORD_EXTRACTION_PROMPT}"
    else:
        system_prompt = _ZHIPU_KEYWORD_EXTRACTION_PROMPT

    # Deliberate best-effort: any failure is logged and yields empty keywords
    # instead of propagating.
    try:
        response = await zhipu_complete_if_cache(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )

        # Try to parse as JSON
        try:
            return _zhipu_keywords_from_json(response)
        except json.JSONDecodeError:
            # If direct JSON parsing fails, try to extract JSON from text
            match = re.search(r"\{[\s\S]*\}", response)
            if match:
                try:
                    return _zhipu_keywords_from_json(match.group())
                except json.JSONDecodeError:
                    pass

        # If all parsing fails, log warning and return empty format
        logger.warning(f"Failed to parse keyword extraction response: {response}")
        return GPTKeywordExtractionFormat(
            high_level_keywords=[], low_level_keywords=[]
        )
    except Exception as e:
        logger.error(f"Error during keyword extraction: {str(e)}")
        return GPTKeywordExtractionFormat(
            high_level_keywords=[], low_level_keywords=[]
        )
|
726 |
+
|
727 |
+
|
728 |
+
# NOTE(review): embedding_dim is declared as 1024 here, while the accompanying
# demo wraps this function with embedding_dim=2048 for "embedding-3" --
# confirm which value matches the vectors the API actually returns.
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def zhipu_embedding(
    texts: list[str],
    model: str = "embedding-3",
    api_key: Optional[str] = None,
    **kwargs,
) -> np.ndarray:
    """Embed each text with the ZhipuAI embeddings API.

    One request is issued per text, sequentially, so the rows of the result
    are in the same order as ``texts``.

    Args:
        texts: Texts to embed; a single ``str`` is treated as a one-item list.
        model: ZhipuAI embedding model name.
        api_key: Explicit API key. When falsy, ``ZhipuAI()`` is constructed
            without arguments (the original comments say to set
            ``ZHIPUAI_API_KEY`` in the environment for that case).
        **kwargs: Forwarded to ``client.embeddings.create``.

    Returns:
        ``np.array`` built from the list of returned embedding vectors.

    Raises:
        ImportError: If the ``zhipuai`` package is not installed.
        Exception: If any embedding request fails; the underlying error is
            attached as ``__cause__``.
    """
    # Function-scope import so the block does not rely on a module-level one.
    import asyncio

    # dynamically load ZhipuAI
    try:
        from zhipuai import ZhipuAI
    except ImportError as err:
        raise ImportError(
            "Please install zhipuai before initialize zhipuai backend."
        ) from err

    if api_key:
        client = ZhipuAI(api_key=api_key)
    else:
        # please set ZHIPUAI_API_KEY in your environment
        # os.environ["ZHIPUAI_API_KEY"]
        client = ZhipuAI()

    # Convert single text to list if needed
    if isinstance(texts, str):
        texts = [texts]

    embeddings = []
    for text in texts:
        try:
            # The SDK call is synchronous; run it in a worker thread so the
            # event loop stays responsive while waiting on the network.
            response = await asyncio.to_thread(
                client.embeddings.create,
                model=model,
                input=[text],
                **kwargs,
            )
            embeddings.append(response.data[0].embedding)
        except Exception as e:
            # Same exception type and message as before; the original error
            # is now chained explicitly for debugging.
            raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}") from e

    return np.array(embeddings)
|
770 |
+
|
771 |
+
|
772 |
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
773 |
@retry(
|
774 |
stop=stop_after_attempt(3),
|