Jason Guo committed on
Commit
577f5ec
·
1 Parent(s): 1e5c642

Modify the chat_complete method to support keywords extraction.

Browse files
Files changed (2) hide show
  1. examples/lightrag_zhipu_demo.py +61 -0
  2. lightrag/llm.py +174 -1
examples/lightrag_zhipu_demo.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Demo script: index ``./book.txt`` with LightRAG backed by ZhipuAI models.

The script creates the working directory, checks that ``ZHIPUAI_API_KEY`` is
available, inserts the book text into a ``LightRAG`` instance and then prints
the answer to one question for each query mode (naive, local, global, hybrid).
"""
import asyncio
import os
import inspect
import logging

from dotenv import load_dotenv

from lightrag import LightRAG, QueryParam
from lightrag.llm import zhipu_complete, zhipu_embedding
from lightrag.utils import EmbeddingFunc

# Pick up ZHIPUAI_API_KEY from a local .env file when one exists; by default
# python-dotenv does not override variables already set in the environment.
load_dotenv()

WORKING_DIR = "./dickens"
QUESTION = "What are the top themes in this story?"
QUERY_MODES = ("naive", "local", "global", "hybrid")

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(WORKING_DIR, exist_ok=True)

# Fail early with the *actual* variable name that is read here.
api_key = os.environ.get("ZHIPUAI_API_KEY")
if api_key is None:
    raise Exception("Please set ZHIPUAI_API_KEY in your environment")


rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=zhipu_complete,
    llm_model_name="glm-4-flashx",  # Using the most cost/performance balance model, but you can change it here.
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    embedding_func=EmbeddingFunc(
        embedding_dim=2048,  # Zhipu embedding-3 dimension
        max_token_size=8192,
        func=lambda texts: zhipu_embedding(texts),
    ),
)

with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Ask the same question once per search mode: naive, local, global, hybrid.
for mode in QUERY_MODES:
    print(rag.query(QUESTION, param=QueryParam(mode=mode)))
lightrag/llm.py CHANGED
@@ -4,7 +4,7 @@ import json
4
  import os
5
  import struct
6
  from functools import lru_cache
7
- from typing import List, Dict, Callable, Any, Union
8
  import aioboto3
9
  import aiohttp
10
  import numpy as np
@@ -596,6 +596,179 @@ async def ollama_model_complete(
596
  )
597
 
598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599
  @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
600
  @retry(
601
  stop=stop_after_attempt(3),
 
4
  import os
5
  import struct
6
  from functools import lru_cache
7
+ from typing import List, Dict, Callable, Any, Union, Optional
8
  import aioboto3
9
  import aiohttp
10
  import numpy as np
 
596
  )
597
 
598
 
599
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def zhipu_complete_if_cache(
    prompt: Union[str, List[Dict[str, str]]],
    model: str = "glm-4-flashx",  # The most cost/performance balance model in glm-4 series
    api_key: Optional[str] = None,
    system_prompt: Optional[str] = None,
    history_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    """Send one chat-completion request to ZhipuAI and return the reply text.

    Args:
        prompt: Content of the final ``user`` message.
        model: ZhipuAI chat model name.
        api_key: Explicit API key. When falsy, ``ZhipuAI()`` is constructed
            without arguments (set ``ZHIPUAI_API_KEY`` in the environment).
        system_prompt: Content of the ``system`` message. When falsy, a
            built-in default prompt is used, so a system message is always
            sent.
        history_messages: Earlier chat messages inserted between the system
            and the user message. ``None`` means no history.
        **kwargs: Forwarded to ``client.chat.completions.create`` after
            dropping ``hashing_kv`` and ``keyword_extraction``, which are
            not forwarded to the SDK call.

    Returns:
        ``response.choices[0].message.content`` of the completion.

    Raises:
        ImportError: If the ``zhipuai`` package is not installed.
    """
    import asyncio

    # dynamically load ZhipuAI so the dependency stays optional
    try:
        from zhipuai import ZhipuAI
    except ImportError as exc:
        raise ImportError(
            "Please install zhipuai before initialize zhipuai backend."
        ) from exc

    # please set ZHIPUAI_API_KEY in your environment when api_key is omitted
    client = ZhipuAI(api_key=api_key) if api_key else ZhipuAI()

    if not system_prompt:
        # Default prompt; the Chinese sentence asks the model to replace
        # sensitive words in the content with ***.
        system_prompt = "You are a helpful assistant. 注意内容里的敏感词用***替换。"

    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(history_messages or [])
    messages.append({"role": "user", "content": prompt})

    # Add debug logging
    logger.debug("===== Query Input to LLM =====")
    logger.debug(f"Query: {prompt}")
    logger.debug(f"System prompt: {system_prompt}")

    # Remove unsupported kwargs
    kwargs = {
        k: v for k, v in kwargs.items() if k not in ['hashing_kv', 'keyword_extraction']
    }

    # The SDK call is synchronous; run it in a worker thread so this
    # coroutine does not block the event loop while waiting on the network.
    response = await asyncio.to_thread(
        client.chat.completions.create,
        model=model,
        messages=messages,
        **kwargs,
    )

    return response.choices[0].message.content
651
+
652
+
653
# System prompt that instructs the model to answer with keyword JSON only.
_ZHIPU_KEYWORD_EXTRACTION_PROMPT = "\n".join(
    (
        "You are a helpful assistant that extracts keywords from text.",
        "Please analyze the content and extract two types of keywords:",
        "1. High-level keywords: Important concepts and main themes",
        "2. Low-level keywords: Specific details and supporting elements",
        "",
        "Return your response in this exact JSON format:",
        "{",
        '"high_level_keywords": ["keyword1", "keyword2"],',
        '"low_level_keywords": ["keyword1", "keyword2", "keyword3"]',
        "}",
        "",
        "Only return the JSON, no other text.",
    )
)


def _parse_zhipu_keywords(response):
    """Return the JSON object contained in *response*, or ``None``.

    The whole text is tried first; if that is not a JSON object, the span
    from the first ``{`` to the last ``}`` is tried as a fallback.
    """
    import re

    if not isinstance(response, str):
        return None

    candidates = [response]
    embedded = re.search(r"\{[\s\S]*\}", response)
    if embedded:
        candidates.append(embedded.group())

    for text in candidates:
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            continue
        # Only a JSON object can carry the two keyword lists.
        if isinstance(data, dict):
            return data
    return None


async def zhipu_complete(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
):
    """Complete *prompt* with ZhipuAI, optionally extracting keywords.

    Args:
        prompt: User prompt passed to ``zhipu_complete_if_cache``.
        system_prompt: Optional system prompt. In keyword-extraction mode the
            extraction instructions are appended to it.
        history_messages: Earlier chat messages; ``None`` means no history.
        keyword_extraction: When truthy, ask the model for keyword JSON and
            return it as a ``GPTKeywordExtractionFormat``.
        **kwargs: Forwarded unchanged to ``zhipu_complete_if_cache``.

    Returns:
        The raw response of ``zhipu_complete_if_cache`` when
        ``keyword_extraction`` is falsy. Otherwise a
        ``GPTKeywordExtractionFormat``; it has empty keyword lists when the
        request fails or the reply cannot be parsed (best effort, the error
        is logged instead of raised).
    """
    # NOTE: ``keyword_extraction`` is a named parameter, so it can never be
    # present in ``kwargs``; it must be read directly rather than popped from
    # ``kwargs`` (which would always yield None and disable this mode).
    history_messages = [] if history_messages is None else history_messages

    if not keyword_extraction:
        # For non-keyword-extraction, just return the raw response string
        return await zhipu_complete_if_cache(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )

    # Combine with existing system prompt if any
    if system_prompt:
        system_prompt = f"{system_prompt}\n\n{_ZHIPU_KEYWORD_EXTRACTION_PROMPT}"
    else:
        system_prompt = _ZHIPU_KEYWORD_EXTRACTION_PROMPT

    try:
        response = await zhipu_complete_if_cache(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )

        data = _parse_zhipu_keywords(response)
        if data is None:
            # If all parsing fails, log warning and return empty format
            logger.warning(f"Failed to parse keyword extraction response: {response}")
            return GPTKeywordExtractionFormat(
                high_level_keywords=[], low_level_keywords=[]
            )

        return GPTKeywordExtractionFormat(
            high_level_keywords=data.get("high_level_keywords", []),
            low_level_keywords=data.get("low_level_keywords", []),
        )
    except Exception as e:
        # Deliberate best effort: keyword extraction must not abort the
        # caller, so any failure degrades to an empty result.
        logger.error(f"Error during keyword extraction: {str(e)}")
        return GPTKeywordExtractionFormat(
            high_level_keywords=[], low_level_keywords=[]
        )
726
+
727
+
728
# NOTE(review): embedding_dim is declared as 1024 here, while the demo added
# in this same commit configures embedding_dim=2048 for "embedding-3" —
# confirm which value matches the vectors the API actually returns.
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def zhipu_embedding(
    texts: list[str],
    model: str = "embedding-3",
    api_key: Optional[str] = None,
    **kwargs,
) -> np.ndarray:
    """Embed each text with the ZhipuAI embeddings API.

    Args:
        texts: Texts to embed. A single ``str`` is accepted and treated as a
            one-element list.
        model: ZhipuAI embedding model name.
        api_key: Explicit API key. When falsy, ``ZhipuAI()`` is constructed
            without arguments (set ``ZHIPUAI_API_KEY`` in the environment).
        **kwargs: Forwarded to ``client.embeddings.create``.

    Returns:
        ``np.array`` built from one embedding per input text, in input order.

    Raises:
        ImportError: If the ``zhipuai`` package is not installed.
        RateLimitError, APIConnectionError, Timeout: Propagated unchanged so
            the ``@retry`` filter above can match them.
        Exception: Any other API failure, wrapped with a descriptive message
            and chained to the original error.
    """
    import asyncio

    # dynamically load ZhipuAI so the dependency stays optional
    try:
        from zhipuai import ZhipuAI
    except ImportError as exc:
        raise ImportError(
            "Please install zhipuai before initialize zhipuai backend."
        ) from exc

    # please set ZHIPUAI_API_KEY in your environment when api_key is omitted
    client = ZhipuAI(api_key=api_key) if api_key else ZhipuAI()

    # Convert single text to list if needed
    if isinstance(texts, str):
        texts = [texts]

    embeddings = []
    for text in texts:
        try:
            # The SDK call is synchronous; run it in a worker thread so the
            # event loop stays responsive. One request is issued per text.
            response = await asyncio.to_thread(
                client.embeddings.create,
                model=model,
                input=[text],
                **kwargs,
            )
            embeddings.append(response.data[0].embedding)
        except (RateLimitError, APIConnectionError, Timeout):
            # Re-raise untouched: wrapping these in a plain Exception would
            # hide them from retry_if_exception_type and disable retries.
            raise
        except Exception as e:
            raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}") from e

    return np.array(embeddings)
770
+
771
+
772
  @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
773
  @retry(
774
  stop=stop_after_attempt(3),