ultrageopro commited on
Commit
e1e3f6b
·
unverified ·
1 Parent(s): cf0bd5c

feat: trimming the model’s reasoning

Browse files
Files changed (3) hide show
  1. README.md +6 -0
  2. lightrag/llm/ollama.py +16 -2
  3. lightrag/utils.py +33 -0
README.md CHANGED
@@ -338,6 +338,12 @@ rag = LightRAG(
338
 
339
 There is a fully functional example `examples/lightrag_ollama_demo.py` that utilizes the `gemma2:2b` model, runs only 4 requests in parallel and sets the context size to 32k.
340
 
 
 
 
 
 
 
341
  #### Low RAM GPUs
342
 
343
 In order to run this experiment on a low-RAM GPU you should select a small model and tune the context window (increasing the context increases memory consumption). For example, running this ollama example on a repurposed mining GPU with 6 GB of RAM required setting the context size to 26k while using `gemma2:2b`. It was able to find 197 entities and 19 relations in `book.txt`.
 
338
 
339
 There is a fully functional example `examples/lightrag_ollama_demo.py` that utilizes the `gemma2:2b` model, runs only 4 requests in parallel and sets the context size to 32k.
340
 
341
+ #### Using "Thinking" Models (e.g., DeepSeek)
342
+
343
+ To return only the model's response, you can pass `reasoning_tag` in `llm_model_kwargs`.
344
+
345
+ For example, for DeepSeek models, `reasoning_tag` should be set to `think`.
346
+
347
  #### Low RAM GPUs
348
 
349
 In order to run this experiment on a low-RAM GPU you should select a small model and tune the context window (increasing the context increases memory consumption). For example, running this ollama example on a repurposed mining GPU with 6 GB of RAM required setting the context size to 26k while using `gemma2:2b`. It was able to find 197 entities and 19 relations in `book.txt`.
lightrag/llm/ollama.py CHANGED
@@ -66,6 +66,7 @@ from lightrag.exceptions import (
66
  RateLimitError,
67
  APITimeoutError,
68
  )
 
69
  import numpy as np
70
  from typing import Union
71
 
@@ -85,6 +86,7 @@ async def ollama_model_if_cache(
85
  **kwargs,
86
  ) -> Union[str, AsyncIterator[str]]:
87
  stream = True if kwargs.get("stream") else False
 
88
  kwargs.pop("max_tokens", None)
89
  # kwargs.pop("response_format", None) # allow json
90
  host = kwargs.pop("host", None)
@@ -105,7 +107,7 @@ async def ollama_model_if_cache(
105
 
106
  response = await ollama_client.chat(model=model, messages=messages, **kwargs)
107
  if stream:
108
- """cannot cache stream response"""
109
 
110
  async def inner():
111
  async for chunk in response:
@@ -113,7 +115,19 @@ async def ollama_model_if_cache(
113
 
114
  return inner()
115
  else:
116
- return response["message"]["content"]
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  async def ollama_model_complete(
 
66
  RateLimitError,
67
  APITimeoutError,
68
  )
69
+ from lightrag.utils import extract_reasoning
70
  import numpy as np
71
  from typing import Union
72
 
 
86
  **kwargs,
87
  ) -> Union[str, AsyncIterator[str]]:
88
  stream = True if kwargs.get("stream") else False
89
+ reasoning_tag = kwargs.pop("reasoning_tag", None)
90
  kwargs.pop("max_tokens", None)
91
  # kwargs.pop("response_format", None) # allow json
92
  host = kwargs.pop("host", None)
 
107
 
108
  response = await ollama_client.chat(model=model, messages=messages, **kwargs)
109
  if stream:
110
+ """cannot cache stream response and process reasoning"""
111
 
112
  async def inner():
113
  async for chunk in response:
 
115
 
116
  return inner()
117
  else:
118
+ model_response = response["message"]["content"]
119
+
120
+ """
121
+ If the model also wraps its thoughts in a specific tag,
122
+ this information is not needed for the final
123
+ response and can simply be trimmed.
124
+ """
125
+
126
+ return (
127
+ model_response
128
+ if reasoning_tag is None
129
+ else extract_reasoning(model_response, reasoning_tag).response_content
130
+ )
131
 
132
 
133
  async def ollama_model_complete(
lightrag/utils.py CHANGED
@@ -11,6 +11,7 @@ from functools import wraps
11
  from hashlib import md5
12
  from typing import Any, Union, List, Optional
13
  import xml.etree.ElementTree as ET
 
14
 
15
  import numpy as np
16
  import tiktoken
@@ -64,6 +65,13 @@ class EmbeddingFunc:
64
  return await self.func(*args, **kwargs)
65
 
66
 
 
 
 
 
 
 
 
67
  def locate_json_string_body_from_string(content: str) -> Union[str, None]:
68
  """Locate the JSON string body from a string"""
69
  try:
@@ -666,3 +674,28 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
666
  )
667
 
668
  return "\n".join(formatted_turns)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from hashlib import md5
12
  from typing import Any, Union, List, Optional
13
  import xml.etree.ElementTree as ET
14
+ import bs4
15
 
16
  import numpy as np
17
  import tiktoken
 
65
  return await self.func(*args, **kwargs)
66
 
67
 
68
@dataclass
class ReasoningResponse:
    """Result of splitting an LLM response into its reasoning and answer parts.

    Produced by ``extract_reasoning``.
    """

    # Text found inside the reasoning tag; None when no reasoning tag was
    # present in the response (extract_reasoning passes None in that case,
    # so the annotation must be Optional, not plain str).
    reasoning_content: Optional[str]
    # Text following the reasoning section — the model's actual answer.
    response_content: str
    # The tag name (e.g. "think") that delimited the reasoning section.
    tag: str
73
+
74
+
75
  def locate_json_string_body_from_string(content: str) -> Union[str, None]:
76
  """Locate the JSON string body from a string"""
77
  try:
 
674
  )
675
 
676
  return "\n".join(formatted_turns)
677
+
678
+
679
def extract_reasoning(response: str, tag: str) -> ReasoningResponse:
    """Split an LLM response into its reasoning section and the final answer.

    "Thinking" models (e.g. DeepSeek) wrap their chain of thought in a tag
    such as ``<think>...</think>``; callers usually only want what follows.

    Args:
        response: Raw LLM response text.
        tag: Name of the reasoning tag to extract (e.g. ``"think"``).

    Returns:
        ReasoningResponse: ``reasoning_content`` is the stripped text inside
        the tag (``None`` when the tag is absent, with ``response_content``
        set to the whole response); ``response_content`` is everything after
        the closing tag, stripped. If the tag is opened but never closed
        (e.g. a truncated response), the remainder counts as reasoning and
        ``response_content`` is empty.
    """
    # Stdlib regex instead of an HTML parser: the response is not HTML, and
    # parsing it as such can mangle content (entities, stray '<'). This also
    # keeps everything after the closing tag, whereas a bs4 ``next_sibling``
    # approach would drop content past the first sibling node.
    import re

    pattern = re.compile(
        rf"<{re.escape(tag)}>(.*?)(?:</{re.escape(tag)}>|\Z)",
        re.DOTALL,
    )
    match = pattern.search(response)
    if match is None:
        # No reasoning tag at all: return the response untouched.
        return ReasoningResponse(None, response, tag)

    reasoning_content = match.group(1).strip()
    response_content = response[match.end() :].strip()
    return ReasoningResponse(reasoning_content, response_content, tag)