yangdx committed
Commit d630b99 · Parent: 86d15dc

Refactor embedding function initialization and remove start-server.sh


- Simplified RAG initialization by deduplicating the embedding function setup
- Removed the start-server.sh script, which is no longer needed
- No functional changes to the application
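The first bullet is the heart of the commit; here is a minimal, self-contained sketch of the pattern it applies. The `EmbeddingFunc` dataclass, `fake_embed`, and `init_rag` below are illustrative stand-ins, not the real lightrag API: before this commit each `LightRAG(...)` branch built its own `EmbeddingFunc`; now a single instance is built once and passed to every branch.

```python
from dataclasses import dataclass
from typing import Callable, List

@dataclass
class EmbeddingFunc:  # stand-in for lightrag.utils.EmbeddingFunc
    embedding_dim: int
    max_token_size: int
    func: Callable[[List[str]], list]

def fake_embed(texts: List[str]) -> list:  # hypothetical embedding backend
    return [[0.0] * 4 for _ in texts]

# Built once, shared by every initialization branch:
embedding_func = EmbeddingFunc(embedding_dim=4, max_token_size=8192, func=fake_embed)

def init_rag(binding: str, embedding_func: EmbeddingFunc) -> dict:
    # Each branch now receives the same object instead of duplicating the literal.
    return {"binding": binding, "embedding_func": embedding_func}

rag_a = init_rag("ollama", embedding_func)
rag_b = init_rag("openai", embedding_func)
assert rag_a["embedding_func"] is rag_b["embedding_func"]
```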

lightrag/api/lightrag_ollama.py DELETED
@@ -1,924 +0,0 @@
-from fastapi import FastAPI, HTTPException, File, UploadFile, Form, Request
-from pydantic import BaseModel
-import logging
-import argparse
-import json
-import time
-import re
-from typing import List, Dict, Any, Optional
-from lightrag import LightRAG, QueryParam
-from lightrag.llm import openai_complete_if_cache, ollama_embedding
-
-from lightrag.utils import EmbeddingFunc
-from enum import Enum
-from pathlib import Path
-import shutil
-import aiofiles
-from ascii_colors import trace_exception
-import os
-
-from fastapi import Depends, Security
-from fastapi.security import APIKeyHeader
-from fastapi.middleware.cors import CORSMiddleware
-
-from starlette.status import HTTP_403_FORBIDDEN
-
-from dotenv import load_dotenv
-
-load_dotenv()
-
-
-def estimate_tokens(text: str) -> int:
-    """Estimate the number of tokens in text
-    Chinese characters: approximately 1.5 tokens per character
-    English characters: approximately 0.25 tokens per character
-    """
-    # Use regex to match Chinese and non-Chinese characters separately
-    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
-    non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
-
-    # Calculate estimated token count
-    tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
-
-    return int(tokens)
-
-
-# Constants for model information
-LIGHTRAG_NAME = "lightrag"
-LIGHTRAG_TAG = "latest"
-LIGHTRAG_MODEL = "lightrag:latest"
-LIGHTRAG_SIZE = 7365960935
-LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
-LIGHTRAG_DIGEST = "sha256:lightrag"
-
-
-async def llm_model_func(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
-    return await openai_complete_if_cache(
-        "deepseek-chat",
-        prompt,
-        system_prompt=system_prompt,
-        history_messages=history_messages,
-        api_key=os.getenv("DEEPSEEK_API_KEY"),
-        base_url=os.getenv("DEEPSEEK_ENDPOINT"),
-        **kwargs,
-    )
-
-
-def get_default_host(binding_type: str) -> str:
-    default_hosts = {
-        "ollama": "http://m4.lan.znipower.com:11434",
-        "lollms": "http://localhost:9600",
-        "azure_openai": "https://api.openai.com/v1",
-        "openai": os.getenv("DEEPSEEK_ENDPOINT"),
-    }
-    return default_hosts.get(
-        binding_type, "http://localhost:11434"
-    )  # fallback to ollama if unknown
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(
-        description="LightRAG FastAPI Server with separate working and input directories"
-    )
-
-    # Start with the bindings
-    parser.add_argument(
-        "--llm-binding",
-        default="ollama",
-        help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)",
-    )
-    parser.add_argument(
-        "--embedding-binding",
-        default="ollama",
-        help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)",
-    )
-
-    # Parse just these arguments first
-    temp_args, _ = parser.parse_known_args()
-
-    # Add remaining arguments with dynamic defaults for hosts
-    # Server configuration
-    parser.add_argument(
-        "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
-    )
-    parser.add_argument(
-        "--port", type=int, default=9621, help="Server port (default: 9621)"
-    )
-
-    # Directory configuration
-    parser.add_argument(
-        "--working-dir",
-        default="./rag_storage",
-        help="Working directory for RAG storage (default: ./rag_storage)",
-    )
-    parser.add_argument(
-        "--input-dir",
-        default="./inputs",
-        help="Directory containing input documents (default: ./inputs)",
-    )
-
-    # LLM Model configuration
-    default_llm_host = get_default_host(temp_args.llm_binding)
-    parser.add_argument(
-        "--llm-binding-host",
-        default=default_llm_host,
-        help=f"llm server host URL (default: {default_llm_host})",
-    )
-
-    parser.add_argument(
-        "--llm-model",
-        default="mistral-nemo:latest",
-        help="LLM model name (default: mistral-nemo:latest)",
-    )
-
-    # Embedding model configuration
-    default_embedding_host = get_default_host(temp_args.embedding_binding)
-    parser.add_argument(
-        "--embedding-binding-host",
-        default=default_embedding_host,
-        help=f"embedding server host URL (default: {default_embedding_host})",
-    )
-
-    parser.add_argument(
-        "--embedding-model",
-        default="bge-m3:latest",
-        help="Embedding model name (default: bge-m3:latest)",
-    )
-
-    def timeout_type(value):
-        if value is None or value == "None":
-            return None
-        return int(value)
-
-    parser.add_argument(
-        "--timeout",
-        default=None,
-        type=timeout_type,
-        help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
-    )
-    # RAG configuration
-    parser.add_argument(
-        "--max-async", type=int, default=4, help="Maximum async operations (default: 4)"
-    )
-    parser.add_argument(
-        "--max-tokens",
-        type=int,
-        default=32768,
-        help="Maximum token size (default: 32768)",
-    )
-    parser.add_argument(
-        "--embedding-dim",
-        type=int,
-        default=1024,
-        help="Embedding dimensions (default: 1024)",
-    )
-    parser.add_argument(
-        "--max-embed-tokens",
-        type=int,
-        default=8192,
-        help="Maximum embedding token size (default: 8192)",
-    )
-
-    # Logging configuration
-    parser.add_argument(
-        "--log-level",
-        default="INFO",
-        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
-        help="Logging level (default: INFO)",
-    )
-
-    parser.add_argument(
-        "--key",
-        type=str,
-        help="API key for authentication. This protects lightrag server against unauthorized access",
-        default=None,
-    )
-
-    # Optional https parameters
-    parser.add_argument(
-        "--ssl", action="store_true", help="Enable HTTPS (default: False)"
-    )
-    parser.add_argument(
-        "--ssl-certfile",
-        default=None,
-        help="Path to SSL certificate file (required if --ssl is enabled)",
-    )
-    parser.add_argument(
-        "--ssl-keyfile",
-        default=None,
-        help="Path to SSL private key file (required if --ssl is enabled)",
-    )
-    return parser.parse_args()
-
-
-class DocumentManager:
-    """Handles document operations and tracking"""
-
-    def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")):
-        self.input_dir = Path(input_dir)
-        self.supported_extensions = supported_extensions
-        self.indexed_files = set()
-
-        # Create input directory if it doesn't exist
-        self.input_dir.mkdir(parents=True, exist_ok=True)
-
-    def scan_directory(self) -> List[Path]:
-        """Scan input directory for new files"""
-        new_files = []
-        for ext in self.supported_extensions:
-            for file_path in self.input_dir.rglob(f"*{ext}"):
-                if file_path not in self.indexed_files:
-                    new_files.append(file_path)
-        return new_files
-
-    def mark_as_indexed(self, file_path: Path):
-        """Mark a file as indexed"""
-        self.indexed_files.add(file_path)
-
-    def is_supported_file(self, filename: str) -> bool:
-        """Check if file type is supported"""
-        return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
-
-
-# Pydantic models
-class SearchMode(str, Enum):
-    naive = "naive"
-    local = "local"
-    global_ = "global"  # Using global_ because global is a Python reserved keyword, but enum value will be converted to string "global"
-    hybrid = "hybrid"
-    mix = "mix"
-
-
-# Ollama API compatible models
-class OllamaMessage(BaseModel):
-    role: str
-    content: str
-    images: Optional[List[str]] = None
-
-
-class OllamaChatRequest(BaseModel):
-    model: str = LIGHTRAG_MODEL
-    messages: List[OllamaMessage]
-    stream: bool = True  # Default to streaming mode
-    options: Optional[Dict[str, Any]] = None
-
-
-class OllamaChatResponse(BaseModel):
-    model: str
-    created_at: str
-    message: OllamaMessage
-    done: bool
-
-
-class OllamaVersionResponse(BaseModel):
-    version: str
-
-
-class OllamaModelDetails(BaseModel):
-    parent_model: str
-    format: str
-    family: str
-    families: List[str]
-    parameter_size: str
-    quantization_level: str
-
-
-class OllamaModel(BaseModel):
-    name: str
-    model: str
-    size: int
-    digest: str
-    modified_at: str
-    details: OllamaModelDetails
-
-
-class OllamaTagResponse(BaseModel):
-    models: List[OllamaModel]
-
-
-# Original LightRAG models
-class QueryRequest(BaseModel):
-    query: str
-    mode: SearchMode = SearchMode.hybrid
-    stream: bool = False
-    only_need_context: bool = False
-
-
-class QueryResponse(BaseModel):
-    response: str
-
-
-class InsertTextRequest(BaseModel):
-    text: str
-    description: Optional[str] = None
-
-
-class InsertResponse(BaseModel):
-    status: str
-    message: str
-    document_count: int
-
-
-def get_api_key_dependency(api_key: Optional[str]):
-    if not api_key:
-        # If no API key is configured, return a dummy dependency that always succeeds
-        async def no_auth():
-            return None
-
-        return no_auth
-
-    # If API key is configured, use proper authentication
-    api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
-
-    async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
-        if not api_key_header_value:
-            raise HTTPException(
-                status_code=HTTP_403_FORBIDDEN, detail="API Key required"
-            )
-        if api_key_header_value != api_key:
-            raise HTTPException(
-                status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
-            )
-        return api_key_header_value
-
-    return api_key_auth
-
-
-def create_app(args):
-    # Verify that bindings are correctly set up
-    if args.llm_binding not in ["lollms", "ollama", "openai"]:
-        raise Exception("llm binding not supported")
-
-    if args.embedding_binding not in ["lollms", "ollama", "openai"]:
-        raise Exception("embedding binding not supported")
-
-    # Add SSL validation
-    if args.ssl:
-        if not args.ssl_certfile or not args.ssl_keyfile:
-            raise Exception(
-                "SSL certificate and key files must be provided when SSL is enabled"
-            )
-        if not os.path.exists(args.ssl_certfile):
-            raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
-        if not os.path.exists(args.ssl_keyfile):
-            raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
-
-    # Setup logging
-    logging.basicConfig(
-        format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
-    )
-
-    # Check if API key is provided either through env var or args
-    api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
-
-    # Initialize FastAPI
-    app = FastAPI(
-        title="LightRAG API",
-        description="API for querying text using LightRAG with separate storage and input directories"
-        + "(With authentication)"
-        if api_key
-        else "",
-        version="1.0.1",
-        openapi_tags=[{"name": "api"}],
-    )
-
-    # Add CORS middleware
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=["*"],
-        allow_credentials=True,
-        allow_methods=["*"],
-        allow_headers=["*"],
-    )
-
-    # Create the optional API key dependency
-    optional_api_key = get_api_key_dependency(api_key)
-
-    # Create working directory if it doesn't exist
-    Path(args.working_dir).mkdir(parents=True, exist_ok=True)
-
-    # Initialize document manager
-    doc_manager = DocumentManager(args.input_dir)
-
-    # Initialize RAG
-    rag = LightRAG(
-        working_dir=args.working_dir,
-        llm_model_func=llm_model_func,
-        embedding_func=EmbeddingFunc(
-            embedding_dim=1024,
-            max_token_size=8192,
-            func=lambda texts: ollama_embedding(
-                texts,
-                embed_model="bge-m3:latest",
-                host="http://m4.lan.znipower.com:11434",
-            ),
-        ),
-    )
-
-    @app.on_event("startup")
-    async def startup_event():
-        """Index all files in input directory during startup"""
-        try:
-            new_files = doc_manager.scan_directory()
-            for file_path in new_files:
-                try:
-                    # Use async file reading
-                    async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
-                        content = await f.read()
-                        # Use the async version of insert directly
-                        await rag.ainsert(content)
-                        doc_manager.mark_as_indexed(file_path)
-                        logging.info(f"Indexed file: {file_path}")
-                except Exception as e:
-                    trace_exception(e)
-                    logging.error(f"Error indexing file {file_path}: {str(e)}")
-
-            logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}")
-
-        except Exception as e:
-            logging.error(f"Error during startup indexing: {str(e)}")
-
-    @app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
-    async def scan_for_new_documents():
-        """Manually trigger scanning for new documents"""
-        try:
-            new_files = doc_manager.scan_directory()
-            indexed_count = 0
-
-            for file_path in new_files:
-                try:
-                    with open(file_path, "r", encoding="utf-8") as f:
-                        content = f.read()
-                        await rag.ainsert(content)
-                        doc_manager.mark_as_indexed(file_path)
-                        indexed_count += 1
-                except Exception as e:
-                    logging.error(f"Error indexing file {file_path}: {str(e)}")
-
-            return {
-                "status": "success",
-                "indexed_count": indexed_count,
-                "total_documents": len(doc_manager.indexed_files),
-            }
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
-
-    @app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
-    async def upload_to_input_dir(file: UploadFile = File(...)):
-        """Upload a file to the input directory"""
-        try:
-            if not doc_manager.is_supported_file(file.filename):
-                raise HTTPException(
-                    status_code=400,
-                    detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
-                )
-
-            file_path = doc_manager.input_dir / file.filename
-            with open(file_path, "wb") as buffer:
-                shutil.copyfileobj(file.file, buffer)
-
-            # Immediately index the uploaded file
-            with open(file_path, "r", encoding="utf-8") as f:
-                content = f.read()
-                await rag.ainsert(content)
-                doc_manager.mark_as_indexed(file_path)
-
-            return {
-                "status": "success",
-                "message": f"File uploaded and indexed: {file.filename}",
-                "total_documents": len(doc_manager.indexed_files),
-            }
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
-
-    @app.post(
-        "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
-    )
-    async def query_text(request: QueryRequest):
-        try:
-            response = await rag.aquery(
-                request.query,
-                param=QueryParam(
-                    mode=request.mode,
-                    stream=request.stream,
-                    only_need_context=request.only_need_context,
-                ),
-            )
-
-            # If response is a string (e.g. cache hit), return directly
-            if isinstance(response, str):
-                return QueryResponse(response=response)
-
-            # If it's an async generator, decide whether to stream based on stream parameter
-            if request.stream:
-                result = ""
-                async for chunk in response:
-                    result += chunk
-                return QueryResponse(response=result)
-            else:
-                result = ""
-                async for chunk in response:
-                    result += chunk
-                return QueryResponse(response=result)
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
-
-    @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
-    async def query_text_stream(request: QueryRequest):
-        try:
-            response = await rag.aquery(  # Use aquery instead of query, and add await
-                request.query,
-                param=QueryParam(
-                    mode=request.mode,
-                    stream=True,
-                    only_need_context=request.only_need_context,
-                ),
-            )
-
-            from fastapi.responses import StreamingResponse
-
-            async def stream_generator():
-                if isinstance(response, str):
-                    # If it's a string, send it all at once
-                    yield f"{json.dumps({'response': response})}\n"
-                else:
-                    # If it's an async generator, send chunks one by one
-                    try:
-                        async for chunk in response:
-                            if chunk:  # Only send non-empty content
-                                yield f"{json.dumps({'response': chunk})}\n"
-                    except Exception as e:
-                        logging.error(f"Streaming error: {str(e)}")
-                        yield f"{json.dumps({'error': str(e)})}\n"
-
-            return StreamingResponse(
-                stream_generator(),
-                media_type="application/x-ndjson",
-                headers={
-                    "Cache-Control": "no-cache",
-                    "Connection": "keep-alive",
-                    "Content-Type": "application/x-ndjson",
-                    "Access-Control-Allow-Origin": "*",
-                    "Access-Control-Allow-Methods": "POST, OPTIONS",
-                    "Access-Control-Allow-Headers": "Content-Type",
-                    "X-Accel-Buffering": "no",  # Disable Nginx buffering
-                },
-            )
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
-
-    @app.post(
-        "/documents/text",
-        response_model=InsertResponse,
-        dependencies=[Depends(optional_api_key)],
-    )
-    async def insert_text(request: InsertTextRequest):
-        try:
-            await rag.ainsert(request.text)
-            return InsertResponse(
-                status="success",
-                message="Text successfully inserted",
-                document_count=1,
-            )
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
-
-    @app.post(
-        "/documents/file",
-        response_model=InsertResponse,
-        dependencies=[Depends(optional_api_key)],
-    )
-    async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
-        try:
-            content = await file.read()
-
-            if file.filename.endswith((".txt", ".md")):
-                text = content.decode("utf-8")
-                await rag.ainsert(text)
-            else:
-                raise HTTPException(
-                    status_code=400,
-                    detail="Unsupported file type. Only .txt and .md files are supported",
-                )
-
-            return InsertResponse(
-                status="success",
-                message=f"File '{file.filename}' successfully inserted",
-                document_count=1,
-            )
-        except UnicodeDecodeError:
-            raise HTTPException(status_code=400, detail="File encoding not supported")
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
-
-    @app.post(
-        "/documents/batch",
-        response_model=InsertResponse,
-        dependencies=[Depends(optional_api_key)],
-    )
-    async def insert_batch(files: List[UploadFile] = File(...)):
-        try:
-            inserted_count = 0
-            failed_files = []
-
-            for file in files:
-                try:
-                    content = await file.read()
-                    if file.filename.endswith((".txt", ".md")):
-                        text = content.decode("utf-8")
-                        await rag.ainsert(text)
-                        inserted_count += 1
-                    else:
-                        failed_files.append(f"{file.filename} (unsupported type)")
-                except Exception as e:
-                    failed_files.append(f"{file.filename} ({str(e)})")
-
-            status_message = f"Successfully inserted {inserted_count} documents"
-            if failed_files:
-                status_message += f". Failed files: {', '.join(failed_files)}"
-
-            return InsertResponse(
-                status="success" if inserted_count > 0 else "partial_success",
-                message=status_message,
-                document_count=len(files),
-            )
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
-
-    @app.delete(
-        "/documents",
-        response_model=InsertResponse,
-        dependencies=[Depends(optional_api_key)],
-    )
-    async def clear_documents():
-        try:
-            rag.text_chunks = []
-            rag.entities_vdb = None
-            rag.relationships_vdb = None
-            return InsertResponse(
-                status="success",
-                message="All documents cleared successfully",
-                document_count=0,
-            )
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
-
-    # Ollama compatible API endpoints
-    @app.get("/api/version")
-    async def get_version():
-        """Get Ollama version information"""
-        return OllamaVersionResponse(version="0.5.4")
-
-    @app.get("/api/tags")
-    async def get_tags():
-        """Get available models"""
-        return OllamaTagResponse(
-            models=[
-                {
-                    "name": LIGHTRAG_MODEL,
-                    "model": LIGHTRAG_MODEL,
-                    "size": LIGHTRAG_SIZE,
-                    "digest": LIGHTRAG_DIGEST,
-                    "modified_at": LIGHTRAG_CREATED_AT,
-                    "details": {
-                        "parent_model": "",
-                        "format": "gguf",
-                        "family": LIGHTRAG_NAME,
-                        "families": [LIGHTRAG_NAME],
-                        "parameter_size": "13B",
-                        "quantization_level": "Q4_0",
-                    },
-                }
-            ]
-        )
-
-    def parse_query_mode(query: str) -> tuple[str, SearchMode]:
-        """Parse query prefix to determine search mode
-        Returns tuple of (cleaned_query, search_mode)
-        """
-        mode_map = {
-            "/local ": SearchMode.local,
-            "/global ": SearchMode.global_,  # global_ is used because 'global' is a Python keyword
-            "/naive ": SearchMode.naive,
-            "/hybrid ": SearchMode.hybrid,
-            "/mix ": SearchMode.mix,
-        }
-
-        for prefix, mode in mode_map.items():
-            if query.startswith(prefix):
-                # After removing prefix and leading spaces
-                cleaned_query = query[len(prefix) :].lstrip()
-                return cleaned_query, mode
-
-        return query, SearchMode.hybrid
-
-    @app.post("/api/chat")
-    async def chat(raw_request: Request, request: OllamaChatRequest):
-        """Handle chat completion requests"""
-        try:
-            # Get all messages
-            messages = request.messages
-            if not messages:
-                raise HTTPException(status_code=400, detail="No messages provided")
-
-            # Get the last message as query
-            query = messages[-1].content
-
-            # Parse the query mode
-            cleaned_query, mode = parse_query_mode(query)
-
-            # Start timing
-            start_time = time.time_ns()
-
-            # Count the input tokens
-            prompt_tokens = estimate_tokens(cleaned_query)
-
-            # Run the query through RAG
-            query_param = QueryParam(
-                mode=mode, stream=request.stream, only_need_context=False
-            )
-
-            if request.stream:
-                from fastapi.responses import StreamingResponse
-
-                response = await rag.aquery(  # Need await to get async generator
-                    cleaned_query, param=query_param
-                )
-
-                async def stream_generator():
-                    try:
-                        first_chunk_time = None
-                        last_chunk_time = None
-                        total_response = ""
-
-                        # Ensure response is an async generator
-                        if isinstance(response, str):
-                            # If it's a string, send in two parts
-                            first_chunk_time = time.time_ns()
-                            last_chunk_time = first_chunk_time
-                            total_response = response
-
-                            data = {
-                                "model": LIGHTRAG_MODEL,
-                                "created_at": LIGHTRAG_CREATED_AT,
-                                "message": {
-                                    "role": "assistant",
-                                    "content": response,
-                                    "images": None,
-                                },
-                                "done": False,
-                            }
-                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
-
-                            completion_tokens = estimate_tokens(total_response)
-                            total_time = last_chunk_time - start_time
-                            prompt_eval_time = first_chunk_time - start_time
-                            eval_time = last_chunk_time - first_chunk_time
-
-                            data = {
-                                "model": LIGHTRAG_MODEL,
-                                "created_at": LIGHTRAG_CREATED_AT,
-                                "done": True,
-                                "total_duration": total_time,
-                                "load_duration": 0,
-                                "prompt_eval_count": prompt_tokens,
-                                "prompt_eval_duration": prompt_eval_time,
-                                "eval_count": completion_tokens,
-                                "eval_duration": eval_time,
-                            }
-                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
-                        else:
-                            async for chunk in response:
-                                if chunk:
-                                    if first_chunk_time is None:
-                                        first_chunk_time = time.time_ns()
-
-                                    last_chunk_time = time.time_ns()
-
-                                    total_response += chunk
-                                    data = {
-                                        "model": LIGHTRAG_MODEL,
-                                        "created_at": LIGHTRAG_CREATED_AT,
-                                        "message": {
-                                            "role": "assistant",
-                                            "content": chunk,
-                                            "images": None,
-                                        },
-                                        "done": False,
-                                    }
-                                    yield f"{json.dumps(data, ensure_ascii=False)}\n"
-
-                            completion_tokens = estimate_tokens(total_response)
-                            total_time = last_chunk_time - start_time
-                            prompt_eval_time = first_chunk_time - start_time
-                            eval_time = last_chunk_time - first_chunk_time
-
-                            data = {
-                                "model": LIGHTRAG_MODEL,
-                                "created_at": LIGHTRAG_CREATED_AT,
-                                "done": True,
-                                "total_duration": total_time,
-                                "load_duration": 0,
-                                "prompt_eval_count": prompt_tokens,
-                                "prompt_eval_duration": prompt_eval_time,
-                                "eval_count": completion_tokens,
-                                "eval_duration": eval_time,
-                            }
-                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
-                        return  # Ensure the generator ends immediately after sending the completion marker
-                    except Exception as e:
-                        logging.error(f"Error in stream_generator: {str(e)}")
-                        raise
-
-                return StreamingResponse(
-                    stream_generator(),
-                    media_type="application/x-ndjson",
-                    headers={
-                        "Cache-Control": "no-cache",
-                        "Connection": "keep-alive",
-                        "Content-Type": "application/x-ndjson",
-                        "Access-Control-Allow-Origin": "*",
-                        "Access-Control-Allow-Methods": "POST, OPTIONS",
-                        "Access-Control-Allow-Headers": "Content-Type",
-                    },
-                )
-            else:
-                first_chunk_time = time.time_ns()
-                response_text = await rag.aquery(cleaned_query, param=query_param)
-                last_chunk_time = time.time_ns()
-
-                if not response_text:
-                    response_text = "No response generated"
-
-                completion_tokens = estimate_tokens(str(response_text))
-                total_time = last_chunk_time - start_time
-                prompt_eval_time = first_chunk_time - start_time
-                eval_time = last_chunk_time - first_chunk_time
-
-                return {
-                    "model": LIGHTRAG_MODEL,
-                    "created_at": LIGHTRAG_CREATED_AT,
-                    "message": {
-                        "role": "assistant",
-                        "content": str(response_text),
-                        "images": None,
-                    },
-                    "done": True,
-                    "total_duration": total_time,
-                    "load_duration": 0,
-                    "prompt_eval_count": prompt_tokens,
-                    "prompt_eval_duration": prompt_eval_time,
-                    "eval_count": completion_tokens,
-                    "eval_duration": eval_time,
-                }
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
-
-    @app.get("/health", dependencies=[Depends(optional_api_key)])
-    async def get_status():
-        """Get current system status"""
-        return {
-            "status": "healthy",
-            "working_directory": str(args.working_dir),
-            "input_directory": str(args.input_dir),
-            "indexed_files": len(doc_manager.indexed_files),
-            "configuration": {
-                # LLM configuration binding/host address (if applicable)/model (if applicable)
-                "llm_binding": args.llm_binding,
-                "llm_binding_host": args.llm_binding_host,
-                "llm_model": args.llm_model,
-                # embedding model configuration binding/host address (if applicable)/model (if applicable)
-                "embedding_binding": args.embedding_binding,
-                "embedding_binding_host": args.embedding_binding_host,
-                "embedding_model": args.embedding_model,
-                "max_tokens": args.max_tokens,
-            },
-        }
-
-    return app
-
-
-def main():
-    args = parse_args()
-    import uvicorn
-
-    app = create_app(args)
-    uvicorn_config = {
-        "app": app,
-        "host": args.host,
-        "port": args.port,
-    }
-    if args.ssl:
-        uvicorn_config.update(
-            {
-                "ssl_certfile": args.ssl_certfile,
-                "ssl_keyfile": args.ssl_keyfile,
-            }
-        )
-    uvicorn.run(**uvicorn_config)
-
-
-if __name__ == "__main__":
-    main()
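One detail of the deleted file worth keeping in mind: estimate_tokens is a pure heuristic that weights CJK characters at about 1.5 tokens each and everything else at 0.25. Here is a standalone copy of the same function with a worked example (the sample string is illustrative):

```python
import re

def estimate_tokens(text: str) -> int:
    # Same heuristic as the deleted lightrag_ollama.py:
    # ~1.5 tokens per Chinese character, ~0.25 per other character.
    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
    non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
    return int(chinese_chars * 1.5 + non_chinese_chars * 0.25)

# "Hello 世界": 2 Chinese characters and 6 other characters ("Hello" plus the space),
# so 2 * 1.5 + 6 * 0.25 = 4.5, truncated to 4.
print(estimate_tokens("Hello 世界"))  # 4
```
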
lightrag/api/lightrag_server.py CHANGED
@@ -615,6 +615,32 @@ def create_app(args):
             **kwargs,
         )
 
+    embedding_func = EmbeddingFunc(
+        embedding_dim=args.embedding_dim,
+        max_token_size=args.max_embed_tokens,
+        func=lambda texts: lollms_embed(
+            texts,
+            embed_model=args.embedding_model,
+            host=args.embedding_binding_host,
+        )
+        if args.embedding_binding == "lollms"
+        else ollama_embed(
+            texts,
+            embed_model=args.embedding_model,
+            host=args.embedding_binding_host,
+        )
+        if args.embedding_binding == "ollama"
+        else azure_openai_embedding(
+            texts,
+            model=args.embedding_model,  # no host is used for openai
+        )
+        if args.embedding_binding == "azure_openai"
+        else openai_embedding(
+            texts,
+            model=args.embedding_model,  # no host is used for openai
+        ),
+    )
+
     # Initialize RAG
     if args.llm_binding in ["lollms", "ollama"] :
         rag = LightRAG(
@@ -630,31 +656,7 @@ def create_app(args):
                 "timeout": args.timeout,
                 "options": {"num_ctx": args.max_tokens},
             },
-            embedding_func=EmbeddingFunc(
-                embedding_dim=args.embedding_dim,
-                max_token_size=args.max_embed_tokens,
-                func=lambda texts: lollms_embed(
-                    texts,
-                    embed_model=args.embedding_model,
-                    host=args.embedding_binding_host,
-                )
-                if args.embedding_binding == "lollms"
-                else ollama_embed(
-                    texts,
-                    embed_model=args.embedding_model,
-                    host=args.embedding_binding_host,
-                )
-                if args.embedding_binding == "ollama"
-                else azure_openai_embedding(
-                    texts,
-                    model=args.embedding_model,  # no host is used for openai
-                )
-                if args.embedding_binding == "azure_openai"
-                else openai_embedding(
-                    texts,
-                    model=args.embedding_model,  # no host is used for openai
-                ),
-            ),
+            embedding_func=embedding_func,
         )
     else :
         rag = LightRAG(
@@ -662,31 +664,7 @@ def create_app(args):
             llm_model_func=azure_openai_model_complete
            if args.llm_binding == "azure_openai"
            else openai_alike_model_complete,
-            embedding_func=EmbeddingFunc(
-                embedding_dim=args.embedding_dim,
-                max_token_size=args.max_embed_tokens,
-                func=lambda texts: lollms_embed(
-                    texts,
-                    embed_model=args.embedding_model,
-                    host=args.embedding_binding_host,
-                )
-                if args.embedding_binding == "lollms"
-                else ollama_embed(
-                    texts,
-                    embed_model=args.embedding_model,
-                    host=args.embedding_binding_host,
-                )
-                if args.embedding_binding == "ollama"
-                else azure_openai_embedding(
-                    texts,
-                    model=args.embedding_model,  # no host is used for openai
-                )
-                if args.embedding_binding == "azure_openai"
-                else openai_embedding(
-                    texts,
-                    model=args.embedding_model,  # no host is used for openai
-                ),
-            ),
+            embedding_func=embedding_func,
         )
 
     async def index_file(file_path: Union[str, Path]) -> None:
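The hoisted embedding_func above selects its backend with a chained conditional expression inside the lambda, so the choice follows args.embedding_binding each time the function is called. A minimal sketch of that dispatch shape, with hypothetical one-line stubs standing in for lollms_embed / ollama_embed / openai_embedding:

```python
from typing import Callable, List

# Hypothetical stand-ins for the real backend functions:
def lollms_embed(texts: List[str]) -> str: return "lollms"
def ollama_embed(texts: List[str]) -> str: return "ollama"
def openai_embedding(texts: List[str]) -> str: return "openai"

def make_embed(binding: str) -> Callable[[List[str]], str]:
    # Same shape as the lambda in the diff: a chained conditional
    # expression, evaluated on every call of the returned function.
    return lambda texts: (
        lollms_embed(texts)
        if binding == "lollms"
        else ollama_embed(texts)
        if binding == "ollama"
        else openai_embedding(texts)
    )

print(make_embed("ollama")(["hi"]))  # ollama
```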
setup.py CHANGED
@@ -101,7 +101,6 @@ setuptools.setup(
     entry_points={
         "console_scripts": [
             "lightrag-server=lightrag.api.lightrag_server:main [api]",
-            "lightrag-ollama=lightrag.api.lightrag_ollama:main [api]",
         ],
     },
 )
start-server.sh DELETED
@@ -1,3 +0,0 @@
-. venv/bin/activate
-
-lightrag-ollama --llm-binding openai --llm-model deepseek-chat --embedding-model "bge-m3:latest" --embedding-dim 1024