improved typing
- lightrag/base.py +11 -9
- lightrag/exceptions.py +2 -0
- lightrag/lightrag.py +3 -3
- lightrag/llm.py +5 -3
- lightrag/namespace.py +2 -0
- lightrag/operate.py +5 -3
- lightrag/prompt.py +2 -0
- lightrag/types.py +11 -8
- lightrag/utils.py +11 -9
lightrag/base.py
CHANGED
@@ -1,13 +1,13 @@
+from __future__ import annotations
+
 import os
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import (
     Any,
     Literal,
-    Optional,
     TypedDict,
     TypeVar,
-    Union,
 )

 import numpy as np
@@ -115,7 +115,7 @@ class BaseVectorStorage(StorageNameSpace):
 class BaseKVStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc | None = None

-    async def get_by_id(self, id: str) ->
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         raise NotImplementedError

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -157,21 +157,23 @@ class BaseGraphStorage(StorageNameSpace):

     """Get a node by its id."""

-    async def get_node(self, node_id: str) ->
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         raise NotImplementedError

     """Get an edge by its source and target node ids."""

     async def get_edge(
-        self,
-
+        self,
+        source_node_id: str,
+        target_node_id: str
+    ) -> dict[str, str] | None :
         raise NotImplementedError

     """Get all edges connected to a node."""

     async def get_node_edges(
         self, source_node_id: str
-    ) ->
+    ) -> list[tuple[str, str]] | None:
         raise NotImplementedError

     """Upsert a node into the graph."""
@@ -236,9 +238,9 @@ class DocProcessingStatus:
     """ISO format timestamp when document was created"""
     updated_at: str
     """ISO format timestamp when document was last updated"""
-    chunks_count:
+    chunks_count: int | None = None
     """Number of chunks after splitting, used for processing"""
-    error:
+    error: str | None = None
     """Error message if failed"""
     metadata: dict[str, Any] = field(default_factory=dict)
     """Additional metadata"""
lightrag/exceptions.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import httpx
 from typing import Literal

lightrag/lightrag.py
CHANGED
@@ -6,7 +6,7 @@ import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator,
+from typing import Any, AsyncIterator, Callable, Iterator, cast

 from .base import (
     BaseGraphStorage,
@@ -314,7 +314,7 @@ class LightRAG:
     """Maximum number of concurrent embedding function calls."""

     # LLM Configuration
-    llm_model_func:
+    llm_model_func: Callable[..., object] | None = None
     """Function for interacting with the large language model (LLM). Must be set before use."""

     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -354,7 +354,7 @@ class LightRAG:
     chunking_func: Callable[
         [
             str,
-
+            str | None,
             bool,
             int,
             int,
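
The new cast import pairs with the deliberately loose Callable[..., object] | None declaration: call sites must narrow llm_model_func before invoking it. A hypothetical call-site sketch (names and control flow are illustrative, not taken from this diff):

from typing import Any, Callable, cast


async def call_model(rag: "LightRAG", prompt: str) -> Any:
    # llm_model_func is declared Callable[..., object] | None, so rule
    # out None first, then cast so the type checker accepts the call.
    if rag.llm_model_func is None:
        raise ValueError("llm_model_func must be set before use")
    func = cast(Callable[..., Any], rag.llm_model_func)
    return await func(prompt)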
lightrag/llm.py
CHANGED
@@ -1,4 +1,6 @@
-from
+from __future__ import annotations
+
+from typing import Callable, Any
 from pydantic import BaseModel, Field


@@ -23,7 +25,7 @@ class Model(BaseModel):
         ...,
         description="A function that generates the response from the llm. The response must be a string",
     )
-    kwargs:
+    kwargs: dict[str, Any] = Field(
         ...,
         description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
     )
@@ -57,7 +59,7 @@ class MultiModel:
     ```
     """

-    def __init__(self, models:
+    def __init__(self, models: list[Model]):
         self._models = models
         self._current_model = 0

lightrag/namespace.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable

lightrag/operate.py
CHANGED
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import re
 from tqdm.asyncio import tqdm as tqdm_async
-from typing import Any, AsyncIterator
+from typing import Any, AsyncIterator
 from collections import Counter, defaultdict
 from .utils import (
     logger,
@@ -36,7 +38,7 @@ import time

 def chunking_by_token_size(
     content: str,
-    split_by_character:
+    split_by_character: str | None = None,
     split_by_character_only: bool = False,
     overlap_token_size: int = 128,
     max_token_size: int = 1024,
@@ -297,7 +299,7 @@ async def extract_entities(
     relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
     llm_response_cache: BaseKVStorage | None = None,
-) ->
+) -> BaseGraphStorage | None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
lightrag/prompt.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 GRAPH_FIELD_SEP = "<SEP>"

 PROMPTS = {}
lightrag/types.py
CHANGED
@@ -1,16 +1,19 @@
+
+from __future__ import annotations
+
 from pydantic import BaseModel
-from typing import
+from typing import Any


 class GPTKeywordExtractionFormat(BaseModel):
-    high_level_keywords:
-    low_level_keywords:
+    high_level_keywords: list[str]
+    low_level_keywords: list[str]


 class KnowledgeGraphNode(BaseModel):
     id: str
-    labels:
-    properties:
+    labels: list[str]
+    properties: dict[str, Any] # anything else goes here


 class KnowledgeGraphEdge(BaseModel):
@@ -18,9 +21,9 @@ class KnowledgeGraphEdge(BaseModel):
     type: str
     source: str # id of source node
     target: str # id of target node
-    properties:
+    properties: dict[str, Any] # anything else goes here


 class KnowledgeGraph(BaseModel):
-    nodes:
-    edges:
+    nodes: list[KnowledgeGraphNode] = []
+    edges: list[KnowledgeGraphEdge] = []
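
One reviewer-level note on this file: nodes: list[KnowledgeGraphNode] = [] uses a bare mutable default, which would be a shared-state bug on a plain class or dataclass. Pydantic copies field defaults per instance, so each KnowledgeGraph gets its own list. A quick self-contained check (Python 3.9+, pydantic installed):

from pydantic import BaseModel


class KG(BaseModel):  # stand-in for KnowledgeGraph
    nodes: list[str] = []


a, b = KG(), KG()
a.nodes.append("n1")
assert a.nodes == ["n1"] and b.nodes == []  # defaults are not shared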
lightrag/utils.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import html
 import io
@@ -9,7 +11,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Callable
+from typing import Any, Callable
 import xml.etree.ElementTree as ET
 import bs4

@@ -72,7 +74,7 @@ class ReasoningResponse:
     tag: str


-def locate_json_string_body_from_string(content: str) ->
+def locate_json_string_body_from_string(content: str) -> str | None:
     """Locate the JSON string body from a string"""
     try:
         maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -238,7 +240,7 @@ def truncate_list_by_token_size(
     return list_data


-def list_of_list_to_csv(data: List[List[str]]) -> str:
+def list_of_list_to_csv(data: list[list[str]]) -> str:
     output = io.StringIO()
     writer = csv.writer(
         output,
@@ -251,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
     return output.getvalue()


-def csv_string_to_list(csv_string: str) ->
+def csv_string_to_list(csv_string: str) -> list[list[str]]:
     # Clean the string by removing NUL characters
     cleaned_string = csv_string.replace("\0", "")

@@ -382,7 +384,7 @@ async def get_best_cached_response(
     llm_func=None,
     original_prompt=None,
     cache_type=None,
-) ->
+) -> str | None:
     logger.debug(
         f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
     )
@@ -486,7 +488,7 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)


-def quantize_embedding(embedding:
+def quantize_embedding(embedding: np.ndarray | list[float], bits: int=8) -> tuple:
     """Quantize embedding to specified bits"""
     # Convert list to numpy array if needed
     if isinstance(embedding, list):
@@ -577,9 +579,9 @@ class CacheData:
     args_hash: str
     content: str
     prompt: str
-    quantized:
-    min_val:
-    max_val:
+    quantized: np.ndarray | None = None
+    min_val: float | None = None
+    max_val: float | None = None
     mode: str = "default"
     cache_type: str = "query"

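
CacheData's new optional fields mirror quantize_embedding's return: the commit leaves the return annotation as a bare tuple, but the quantized/min_val/max_val trio suggests its shape. A hedged usage sketch (the 3-tuple unpacking is inferred from CacheData's fields, not spelled out in this diff):

import numpy as np

from lightrag.utils import quantize_embedding

# Hypothetical round trip: quantize an embedding, then keep the pieces
# that CacheData now types as np.ndarray | None and float | None.
emb = np.random.rand(128).astype(np.float32)
quantized, min_val, max_val = quantize_embedding(emb, bits=8)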