yangdx commited on
Commit
073182d
·
1 Parent(s): d268e77

Make batch methods in BaseGraphStorage optional with default implementations

Browse files

- Removing the @abstractmethod decorator
- Adding default implementations that call the corresponding non-batch methods
- Preserving full backward compatibility with existing implementations

Files changed (1) hide show
  1. lightrag/base.py +59 -10
lightrag/base.py CHANGED
@@ -361,25 +361,74 @@ class BaseGraphStorage(StorageNameSpace, ABC):
361
  or None if the node doesn't exist
362
  """
363
 
364
- @abstractmethod
365
  async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
366
- """Get nodes as a batch using UNWIND"""
 
 
 
 
 
 
 
 
 
 
 
367
 
368
- @abstractmethod
369
  async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
370
- """Node degrees as a batch using UNWIND"""
 
 
 
 
 
 
 
 
 
 
371
 
372
- @abstractmethod
373
  async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
374
- """Edge degrees as a batch using UNWIND also uses node_degrees_batch"""
 
 
 
 
 
 
 
 
 
 
375
 
376
- @abstractmethod
377
  async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
378
- """Get edges as a batch using UNWIND"""
 
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
- @abstractmethod
381
  async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
382
- """"Get nodes edges as a batch using UNWIND"""
 
 
 
 
 
 
 
 
 
 
383
 
384
  @abstractmethod
385
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
 
361
  or None if the node doesn't exist
362
  """
363
 
 
364
  async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
365
+ """Get nodes as a batch using UNWIND
366
+
367
+ Default implementation fetches nodes one by one.
368
+ Override this method for better performance in storage backends
369
+ that support batch operations.
370
+ """
371
+ result = {}
372
+ for node_id in node_ids:
373
+ node = await self.get_node(node_id)
374
+ if node is not None:
375
+ result[node_id] = node
376
+ return result
377
 
 
378
  async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
379
+ """Node degrees as a batch using UNWIND
380
+
381
+ Default implementation fetches node degrees one by one.
382
+ Override this method for better performance in storage backends
383
+ that support batch operations.
384
+ """
385
+ result = {}
386
+ for node_id in node_ids:
387
+ degree = await self.node_degree(node_id)
388
+ result[node_id] = degree
389
+ return result
390
 
 
391
  async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
392
+ """Edge degrees as a batch using UNWIND also uses node_degrees_batch
393
+
394
+ Default implementation calculates edge degrees one by one.
395
+ Override this method for better performance in storage backends
396
+ that support batch operations.
397
+ """
398
+ result = {}
399
+ for src_id, tgt_id in edge_pairs:
400
+ degree = await self.edge_degree(src_id, tgt_id)
401
+ result[(src_id, tgt_id)] = degree
402
+ return result
403
 
 
404
  async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
405
+ """Get edges as a batch using UNWIND
406
+
407
+ Default implementation fetches edges one by one.
408
+ Override this method for better performance in storage backends
409
+ that support batch operations.
410
+ """
411
+ result = {}
412
+ for pair in pairs:
413
+ src_id = pair["src"]
414
+ tgt_id = pair["tgt"]
415
+ edge = await self.get_edge(src_id, tgt_id)
416
+ if edge is not None:
417
+ result[(src_id, tgt_id)] = edge
418
+ return result
419
 
 
420
  async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
421
+ """Get nodes edges as a batch using UNWIND
422
+
423
+ Default implementation fetches node edges one by one.
424
+ Override this method for better performance in storage backends
425
+ that support batch operations.
426
+ """
427
+ result = {}
428
+ for node_id in node_ids:
429
+ edges = await self.get_node_edges(node_id)
430
+ result[node_id] = edges if edges is not None else []
431
+ return result
432
 
433
  @abstractmethod
434
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: