Commit e061d88 by YanSte (parent: b29815b)

improved typing

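The change is mechanical across all nine files: replace typing.Optional/Union/List/Dict with PEP 604 unions (X | None) and PEP 585 built-in generics (list[str], dict[str, Any]), and add `from __future__ import annotations` (PEP 563) so the new syntax stays importable on interpreters older than 3.10, where the annotations would otherwise be evaluated eagerly at import time. A minimal before/after sketch of the pattern (illustrative names, not from the repo):

from __future__ import annotations  # annotations become lazy strings (PEP 563)

from typing import Any

# Before (still valid, just noisier):
#   from typing import Dict, Optional
#   def get_by_id(id: str) -> Optional[Dict[str, Any]]: ...

# After: with the future import this parses even on Python 3.8,
# because the annotation is never evaluated at runtime.
def get_by_id(id: str) -> dict[str, Any] | None:
    ...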
lightrag/base.py CHANGED
@@ -1,13 +1,13 @@
+from __future__ import annotations
+
 import os
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import (
     Any,
     Literal,
-    Optional,
     TypedDict,
     TypeVar,
-    Union,
 )
 
 import numpy as np
@@ -115,7 +115,7 @@ class BaseVectorStorage(StorageNameSpace):
 class BaseKVStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc | None = None
 
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         raise NotImplementedError
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -157,21 +157,23 @@ class BaseGraphStorage(StorageNameSpace):
 
     """Get a node by its id."""
 
-    async def get_node(self, node_id: str) -> Union[dict[str, str], None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         raise NotImplementedError
 
     """Get an edge by its source and target node ids."""
 
     async def get_edge(
-        self, source_node_id: str, target_node_id: str
-    ) -> Union[dict[str, str], None]:
+        self,
+        source_node_id: str,
+        target_node_id: str
+    ) -> dict[str, str] | None:
         raise NotImplementedError
 
     """Get all edges connected to a node."""
 
     async def get_node_edges(
         self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
+    ) -> list[tuple[str, str]] | None:
         raise NotImplementedError
 
     """Upsert a node into the graph."""
@@ -236,9 +238,9 @@ class DocProcessingStatus:
     """ISO format timestamp when document was created"""
     updated_at: str
     """ISO format timestamp when document was last updated"""
-    chunks_count: Optional[int] = None
+    chunks_count: int | None = None
     """Number of chunks after splitting, used for processing"""
-    error: Optional[str] = None
+    error: str | None = None
     """Error message if failed"""
     metadata: dict[str, Any] = field(default_factory=dict)
     """Additional metadata"""
lightrag/exceptions.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import httpx
 from typing import Literal
 
lightrag/lightrag.py CHANGED
@@ -6,7 +6,7 @@ import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union, cast
+from typing import Any, AsyncIterator, Callable, Iterator, cast
 
 from .base import (
     BaseGraphStorage,
@@ -314,7 +314,7 @@ class LightRAG:
     """Maximum number of concurrent embedding function calls."""
 
     # LLM Configuration
-    llm_model_func: Union[Callable[..., object], None] = None
+    llm_model_func: Callable[..., object] | None = None
     """Function for interacting with the large language model (LLM). Must be set before use."""
 
     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -354,7 +354,7 @@ class LightRAG:
     chunking_func: Callable[
         [
             str,
-            Optional[str],
+            str | None,
             bool,
             int,
             int,
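The `llm_model_func` rewrite shows that PEP 604 composes with typing.Callable as well; a sketch of the same field shape (class name illustrative, not the real LightRAG dataclass):

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable

@dataclass
class RAGConfig:  # illustrative only
    # was: Union[Callable[..., object], None] = None
    llm_model_func: Callable[..., object] | None = None
    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"

cfg = RAGConfig(llm_model_func=lambda prompt, **kw: prompt.upper())
assert callable(cfg.llm_model_func)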
lightrag/llm.py CHANGED
@@ -1,4 +1,6 @@
-from typing import List, Dict, Callable, Any
+from __future__ import annotations
+
+from typing import Callable, Any
 from pydantic import BaseModel, Field
 
 
@@ -23,7 +25,7 @@ class Model(BaseModel):
         ...,
         description="A function that generates the response from the llm. The response must be a string",
     )
-    kwargs: Dict[str, Any] = Field(
+    kwargs: dict[str, Any] = Field(
         ...,
         description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
     )
@@ -57,7 +59,7 @@ class MultiModel:
     ```
     """
 
-    def __init__(self, models: List[Model]):
+    def __init__(self, models: list[Model]):
         self._models = models
         self._current_model = 0
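The `_current_model = 0` counter alongside `list[Model]` suggests MultiModel cycles through its configured models; its method bodies are not part of this diff, so the following round-robin step is only a guess (illustrative class, not the repo's code):

from __future__ import annotations

class RoundRobin:  # hypothetical; MultiModel's real logic is not shown above
    def __init__(self, models: list[str]):
        self._models = models
        self._current_model = 0

    def next(self) -> str:
        model = self._models[self._current_model]
        self._current_model = (self._current_model + 1) % len(self._models)
        return model

rr = RoundRobin(["model-a", "model-b"])
assert [rr.next(), rr.next(), rr.next()] == ["model-a", "model-b", "model-a"]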
 
lightrag/namespace.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable
 
 
lightrag/operate.py CHANGED
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import re
 from tqdm.asyncio import tqdm as tqdm_async
-from typing import Any, AsyncIterator, Union
+from typing import Any, AsyncIterator
 from collections import Counter, defaultdict
 from .utils import (
     logger,
@@ -36,7 +38,7 @@ import time
 
 def chunking_by_token_size(
     content: str,
-    split_by_character: Union[str, None] = None,
+    split_by_character: str | None = None,
     split_by_character_only: bool = False,
     overlap_token_size: int = 128,
     max_token_size: int = 1024,
@@ -297,7 +299,7 @@ async def extract_entities(
     relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
     llm_response_cache: BaseKVStorage | None = None,
-) -> Union[BaseGraphStorage, None]:
+) -> BaseGraphStorage | None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
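In function signatures the rewrite touches only the annotations, never the runtime defaults. A simplified, runnable sketch of the chunking signature (the real function tokenizes; this body only splits, for illustration):

from __future__ import annotations

def chunk_sketch(
    content: str,
    split_by_character: str | None = None,  # was Union[str, None]
    overlap_token_size: int = 128,
    max_token_size: int = 1024,
) -> list[dict[str, str]]:
    # Illustrative only: split on the character if one is given,
    # otherwise treat the whole input as a single chunk.
    parts = content.split(split_by_character) if split_by_character else [content]
    return [{"content": p} for p in parts]

assert chunk_sketch("a\n\nb", split_by_character="\n\n") == [
    {"content": "a"},
    {"content": "b"},
]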
lightrag/prompt.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 GRAPH_FIELD_SEP = "<SEP>"
 
 PROMPTS = {}
lightrag/types.py CHANGED
@@ -1,16 +1,19 @@
+
+from __future__ import annotations
+
 from pydantic import BaseModel
-from typing import List, Dict, Any
+from typing import Any
 
 
 class GPTKeywordExtractionFormat(BaseModel):
-    high_level_keywords: List[str]
-    low_level_keywords: List[str]
+    high_level_keywords: list[str]
+    low_level_keywords: list[str]
 
 
 class KnowledgeGraphNode(BaseModel):
     id: str
-    labels: List[str]
-    properties: Dict[str, Any]  # anything else goes here
+    labels: list[str]
+    properties: dict[str, Any]  # anything else goes here
 
 
 class KnowledgeGraphEdge(BaseModel):
@@ -18,9 +21,9 @@ class KnowledgeGraphEdge(BaseModel):
     type: str
     source: str  # id of source node
     target: str  # id of target node
-    properties: Dict[str, Any]  # anything else goes here
+    properties: dict[str, Any]  # anything else goes here
 
 
 class KnowledgeGraph(BaseModel):
-    nodes: List[KnowledgeGraphNode] = []
-    edges: List[KnowledgeGraphEdge] = []
+    nodes: list[KnowledgeGraphNode] = []
+    edges: list[KnowledgeGraphEdge] = []
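One caveat for this file: unlike dataclasses, pydantic evaluates annotations at runtime to build validators, so built-in generics like list[KnowledgeGraphNode] still need Python 3.9+ here despite the future import, and any future `|` unions would need 3.10+ (pydantic v2 can fall back to the eval_type_backport package on older interpreters). The mutable `[]` defaults are safe because pydantic copies field defaults per instance; a sketch (model names mirror the shapes above, not the real classes):

from __future__ import annotations

from typing import Any

from pydantic import BaseModel

class Node(BaseModel):  # mirrors KnowledgeGraphNode's shape
    id: str
    labels: list[str]
    properties: dict[str, Any]

class Graph(BaseModel):
    nodes: list[Node] = []  # pydantic copies mutable defaults per instance

g1, g2 = Graph(), Graph()
g1.nodes.append(Node(id="a", labels=[], properties={}))
assert g2.nodes == []  # g2 keeps its own empty list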
lightrag/utils.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import html
 import io
@@ -9,7 +11,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Callable, Union, List, Optional
+from typing import Any, Callable
 import xml.etree.ElementTree as ET
 import bs4
 
@@ -72,7 +74,7 @@ class ReasoningResponse:
     tag: str
 
 
-def locate_json_string_body_from_string(content: str) -> Union[str, None]:
+def locate_json_string_body_from_string(content: str) -> str | None:
     """Locate the JSON string body from a string"""
     try:
         maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -238,7 +240,7 @@ def truncate_list_by_token_size(
     return list_data
 
 
-def list_of_list_to_csv(data: List[List[str]]) -> str:
+def list_of_list_to_csv(data: list[list[str]]) -> str:
     output = io.StringIO()
     writer = csv.writer(
         output,
@@ -251,7 +253,7 @@ def list_of_list_to_csv(data: list[list[str]]) -> str:
     return output.getvalue()
 
 
-def csv_string_to_list(csv_string: str) -> List[List[str]]:
+def csv_string_to_list(csv_string: str) -> list[list[str]]:
     # Clean the string by removing NUL characters
     cleaned_string = csv_string.replace("\0", "")
 
@@ -382,7 +384,7 @@ async def get_best_cached_response(
     llm_func=None,
     original_prompt=None,
     cache_type=None,
-) -> Union[str, None]:
+) -> str | None:
     logger.debug(
         f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
     )
@@ -486,7 +488,7 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)
 
 
-def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
+def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
     """Quantize embedding to specified bits"""
     # Convert list to numpy array if needed
     if isinstance(embedding, list):
@@ -577,9 +579,9 @@ class CacheData:
     args_hash: str
     content: str
     prompt: str
-    quantized: Optional[np.ndarray] = None
-    min_val: Optional[float] = None
-    max_val: Optional[float] = None
+    quantized: np.ndarray | None = None
+    min_val: float | None = None
+    max_val: float | None = None
     mode: str = "default"
     cache_type: str = "query"
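The new quantize_embedding signature pairs with the quantized/min_val/max_val fields on CacheData, which suggests min-max quantization that returns the array plus the bounds needed to dequantize. A minimal sketch under that assumption (not the repo's exact implementation; assumes bits <= 8):

from __future__ import annotations

import numpy as np

def quantize_sketch(
    embedding: np.ndarray | list[float], bits: int = 8
) -> tuple[np.ndarray, float, float]:
    """Min-max quantize an embedding to 2**bits levels (illustrative)."""
    if isinstance(embedding, list):  # mirror the real function's list handling
        embedding = np.array(embedding)
    min_val, max_val = float(embedding.min()), float(embedding.max())
    scale = (2**bits - 1) / ((max_val - min_val) or 1.0)  # avoid divide-by-zero
    quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
    return quantized, min_val, max_val

q, lo, hi = quantize_sketch([0.0, 0.5, 1.0])
assert q.tolist() == [0, 128, 255] and (lo, hi) == (0.0, 1.0)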