diff --git a/.gitattributes b/.gitattributes index b967487de257bc1d982c9adf5c6bbfc8c3363483..d2aa1293a042924406c75b4f56749bb8e3ec6d27 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,4 +1,5 @@ -lightrag/api/webui/** -diff +lightrag/api/webui/** binary +lightrag/api/webui/** linguist-generated *.png filter=lfs diff=lfs merge=lfs -text *.ttf filter=lfs diff=lfs merge=lfs -text *.ico filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index c9a352602c6daadca2f825d4311146dd0578e08a..eb2575e7317cfe9da7df074f8d388a40858495f5 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ This repository hosts the code of LightRAG. The structure of this code is based 🎉 News +- [X] [2025.03.18]🎯📢LightRAG now supports citation functionality. - [X] [2025.02.05]🎯📢Our team has released [VideoRAG](https://github.com/HKUDS/VideoRAG) understanding extremely long-context videos. - [X] [2025.01.13]🎯📢Our team has released [MiniRAG](https://github.com/HKUDS/MiniRAG) making RAG simpler with small models. - [X] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](#using-postgresql-for-storage). @@ -673,6 +674,22 @@ rag.insert(text_content.decode('utf-8')) +
+ Citation Functionality + +By providing file paths, the system ensures that sources can be traced back to their original documents. + +```python +# Define documents and their file paths +documents = ["Document content 1", "Document content 2"] +file_paths = ["path/to/doc1.txt", "path/to/doc2.txt"] + +# Insert documents with file paths +rag.insert(documents, file_paths=file_paths) +``` + +
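For readers who want to see the citation feature end to end, here is a minimal, hypothetical sketch of querying after the insert above. It assumes `rag` is an already-initialized `LightRAG` instance and uses the standard `QueryParam` API; the exact citation wording is produced by the LLM from the prompt template, so the output shown in the comments is only indicative.

```python
from lightrag import QueryParam

# Query the knowledge base after inserting documents together with file paths.
response = rag.query(
    "What does the first document talk about?",
    param=QueryParam(mode="hybrid"),
)

# The system prompt asks the model to close with a "References" section, e.g.:
#   [KG] Document content 1 (File: path/to/doc1.txt)
print(response)
```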
+ ## Storage
diff --git a/env.example b/env.example index 66d209ade2eb83f2be2fca7095fdf5b313fdccf6..fffa89a492f182299658470f676ff7afdd0247bc 100644 --- a/env.example +++ b/env.example @@ -73,6 +73,8 @@ LLM_BINDING_HOST=http://localhost:11434 ### Embedding Configuration (Use valid host. For local services installed with docker, you can use host.docker.internal) EMBEDDING_MODEL=bge-m3:latest EMBEDDING_DIM=1024 +EMBEDDING_BATCH_NUM=32 +EMBEDDING_FUNC_MAX_ASYNC=16 # EMBEDDING_BINDING_API_KEY=your_api_key ### ollama example EMBEDDING_BINDING=ollama @@ -151,9 +153,9 @@ QDRANT_URL=http://localhost:16333 ### Redis REDIS_URI=redis://localhost:6379 -# For jwt auth -AUTH_USERNAME=admin # login name -AUTH_PASSWORD=admin123 # password -TOKEN_SECRET=your-key # JWT key -TOKEN_EXPIRE_HOURS=4 # expire duration +### For JWT Auth +AUTH_USERNAME=admin # login name +AUTH_PASSWORD=admin123 # password +TOKEN_SECRET=your-key-for-LightRAG-API-Server # JWT key +TOKEN_EXPIRE_HOURS=4 # expire duration WHITELIST_PATHS=/login,/health # white list diff --git a/lightrag/__init__.py b/lightrag/__init__.py index 89475dca3e361886c5f5d2f0ea1b520512768d7e..f7dee88860788e593dd771bcbee63869e97771e5 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "1.2.6" +__version__ = "1.2.7" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" diff --git a/lightrag/api/auth.py b/lightrag/api/auth.py index 4d905de8795fb2a25a356b323f83408ccd6f6037..5d9b00acfa3e70cc3c43bfb20881cc1eb649045a 100644 --- a/lightrag/api/auth.py +++ b/lightrag/api/auth.py @@ -3,11 +3,16 @@ from datetime import datetime, timedelta import jwt from fastapi import HTTPException, status from pydantic import BaseModel +from dotenv import load_dotenv + +load_dotenv() class TokenPayload(BaseModel): - sub: str - exp: datetime + sub: str # Username + exp: datetime # Expiration time + role: str = "user" # User role, default is regular user + metadata: dict = {} # Additional metadata class AuthHandler: @@ -15,13 +20,60 @@ class AuthHandler: self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46") self.algorithm = "HS256" self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4)) + self.guest_expire_hours = int( + os.getenv("GUEST_TOKEN_EXPIRE_HOURS", 2) + ) # Guest token default expiration time + + def create_token( + self, + username: str, + role: str = "user", + custom_expire_hours: int = None, + metadata: dict = None, + ) -> str: + """ + Create JWT token + + Args: + username: Username + role: User role, default is "user", guest is "guest" + custom_expire_hours: Custom expiration time (hours), if None use default value + metadata: Additional metadata + + Returns: + str: Encoded JWT token + """ + # Choose default expiration time based on role + if custom_expire_hours is None: + if role == "guest": + expire_hours = self.guest_expire_hours + else: + expire_hours = self.expire_hours + else: + expire_hours = custom_expire_hours + + expire = datetime.utcnow() + timedelta(hours=expire_hours) + + # Create payload + payload = TokenPayload( + sub=username, exp=expire, role=role, metadata=metadata or {} + ) - def create_token(self, username: str) -> str: - expire = datetime.utcnow() + timedelta(hours=self.expire_hours) - payload = TokenPayload(sub=username, exp=expire) return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm) - def validate_token(self, token: str) -> str: + def validate_token(self, token: str) -> dict: + """ + Validate JWT token + + Args:
+ token: JWT token + + Returns: + dict: Dictionary containing user information + + Raises: + HTTPException: If token is invalid or expired + """ try: payload = jwt.decode(token, self.secret, algorithms=[self.algorithm]) expire_timestamp = payload["exp"] @@ -31,7 +83,14 @@ class AuthHandler: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired" ) - return payload["sub"] + + # Return complete payload instead of just username + return { + "username": payload["sub"], + "role": payload.get("role", "user"), + "metadata": payload.get("metadata", {}), + "exp": expire_time, + } except jwt.PyJWTError: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token" diff --git a/lightrag/api/gunicorn_config.py b/lightrag/api/gunicorn_config.py index 23e468078478be73f9690d45eda5fd4dcff4f844..0aef108ed7fcbd69701a8099b015121f37c41681 100644 --- a/lightrag/api/gunicorn_config.py +++ b/lightrag/api/gunicorn_config.py @@ -29,7 +29,9 @@ preload_app = True worker_class = "uvicorn.workers.UvicornWorker" # Other Gunicorn configurations -timeout = int(os.getenv("TIMEOUT", 150)) # Default 150s to match run_with_gunicorn.py +timeout = int( + os.getenv("TIMEOUT", 150 * 2) +) # Default 150s *2 to match run_with_gunicorn.py keepalive = int(os.getenv("KEEPALIVE", 5)) # Default 5s # Logging configuration diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index ca4425e542ade42c42f068de6dc404097366211e..6c8d11f1d327bcc2925f120627acaf6e89f94ef0 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -10,6 +10,7 @@ import logging.config import uvicorn import pipmaster as pm from fastapi.staticfiles import StaticFiles +from fastapi.responses import RedirectResponse from pathlib import Path import configparser from ascii_colors import ASCIIColors @@ -48,7 +49,7 @@ from .auth import auth_handler # Load environment variables # Updated to use the .env that is inside the current folder # This update allows the user to put a different.env file for each lightrag folder -load_dotenv(".env", override=True) +load_dotenv() # Initialize config parser config = configparser.ConfigParser() @@ -341,25 +342,62 @@ def create_app(args): ollama_api = OllamaAPI(rag, top_k=args.top_k) app.include_router(ollama_api.router, prefix="/api") - @app.post("/login") + @app.get("/") + async def redirect_to_webui(): + """Redirect root path to /webui""" + return RedirectResponse(url="/webui") + + @app.get("/auth-status", dependencies=[Depends(optional_api_key)]) + async def get_auth_status(): + """Get authentication status and guest token if auth is not configured""" + username = os.getenv("AUTH_USERNAME") + password = os.getenv("AUTH_PASSWORD") + + if not (username and password): + # Authentication not configured, return guest token + guest_token = auth_handler.create_token( + username="guest", role="guest", metadata={"auth_mode": "disabled"} + ) + return { + "auth_configured": False, + "access_token": guest_token, + "token_type": "bearer", + "auth_mode": "disabled", + "message": "Authentication is disabled. 
Using guest access.", + } + + return {"auth_configured": True, "auth_mode": "enabled"} + + @app.post("/login", dependencies=[Depends(optional_api_key)]) async def login(form_data: OAuth2PasswordRequestForm = Depends()): username = os.getenv("AUTH_USERNAME") password = os.getenv("AUTH_PASSWORD") if not (username and password): - raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, - detail="Authentication not configured", + # Authentication not configured, return guest token + guest_token = auth_handler.create_token( + username="guest", role="guest", metadata={"auth_mode": "disabled"} ) + return { + "access_token": guest_token, + "token_type": "bearer", + "auth_mode": "disabled", + "message": "Authentication is disabled. Using guest access.", + } if form_data.username != username or form_data.password != password: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials" ) + # Regular user login + user_token = auth_handler.create_token( + username=username, role="user", metadata={"auth_mode": "enabled"} + ) return { - "access_token": auth_handler.create_token(username), + "access_token": user_token, "token_type": "bearer", + "auth_mode": "enabled", } @app.get("/health", dependencies=[Depends(optional_api_key)]) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 7b6f11c1e73d6a4870e9604d6159a717f757fbf3..e0c8f545c2d075d04a942a3df6f23b27e9b7bcdf 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -405,7 +405,7 @@ async def pipeline_index_file(rag: LightRAG, file_path: Path): async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]): - """Index multiple files concurrently + """Index multiple files sequentially to avoid high CPU load Args: rag: LightRAG instance @@ -416,12 +416,12 @@ async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]): try: enqueued = False - if len(file_paths) == 1: - enqueued = await pipeline_enqueue_file(rag, file_paths[0]) - else: - tasks = [pipeline_enqueue_file(rag, path) for path in file_paths] - enqueued = any(await asyncio.gather(*tasks)) + # Process files sequentially + for file_path in file_paths: + if await pipeline_enqueue_file(rag, file_path): + enqueued = True + # Process the queue only if at least one file was successfully enqueued if enqueued: await rag.apipeline_process_enqueue_documents() except Exception as e: @@ -472,14 +472,34 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): total_files = len(new_files) logger.info(f"Found {total_files} new files to index.") - for idx, file_path in enumerate(new_files): - try: - await pipeline_index_file(rag, file_path) - except Exception as e: - logger.error(f"Error indexing file {file_path}: {str(e)}") + if not new_files: + return + + # Get MAX_PARALLEL_INSERT from global_args + max_parallel = global_args["max_parallel_insert"] + # Calculate batch size as 2 * MAX_PARALLEL_INSERT + batch_size = 2 * max_parallel + + # Process files in batches + for i in range(0, total_files, batch_size): + batch_files = new_files[i : i + batch_size] + batch_num = i // batch_size + 1 + total_batches = (total_files + batch_size - 1) // batch_size + + logger.info( + f"Processing batch {batch_num}/{total_batches} with {len(batch_files)} files" + ) + await pipeline_index_files(rag, batch_files) + + # Log progress + processed = min(i + batch_size, total_files) + logger.info( + f"Processed {processed}/{total_files} files 
({processed/total_files*100:.1f}%)" + ) except Exception as e: logger.error(f"Error during scanning process: {str(e)}") + logger.error(traceback.format_exc()) def create_document_routes( diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py index cf9b3b915483534c9dd01938403bc28b23f4db33..126d772d015ace5568798dc629e03bb653f0640e 100644 --- a/lightrag/api/run_with_gunicorn.py +++ b/lightrag/api/run_with_gunicorn.py @@ -13,7 +13,7 @@ from dotenv import load_dotenv # Updated to use the .env that is inside the current folder # This update allows the user to put a different.env file for each lightrag folder -load_dotenv(".env") +load_dotenv() def check_and_install_dependencies(): @@ -140,7 +140,7 @@ def main(): # Timeout configuration prioritizes command line arguments gunicorn_config.timeout = ( - args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150)) + args.timeout if args.timeout * 2 else int(os.getenv("TIMEOUT", 150 * 2)) ) # Keepalive configuration diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index 88a0132c28a003485afd387313890dc50f142f23..25136bd28128abbea920d62cbea29e5dee0e8670 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -9,14 +9,14 @@ import sys import logging from ascii_colors import ASCIIColors from lightrag.api import __api_version__ -from fastapi import HTTPException, Security, Depends, Request +from fastapi import HTTPException, Security, Depends, Request, status from dotenv import load_dotenv from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from starlette.status import HTTP_403_FORBIDDEN from .auth import auth_handler # Load environment variables -load_dotenv(override=True) +load_dotenv() global_args = {"main_args": None} @@ -35,19 +35,46 @@ ollama_server_infos = OllamaServerInfos() def get_auth_dependency(): - whitelist = os.getenv("WHITELIST_PATHS", "").split(",") + # Set default whitelist paths + whitelist = os.getenv("WHITELIST_PATHS", "/login,/health").split(",") async def dependency( request: Request, token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)), ): - if request.url.path in whitelist: + # Check if authentication is configured + auth_configured = bool( + os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD") + ) + + # If authentication is not configured, skip all validation + if not auth_configured: return - if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")): + # For configured auth, allow whitelist paths without token + if request.url.path in whitelist: return - auth_handler.validate_token(token) + # Require token for all other paths when auth is configured + if not token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Token required" + ) + + try: + token_info = auth_handler.validate_token(token) + # Reject guest tokens when authentication is configured + if token_info.get("role") == "guest": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required. 
Guest access not allowed when authentication is configured.", + ) + except Exception: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token" + ) + + return return dependency @@ -338,6 +365,9 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE ) + # Get MAX_PARALLEL_INSERT from environment + global_args["max_parallel_insert"] = get_env_value("MAX_PARALLEL_INSERT", 2, int) + # Handle openai-ollama special case if args.llm_binding == "openai-ollama": args.llm_binding = "openai" @@ -414,8 +444,8 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{args.log_level}") ASCIIColors.white(" ├─ Verbose Debug: ", end="") ASCIIColors.yellow(f"{args.verbose}") - ASCIIColors.white(" ├─ Timeout: ", end="") - ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}") + ASCIIColors.white(" ├─ History Turns: ", end="") + ASCIIColors.yellow(f"{args.history_turns}") ASCIIColors.white(" └─ API Key: ", end="") ASCIIColors.yellow("Set" if args.key else "Not Set") @@ -432,8 +462,10 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{args.llm_binding}") ASCIIColors.white(" ├─ Host: ", end="") ASCIIColors.yellow(f"{args.llm_binding_host}") - ASCIIColors.white(" └─ Model: ", end="") + ASCIIColors.white(" ├─ Model: ", end="") ASCIIColors.yellow(f"{args.llm_model}") + ASCIIColors.white(" └─ Timeout: ", end="") + ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}") # Embedding Configuration ASCIIColors.magenta("\n📊 Embedding Configuration:") @@ -448,8 +480,10 @@ def display_splash_screen(args: argparse.Namespace) -> None: # RAG Configuration ASCIIColors.magenta("\n⚙️ RAG Configuration:") - ASCIIColors.white(" ├─ Max Async Operations: ", end="") + ASCIIColors.white(" ├─ Max Async for LLM: ", end="") ASCIIColors.yellow(f"{args.max_async}") + ASCIIColors.white(" ├─ Max Parallel Insert: ", end="") + ASCIIColors.yellow(f"{global_args['max_parallel_insert']}") ASCIIColors.white(" ├─ Max Tokens: ", end="") ASCIIColors.yellow(f"{args.max_tokens}") ASCIIColors.white(" ├─ Max Embed Tokens: ", end="") @@ -458,8 +492,6 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{args.chunk_size}") ASCIIColors.white(" ├─ Chunk Overlap Size: ", end="") ASCIIColors.yellow(f"{args.chunk_overlap_size}") - ASCIIColors.white(" ├─ History Turns: ", end="") - ASCIIColors.yellow(f"{args.history_turns}") ASCIIColors.white(" ├─ Cosine Threshold: ", end="") ASCIIColors.yellow(f"{args.cosine_threshold}") ASCIIColors.white(" ├─ Top-K: ", end="") diff --git a/lightrag/api/webui/assets/index-DwcJE583.js b/lightrag/api/webui/assets/index-4I5HV9Fr.js similarity index 51% rename from lightrag/api/webui/assets/index-DwcJE583.js rename to lightrag/api/webui/assets/index-4I5HV9Fr.js index a00231338ea2c1bcb87cff4021cd00b449d23de8..297c5542c1078f2b2c88099c6bd47502727b1ec2 100644 Binary files a/lightrag/api/webui/assets/index-DwcJE583.js and b/lightrag/api/webui/assets/index-4I5HV9Fr.js differ diff --git a/lightrag/api/webui/assets/index-BSOt8Nur.css b/lightrag/api/webui/assets/index-BSOt8Nur.css new file mode 100644 index 0000000000000000000000000000000000000000..74781c969e8a6b8fe82135ffefb03e11176e202e Binary files /dev/null and b/lightrag/api/webui/assets/index-BSOt8Nur.css differ diff --git a/lightrag/api/webui/assets/index-BV5s8k-a.css b/lightrag/api/webui/assets/index-BV5s8k-a.css deleted file mode 
100644 index 8dca5fe7205b3003d6824ddcc5f7fcb11080d411..0000000000000000000000000000000000000000 Binary files a/lightrag/api/webui/assets/index-BV5s8k-a.css and /dev/null differ diff --git a/lightrag/api/webui/index.html b/lightrag/api/webui/index.html index 49fc0ea6ebcac282b3edc530da229efe63be7ad1..2135cfc3e755afeb4d20601e3b02e54ca6cdb467 100644 Binary files a/lightrag/api/webui/index.html and b/lightrag/api/webui/index.html differ diff --git a/lightrag/base.py b/lightrag/base.py index 865667879ab74af1514a507a99aeec73f5096ffa..f0376c01075da3ee7671415a6522663630640f2a 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -257,6 +257,8 @@ class DocProcessingStatus: """First 100 chars of document content, used for preview""" content_length: int """Total length of document""" + file_path: str + """File path of the document""" status: DocStatus """Current processing status""" created_at: str diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 57a34ae51cbbd49095c958786eabfe81efe1f53f..22da07b5f5b9ab46675bdd0ba1a3fac6e13e14c8 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -87,6 +87,9 @@ class JsonDocStatusStorage(DocStatusStorage): # If content is missing, use content_summary as content if "content" not in data and "content_summary" in data: data["content"] = data["content_summary"] + # If file_path is not in data, use a placeholder value + if "file_path" not in data: + data["file_path"] = "no-file-path" result[k] = DocProcessingStatus(**data) except KeyError as e: logger.error(f"Missing required field for document {k}: {e}") diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 004cfd4f6a6e2df2a82a1e0e9968fc0a0fbb5656..7026cf6d76215cca5b3fec68f7742c126d1afa43 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -373,6 +373,9 @@ class NetworkXStorage(BaseGraphStorage): # Add edges to result for edge in subgraph.edges(): source, target = edge + # Ensure unique edge_id for undirected graph + if source > target: + source, target = target, source edge_id = f"{source}-{target}" if edge_id in seen_edges: continue diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index d2630659599c30620aff9b82ae3f34200f89a7a0..4ff34e1309f9edc36c2f4f28a5a598113aacf675 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -423,6 +423,7 @@ class PGVectorStorage(BaseVectorStorage): "full_doc_id": item["full_doc_id"], "content": item["content"], "content_vector": json.dumps(item["__vector__"].tolist()), + "file_path": item["file_path"], } except Exception as e: logger.error(f"Error to prepare upsert,\nsql: {e}\nitem: {item}") @@ -445,6 +446,7 @@ class PGVectorStorage(BaseVectorStorage): "content": item["content"], "content_vector": json.dumps(item["__vector__"].tolist()), "chunk_ids": chunk_ids, + "file_path": item["file_path"], # TODO: add document_id } return upsert_sql, data @@ -465,6 +467,7 @@ class PGVectorStorage(BaseVectorStorage): "content": item["content"], "content_vector": json.dumps(item["__vector__"].tolist()), "chunk_ids": chunk_ids, + "file_path": item["file_path"], # TODO: add document_id } return upsert_sql, data @@ -732,7 +735,7 @@ class PGDocStatusStorage(DocStatusStorage): if result is None or result == []: return None else: - return DocProcessingStatus( + return dict( content=result[0]["content"], content_length=result[0]["content_length"], content_summary=result[0]["content_summary"], @@ -740,11 +743,34 @@ class
PGDocStatusStorage(DocStatusStorage): chunks_count=result[0]["chunks_count"], created_at=result[0]["created_at"], updated_at=result[0]["updated_at"], + file_path=result[0]["file_path"], ) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - """Get doc_chunks data by id""" - raise NotImplementedError + """Get doc_chunks data by multiple IDs.""" + if not ids: + return [] + + sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)" + params = {"workspace": self.db.workspace, "ids": ids} + + results = await self.db.query(sql, params, True) + + if not results: + return [] + return [ + { + "content": row["content"], + "content_length": row["content_length"], + "content_summary": row["content_summary"], + "status": row["status"], + "chunks_count": row["chunks_count"], + "created_at": row["created_at"], + "updated_at": row["updated_at"], + "file_path": row["file_path"], + } + for row in results + ] async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" @@ -774,6 +800,7 @@ class PGDocStatusStorage(DocStatusStorage): created_at=element["created_at"], updated_at=element["updated_at"], chunks_count=element["chunks_count"], + file_path=element["file_path"], ) for element in result } @@ -793,14 +820,15 @@ class PGDocStatusStorage(DocStatusStorage): if not data: return - sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status) - values($1,$2,$3,$4,$5,$6,$7) + sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path) + values($1,$2,$3,$4,$5,$6,$7,$8) on conflict(id,workspace) do update set content = EXCLUDED.content, content_summary = EXCLUDED.content_summary, content_length = EXCLUDED.content_length, chunks_count = EXCLUDED.chunks_count, status = EXCLUDED.status, + file_path = EXCLUDED.file_path, updated_at = CURRENT_TIMESTAMP""" for k, v in data.items(): # chunks_count is optional @@ -814,6 +842,7 @@ class PGDocStatusStorage(DocStatusStorage): "content_length": v["content_length"], "chunks_count": v["chunks_count"] if "chunks_count" in v else -1, "status": v["status"], + "file_path": v["file_path"], }, ) @@ -1058,7 +1087,6 @@ class PGGraphStorage(BaseGraphStorage): Args: query (str): a cypher query to be executed - params (dict): parameters for the query Returns: list[dict[str, Any]]: a list of dictionaries containing the result set @@ -1549,6 +1577,7 @@ TABLES = { tokens INTEGER, content TEXT, content_vector VECTOR, + file_path VARCHAR(256), create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP, CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) @@ -1563,7 +1592,8 @@ TABLES = { content_vector VECTOR, create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP, - chunk_id TEXT NULL, + chunk_ids VARCHAR(255)[] NULL, + file_path TEXT NULL, CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id) )""" }, @@ -1577,7 +1607,8 @@ TABLES = { content_vector VECTOR, create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP, - chunk_id TEXT NULL, + chunk_ids VARCHAR(255)[] NULL, + file_path TEXT NULL, CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id) )""" }, @@ -1602,6 +1633,7 @@ TABLES = { content_length int4 NULL, chunks_count int4 NULL, status varchar(64) NULL, + file_path TEXT NULL, created_at timestamp DEFAULT CURRENT_TIMESTAMP NULL, updated_at timestamp DEFAULT CURRENT_TIMESTAMP NULL, CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY 
(workspace, id) @@ -1650,35 +1682,38 @@ SQL_TEMPLATES = { update_time = CURRENT_TIMESTAMP """, "upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, - chunk_order_index, full_doc_id, content, content_vector) - VALUES ($1, $2, $3, $4, $5, $6, $7) + chunk_order_index, full_doc_id, content, content_vector, file_path) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (workspace,id) DO UPDATE SET tokens=EXCLUDED.tokens, chunk_order_index=EXCLUDED.chunk_order_index, full_doc_id=EXCLUDED.full_doc_id, content = EXCLUDED.content, content_vector=EXCLUDED.content_vector, + file_path=EXCLUDED.file_path, update_time = CURRENT_TIMESTAMP """, "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, - content_vector, chunk_ids) - VALUES ($1, $2, $3, $4, $5, $6::varchar[]) + content_vector, chunk_ids, file_path) + VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7) ON CONFLICT (workspace,id) DO UPDATE SET entity_name=EXCLUDED.entity_name, content=EXCLUDED.content, content_vector=EXCLUDED.content_vector, chunk_ids=EXCLUDED.chunk_ids, + file_path=EXCLUDED.file_path, update_time=CURRENT_TIMESTAMP """, "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id, - target_id, content, content_vector, chunk_ids) - VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[]) + target_id, content, content_vector, chunk_ids, file_path) + VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8) ON CONFLICT (workspace,id) DO UPDATE SET source_id=EXCLUDED.source_id, target_id=EXCLUDED.target_id, content=EXCLUDED.content, content_vector=EXCLUDED.content_vector, chunk_ids=EXCLUDED.chunk_ids, + file_path=EXCLUDED.file_path, update_time = CURRENT_TIMESTAMP """, # SQL for VectorStorage diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 736887a6aaebf16dc1c69e58a4befaa37fdc93cc..e26645c8a4e351041cd8083055aaafa24586a41a 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -41,6 +41,9 @@ _pipeline_status_lock: Optional[LockType] = None _graph_db_lock: Optional[LockType] = None _data_init_lock: Optional[LockType] = None +# async locks for coroutine synchronization in multiprocess mode +_async_locks: Optional[Dict[str, asyncio.Lock]] = None + class UnifiedLock(Generic[T]): """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock""" @@ -51,12 +54,14 @@ class UnifiedLock(Generic[T]): is_async: bool, name: str = "unnamed", enable_logging: bool = True, + async_lock: Optional[asyncio.Lock] = None, ): self._lock = lock self._is_async = is_async self._pid = os.getpid() # for debug only self._name = name # for debug only self._enable_logging = enable_logging # for debug only + self._async_lock = async_lock # auxiliary lock for coroutine synchronization async def __aenter__(self) -> "UnifiedLock[T]": try: @@ -64,16 +69,39 @@ class UnifiedLock(Generic[T]): f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging, ) + + # If in multiprocess mode and async lock exists, acquire it first + if not self._is_async and self._async_lock is not None: + direct_log( + f"== Lock == Process {self._pid}: Acquiring async lock for '{self._name}'", + enable_output=self._enable_logging, + ) + await self._async_lock.acquire() + direct_log( + f"== Lock == Process {self._pid}: Async lock for '{self._name}' acquired", + enable_output=self._enable_logging, + ) + + # Then acquire the main lock if self._is_async: await self._lock.acquire() else: 
self._lock.acquire() + direct_log( f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})", enable_output=self._enable_logging, ) return self except Exception as e: + # If main lock acquisition fails, release the async lock if it was acquired + if ( + not self._is_async + and self._async_lock is not None + and self._async_lock.locked() + ): + self._async_lock.release() + direct_log( f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}", level="ERROR", @@ -82,15 +110,29 @@ class UnifiedLock(Generic[T]): raise async def __aexit__(self, exc_type, exc_val, exc_tb): + main_lock_released = False try: direct_log( f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging, ) + + # Release main lock first if self._is_async: self._lock.release() else: self._lock.release() + + main_lock_released = True + + # Then release async lock if in multiprocess mode + if not self._is_async and self._async_lock is not None: + direct_log( + f"== Lock == Process {self._pid}: Releasing async lock for '{self._name}'", + enable_output=self._enable_logging, + ) + self._async_lock.release() + direct_log( f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})", enable_output=self._enable_logging, @@ -101,6 +143,31 @@ class UnifiedLock(Generic[T]): level="ERROR", enable_output=self._enable_logging, ) + + # If main lock release failed but async lock hasn't been released, try to release it + if ( + not main_lock_released + and not self._is_async + and self._async_lock is not None + ): + try: + direct_log( + f"== Lock == Process {self._pid}: Attempting to release async lock after main lock failure", + level="WARNING", + enable_output=self._enable_logging, + ) + self._async_lock.release() + direct_log( + f"== Lock == Process {self._pid}: Successfully released async lock after main lock failure", + enable_output=self._enable_logging, + ) + except Exception as inner_e: + direct_log( + f"== Lock == Process {self._pid}: Failed to release async lock after main lock failure: {inner_e}", + level="ERROR", + enable_output=self._enable_logging, + ) + raise def __enter__(self) -> "UnifiedLock[T]": @@ -151,51 +218,61 @@ class UnifiedLock(Generic[T]): def get_internal_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" + async_lock = _async_locks.get("internal_lock") if is_multiprocess else None return UnifiedLock( lock=_internal_lock, is_async=not is_multiprocess, name="internal_lock", enable_logging=enable_logging, + async_lock=async_lock, ) def get_storage_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" + async_lock = _async_locks.get("storage_lock") if is_multiprocess else None return UnifiedLock( lock=_storage_lock, is_async=not is_multiprocess, name="storage_lock", enable_logging=enable_logging, + async_lock=async_lock, ) def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" + async_lock = _async_locks.get("pipeline_status_lock") if is_multiprocess else None return UnifiedLock( lock=_pipeline_status_lock, is_async=not is_multiprocess, name="pipeline_status_lock", enable_logging=enable_logging, + async_lock=async_lock, ) def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock: """return unified graph database lock for ensuring atomic operations""" + async_lock = 
_async_locks.get("graph_db_lock") if is_multiprocess else None return UnifiedLock( lock=_graph_db_lock, is_async=not is_multiprocess, name="graph_db_lock", enable_logging=enable_logging, + async_lock=async_lock, ) def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock: """return unified data initialization lock for ensuring atomic data initialization""" + async_lock = _async_locks.get("data_init_lock") if is_multiprocess else None return UnifiedLock( lock=_data_init_lock, is_async=not is_multiprocess, name="data_init_lock", enable_logging=enable_logging, + async_lock=async_lock, ) @@ -229,7 +306,8 @@ def initialize_share_data(workers: int = 1): _shared_dicts, \ _init_flags, \ _initialized, \ - _update_flags + _update_flags, \ + _async_locks # Check if already initialized if _initialized: @@ -251,6 +329,16 @@ def initialize_share_data(workers: int = 1): _shared_dicts = _manager.dict() _init_flags = _manager.dict() _update_flags = _manager.dict() + + # Initialize async locks for multiprocess mode + _async_locks = { + "internal_lock": asyncio.Lock(), + "storage_lock": asyncio.Lock(), + "pipeline_status_lock": asyncio.Lock(), + "graph_db_lock": asyncio.Lock(), + "data_init_lock": asyncio.Lock(), + } + direct_log( f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})" ) @@ -264,6 +352,7 @@ def initialize_share_data(workers: int = 1): _shared_dicts = {} _init_flags = {} _update_flags = {} + _async_locks = None # No need for async locks in single process mode direct_log(f"Process {os.getpid()} Shared-Data created for Single Process") # Mark as initialized @@ -458,7 +547,8 @@ def finalize_share_data(): _shared_dicts, \ _init_flags, \ _initialized, \ - _update_flags + _update_flags, \ + _async_locks # Check if already initialized if not _initialized: @@ -523,5 +613,6 @@ def finalize_share_data(): _graph_db_lock = None _data_init_lock = None _update_flags = None + _async_locks = None direct_log(f"Process {os.getpid()} storage data finalization complete") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 27a03e12caa1663e8e1aa04874188d1691e83154..f053b25e483bdb0a212ca72e350bcdc4dbeb4593 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -183,10 +183,10 @@ class LightRAG: embedding_func: EmbeddingFunc | None = field(default=None) """Function for computing text embeddings. 
Must be set before use.""" - embedding_batch_num: int = field(default=32) + embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 32))) """Batch size for embedding computations.""" - embedding_func_max_async: int = field(default=16) + embedding_func_max_async: int = field(default=int(os.getenv("EMBEDDING_FUNC_MAX_ASYNC", 16))) """Maximum number of concurrent embedding function calls.""" embedding_cache_config: dict[str, Any] = field( @@ -389,20 +389,21 @@ class LightRAG: self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES ), embedding_func=self.embedding_func, - meta_fields={"entity_name", "source_id", "content"}, + meta_fields={"entity_name", "source_id", "content", "file_path"}, ) self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS ), embedding_func=self.embedding_func, - meta_fields={"src_id", "tgt_id", "source_id", "content"}, + meta_fields={"src_id", "tgt_id", "source_id", "content", "file_path"}, ) self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS ), embedding_func=self.embedding_func, + meta_fields={"full_doc_id", "content", "file_path"}, ) # Initialize document status storage @@ -547,6 +548,7 @@ class LightRAG: split_by_character: str | None = None, split_by_character_only: bool = False, ids: str | list[str] | None = None, + file_paths: str | list[str] | None = None, ) -> None: """Sync Insert documents with checkpoint support @@ -557,10 +559,13 @@ class LightRAG: split_by_character_only: if split_by_character_only is True, split the string by character only, when split_by_character is None, this parameter is ignored. ids: single string of the document ID or list of unique document IDs, if not provided, MD5 hash IDs will be generated + file_paths: single string of the file path or list of file paths, used for citation """ loop = always_get_an_event_loop() loop.run_until_complete( - self.ainsert(input, split_by_character, split_by_character_only, ids) + self.ainsert( + input, split_by_character, split_by_character_only, ids, file_paths + ) ) async def ainsert( @@ -569,6 +574,7 @@ class LightRAG: split_by_character: str | None = None, split_by_character_only: bool = False, ids: str | list[str] | None = None, + file_paths: str | list[str] | None = None, ) -> None: """Async Insert documents with checkpoint support @@ -579,8 +585,9 @@ class LightRAG: split_by_character_only: if split_by_character_only is True, split the string by character only, when split_by_character is None, this parameter is ignored. ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated + file_paths: list of file paths corresponding to each document, used for citation """ - await self.apipeline_enqueue_documents(input, ids) + await self.apipeline_enqueue_documents(input, ids, file_paths) await self.apipeline_process_enqueue_documents( split_by_character, split_by_character_only ) @@ -654,7 +661,10 @@ class LightRAG: await self._insert_done() async def apipeline_enqueue_documents( - self, input: str | list[str], ids: list[str] | None = None + self, + input: str | list[str], + ids: list[str] | None = None, + file_paths: str | list[str] | None = None, ) -> None: """ Pipeline for Processing Documents @@ -664,11 +674,30 @@ class LightRAG: 3. Generate document initial status 4. Filter out already processed documents 5. 
Enqueue document in status + + Args: + input: Single document string or list of document strings + ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated + file_paths: list of file paths corresponding to each document, used for citation """ if isinstance(input, str): input = [input] if isinstance(ids, str): ids = [ids] + if isinstance(file_paths, str): + file_paths = [file_paths] + + # If file_paths is provided, ensure it matches the number of documents + if file_paths is not None: + if isinstance(file_paths, str): + file_paths = [file_paths] + if len(file_paths) != len(input): + raise ValueError( + "Number of file paths must match the number of documents" + ) + else: + # If no file paths provided, use placeholder + file_paths = ["unknown_source"] * len(input) # 1. Validate ids if provided or generate MD5 hash IDs if ids is not None: @@ -681,32 +710,59 @@ class LightRAG: raise ValueError("IDs must be unique") # Generate contents dict of IDs provided by user and documents - contents = {id_: doc for id_, doc in zip(ids, input)} + contents = { + id_: {"content": doc, "file_path": path} + for id_, doc, path in zip(ids, input, file_paths) + } else: # Clean input text and remove duplicates - input = list(set(clean_text(doc) for doc in input)) - # Generate contents dict of MD5 hash IDs and documents - contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input} + cleaned_input = [ + (clean_text(doc), path) for doc, path in zip(input, file_paths) + ] + unique_content_with_paths = {} + + # Keep track of unique content and their paths + for content, path in cleaned_input: + if content not in unique_content_with_paths: + unique_content_with_paths[content] = path + + # Generate contents dict of MD5 hash IDs and documents with paths + contents = { + compute_mdhash_id(content, prefix="doc-"): { + "content": content, + "file_path": path, + } + for content, path in unique_content_with_paths.items() + } # 2. Remove duplicate contents - unique_contents = { - id_: content - for content, id_ in { - content: id_ for id_, content in contents.items() - }.items() + unique_contents = {} + for id_, content_data in contents.items(): + content = content_data["content"] + file_path = content_data["file_path"] + if content not in unique_contents: + unique_contents[content] = (id_, file_path) + + # Reconstruct contents with unique content + contents = { + id_: {"content": content, "file_path": file_path} + for content, (id_, file_path) in unique_contents.items() } # 3. Generate document initial status new_docs: dict[str, Any] = { id_: { - "content": content, - "content_summary": get_content_summary(content), - "content_length": len(content), "status": DocStatus.PENDING, + "content": content_data["content"], + "content_summary": get_content_summary(content_data["content"]), + "content_length": len(content_data["content"]), "created_at": datetime.now().isoformat(), "updated_at": datetime.now().isoformat(), + "file_path": content_data[ + "file_path" + ], # Store file path in document status } - for id_, content in unique_contents.items() + for id_, content_data in contents.items() } # 4. 
Filter out already processed documents @@ -841,11 +897,15 @@ class LightRAG: ) -> None: """Process single document""" try: + # Get file path from status document + file_path = getattr(status_doc, "file_path", "unknown_source") + # Generate chunks from document chunks: dict[str, Any] = { compute_mdhash_id(dp["content"], prefix="chunk-"): { **dp, "full_doc_id": doc_id, + "file_path": file_path, # Add file path to each chunk } for dp in self.chunking_func( status_doc.content, @@ -856,6 +916,7 @@ class LightRAG: self.tiktoken_model_name, ) } + # Process document (text chunks and full docs) in parallel # Create tasks with references for potential cancellation doc_status_task = asyncio.create_task( @@ -863,11 +924,13 @@ class LightRAG: { doc_id: { "status": DocStatus.PROCESSING, - "updated_at": datetime.now().isoformat(), + "chunks_count": len(chunks), "content": status_doc.content, "content_summary": status_doc.content_summary, "content_length": status_doc.content_length, "created_at": status_doc.created_at, + "updated_at": datetime.now().isoformat(), + "file_path": file_path, } } ) @@ -906,6 +969,7 @@ class LightRAG: "content_length": status_doc.content_length, "created_at": status_doc.created_at, "updated_at": datetime.now().isoformat(), + "file_path": file_path, } } ) @@ -937,6 +1001,7 @@ class LightRAG: "content_length": status_doc.content_length, "created_at": status_doc.created_at, "updated_at": datetime.now().isoformat(), + "file_path": file_path, } } ) @@ -1063,7 +1128,10 @@ class LightRAG: loop.run_until_complete(self.ainsert_custom_kg(custom_kg, full_doc_id)) async def ainsert_custom_kg( - self, custom_kg: dict[str, Any], full_doc_id: str = None + self, + custom_kg: dict[str, Any], + full_doc_id: str = None, + file_path: str = "custom_kg", ) -> None: update_storage = False try: @@ -1093,6 +1161,7 @@ class LightRAG: "full_doc_id": full_doc_id if full_doc_id is not None else source_id, + "file_path": file_path, # Add file path "status": DocStatus.PROCESSED, } all_chunks_data[chunk_id] = chunk_entry @@ -1197,6 +1266,7 @@ class LightRAG: "source_id": dp["source_id"], "description": dp["description"], "entity_type": dp["entity_type"], + "file_path": file_path, # Add file path } for dp in all_entities_data } @@ -1212,6 +1282,7 @@ class LightRAG: "keywords": dp["keywords"], "description": dp["description"], "weight": dp["weight"], + "file_path": file_path, # Add file path } for dp in all_relationships_data } @@ -1473,8 +1544,7 @@ class LightRAG: """ try: # 1. Get the document status and related data - doc_status = await self.doc_status.get_by_id(doc_id) - if not doc_status: + if not await self.doc_status.get_by_id(doc_id): logger.warning(f"Document {doc_id} not found") return @@ -1877,6 +1947,8 @@ class LightRAG: # 2. 
Update entity information in the graph new_node_data = {**node_data, **updated_data} + new_node_data["entity_id"] = new_entity_name + if "entity_name" in new_node_data: del new_node_data[ "entity_name" @@ -1893,7 +1965,7 @@ class LightRAG: # Store relationships that need to be updated relations_to_update = [] - + relations_to_delete = [] # Get all edges related to the original entity edges = await self.chunk_entity_relation_graph.get_node_edges( entity_name @@ -1905,6 +1977,12 @@ class LightRAG: source, target ) if edge_data: + relations_to_delete.append( + compute_mdhash_id(source + target, prefix="rel-") + ) + relations_to_delete.append( + compute_mdhash_id(target + source, prefix="rel-") + ) if source == entity_name: await self.chunk_entity_relation_graph.upsert_edge( new_entity_name, target, edge_data @@ -1930,6 +2008,12 @@ class LightRAG: f"Deleted old entity '{entity_name}' and its vector embedding from database" ) + # Delete old relation records from vector database + await self.relationships_vdb.delete(relations_to_delete) + logger.info( + f"Deleted {len(relations_to_delete)} relation records for entity '{entity_name}' from vector database" + ) + # Update relationship vector representations for src, tgt, edge_data in relations_to_update: description = edge_data.get("description", "") @@ -2220,7 +2304,6 @@ class LightRAG: """Synchronously create a new entity. Creates a new entity in the knowledge graph and adds it to the vector database. - Args: entity_name: Name of the new entity entity_data: Dictionary containing entity attributes, e.g. {"description": "description", "entity_type": "type"} @@ -2429,39 +2512,21 @@ class LightRAG: # 4. Get all relationships of the source entities all_relations = [] for entity_name in source_entities: - # Get all relationships where this entity is the source - outgoing_edges = await self.chunk_entity_relation_graph.get_node_edges( + # Get all relationships of the source entities + edges = await self.chunk_entity_relation_graph.get_node_edges( entity_name ) - if outgoing_edges: - for src, tgt in outgoing_edges: + if edges: + for src, tgt in edges: # Ensure src is the current entity if src == entity_name: edge_data = await self.chunk_entity_relation_graph.get_edge( src, tgt ) - all_relations.append(("outgoing", src, tgt, edge_data)) - - # Get all relationships where this entity is the target - incoming_edges = [] - all_labels = await self.chunk_entity_relation_graph.get_all_labels() - for label in all_labels: - if label == entity_name: - continue - node_edges = await self.chunk_entity_relation_graph.get_node_edges( - label - ) - for src, tgt in node_edges or []: - if tgt == entity_name: - incoming_edges.append((src, tgt)) - - for src, tgt in incoming_edges: - edge_data = await self.chunk_entity_relation_graph.get_edge( - src, tgt - ) - all_relations.append(("incoming", src, tgt, edge_data)) + all_relations.append((src, tgt, edge_data)) # 5. Create or update the target entity + merged_entity_data["entity_id"] = target_entity if not target_exists: await self.chunk_entity_relation_graph.upsert_node( target_entity, merged_entity_data @@ -2475,8 +2540,11 @@ class LightRAG: # 6. 
Recreate all relationships, pointing to the target entity relation_updates = {} # Track relationships that need to be merged + relations_to_delete = [] - for rel_type, src, tgt, edge_data in all_relations: + for src, tgt, edge_data in all_relations: + relations_to_delete.append(compute_mdhash_id(src + tgt, prefix="rel-")) + relations_to_delete.append(compute_mdhash_id(tgt + src, prefix="rel-")) new_src = target_entity if src in source_entities else src new_tgt = target_entity if tgt in source_entities else tgt @@ -2521,6 +2589,12 @@ class LightRAG: f"Created or updated relationship: {rel_data['src']} -> {rel_data['tgt']}" ) + # Delete relationships records from vector database + await self.relationships_vdb.delete(relations_to_delete) + logger.info( + f"Deleted {len(relations_to_delete)} relation records for entity '{entity_name}' from vector database" + ) + # 7. Update entity vector representation description = merged_entity_data.get("description", "") source_id = merged_entity_data.get("source_id", "") @@ -2583,19 +2657,6 @@ class LightRAG: entity_id = compute_mdhash_id(entity_name, prefix="ent-") await self.entities_vdb.delete([entity_id]) - # Also ensure any relationships specific to this entity are deleted from vector DB - # This is a safety check, as these should have been transformed to the target entity already - entity_relation_prefix = compute_mdhash_id(entity_name, prefix="rel-") - relations_with_entity = await self.relationships_vdb.search_by_prefix( - entity_relation_prefix - ) - if relations_with_entity: - relation_ids = [r["id"] for r in relations_with_entity] - await self.relationships_vdb.delete(relation_ids) - logger.info( - f"Deleted {len(relation_ids)} relation records for entity '{entity_name}' from vector database" - ) - logger.info( f"Deleted source entity '{entity_name}' and its vector embedding from database" ) diff --git a/lightrag/llm/hf.py b/lightrag/llm/hf.py index fb5208b016a083b069fe44772870a95e69b4494b..954a99b76e1e244d2f4abb34493f2d956fabe442 100644 --- a/lightrag/llm/hf.py +++ b/lightrag/llm/hf.py @@ -138,16 +138,31 @@ async def hf_model_complete( async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray: - device = next(embed_model.parameters()).device + # Detect the appropriate device + if torch.cuda.is_available(): + device = next(embed_model.parameters()).device # Use CUDA if available + elif torch.backends.mps.is_available(): + device = torch.device("mps") # Use MPS for Apple Silicon + else: + device = torch.device("cpu") # Fallback to CPU + + # Move the model to the detected device + embed_model = embed_model.to(device) + + # Tokenize the input texts and move them to the same device encoded_texts = tokenizer( texts, return_tensors="pt", padding=True, truncation=True ).to(device) + + # Perform inference with torch.no_grad(): outputs = embed_model( input_ids=encoded_texts["input_ids"], attention_mask=encoded_texts["attention_mask"], ) embeddings = outputs.last_hidden_state.mean(dim=1) + + # Convert embeddings to NumPy if embeddings.dtype == torch.bfloat16: return embeddings.detach().to(torch.float32).cpu().numpy() else: diff --git a/lightrag/operate.py b/lightrag/operate.py index d062ae73a03831718340f948de80706aca98a910..3291c49f30a4403591d01315b245d26846a0ffad 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -138,6 +138,7 @@ async def _handle_entity_relation_summary( async def _handle_single_entity_extraction( record_attributes: list[str], chunk_key: str, + file_path: str = "unknown_source", ): if len(record_attributes) < 
4 or record_attributes[0] != '"entity"': return None @@ -171,13 +172,14 @@ async def _handle_single_entity_extraction( entity_type=entity_type, description=entity_description, source_id=chunk_key, - metadata={"created_at": time.time()}, + file_path=file_path, ) async def _handle_single_relationship_extraction( record_attributes: list[str], chunk_key: str, + file_path: str = "unknown_source", ): if len(record_attributes) < 5 or record_attributes[0] != '"relationship"': return None @@ -199,7 +201,7 @@ async def _handle_single_relationship_extraction( description=edge_description, keywords=edge_keywords, source_id=edge_source_id, - metadata={"created_at": time.time()}, + file_path=file_path, ) @@ -213,6 +215,7 @@ async def _merge_nodes_then_upsert( already_entity_types = [] already_source_ids = [] already_description = [] + already_file_paths = [] already_node = await knowledge_graph_inst.get_node(entity_name) if already_node is not None: @@ -220,6 +223,9 @@ async def _merge_nodes_then_upsert( already_source_ids.extend( split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP]) ) + already_file_paths.extend( + split_string_by_multi_markers(already_node["file_path"], [GRAPH_FIELD_SEP]) + ) already_description.append(already_node["description"]) entity_type = sorted( @@ -235,6 +241,11 @@ async def _merge_nodes_then_upsert( source_id = GRAPH_FIELD_SEP.join( set([dp["source_id"] for dp in nodes_data] + already_source_ids) ) + file_path = GRAPH_FIELD_SEP.join( + set([dp["file_path"] for dp in nodes_data] + already_file_paths) + ) + + logger.debug(f"file_path: {file_path}") description = await _handle_entity_relation_summary( entity_name, description, global_config ) @@ -243,6 +254,7 @@ async def _merge_nodes_then_upsert( entity_type=entity_type, description=description, source_id=source_id, + file_path=file_path, ) await knowledge_graph_inst.upsert_node( entity_name, @@ -263,6 +275,7 @@ async def _merge_edges_then_upsert( already_source_ids = [] already_description = [] already_keywords = [] + already_file_paths = [] if await knowledge_graph_inst.has_edge(src_id, tgt_id): already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) @@ -279,6 +292,14 @@ async def _merge_edges_then_upsert( ) ) + # Get file_path with empty string default if missing or None + if already_edge.get("file_path") is not None: + already_file_paths.extend( + split_string_by_multi_markers( + already_edge["file_path"], [GRAPH_FIELD_SEP] + ) + ) + # Get description with empty string default if missing or None if already_edge.get("description") is not None: already_description.append(already_edge["description"]) @@ -315,6 +336,12 @@ async def _merge_edges_then_upsert( + already_source_ids ) ) + file_path = GRAPH_FIELD_SEP.join( + set( + [dp["file_path"] for dp in edges_data if dp.get("file_path")] + + already_file_paths + ) + ) for need_insert_id in [src_id, tgt_id]: if not (await knowledge_graph_inst.has_node(need_insert_id)): @@ -325,6 +352,7 @@ async def _merge_edges_then_upsert( "source_id": source_id, "description": description, "entity_type": "UNKNOWN", + "file_path": file_path, }, ) description = await _handle_entity_relation_summary( @@ -338,6 +366,7 @@ async def _merge_edges_then_upsert( description=description, keywords=keywords, source_id=source_id, + file_path=file_path, ), ) @@ -347,6 +376,7 @@ async def _merge_edges_then_upsert( description=description, keywords=keywords, source_id=source_id, + file_path=file_path, ) return edge_data @@ -456,11 +486,14 @@ async def extract_entities( else: 
return await use_llm_func(input_text) - async def _process_extraction_result(result: str, chunk_key: str): + async def _process_extraction_result( + result: str, chunk_key: str, file_path: str = "unknown_source" + ): """Process a single extraction result (either initial or gleaning) Args: result (str): The extraction result to process chunk_key (str): The chunk key for source tracking + file_path (str): The file path for citation Returns: tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships """ @@ -482,14 +515,14 @@ async def extract_entities( ) if_entities = await _handle_single_entity_extraction( - record_attributes, chunk_key + record_attributes, chunk_key, file_path ) if if_entities is not None: maybe_nodes[if_entities["entity_name"]].append(if_entities) continue if_relation = await _handle_single_relationship_extraction( - record_attributes, chunk_key + record_attributes, chunk_key, file_path ) if if_relation is not None: maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append( @@ -508,6 +541,8 @@ async def extract_entities( chunk_key = chunk_key_dp[0] chunk_dp = chunk_key_dp[1] content = chunk_dp["content"] + # Get file path from chunk data or use default + file_path = chunk_dp.get("file_path", "unknown_source") # Get initial extraction hint_prompt = entity_extract_prompt.format( @@ -517,9 +552,9 @@ async def extract_entities( final_result = await _user_llm_func_with_cache(hint_prompt) history = pack_user_ass_to_openai_messages(hint_prompt, final_result) - # Process initial extraction + # Process initial extraction with file path maybe_nodes, maybe_edges = await _process_extraction_result( - final_result, chunk_key + final_result, chunk_key, file_path ) # Process additional gleaning results @@ -530,9 +565,9 @@ async def extract_entities( history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) - # Process gleaning result separately + # Process gleaning result separately with file path glean_nodes, glean_edges = await _process_extraction_result( - glean_result, chunk_key + glean_result, chunk_key, file_path ) # Merge results @@ -637,9 +672,7 @@ async def extract_entities( "entity_type": dp["entity_type"], "content": f"{dp['entity_name']}\n{dp['description']}", "source_id": dp["source_id"], - "metadata": { - "created_at": dp.get("metadata", {}).get("created_at", time.time()) - }, + "file_path": dp.get("file_path", "unknown_source"), } for dp in all_entities_data } @@ -653,9 +686,7 @@ async def extract_entities( "keywords": dp["keywords"], "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}", "source_id": dp["source_id"], - "metadata": { - "created_at": dp.get("metadata", {}).get("created_at", time.time()) - }, + "file_path": dp.get("file_path", "unknown_source"), } for dp in all_relationships_data } @@ -1232,12 +1263,17 @@ async def _get_node_data( "description", "rank", "created_at", + "file_path", ] ] for i, n in enumerate(node_datas): created_at = n.get("created_at", "UNKNOWN") if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) + + # Get file path from node data + file_path = n.get("file_path", "unknown_source") + entites_section_list.append( [ i, @@ -1246,6 +1282,7 @@ async def _get_node_data( n.get("description", "UNKNOWN"), n["rank"], created_at, + file_path, ] ) entities_context = list_of_list_to_csv(entites_section_list) @@ -1260,6 +1297,7 @@ async def _get_node_data( "weight", "rank", "created_at", + "file_path", ] ] for 
i, e in enumerate(use_relations): @@ -1267,6 +1305,10 @@ async def _get_node_data( # Convert timestamp to readable format if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) + + # Get file path from edge data + file_path = e.get("file_path", "unknown_source") + relations_section_list.append( [ i, @@ -1277,6 +1319,7 @@ async def _get_node_data( e["weight"], e["rank"], created_at, + file_path, ] ) relations_context = list_of_list_to_csv(relations_section_list) @@ -1492,6 +1535,7 @@ async def _get_edge_data( "weight", "rank", "created_at", + "file_path", ] ] for i, e in enumerate(edge_datas): @@ -1499,6 +1543,10 @@ async def _get_edge_data( # Convert timestamp to readable format if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) + + # Get file path from edge data + file_path = e.get("file_path", "unknown_source") + relations_section_list.append( [ i, @@ -1509,16 +1557,23 @@ async def _get_edge_data( e["weight"], e["rank"], created_at, + file_path, ] ) relations_context = list_of_list_to_csv(relations_section_list) - entites_section_list = [["id", "entity", "type", "description", "rank"]] + entites_section_list = [ + ["id", "entity", "type", "description", "rank", "created_at", "file_path"] + ] for i, n in enumerate(use_entities): - created_at = e.get("created_at", "Unknown") + created_at = n.get("created_at", "Unknown") # Convert timestamp to readable format if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) + + # Get file path from node data + file_path = n.get("file_path", "unknown_source") + entites_section_list.append( [ i, @@ -1527,6 +1582,7 @@ async def _get_edge_data( n.get("description", "UNKNOWN"), n["rank"], created_at, + file_path, ] ) entities_context = list_of_list_to_csv(entites_section_list) @@ -1882,13 +1938,14 @@ async def kg_query_with_keywords( len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt)) logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}") + # 6. Generate response response = await use_model_func( query, system_prompt=sys_prompt, stream=query_param.stream, ) - # 清理响应内容 + # Clean up response content if isinstance(response, str) and len(response) > len(sys_prompt): response = ( response.replace(sys_prompt, "") diff --git a/lightrag/prompt.py b/lightrag/prompt.py index f81cd44168e8f3ec52d5c76599ca29ba1dd522e9..88ebd7fceca20aaf048a40d347e630d9c282e86f 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -61,7 +61,7 @@ Text: ``` while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order. -Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.” +Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. "If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us." 
The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce. @@ -92,7 +92,7 @@ Among the hardest hit, Nexon Technologies saw its stock plummet by 7.8% after re Meanwhile, commodity markets reflected a mixed sentiment. Gold futures rose by 1.5%, reaching $2,080 per ounce, as investors sought safe-haven assets. Crude oil prices continued their rally, climbing to $87.60 per barrel, supported by supply constraints and strong demand. -Financial experts are closely watching the Federal Reserve’s next move, as speculation grows over potential rate hikes. The upcoming policy announcement is expected to influence investor confidence and overall market stability. +Financial experts are closely watching the Federal Reserve's next move, as speculation grows over potential rate hikes. The upcoming policy announcement is expected to influence investor confidence and overall market stability. ``` Output: @@ -222,6 +222,7 @@ When handling relationships with timestamps: - Use markdown formatting with appropriate section headings - Please respond in the same language as the user's question. - Ensure the response maintains continuity with the conversation history. +- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path) - If you don't know the answer, just say so. - Do not make anything up. Do not include information not provided by the Knowledge Base.""" @@ -319,6 +320,7 @@ When handling content with timestamps: - Use markdown formatting with appropriate section headings - Please respond in the same language as the user's question. - Ensure the response maintains continuity with the conversation history. +- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path) - If you don't know the answer, just say so. - Do not include information not provided by the Document Chunks.""" @@ -378,8 +380,8 @@ When handling information with timestamps: - Use markdown formatting with appropriate section headings - Please respond in the same language as the user's question. - Ensure the response maintains continuity with the conversation history. -- Organize answer in sesctions focusing on one main point or aspect of the answer +- Organize answer in sections focusing on one main point or aspect of the answer - Use clear and descriptive section titles that reflect the content -- List up to 5 most important reference sources at the end under "References" sesction. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), in the following format: [KG/DC] Source content +- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path) - If you don't know the answer, just say so. Do not make anything up. 
- Do not include information not provided by the Data Sources.""" diff --git a/lightrag/utils.py b/lightrag/utils.py index 362e553116f6f1005125856c68fa873918e39df3..282042f6a2eb3fb943a48df705859f002dfe4705 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -109,15 +109,17 @@ def setup_logger( logger_name: str, level: str = "INFO", add_filter: bool = False, - log_file_path: str = None, + log_file_path: str | None = None, + enable_file_logging: bool = True, ): - """Set up a logger with console and file handlers + """Set up a logger with console and optionally file handlers Args: logger_name: Name of the logger to set up level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) add_filter: Whether to add LightragPathFilter to the logger - log_file_path: Path to the log file. If None, will use current directory/lightrag.log + log_file_path: Path to the log file. If None and file logging is enabled, defaults to lightrag.log in LOG_DIR or cwd + enable_file_logging: Whether to enable logging to a file (defaults to True) """ # Configure formatters detailed_formatter = logging.Formatter( @@ -125,18 +127,6 @@ def setup_logger( ) simple_formatter = logging.Formatter("%(levelname)s: %(message)s") - # Get log file path - if log_file_path is None: - log_dir = os.getenv("LOG_DIR", os.getcwd()) - log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log")) - - # Ensure log directory exists - os.makedirs(os.path.dirname(log_file_path), exist_ok=True) - - # Get log file max size and backup count from environment variables - log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB - log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups - logger_instance = logging.getLogger(logger_name) logger_instance.setLevel(level) logger_instance.handlers = [] # Clear existing handlers @@ -148,16 +138,34 @@ def setup_logger( console_handler.setLevel(level) logger_instance.addHandler(console_handler) - # Add file handler - file_handler = logging.handlers.RotatingFileHandler( - filename=log_file_path, - maxBytes=log_max_bytes, - backupCount=log_backup_count, - encoding="utf-8", - ) - file_handler.setFormatter(detailed_formatter) - file_handler.setLevel(level) - logger_instance.addHandler(file_handler) + # Add file handler by default unless explicitly disabled + if enable_file_logging: + # Get log file path + if log_file_path is None: + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log")) + + # Ensure log directory exists + os.makedirs(os.path.dirname(log_file_path), exist_ok=True) + + # Get log file max size and backup count from environment variables + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups + + try: + # Add file handler + file_handler = logging.handlers.RotatingFileHandler( + filename=log_file_path, + maxBytes=log_max_bytes, + backupCount=log_backup_count, + encoding="utf-8", + ) + file_handler.setFormatter(detailed_formatter) + file_handler.setLevel(level) + logger_instance.addHandler(file_handler) + except PermissionError as e: + logger.warning(f"Could not create log file at {log_file_path}: {str(e)}") + logger.warning("Continuing with console logging only") # Add path filter if requested if add_filter: diff --git a/lightrag_webui/bun.lock b/lightrag_webui/bun.lock index d3f75f5c5560ba7a6c28fbb3c4e9c33efe6cc615..35ffdf1354e84774381b21f2dacf1e4db758b73c 100644 --- a/lightrag_webui/bun.lock +++ 
b/lightrag_webui/bun.lock @@ -40,9 +40,11 @@ "react": "^19.0.0", "react-dom": "^19.0.0", "react-dropzone": "^14.3.6", + "react-error-boundary": "^5.0.0", "react-i18next": "^15.4.1", "react-markdown": "^9.1.0", "react-number-format": "^5.4.3", + "react-router-dom": "^7.3.0", "react-syntax-highlighter": "^15.6.1", "rehype-react": "^8.0.0", "remark-gfm": "^4.0.1", @@ -418,6 +420,8 @@ "@types/bun": ["@types/bun@1.2.3", "", { "dependencies": { "bun-types": "1.2.3" } }, "sha512-054h79ipETRfjtsCW9qJK8Ipof67Pw9bodFWmkfkaUaRiIQ1dIV2VTlheshlBx3mpKr0KeK8VqnMMCtgN9rQtw=="], + "@types/cookie": ["@types/cookie@0.6.0", "https://registry.npmmirror.com/@types/cookie/-/cookie-0.6.0.tgz", {}, "sha512-4Kh9a6B2bQciAhf7FSuMRRkUWecJgJu9nPnx3yzpsfXX/c50REIqpHY4C82bXP90qrLtXtkDxTZosYO3UpOwlA=="], + "@types/debug": ["@types/debug@4.1.12", "", { "dependencies": { "@types/ms": "*" } }, "sha512-vIChWdVG3LG1SMxEvI/AK+FWJthlrqlTu7fbrlywTkkaONwk/UAGaULXRlf8vkzFBLVm0zkMdCquhL5aOjhXPQ=="], "@types/estree": ["@types/estree@1.0.6", "", {}, "sha512-AYnb1nQyY49te+VRAVgmzfcgjYS91mY5P0TKUDCLEM+gNnA+3T6rWITXRLYCpahpqSQbN5cE+gHpnPyXjHWxcw=="], @@ -566,6 +570,8 @@ "convert-source-map": ["convert-source-map@1.9.0", "", {}, "sha512-ASFBup0Mz1uyiIjANan1jzLQami9z1PoYSZCiiYW2FczPbenXc45FZdBZLzOT+r6+iciuEModtmCti+hjaAk0A=="], + "cookie": ["cookie@1.0.2", "https://registry.npmmirror.com/cookie/-/cookie-1.0.2.tgz", {}, "sha512-9Kr/j4O16ISv8zBBhJoi4bXOYNTkFLOqSL3UDB0njXxCXNezjeyVrJyGOWtgfs/q2km1gwBcfH8q1yEGoMYunA=="], + "cosmiconfig": ["cosmiconfig@7.1.0", "", { "dependencies": { "@types/parse-json": "^4.0.0", "import-fresh": "^3.2.1", "parse-json": "^5.0.0", "path-type": "^4.0.0", "yaml": "^1.10.0" } }, "sha512-AdmX6xUzdNASswsFtmwSt7Vj8po9IuqXm0UXz7QKPuEUmPB4XyjGfaAr2PSuELMwkRMVH1EpIkX5bTZGRB3eCA=="], "cross-spawn": ["cross-spawn@7.0.6", "", { "dependencies": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", "which": "^2.0.1" } }, "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA=="], @@ -1102,6 +1108,8 @@ "react-dropzone": ["react-dropzone@14.3.6", "", { "dependencies": { "attr-accept": "^2.2.4", "file-selector": "^2.1.0", "prop-types": "^15.8.1" }, "peerDependencies": { "react": ">= 16.8 || 18.0.0" } }, "sha512-U792j+x0rcwH/U/Slv/OBNU/LGFYbDLHKKiJoPhNaOianayZevCt4Y5S0CraPssH/6/wT6xhKDfzdXUgCBS0HQ=="], + "react-error-boundary": ["react-error-boundary@5.0.0", "", { "dependencies": { "@babel/runtime": "^7.12.5" }, "peerDependencies": { "react": ">=16.13.1" } }, "sha512-tnjAxG+IkpLephNcePNA7v6F/QpWLH8He65+DmedchDwg162JZqx4NmbXj0mlAYVVEd81OW7aFhmbsScYfiAFQ=="], + "react-i18next": ["react-i18next@15.4.1", "", { "dependencies": { "@babel/runtime": "^7.25.0", "html-parse-stringify": "^3.0.1" }, "peerDependencies": { "i18next": ">= 23.2.3", "react": ">= 16.8.0" } }, "sha512-ahGab+IaSgZmNPYXdV1n+OYky95TGpFwnKRflX/16dY04DsYYKHtVLjeny7sBSCREEcoMbAgSkFiGLF5g5Oofw=="], "react-is": ["react-is@16.13.1", "", {}, "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ=="], @@ -1114,6 +1122,10 @@ "react-remove-scroll-bar": ["react-remove-scroll-bar@2.3.8", "", { "dependencies": { "react-style-singleton": "^2.2.2", "tslib": "^2.0.0" }, "peerDependencies": { "@types/react": "*", "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" }, "optionalPeers": ["@types/react"] }, "sha512-9r+yi9+mgU33AKcj6IbT9oRCO78WriSj6t/cF8DWBZJ9aOGPOTEDvdUDz1FwKim7QXWwmHqtdHnRJfhAxEG46Q=="], + "react-router": ["react-router@7.3.0", 
"https://registry.npmmirror.com/react-router/-/react-router-7.3.0.tgz", { "dependencies": { "@types/cookie": "^0.6.0", "cookie": "^1.0.1", "set-cookie-parser": "^2.6.0", "turbo-stream": "2.4.0" }, "peerDependencies": { "react": ">=18", "react-dom": ">=18" }, "optionalPeers": ["react-dom"] }, "sha512-466f2W7HIWaNXTKM5nHTqNxLrHTyXybm7R0eBlVSt0k/u55tTCDO194OIx/NrYD4TS5SXKTNekXfT37kMKUjgw=="], + + "react-router-dom": ["react-router-dom@7.3.0", "https://registry.npmmirror.com/react-router-dom/-/react-router-dom-7.3.0.tgz", { "dependencies": { "react-router": "7.3.0" }, "peerDependencies": { "react": ">=18", "react-dom": ">=18" } }, "sha512-z7Q5FTiHGgQfEurX/FBinkOXhWREJIAB2RiU24lvcBa82PxUpwqvs/PAXb9lJyPjTs2jrl6UkLvCZVGJPeNuuQ=="], + "react-select": ["react-select@5.10.0", "", { "dependencies": { "@babel/runtime": "^7.12.0", "@emotion/cache": "^11.4.0", "@emotion/react": "^11.8.1", "@floating-ui/dom": "^1.0.1", "@types/react-transition-group": "^4.4.0", "memoize-one": "^6.0.0", "prop-types": "^15.6.0", "react-transition-group": "^4.3.0", "use-isomorphic-layout-effect": "^1.2.0" }, "peerDependencies": { "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-k96gw+i6N3ExgDwPIg0lUPmexl1ygPe6u5BdQFNBhkpbwroIgCNXdubtIzHfThYXYYTubwOBafoMnn7ruEP1xA=="], "react-style-singleton": ["react-style-singleton@2.2.3", "", { "dependencies": { "get-nonce": "^1.0.0", "tslib": "^2.0.0" }, "peerDependencies": { "@types/react": "*", "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react"] }, "sha512-b6jSvxvVnyptAiLjbkWLE/lOnR4lfTtDAl+eUC7RZy+QQWc6wRzIV2CE6xBuMmDxc2qIihtDCZD5NPOFl7fRBQ=="], @@ -1164,6 +1176,8 @@ "semver": ["semver@6.3.1", "", { "bin": { "semver": "bin/semver.js" } }, "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA=="], + "set-cookie-parser": ["set-cookie-parser@2.7.1", "https://registry.npmmirror.com/set-cookie-parser/-/set-cookie-parser-2.7.1.tgz", {}, "sha512-IOc8uWeOZgnb3ptbCURJWNjWUPcO3ZnTTdzsurqERrP6nPyv+paC55vJM0LpOlT2ne+Ix+9+CRG1MNLlyZ4GjQ=="], + "set-function-length": ["set-function-length@1.2.2", "", { "dependencies": { "define-data-property": "^1.1.4", "es-errors": "^1.3.0", "function-bind": "^1.1.2", "get-intrinsic": "^1.2.4", "gopd": "^1.0.1", "has-property-descriptors": "^1.0.2" } }, "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg=="], "set-function-name": ["set-function-name@2.0.2", "", { "dependencies": { "define-data-property": "^1.1.4", "es-errors": "^1.3.0", "functions-have-names": "^1.2.3", "has-property-descriptors": "^1.0.2" } }, "sha512-7PGFlmtwsEADb0WYyvCMa1t+yke6daIG4Wirafur5kcf+MhUnPms1UeR0CKQdTZD81yESwMHbtn+TR+dMviakQ=="], @@ -1234,6 +1248,8 @@ "tslib": ["tslib@2.8.1", "", {}, "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w=="], + "turbo-stream": ["turbo-stream@2.4.0", "https://registry.npmmirror.com/turbo-stream/-/turbo-stream-2.4.0.tgz", {}, "sha512-FHncC10WpBd2eOmGwpmQsWLDoK4cqsA/UT/GqNoaKOQnT8uzhtCbg3EoUDMvqpOSAI0S26mr0rkjzbOO6S3v1g=="], + "type-check": ["type-check@0.4.0", "", { "dependencies": { "prelude-ls": "^1.2.1" } }, "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew=="], "typed-array-buffer": ["typed-array-buffer@1.0.3", "", { "dependencies": { "call-bound": "^1.0.3", "es-errors": "^1.3.0", "is-typed-array": "^1.1.14" } }, 
"sha512-nAYYwfY3qnzX30IkA6AQZjVbtK6duGontcQm1WSG1MD94YLqK0515GNApXkoxKOWMusVssAHWLh9SeaoefYFGw=="], diff --git a/lightrag_webui/env.development.smaple b/lightrag_webui/env.development.smaple new file mode 100644 index 0000000000000000000000000000000000000000..080cf95f08ed80fb67e4e03c4ca97b7bcf4ef03e --- /dev/null +++ b/lightrag_webui/env.development.smaple @@ -0,0 +1,2 @@ +# Development environment configuration +VITE_BACKEND_URL=/api diff --git a/lightrag_webui/env.local.sample b/lightrag_webui/env.local.sample new file mode 100644 index 0000000000000000000000000000000000000000..0cd53ad5719728d49cdc69e99c2a7aa7322d502f --- /dev/null +++ b/lightrag_webui/env.local.sample @@ -0,0 +1,3 @@ +VITE_BACKEND_URL=http://localhost:9621 +VITE_API_PROXY=true +VITE_API_ENDPOINTS=/,/api,/documents,/graphs,/graph,/health,/query,/docs,/openapi.json,/login,/auth-status diff --git a/lightrag_webui/index.html b/lightrag_webui/index.html index 5550430225d62c54c486e15a1b60f61e7954c158..3dd1ebbc2cc79e324b645bcbee8787d22c0383ec 100644 --- a/lightrag_webui/index.html +++ b/lightrag_webui/index.html @@ -5,7 +5,7 @@ - + Lightrag diff --git a/lightrag_webui/package.json b/lightrag_webui/package.json index 8a10d1e639ca6482be011f03478145dc52d4f79b..8476f8f9a9a9fd6c8e27480f3f31621e2804918e 100644 --- a/lightrag_webui/package.json +++ b/lightrag_webui/package.json @@ -49,9 +49,11 @@ "react": "^19.0.0", "react-dom": "^19.0.0", "react-dropzone": "^14.3.6", + "react-error-boundary": "^5.0.0", "react-i18next": "^15.4.1", "react-markdown": "^9.1.0", "react-number-format": "^5.4.3", + "react-router-dom": "^7.3.0", "react-syntax-highlighter": "^15.6.1", "rehype-react": "^8.0.0", "remark-gfm": "^4.0.1", diff --git a/lightrag_webui/src/App.tsx b/lightrag_webui/src/App.tsx index b7d66b7ebcc3223d50b39a5a5771058880a66372..4596f684beca4c841fa6a7296ee4b5a626387a1f 100644 --- a/lightrag_webui/src/App.tsx +++ b/lightrag_webui/src/App.tsx @@ -8,7 +8,6 @@ import { healthCheckInterval } from '@/lib/constants' import { useBackendState } from '@/stores/state' import { useSettingsStore } from '@/stores/settings' import { useEffect } from 'react' -import { Toaster } from 'sonner' import SiteHeader from '@/features/SiteHeader' import { InvalidApiKeyError, RequireApiKeError } from '@/api/lightrag' @@ -27,8 +26,6 @@ function App() { // Health check useEffect(() => { - if (!enableHealthCheck) return - // Check immediately useBackendState.getState().check() @@ -56,24 +53,24 @@ function App() { return ( -
+
- + - + - + - +
@@ -81,7 +78,6 @@ function App() { {enableHealthCheck && } {message !== null && !apiKeyInvalid && } {apiKeyInvalid && } -
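Note on the new router added in the next file: `AppRouter.tsx` and the axios response interceptor both call into a small `navigationService` imported from `@/services/navigation`, which is not part of this patch excerpt. A minimal sketch of what such a module might look like, inferred only from the two calls used in this patch (`setNavigate` and `navigateToLogin`); the real implementation may differ.

```typescript
// Hypothetical sketch of the navigation service assumed by AppRouter and the
// axios 401 handler. Only setNavigate() and navigateToLogin() are taken from
// their usage in this patch; everything else here is an assumption.
import { NavigateFunction } from 'react-router-dom'

class NavigationService {
  private navigate: NavigateFunction | null = null

  // Registered once by a component that has access to useNavigate()
  setNavigate(navigate: NavigateFunction) {
    this.navigate = navigate
  }

  // Redirect to the login route; fall back to a hash change if the
  // navigate function has not been registered yet
  navigateToLogin() {
    if (this.navigate) {
      this.navigate('/login')
    } else {
      window.location.hash = '#/login'
    }
  }
}

export const navigationService = new NavigationService()
```

Keeping navigation behind a plain module like this lets non-React code (the axios interceptor) trigger a redirect without importing router hooks.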
diff --git a/lightrag_webui/src/AppRouter.tsx b/lightrag_webui/src/AppRouter.tsx new file mode 100644 index 0000000000000000000000000000000000000000..e85a97bebfb2340c7eeb63cc8d73905f4c2031b3 --- /dev/null +++ b/lightrag_webui/src/AppRouter.tsx @@ -0,0 +1,190 @@ +import { HashRouter as Router, Routes, Route, useNavigate } from 'react-router-dom' +import { useEffect, useState } from 'react' +import { useAuthStore } from '@/stores/state' +import { navigationService } from '@/services/navigation' +import { getAuthStatus } from '@/api/lightrag' +import { toast } from 'sonner' +import { Toaster } from 'sonner' +import App from './App' +import LoginPage from '@/features/LoginPage' +import ThemeProvider from '@/components/ThemeProvider' + +interface ProtectedRouteProps { + children: React.ReactNode +} + +const ProtectedRoute = ({ children }: ProtectedRouteProps) => { + const { isAuthenticated } = useAuthStore() + const [isChecking, setIsChecking] = useState(true) + const navigate = useNavigate() + + // Set navigate function for navigation service + useEffect(() => { + navigationService.setNavigate(navigate) + }, [navigate]) + + useEffect(() => { + let isMounted = true; // Flag to prevent state updates after unmount + + // This effect will run when the component mounts + // and will check if authentication is required + const checkAuthStatus = async () => { + try { + // Skip check if already authenticated + if (isAuthenticated) { + if (isMounted) setIsChecking(false); + return; + } + + const status = await getAuthStatus() + + // Only proceed if component is still mounted + if (!isMounted) return; + + if (!status.auth_configured && status.access_token) { + // If auth is not configured, use the guest token + useAuthStore.getState().login(status.access_token, true) + if (status.message) { + toast.info(status.message) + } + } + } catch (error) { + console.error('Failed to check auth status:', error) + } finally { + // Only update state if component is still mounted + if (isMounted) { + setIsChecking(false) + } + } + } + + // Execute immediately + checkAuthStatus() + + // Cleanup function to prevent state updates after unmount + return () => { + isMounted = false; + } + }, [isAuthenticated]) + + // Handle navigation when authentication status changes + useEffect(() => { + if (!isChecking && !isAuthenticated) { + const currentPath = window.location.hash.slice(1); // Remove the '#' from hash + const isLoginPage = currentPath === '/login'; + + if (!isLoginPage) { + // Use navigation service for redirection + console.log('Not authenticated, redirecting to login'); + navigationService.navigateToLogin(); + } + } + }, [isChecking, isAuthenticated]); + + // Show nothing while checking auth status or when not authenticated on login page + if (isChecking || (!isAuthenticated && window.location.hash.slice(1) === '/login')) { + return null; + } + + // Show children only when authenticated + if (!isAuthenticated) { + return null; + } + + return <>{children}; +} + +const AppContent = () => { + const [initializing, setInitializing] = useState(true) + const { isAuthenticated } = useAuthStore() + const navigate = useNavigate() + + // Set navigate function for navigation service + useEffect(() => { + navigationService.setNavigate(navigate) + }, [navigate]) + + // Check token validity and auth configuration on app initialization + useEffect(() => { + let isMounted = true; // Flag to prevent state updates after unmount + + const checkAuth = async () => { + try { + const token = localStorage.getItem('LIGHTRAG-API-TOKEN') + 
+ // If we have a token, we're already authenticated + if (token && isAuthenticated) { + if (isMounted) setInitializing(false); + return; + } + + // If no token or not authenticated, check if auth is configured + const status = await getAuthStatus() + + // Only proceed if component is still mounted + if (!isMounted) return; + + if (!status.auth_configured && status.access_token) { + // If auth is not configured, use the guest token + useAuthStore.getState().login(status.access_token, true) + if (status.message) { + toast.info(status.message) + } + } else if (!token) { + // Only logout if we don't have a token + useAuthStore.getState().logout() + } + } catch (error) { + console.error('Auth initialization error:', error) + if (isMounted && !isAuthenticated) { + useAuthStore.getState().logout() + } + } finally { + // Only update state if component is still mounted + if (isMounted) { + setInitializing(false) + } + } + } + + // Execute immediately + checkAuth() + + // Cleanup function to prevent state updates after unmount + return () => { + isMounted = false; + } + }, [isAuthenticated]) + + // Show nothing while initializing + if (initializing) { + return null + } + + return ( + + } /> + + + + } + /> + + ) +} + +const AppRouter = () => { + return ( + + + + + + + ) +} + +export default AppRouter diff --git a/lightrag_webui/src/api/lightrag.ts b/lightrag_webui/src/api/lightrag.ts index cba9c96425790bc595c899ffa9a92b4f50d4ba1b..0406075a1becf02d3c21e8fa38561dda3fe2fd57 100644 --- a/lightrag_webui/src/api/lightrag.ts +++ b/lightrag_webui/src/api/lightrag.ts @@ -2,6 +2,7 @@ import axios, { AxiosError } from 'axios' import { backendBaseUrl } from '@/lib/constants' import { errorMessage } from '@/lib/utils' import { useSettingsStore } from '@/stores/settings' +import { navigationService } from '@/services/navigation' // Types export type LightragNodeType = { @@ -125,6 +126,21 @@ export type DocsStatusesResponse = { statuses: Record } +export type AuthStatusResponse = { + auth_configured: boolean + access_token?: string + token_type?: string + auth_mode?: 'enabled' | 'disabled' + message?: string +} + +export type LoginResponse = { + access_token: string + token_type: string + auth_mode?: 'enabled' | 'disabled' // Authentication mode identifier + message?: string // Optional message +} + export const InvalidApiKeyError = 'Invalid API Key' export const RequireApiKeError = 'API Key required' @@ -136,9 +152,15 @@ const axiosInstance = axios.create({ } }) -// Interceptor:add api key +// Interceptor: add api key and check authentication axiosInstance.interceptors.request.use((config) => { const apiKey = useSettingsStore.getState().apiKey + const token = localStorage.getItem('LIGHTRAG-API-TOKEN'); + + // Always include token if it exists, regardless of path + if (token) { + config.headers['Authorization'] = `Bearer ${token}` + } if (apiKey) { config.headers['X-API-Key'] = apiKey } @@ -150,6 +172,16 @@ axiosInstance.interceptors.response.use( (response) => response, (error: AxiosError) => { if (error.response) { + if (error.response?.status === 401) { + // For login API, throw error directly + if (error.config?.url?.includes('/login')) { + throw error; + } + // For other APIs, navigate to login page + navigationService.navigateToLogin(); + // Return a never-resolving promise to prevent further execution + return new Promise(() => {}); + } throw new Error( `${error.response.status} ${error.response.statusText}\n${JSON.stringify( error.response.data @@ -324,3 +356,74 @@ export const clearDocuments = async (): 
Promise => { const response = await axiosInstance.delete('/documents') return response.data } + +export const getAuthStatus = async (): Promise => { + try { + // Add a timeout to the request to prevent hanging + const response = await axiosInstance.get('/auth-status', { + timeout: 5000, // 5 second timeout + headers: { + 'Accept': 'application/json' // Explicitly request JSON + } + }); + + // Check if response is HTML (which indicates a redirect or wrong endpoint) + const contentType = response.headers['content-type'] || ''; + if (contentType.includes('text/html')) { + console.warn('Received HTML response instead of JSON for auth-status endpoint'); + return { + auth_configured: true, + auth_mode: 'enabled' + }; + } + + // Strict validation of the response data + if (response.data && + typeof response.data === 'object' && + 'auth_configured' in response.data && + typeof response.data.auth_configured === 'boolean') { + + // For unconfigured auth, ensure we have an access token + if (!response.data.auth_configured) { + if (response.data.access_token && typeof response.data.access_token === 'string') { + return response.data; + } else { + console.warn('Auth not configured but no valid access token provided'); + } + } else { + // For configured auth, just return the data + return response.data; + } + } + + // If response data is invalid but we got a response, log it + console.warn('Received invalid auth status response:', response.data); + + // Default to auth configured if response is invalid + return { + auth_configured: true, + auth_mode: 'enabled' + }; + } catch (error) { + // If the request fails, assume authentication is configured + console.error('Failed to get auth status:', errorMessage(error)); + return { + auth_configured: true, + auth_mode: 'enabled' + }; + } +} + +export const loginToServer = async (username: string, password: string): Promise => { + const formData = new FormData(); + formData.append('username', username); + formData.append('password', password); + + const response = await axiosInstance.post('/login', formData, { + headers: { + 'Content-Type': 'multipart/form-data' + } + }); + + return response.data; +} diff --git a/lightrag_webui/src/components/AppSettings.tsx b/lightrag_webui/src/components/AppSettings.tsx index 284ad67f33b1cb8eee308533653e966c484194cd..a1ac14039346c52d47affbfd02c70605c91fe1dd 100644 --- a/lightrag_webui/src/components/AppSettings.tsx +++ b/lightrag_webui/src/components/AppSettings.tsx @@ -5,8 +5,13 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@ import { useSettingsStore } from '@/stores/settings' import { PaletteIcon } from 'lucide-react' import { useTranslation } from 'react-i18next' +import { cn } from '@/lib/utils' -export default function AppSettings() { +interface AppSettingsProps { + className?: string +} + +export default function AppSettings({ className }: AppSettingsProps) { const [opened, setOpened] = useState(false) const { t } = useTranslation() @@ -27,7 +32,7 @@ export default function AppSettings() { return ( - diff --git a/lightrag_webui/src/components/LanguageToggle.tsx b/lightrag_webui/src/components/LanguageToggle.tsx new file mode 100644 index 0000000000000000000000000000000000000000..0eab780eae449608fb3b78038766a50e110b2f2d --- /dev/null +++ b/lightrag_webui/src/components/LanguageToggle.tsx @@ -0,0 +1,49 @@ +import Button from '@/components/ui/Button' +import { useCallback } from 'react' +import { controlButtonVariant } from '@/lib/constants' +import { useTranslation } from 
'react-i18next' +import { useSettingsStore } from '@/stores/settings' + +/** + * Component that toggles the language between English and Chinese. + */ +export default function LanguageToggle() { + const { i18n } = useTranslation() + const currentLanguage = i18n.language + const setLanguage = useSettingsStore.use.setLanguage() + + const setEnglish = useCallback(() => { + i18n.changeLanguage('en') + setLanguage('en') + }, [i18n, setLanguage]) + + const setChinese = useCallback(() => { + i18n.changeLanguage('zh') + setLanguage('zh') + }, [i18n, setLanguage]) + + if (currentLanguage === 'zh') { + return ( + + ) + } + return ( + + ) +} diff --git a/lightrag_webui/src/components/graph/FocusOnNode.tsx b/lightrag_webui/src/components/graph/FocusOnNode.tsx index 70af75251b3cdf8b2698140110d0b9d083199c2d..3f3cf027af26cf25fa755905e40ac6d6f4df3645 100644 --- a/lightrag_webui/src/components/graph/FocusOnNode.tsx +++ b/lightrag_webui/src/components/graph/FocusOnNode.tsx @@ -13,23 +13,37 @@ const FocusOnNode = ({ node, move }: { node: string | null; move?: boolean }) => * When the selected item changes, highlighted the node and center the camera on it. */ useEffect(() => { + const graph = sigma.getGraph(); + if (move) { - if (node) { - sigma.getGraph().setNodeAttribute(node, 'highlighted', true) - gotoNode(node) + if (node && graph.hasNode(node)) { + try { + graph.setNodeAttribute(node, 'highlighted', true); + gotoNode(node); + } catch (error) { + console.error('Error focusing on node:', error); + } } else { // If no node is selected but move is true, reset to default view - sigma.setCustomBBox(null) - sigma.getCamera().animate({ x: 0.5, y: 0.5, ratio: 1 }, { duration: 0 }) + sigma.setCustomBBox(null); + sigma.getCamera().animate({ x: 0.5, y: 0.5, ratio: 1 }, { duration: 0 }); + } + useGraphStore.getState().setMoveToSelectedNode(false); + } else if (node && graph.hasNode(node)) { + try { + graph.setNodeAttribute(node, 'highlighted', true); + } catch (error) { + console.error('Error highlighting node:', error); } - useGraphStore.getState().setMoveToSelectedNode(false) - } else if (node) { - sigma.getGraph().setNodeAttribute(node, 'highlighted', true) } return () => { - if (node) { - sigma.getGraph().setNodeAttribute(node, 'highlighted', false) + if (node && graph.hasNode(node)) { + try { + graph.setNodeAttribute(node, 'highlighted', false); + } catch (error) { + console.error('Error cleaning up node highlight:', error); + } } } }, [node, move, sigma, gotoNode]) diff --git a/lightrag_webui/src/components/graph/GraphControl.tsx b/lightrag_webui/src/components/graph/GraphControl.tsx index 7d0143162b0de9ee5dd27084aa11bb7d04caf4d2..baa98bfe2f6cbcf3355f39ab5283feec4b7937ed 100644 --- a/lightrag_webui/src/components/graph/GraphControl.tsx +++ b/lightrag_webui/src/components/graph/GraphControl.tsx @@ -1,5 +1,5 @@ -import { useLoadGraph, useRegisterEvents, useSetSettings, useSigma } from '@react-sigma/core' -import Graph from 'graphology' +import { useRegisterEvents, useSetSettings, useSigma } from '@react-sigma/core' +import { AbstractGraph } from 'graphology-types' // import { useLayoutCircular } from '@react-sigma/layout-circular' import { useLayoutForceAtlas2 } from '@react-sigma/layout-forceatlas2' import { useEffect } from 'react' @@ -25,7 +25,6 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean }) const sigma = useSigma() const registerEvents = useRegisterEvents() const setSettings = useSetSettings() - const loadGraph = useLoadGraph() const maxIterations = 
useSettingsStore.use.graphLayoutMaxIterations() const { assign: assignLayout } = useLayoutForceAtlas2({ @@ -45,14 +44,42 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean }) /** * When component mount or maxIterations changes - * => load the graph and apply layout + * => ensure graph reference and apply layout */ useEffect(() => { - if (sigmaGraph) { - loadGraph(sigmaGraph as unknown as Graph) - assignLayout() + if (sigmaGraph && sigma) { + // Ensure sigma binding to sigmaGraph + try { + if (typeof sigma.setGraph === 'function') { + sigma.setGraph(sigmaGraph as unknown as AbstractGraph); + console.log('Binding graph to sigma instance'); + } else { + (sigma as any).graph = sigmaGraph; + console.warn('Simgma missing setGraph function, set graph property directly'); + } + } catch (error) { + console.error('Error setting graph on sigma instance:', error); + } + + assignLayout(); + console.log('Initial layout applied to graph'); + } + }, [sigma, sigmaGraph, assignLayout, maxIterations]) + + /** + * Ensure the sigma instance is set in the store + * This provides a backup in case the instance wasn't set in GraphViewer + */ + useEffect(() => { + if (sigma) { + // Double-check that the store has the sigma instance + const currentInstance = useGraphStore.getState().sigmaInstance; + if (!currentInstance) { + console.log('Setting sigma instance from GraphControl'); + useGraphStore.getState().setSigmaInstance(sigma); + } } - }, [assignLayout, loadGraph, sigmaGraph, maxIterations]) + }, [sigma]); /** * When component mount @@ -138,14 +165,18 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean }) const _focusedNode = focusedNode || selectedNode const _focusedEdge = focusedEdge || selectedEdge - if (_focusedNode) { - if (node === _focusedNode || graph.neighbors(_focusedNode).includes(node)) { - newData.highlighted = true - if (node === selectedNode) { - newData.borderColor = Constants.nodeBorderColorSelected + if (_focusedNode && graph.hasNode(_focusedNode)) { + try { + if (node === _focusedNode || graph.neighbors(_focusedNode).includes(node)) { + newData.highlighted = true + if (node === selectedNode) { + newData.borderColor = Constants.nodeBorderColorSelected + } } + } catch (error) { + console.error('Error in nodeReducer:', error); } - } else if (_focusedEdge) { + } else if (_focusedEdge && graph.hasEdge(_focusedEdge)) { if (graph.extremities(_focusedEdge).includes(node)) { newData.highlighted = true newData.size = 3 @@ -173,21 +204,28 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean }) if (!disableHoverEffect) { const _focusedNode = focusedNode || selectedNode - if (_focusedNode) { - if (hideUnselectedEdges) { - if (!graph.extremities(edge).includes(_focusedNode)) { - newData.hidden = true - } - } else { - if (graph.extremities(edge).includes(_focusedNode)) { - newData.color = Constants.edgeColorHighlighted + if (_focusedNode && graph.hasNode(_focusedNode)) { + try { + if (hideUnselectedEdges) { + if (!graph.extremities(edge).includes(_focusedNode)) { + newData.hidden = true + } + } else { + if (graph.extremities(edge).includes(_focusedNode)) { + newData.color = Constants.edgeColorHighlighted + } } + } catch (error) { + console.error('Error in edgeReducer:', error); } } else { - if (focusedEdge || selectedEdge) { - if (edge === selectedEdge) { + const _selectedEdge = selectedEdge && graph.hasEdge(selectedEdge) ? selectedEdge : null; + const _focusedEdge = focusedEdge && graph.hasEdge(focusedEdge) ? 
focusedEdge : null; + + if (_selectedEdge || _focusedEdge) { + if (edge === _selectedEdge) { newData.color = Constants.edgeColorSelected - } else if (edge === focusedEdge) { + } else if (edge === _focusedEdge) { newData.color = Constants.edgeColorHighlighted } else if (hideUnselectedEdges) { newData.hidden = true diff --git a/lightrag_webui/src/components/graph/GraphLabels.tsx b/lightrag_webui/src/components/graph/GraphLabels.tsx index bd2c8ea0762353105499aead49762a0a36566ef5..305f63bd6abf769f0bf17fd2316bc99a18ca57a4 100644 --- a/lightrag_webui/src/components/graph/GraphLabels.tsx +++ b/lightrag_webui/src/components/graph/GraphLabels.tsx @@ -2,20 +2,23 @@ import { useCallback, useEffect, useRef } from 'react' import { AsyncSelect } from '@/components/ui/AsyncSelect' import { useSettingsStore } from '@/stores/settings' import { useGraphStore } from '@/stores/graph' -import { labelListLimit } from '@/lib/constants' +import { labelListLimit, controlButtonVariant } from '@/lib/constants' import MiniSearch from 'minisearch' import { useTranslation } from 'react-i18next' +import { RefreshCw } from 'lucide-react' +import Button from '@/components/ui/Button' const GraphLabels = () => { const { t } = useTranslation() const label = useSettingsStore.use.queryLabel() const allDatabaseLabels = useGraphStore.use.allDatabaseLabels() + const rawGraph = useGraphStore.use.rawGraph() const labelsLoadedRef = useRef(false) // Track if a fetch is in progress to prevent multiple simultaneous fetches const fetchInProgressRef = useRef(false) - // Fetch labels once on component mount, using global flag to prevent duplicates + // Fetch labels and trigger initial data load useEffect(() => { // Check if we've already attempted to fetch labels in this session const labelsFetchAttempted = useGraphStore.getState().labelsFetchAttempted @@ -26,8 +29,6 @@ const GraphLabels = () => { // Set global flag to indicate we've attempted to fetch in this session useGraphStore.getState().setLabelsFetchAttempted(true) - console.log('Fetching graph labels (once per session)...') - useGraphStore.getState().fetchAllDatabaseLabels() .then(() => { labelsLoadedRef.current = true @@ -42,6 +43,14 @@ const GraphLabels = () => { } }, []) // Empty dependency array ensures this only runs once on mount + // Trigger data load when labels are loaded + useEffect(() => { + if (labelsLoadedRef.current) { + // Reset the fetch attempted flag to force a new data fetch + useGraphStore.getState().setGraphDataFetchAttempted(false) + } + }, [label]) + const getSearchEngine = useCallback(() => { // Create search engine const searchEngine = new MiniSearch({ @@ -83,52 +92,73 @@ const GraphLabels = () => { [getSearchEngine] ) - return ( - - className="ml-2" - triggerClassName="max-h-8" - searchInputClassName="max-h-8" - triggerTooltip={t('graphPanel.graphLabels.selectTooltip')} - fetcher={fetchData} - renderOption={(item) =>
{item}
} - getOptionValue={(item) => item} - getDisplayValue={(item) =>
{item}
} - notFound={
No labels found
} - label={t('graphPanel.graphLabels.label')} - placeholder={t('graphPanel.graphLabels.placeholder')} - value={label !== null ? label : '*'} - onChange={(newLabel) => { - const currentLabel = useSettingsStore.getState().queryLabel + const handleRefresh = useCallback(() => { + // Reset labels fetch status to allow fetching labels again + useGraphStore.getState().setLabelsFetchAttempted(false) - // select the last item means query all - if (newLabel === '...') { - newLabel = '*' - } + // Reset graph data fetch status directly, not depending on allDatabaseLabels changes + useGraphStore.getState().setGraphDataFetchAttempted(false) - // Reset the fetch attempted flag to force a new data fetch - useGraphStore.getState().setGraphDataFetchAttempted(false) + // Fetch all labels again + useGraphStore.getState().fetchAllDatabaseLabels() + .then(() => { + // Trigger a graph data reload by changing the query label back and forth + const currentLabel = useSettingsStore.getState().queryLabel + useSettingsStore.getState().setQueryLabel('') + setTimeout(() => { + useSettingsStore.getState().setQueryLabel(currentLabel) + }, 0) + }) + .catch((error) => { + console.error('Failed to refresh labels:', error) + }) + }, []) - // Clear current graph data to ensure complete reload when label changes - if (newLabel !== currentLabel) { - const graphStore = useGraphStore.getState(); - graphStore.clearSelection(); + return ( +
+ {rawGraph && ( + + )} + + className="ml-2" + triggerClassName="max-h-8" + searchInputClassName="max-h-8" + triggerTooltip={t('graphPanel.graphLabels.selectTooltip')} + fetcher={fetchData} + renderOption={(item) =>
{item}
} + getOptionValue={(item) => item} + getDisplayValue={(item) =>
{item}
} + notFound={
No labels found
} + label={t('graphPanel.graphLabels.label')} + placeholder={t('graphPanel.graphLabels.placeholder')} + value={label !== null ? label : '*'} + onChange={(newLabel) => { + const currentLabel = useSettingsStore.getState().queryLabel + + // select the last item means query all + if (newLabel === '...') { + newLabel = '*' + } - // Reset the graph state but preserve the instance - if (graphStore.sigmaGraph) { - const nodes = Array.from(graphStore.sigmaGraph.nodes()); - nodes.forEach(node => graphStore.sigmaGraph?.dropNode(node)); + // Handle reselecting the same label + if (newLabel === currentLabel && newLabel !== '*') { + newLabel = '*' } - } - if (newLabel === currentLabel && newLabel !== '*') { - // reselect the same itme means qery all - useSettingsStore.getState().setQueryLabel('*') - } else { + // Update the label, which will trigger the useEffect to handle data loading useSettingsStore.getState().setQueryLabel(newLabel) - } - }} - clearable={false} // Prevent clearing value on reselect - /> + }} + clearable={false} // Prevent clearing value on reselect + /> +
) } diff --git a/lightrag_webui/src/components/graph/GraphSearch.tsx b/lightrag_webui/src/components/graph/GraphSearch.tsx index 2ba36bdac73f33f5b3d4292e1c1b3334b9078952..51e76a0bf4d38f3dd5ceef6696dff39b40f3c964 100644 --- a/lightrag_webui/src/components/graph/GraphSearch.tsx +++ b/lightrag_webui/src/components/graph/GraphSearch.tsx @@ -1,4 +1,4 @@ -import { FC, useCallback, useEffect, useMemo } from 'react' +import { FC, useCallback, useEffect } from 'react' import { EdgeById, NodeById, @@ -11,28 +11,34 @@ import { useGraphStore } from '@/stores/graph' import MiniSearch from 'minisearch' import { useTranslation } from 'react-i18next' -interface OptionItem { +// Message item identifier for search results +export const messageId = '__message_item' + +// Search result option item interface +export interface OptionItem { id: string type: 'nodes' | 'edges' | 'message' message?: string } +const NodeOption = ({ id }: { id: string }) => { + const graph = useGraphStore.use.sigmaGraph() + if (!graph?.hasNode(id)) { + return null + } + return +} + function OptionComponent(item: OptionItem) { return (
- {item.type === 'nodes' && } + {item.type === 'nodes' && } {item.type === 'edges' && } {item.type === 'message' &&
{item.message}
}
) } -const messageId = '__message_item' -// Reset this cache when graph changes to ensure fresh search results -const lastGraph: any = { - graph: null, - searchEngine: null -} /** * Component thats display the search input. @@ -48,25 +54,24 @@ export const GraphSearchInput = ({ }) => { const { t } = useTranslation() const graph = useGraphStore.use.sigmaGraph() + const searchEngine = useGraphStore.use.searchEngine() - // Force reset the cache when graph changes + // Reset search engine when graph changes useEffect(() => { if (graph) { - // Reset cache to ensure fresh search results with new graph data - lastGraph.graph = null; - lastGraph.searchEngine = null; + useGraphStore.getState().resetSearchEngine() } }, [graph]); - const searchEngine = useMemo(() => { - if (lastGraph.graph == graph) { - return lastGraph.searchEngine + // Create search engine when needed + useEffect(() => { + // Skip if no graph, empty graph, or search engine already exists + if (!graph || graph.nodes().length === 0 || searchEngine) { + return } - if (!graph || graph.nodes().length == 0) return - - lastGraph.graph = graph - const searchEngine = new MiniSearch({ + // Create new search engine + const newSearchEngine = new MiniSearch({ idField: 'id', fields: ['label'], searchOptions: { @@ -78,16 +83,16 @@ export const GraphSearchInput = ({ } }) - // Add documents + // Add nodes to search engine const documents = graph.nodes().map((id: string) => ({ id: id, label: graph.getNodeAttribute(id, 'label') })) - searchEngine.addAll(documents) + newSearchEngine.addAll(documents) - lastGraph.searchEngine = searchEngine - return searchEngine - }, [graph]) + // Update search engine in store + useGraphStore.getState().setSearchEngine(newSearchEngine) + }, [graph, searchEngine]) /** * Loading the options while the user is typing. 
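The hunk above replaces the module-level `lastGraph` cache with state kept in the graph store (`searchEngine`, `setSearchEngine`, `resetSearchEngine`), so the MiniSearch index is shared across components and dropped whenever the graph changes. The store additions themselves are not shown in this excerpt; a rough sketch of what such a slice might look like, assuming a Zustand-style store as the `useGraphStore.use.*` / `getState()` calls suggest — field and setter names come from their usage here, the rest is an assumption.

```typescript
// Hypothetical sketch of the store slice assumed by GraphSearch.tsx.
import { create } from 'zustand'
import MiniSearch from 'minisearch'

interface GraphSearchSlice {
  searchEngine: MiniSearch | null
  setSearchEngine: (engine: MiniSearch) => void
  resetSearchEngine: () => void
}

export const useGraphSearchStore = create<GraphSearchSlice>((set) => ({
  searchEngine: null,
  // Keep the freshly built index so later renders can reuse it
  setSearchEngine: (engine) => set({ searchEngine: engine }),
  // Drop the index so the next render rebuilds it from the current graph
  resetSearchEngine: () => set({ searchEngine: null }),
}))
```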
@@ -95,22 +100,35 @@ export const GraphSearchInput = ({ const loadOptions = useCallback( async (query?: string): Promise => { if (onFocus) onFocus(null) - if (!graph || !searchEngine) return [] - // If no query, return first searchResultLimit nodes + // Safety checks to prevent crashes + if (!graph || !searchEngine) { + return [] + } + + // Verify graph has nodes before proceeding + if (graph.nodes().length === 0) { + return [] + } + + // If no query, return some nodes for user to select if (!query) { - const nodeIds = graph.nodes().slice(0, searchResultLimit) + const nodeIds = graph.nodes() + .filter(id => graph.hasNode(id)) + .slice(0, searchResultLimit) return nodeIds.map(id => ({ id, type: 'nodes' })) } - // If has query, search nodes - const result: OptionItem[] = searchEngine.search(query).map((r: { id: string }) => ({ - id: r.id, - type: 'nodes' - })) + // If has query, search nodes and verify they still exist + const result: OptionItem[] = searchEngine.search(query) + .filter((r: { id: string }) => graph.hasNode(r.id)) + .map((r: { id: string }) => ({ + id: r.id, + type: 'nodes' + })) // prettier-ignore return result.length <= searchResultLimit diff --git a/lightrag_webui/src/components/graph/LayoutsControl.tsx b/lightrag_webui/src/components/graph/LayoutsControl.tsx index 0ed97f2f88c577e8907e3042af0628472b125019..2f0cc50a0124bc074d3bf5ab1b7425f1f72cbd42 100644 --- a/lightrag_webui/src/components/graph/LayoutsControl.tsx +++ b/lightrag_webui/src/components/graph/LayoutsControl.tsx @@ -7,7 +7,7 @@ import { useLayoutForce, useWorkerLayoutForce } from '@react-sigma/layout-force' import { useLayoutForceAtlas2, useWorkerLayoutForceAtlas2 } from '@react-sigma/layout-forceatlas2' import { useLayoutNoverlap, useWorkerLayoutNoverlap } from '@react-sigma/layout-noverlap' import { useLayoutRandom } from '@react-sigma/layout-random' -import { useCallback, useMemo, useState, useEffect } from 'react' +import { useCallback, useMemo, useState, useEffect, useRef } from 'react' import Button from '@/components/ui/Button' import { Popover, PopoverContent, PopoverTrigger } from '@/components/ui/Popover' @@ -26,43 +26,161 @@ type LayoutName = | 'Force Directed' | 'Force Atlas' -const WorkerLayoutControl = ({ layout, autoRunFor }: WorkerLayoutControlProps) => { +// Extend WorkerLayoutControlProps to include mainLayout +interface ExtendedWorkerLayoutControlProps extends WorkerLayoutControlProps { + mainLayout: LayoutHook; +} + +const WorkerLayoutControl = ({ layout, autoRunFor, mainLayout }: ExtendedWorkerLayoutControlProps) => { const sigma = useSigma() - const { stop, start, isRunning } = layout + // Use local state to track animation running status + const [isRunning, setIsRunning] = useState(false) + // Timer reference for animation + const animationTimerRef = useRef(null) const { t } = useTranslation() + // Function to update node positions using the layout algorithm + const updatePositions = useCallback(() => { + if (!sigma) return + + try { + const graph = sigma.getGraph() + if (!graph || graph.order === 0) return + + // Use mainLayout to get positions, similar to refreshLayout function + // console.log('Getting positions from mainLayout') + const positions = mainLayout.positions() + + // Animate nodes to new positions + // console.log('Updating node positions with layout algorithm') + animateNodes(graph, positions, { duration: 300 }) // Reduced duration for more frequent updates + } catch (error) { + console.error('Error updating positions:', error) + // Stop animation if there's an error + if 
(animationTimerRef.current) { + window.clearInterval(animationTimerRef.current) + animationTimerRef.current = null + setIsRunning(false) + } + } + }, [sigma, mainLayout]) + + // Improved click handler that uses our own animation timer + const handleClick = useCallback(() => { + if (isRunning) { + // Stop the animation + console.log('Stopping layout animation') + if (animationTimerRef.current) { + window.clearInterval(animationTimerRef.current) + animationTimerRef.current = null + } + + // Try to kill the layout algorithm if it's running + try { + if (typeof layout.kill === 'function') { + layout.kill() + console.log('Layout algorithm killed') + } else if (typeof layout.stop === 'function') { + layout.stop() + console.log('Layout algorithm stopped') + } + } catch (error) { + console.error('Error stopping layout algorithm:', error) + } + + setIsRunning(false) + } else { + // Start the animation + console.log('Starting layout animation') + + // Initial position update + updatePositions() + + // Set up interval for continuous updates + animationTimerRef.current = window.setInterval(() => { + updatePositions() + }, 200) // Reduced interval to create overlapping animations for smoother transitions + + setIsRunning(true) + + // Set a timeout to automatically stop the animation after 3 seconds + setTimeout(() => { + if (animationTimerRef.current) { + console.log('Auto-stopping layout animation after 3 seconds') + window.clearInterval(animationTimerRef.current) + animationTimerRef.current = null + setIsRunning(false) + + // Try to stop the layout algorithm + try { + if (typeof layout.kill === 'function') { + layout.kill() + } else if (typeof layout.stop === 'function') { + layout.stop() + } + } catch (error) { + console.error('Error stopping layout algorithm:', error) + } + } + }, 3000) + } + }, [isRunning, layout, updatePositions]) + /** * Init component when Sigma or component settings change. */ useEffect(() => { if (!sigma) { + console.log('No sigma instance available') return } - // we run the algo + // Auto-run if specified let timeout: number | null = null if (autoRunFor !== undefined && autoRunFor > -1 && sigma.getGraph().order > 0) { - start() - // set a timeout to stop it - timeout = - autoRunFor > 0 - ? window.setTimeout(() => { stop() }, autoRunFor) // prettier-ignore - : null + console.log('Auto-starting layout animation') + + // Initial position update + updatePositions() + + // Set up interval for continuous updates + animationTimerRef.current = window.setInterval(() => { + updatePositions() + }, 200) // Reduced interval to create overlapping animations for smoother transitions + + setIsRunning(true) + + // Set a timeout to stop it if autoRunFor > 0 + if (autoRunFor > 0) { + timeout = window.setTimeout(() => { + console.log('Auto-stopping layout animation after timeout') + if (animationTimerRef.current) { + window.clearInterval(animationTimerRef.current) + animationTimerRef.current = null + } + setIsRunning(false) + }, autoRunFor) + } } - //cleaning + // Cleanup function return () => { - stop() + // console.log('Cleaning up WorkerLayoutControl') + if (animationTimerRef.current) { + window.clearInterval(animationTimerRef.current) + animationTimerRef.current = null + } if (timeout) { - clearTimeout(timeout) + window.clearTimeout(timeout) } + setIsRunning(false) } - }, [autoRunFor, start, stop, sigma]) + }, [autoRunFor, sigma, updatePositions]) return ( + + +
{ />
- +
{Object.keys(node.properties) .sort() @@ -181,7 +256,7 @@ const NodePropertiesView = ({ node }: { node: NodeType }) => {
{node.relationships.length > 0 && ( <> -