jnjj committed
Commit dac0de6 · verified · 1 parent: a3631f8

Create app.py

Files changed (1)
  1. app.py +1049 -0
app.py ADDED
@@ -0,0 +1,1049 @@
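# Overview (summarized from the endpoints and settings defined in this file):
# an OpenAI-compatible HTTP API served by vLLM. It exposes /completions,
# /chat/completions, /models, /model_config, /models/download, /health and
# /metrics, and reads its configuration (API key, served model, port, context
# length, rate limits) from the environment variables defined below.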
import asyncio
import json
import os
import time
from http import HTTPStatus
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union

import fastapi
import uvicorn
from fastapi import Request, Depends, HTTPException, BackgroundTasks
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, Response
from packaging import version
from pydantic import BaseModel, Field, ValidationError, validator, conint, root_validator

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

from vllm.entrypoints.openai.protocol import (
    CompletionResponse, CompletionResponseChoice,
    CompletionResponseStreamChoice, CompletionStreamResponse,
    ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
    ModelCard, ModelList, ModelPermission, UsageInfo)

from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid
from vllm import LLM

from huggingface_hub import snapshot_download


try:
    import fastchat
    from fastchat.conversation import Conversation, SeparatorStyle
    from fastchat.model.model_adapter import get_conversation_template
    _fastchat_available = True
except ImportError:
    _fastchat_available = False

TIMEOUT_KEEP_ALIVE = 5
DEFAULT_API_KEY = "your_default_api_key"
API_KEY = os.environ.get("API_KEY", DEFAULT_API_KEY)
MODEL_NAME = os.environ.get("SERVED_MODEL", "jnjj/gemma-3-4b-it-qat-int4-quantized-inference-unrestricted-pruned-sf")
HOST = os.environ.get("HOST", "0.0.0.0")
PORT = int(os.environ.get("PORT", "7860"))
MAX_MODEL_LEN_CONFIG = int(os.environ.get("MAX_MODEL_LEN", "8000"))
GPU_MEMORY_UTILIZATION = float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.0"))
REQUESTS_PER_MINUTE = int(os.environ.get("REQUESTS_PER_MINUTE", "120"))
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
DOWNLOADED_MODEL_PATH = None
ENABLE_REQUEST_LOGGING = os.environ.get("ENABLE_REQUEST_LOGGING", "false").lower() == "true"
MAX_CONCURRENT_DOWNLOADS = int(os.environ.get("MAX_CONCURRENT_DOWNLOADS", "2"))
QUEUE_SIZE = int(os.environ.get("QUEUE_SIZE", "100"))

logger = init_logger(__name__)
served_model = MODEL_NAME
app = fastapi.FastAPI(title="vLLM OpenAI API", description="Concurrent OpenAI Compatible API - vLLM Powered - Advanced, Robust & Optimized", version="1.2.0")
engine = None
tokenizer = None
max_model_len = MAX_MODEL_LEN_CONFIG
download_semaphore = asyncio.Semaphore(MAX_CONCURRENT_DOWNLOADS)
request_queue: asyncio.Queue = asyncio.Queue(maxsize=QUEUE_SIZE)
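
# Example environment configuration (illustrative values only; every variable
# above falls back to the default shown in this file if unset):
#   export API_KEY="change-me"
#   export SERVED_MODEL="jnjj/gemma-3-4b-it-qat-int4-quantized-inference-unrestricted-pruned-sf"
#   export MAX_MODEL_LEN="8000"
#   export REQUESTS_PER_MINUTE="120"
#   export HOST="0.0.0.0"
#   export PORT="7860"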

request_timestamps = []


async def rate_limit_dependency(request: Request):
    current_time = time.monotonic()
    request_timestamps.append(current_time)
    request_timestamps[:] = [ts for ts in request_timestamps if current_time - ts <= 60]
    if len(request_timestamps) > REQUESTS_PER_MINUTE:
        raise HTTPException(status_code=429, detail="Too Many Requests. Please try again later.")
    return True


async def queue_dependency():
    if request_queue.full():
        raise HTTPException(status_code=429, detail="Queue is full. Please try again later.")
    await request_queue.put(1)
    try:
        yield
    finally:
        # asyncio.Queue.get() takes no arguments; release the slot taken above.
        await request_queue.get()
        request_queue.task_done()


class HTTPException(fastapi.HTTPException):
    pass


class ChatCompletionRequest(BaseModel):
    model: str = Field(default=MODEL_NAME, description="Model name for chat completion")
    api_key: str = Field(..., description="API Key for authentication")
    messages: Union[str, List[Dict[str, str]]] = Field(..., description="Conversation messages")
    temperature: Optional[float] = Field(0.7, description="Sampling temperature")
    top_p: Optional[float] = Field(1.0, description="Top p sampling parameter")
    n: Optional[conint(ge=1, le=10)] = Field(1, description="Number of chat completion choices (max 10)")
    max_tokens: Optional[conint(ge=1, le=max_model_len)] = Field(None, description=f"Max tokens, up to {max_model_len}")
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list, description="Stop sequences")
    stream: Optional[bool] = Field(False, description="Enable streaming responses")
    presence_penalty: Optional[float] = Field(0.0, description="Presence penalty")
    frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty")
    logit_bias: Optional[Dict[str, float]] = Field(None, description="Logit bias map")
    user: Optional[str] = Field(None, description="User identifier")
    best_of: Optional[conint(ge=1, le=10)] = Field(None, description="Best of sampling (max 10)")
    top_k: Optional[conint(ge=-1)] = Field(-1, description="Top k sampling")
    ignore_eos: Optional[bool] = Field(False, description="Ignore EOS token")
    use_beam_search: Optional[bool] = Field(False, description="Use beam search (not for chat)")
    stop_token_ids: Optional[List[int]] = Field(default_factory=list, description="Stop token IDs")
    skip_special_tokens: Optional[bool] = Field(True, description="Skip special tokens")
    spaces_between_special_tokens: Optional[bool] = Field(True, description="Spaces between special tokens")

    @validator("messages")
    def messages_must_be_list_or_str(cls, v):
        if not isinstance(v, (str, list)):
            raise ValueError("Messages must be a string or a list of messages")
        return v


class CompletionRequest(BaseModel):
    model: str = Field(default=MODEL_NAME, description="Model name for text completion")
    api_key: str = Field(..., description="API Key for authentication")
    prompt: Union[List[int], List[List[int]], str, List[str]] = Field(..., description="Text prompt for completion")
    suffix: Optional[str] = Field(None, description="Suffix (not supported)")
    max_tokens: Optional[conint(ge=1, le=max_model_len)] = Field(16, description=f"Max completion tokens, up to {max_model_len}")
    temperature: Optional[float] = Field(1.0, description="Sampling temperature")
    top_p: Optional[float] = Field(1.0, description="Top p sampling")
    n: Optional[conint(ge=1, le=10)] = Field(1, description="Number of completions (max 10)")
    stream: Optional[bool] = Field(False, description="Enable streaming responses")
    echo: Optional[bool] = Field(False, description="Echo prompt (not supported)")
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list, description="Stop sequences")
    presence_penalty: Optional[float] = Field(0.0, description="Presence penalty")
    frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty")
    logit_bias: Optional[Dict[str, float]] = Field(None, description="Logit bias map")
    user: Optional[str] = Field(None, description="User identifier")
    best_of: Optional[conint(ge=1, le=10)] = Field(None, description="Best of sampling (max 10)")
    top_k: Optional[conint(ge=-1)] = Field(-1, description="Top k sampling")
    ignore_eos: Optional[bool] = Field(False, description="Ignore EOS token")
    use_beam_search: Optional[bool] = Field(False, description="Use beam search (not for completion)")
    stop_token_ids: Optional[List[int]] = Field(default_factory=list, description="Stop token IDs")
    skip_special_tokens: Optional[bool] = Field(True, description="Skip special tokens")
    spaces_between_special_tokens: Optional[bool] = Field(True, description="Spaces between special tokens")

    @validator("prompt")
    def prompt_must_be_list_or_str(cls, v):
        if not isinstance(v, (str, list)):
            raise ValueError("Prompt must be a string or a list of prompts")
        return v


class ModelDownloadResponse(BaseModel):
    model_name: str = Field(..., description="Name of model downloaded")
    download_path: Optional[str] = Field(None, description="Local download path")
    status: str = Field(..., description="Download status")
    message: Optional[str] = Field(None, description="Download message")


def create_error_response(status_code: HTTPStatus, message: str, err_type="invalid_request_error") -> JSONResponse:
    # Accept either an HTTPStatus member or a plain int status code.
    status = HTTPStatus(status_code)
    logger.error(f"Error Response: {status.value} - {message} ({err_type})")
    return JSONResponse(ErrorResponse(message=message, type=err_type).dict(), status_code=status.value)


def error_response_for_request(error_res: JSONResponse, stream: bool):
    """Return the error as a one-shot SSE stream when the client requested
    streaming, otherwise raise an HTTPException handled by the handlers below."""
    payload = json.loads(error_res.body)
    if stream:
        error_json_str = json.dumps(payload)

        async def error_stream():
            yield f"data: {error_json_str}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(error_stream(), media_type="text/event-stream", status_code=error_res.status_code)
    raise HTTPException(status_code=error_res.status_code, detail=payload["message"])


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    logger.warning(f"Validation Error: {exc}")
    return create_error_response(HTTPStatus.BAD_REQUEST, str(exc), err_type="validation_error")


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    logger.warning(f"HTTP Exception: {exc.detail} Status Code: {exc.status_code}")
    return create_error_response(exc.status_code, exc.detail, err_type="rate_limit_error" if exc.status_code == 429 else "http_error")


async def check_api_key(request: Request):
    # The original dependency used an unannotated lambda, which FastAPI cannot
    # resolve to the Request object; read the credentials from the request directly.
    api_key = request.headers.get("Authorization") or request.query_params.get("api_key")
    if api_key is None or api_key.replace("Bearer ", "") != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key.")
    return True


async def check_model(request_model_name: str) -> Optional[JSONResponse]:
    model_to_check = DOWNLOADED_MODEL_PATH if DOWNLOADED_MODEL_PATH else served_model
    if request_model_name == model_to_check:
        return None
    return create_error_response(
        HTTPStatus.NOT_FOUND,
        f"Model '{request_model_name}' not found. Serving: {model_to_check}",
        err_type="model_not_found"
    )


async def get_gen_prompt(request: ChatCompletionRequest) -> str:
    if not _fastchat_available:
        raise ModuleNotFoundError("fastchat not installed. Install to use chat API: pip install fschat")
    if version.parse(fastchat.__version__) < version.parse("0.2.23"):
        raise ImportError(f"fastchat version too low: {fastchat.__version__}. Upgrade: pip install -U fschat")

    try:
        try:
            conv = get_conversation_template(request.model)
        except Exception:
            logger.warning(f"Conversation template for model '{request.model}' not found. Using default template.")
            if isinstance(request.messages, str):
                return request.messages
            else:
                raise ValueError(f"Conversation template for model '{request.model}' not found and messages is not a string.")

        # Use the template returned by fastchat as-is. Rebuilding a Conversation
        # from the request fields (previously attempted via Conversation.__fields__,
        # a pydantic-only attribute) discards the template and fails, because
        # fastchat's Conversation is a dataclass with required fields.

        if isinstance(request.messages, str):
            prompt = request.messages
        else:
            for message in request.messages:
                role = message["role"]
                if role == "system":
                    conv.system_message = message["content"]
                elif role == "user":
                    conv.append_message(conv.roles[0], message["content"])
                elif role == "assistant":
                    conv.append_message(conv.roles[1], message["content"])
                else:
                    raise ValueError(f"Unknown role: {role}")

            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
        return prompt
    except ValueError as e:
        logger.error(f"Prompt generation error: {e}")
        raise ValueError(f"Failed to generate prompt: {e}")
    except Exception as e:
        logger.error(f"An unexpected error occurred during prompt generation: {e}")
        raise RuntimeError(f"An unexpected error occurred during prompt generation: {e}")


async def check_length(request: Union[ChatCompletionRequest, CompletionRequest], prompt: Optional[str] = None, prompt_ids: Optional[List[int]] = None) -> Tuple[List[int], Optional[JSONResponse]]:
    assert (not (prompt is None and prompt_ids is None) and not (prompt is not None and prompt_ids is not None)), "Provide either prompt or prompt_ids."

    if tokenizer is None:
        return [], create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, "Tokenizer not initialized.", err_type="internal_error")

    try:
        input_ids = prompt_ids if prompt_ids else tokenizer(prompt).input_ids
    except Exception as e:
        logger.error(f"Error during tokenization: {e}")
        return [], create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f"Error tokenizing prompt: {e}", err_type="tokenization_error")

    token_num = len(input_ids)

    if request.max_tokens is None:
        remaining_tokens = max_model_len - token_num
        if remaining_tokens <= 0:
            return input_ids, create_error_response(
                HTTPStatus.BAD_REQUEST,
                f"Prompt length ({token_num}) exceeds or equals max model length ({max_model_len}). No space for completion.",
                err_type="context_length_exceeded"
            )
        request.max_tokens = remaining_tokens

    if token_num + request.max_tokens > max_model_len:
        return input_ids, create_error_response(
            HTTPStatus.BAD_REQUEST,
            f"Context length exceeded. Max: {max_model_len}, Prompt Tokens: {token_num}, Requested Completion Tokens: {request.max_tokens}, Total: {request.max_tokens + token_num}",
            err_type="context_length_exceeded"
        )
    return input_ids, None


@app.get("/health", tags=["System"])
async def health() -> Response:
    if engine is None:
        return Response(status_code=503, content="Engine not initialized")
    try:
        await engine.get_model_config()
        return Response(status_code=200)
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return Response(status_code=503, content=f"Engine health check failed: {e}")


@app.get("/metrics", tags=["System"])
async def metrics() -> Response:
    # Placeholder: no Prometheus metrics are collected yet.
    return Response(content="", media_type="text/plain")


@app.get("/models", response_model=ModelList, tags=["System"])
async def show_available_models():
    model_cards = [
        ModelCard(id=served_model if DOWNLOADED_MODEL_PATH is None else DOWNLOADED_MODEL_PATH,
                  root=served_model if DOWNLOADED_MODEL_PATH is None else DOWNLOADED_MODEL_PATH,
                  permission=[ModelPermission()])
    ]
    return ModelList(data=model_cards)


@app.get("/model_config", tags=["System"])
async def get_model_configuration():
    model_config = {
        "model_name": served_model if DOWNLOADED_MODEL_PATH is None else DOWNLOADED_MODEL_PATH,
        "max_model_len_config": MAX_MODEL_LEN_CONFIG,
        "cpu_only": True,
        "gpu_memory_utilization": GPU_MEMORY_UTILIZATION,
    }
    if engine:
        try:
            engine_model_config = await engine.get_model_config()
            hf_config = engine_model_config.hf_config
            model_config["actual_max_model_len"] = engine_model_config.max_model_len
            # dtype is a torch.dtype and is not JSON-serializable as-is.
            model_config["dtype"] = str(engine_model_config.dtype)
            # Architecture details live on the underlying Hugging Face config.
            model_config["num_layers"] = getattr(hf_config, "num_hidden_layers", None)
            model_config["num_attention_heads"] = getattr(hf_config, "num_attention_heads", None)
            model_config["hidden_size"] = getattr(hf_config, "hidden_size", None)
            model_config["vocab_size"] = getattr(hf_config, "vocab_size", None)

        except Exception as e:
            logger.warning(f"Could not retrieve detailed engine config: {e}")
            model_config["engine_config_status"] = f"Error retrieving engine config: {e}"

    return model_config


@app.post("/models/download", response_model=ModelDownloadResponse, tags=["Model Management"])
async def download_model(background_tasks: BackgroundTasks, model_name: str = fastapi.Query(..., description="Model name to download")):
    # BackgroundTasks is injected by FastAPI; no default instance is needed.
    logger.info(f"Download requested for model: {model_name}")
    if download_semaphore.locked():
        raise HTTPException(status_code=429, detail="Model download already in progress.")

    global DOWNLOADED_MODEL_PATH
    previous_downloaded_path = DOWNLOADED_MODEL_PATH
    DOWNLOADED_MODEL_PATH = None

    background_tasks.add_task(run_model_download, model_name, previous_downloaded_path)

    return ModelDownloadResponse(model_name=model_name, status="pending", message="Model download started. Check logs for progress.")


async def run_model_download(model_name: str, previous_downloaded_path: Optional[str]):
    async with download_semaphore:
        logger.info(f"Starting background download for model: {model_name}")
        loop = asyncio.get_running_loop()
        global DOWNLOADED_MODEL_PATH, engine, tokenizer, max_model_len
        try:
            download_path = await loop.run_in_executor(None, snapshot_download, model_name)
            logger.info(f"Model downloaded to: {download_path}")

            if engine:
                logger.info("Shutting down existing engine...")
                engine = None
                tokenizer = None
                max_model_len = MAX_MODEL_LEN_CONFIG
                await asyncio.sleep(2)
                logger.info("Existing engine dereferenced.")

            await initialize_llm_engine(download_path)
            DOWNLOADED_MODEL_PATH = download_path
            logger.info(f"Model '{model_name}' ready from downloaded path: {DOWNLOADED_MODEL_PATH}")

        except Exception as e:
            logger.error(f"Model download & init error for {model_name}: {e}")
            DOWNLOADED_MODEL_PATH = previous_downloaded_path


async def completion_stream_generator_chat(request: ChatCompletionRequest, result_generator: AsyncGenerator[RequestOutput, None]) -> AsyncGenerator[str, None]:
    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"
    created_time = int(time.time())
    prompt_token_count = 0

    # The first chunk for each choice carries the assistant role.
    for i in range(request.n):
        choice_data = ChatCompletionResponseStreamChoice(index=i, delta=DeltaMessage(role="assistant"))
        chunk = ChatCompletionStreamResponse(id=request_id, choices=[choice_data], model=model_name, created=created_time)
        yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

    previous_texts = [""] * request.n
    output_completion_tokens = [0] * request.n

    async for res in result_generator:
        prompt_token_count = len(res.prompt_token_ids)

        for output in res.outputs:
            i = output.index
            delta_text = output.text[len(previous_texts[i]):]
            previous_texts[i] = output.text

            output_completion_tokens[i] = len(output.token_ids)

            if delta_text:
                choice_data = ChatCompletionResponseStreamChoice(index=i, delta=DeltaMessage(content=delta_text))
                chunk = ChatCompletionStreamResponse(id=request_id, choices=[choice_data], model=model_name, created=created_time)
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

            if output.finish_reason:
                choice_data_finish = ChatCompletionResponseStreamChoice(index=i, delta=DeltaMessage(), finish_reason=output.finish_reason)
                chunk_finish = ChatCompletionStreamResponse(id=request_id, choices=[choice_data_finish], model=model_name, created=created_time)
                yield f"data: {chunk_finish.json(exclude_unset=True, ensure_ascii=False)}\n\n"

    total_completion_tokens = sum(output_completion_tokens)

    yield "data: [DONE]\n\n"


async def completion_stream_generator_completion(request: CompletionRequest, result_generator: AsyncGenerator[RequestOutput, None]) -> AsyncGenerator[str, None]:
    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"
    created_time = int(time.time())
    prompt_token_count = 0

    previous_texts = [""] * request.n
    output_completion_tokens = [0] * request.n

    async for res in result_generator:
        prompt_token_count = len(res.prompt_token_ids)

        for output in res.outputs:
            i = output.index
            delta_text = output.text[len(previous_texts[i]):]
            previous_texts[i] = output.text

            # Count generated tokens from the engine's token ids instead of
            # re-tokenizing the accumulated text on every chunk.
            output_completion_tokens[i] = len(output.token_ids)

            choice_data = CompletionResponseStreamChoice(
                index=i,
                text=delta_text,
                logprobs=None
            )
            chunk = CompletionStreamResponse(id=request_id, choices=[choice_data], model=model_name, created=created_time)
            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

            if output.finish_reason:
                choice_data_finish = CompletionResponseStreamChoice(
                    index=i,
                    text="",
                    logprobs=None,
                    finish_reason=output.finish_reason
                )
                chunk_finish = CompletionStreamResponse(id=request_id, choices=[choice_data_finish], model=model_name, created=created_time)
                yield f"data: {chunk_finish.json(exclude_unset=True, ensure_ascii=False)}\n\n"

    yield "data: [DONE]\n\n"


@app.post("/completions", response_model=CompletionResponse, tags=["Completions"], dependencies=[Depends(rate_limit_dependency), Depends(check_api_key), Depends(queue_dependency)])
async def create_completion(request: CompletionRequest, raw_request: Request):
    start_time = time.monotonic()
    if ENABLE_REQUEST_LOGGING:
        logger.info(f"Completion Request: {request}")

    # Fail fast if the engine never initialized (e.g. startup failure).
    if engine is None:
        return error_response_for_request(
            create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, "Engine not initialized.", err_type="engine_not_ready"),
            request.stream)

    model_error_check = await check_model(request.model)
    if model_error_check:
        return error_response_for_request(model_error_check, request.stream)

    if request.echo:
        return error_response_for_request(
            create_error_response(HTTPStatus.BAD_REQUEST, "Echo not supported.", err_type="not_supported"),
            request.stream)

    if request.suffix:
        return error_response_for_request(
            create_error_response(HTTPStatus.BAD_REQUEST, "Suffix not supported.", err_type="not_supported"),
            request.stream)

    if request.logit_bias and len(request.logit_bias) > 0:
        return error_response_for_request(
            create_error_response(HTTPStatus.BAD_REQUEST, "Logit bias not supported.", err_type="not_supported"),
            request.stream)

    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"

    use_token_ids = False
    prompt = request.prompt
    prompt_processed = None
    prompt_token_ids_input = None

    if isinstance(prompt, list):
        if not prompt:
            return error_response_for_request(
                create_error_response(HTTPStatus.BAD_REQUEST, "Provide at least one prompt.", err_type="invalid_prompt"),
                request.stream)
        # Batched prompts are not supported; require a single prompt entry,
        # or a single flat list of token ids.
        first_element = prompt[0]
        if isinstance(first_element, int):
            # List[int]: one tokenized prompt.
            use_token_ids = True
            prompt_token_ids_input = prompt
        elif len(prompt) > 1:
            return error_response_for_request(
                create_error_response(HTTPStatus.BAD_REQUEST,
                                      "Batch requests are not supported: 'prompt' must contain a single entry.",
                                      err_type="not_supported"),
                request.stream)
        elif isinstance(first_element, str):
            # List[str] with a single prompt string.
            prompt_processed = first_element
        elif isinstance(first_element, list) and first_element and isinstance(first_element[0], int):
            # List[List[int]] with a single tokenized prompt.
            use_token_ids = True
            prompt_token_ids_input = first_element
        else:
            return error_response_for_request(
                create_error_response(HTTPStatus.BAD_REQUEST, "Invalid format for 'prompt' list.", err_type="invalid_prompt"),
                request.stream)
    elif isinstance(prompt, str):
        prompt_processed = prompt
    else:
        return error_response_for_request(
            create_error_response(HTTPStatus.BAD_REQUEST, "Prompt must be a string or a list.", err_type="invalid_prompt"),
            request.stream)

    try:
        if use_token_ids:
            input_ids, length_error = await check_length(request, prompt_ids=prompt_token_ids_input)
        else:
            input_ids, length_error = await check_length(request, prompt=prompt_processed)
        if length_error:
            return error_response_for_request(length_error, request.stream)
    except ValueError as ve:
        return error_response_for_request(
            create_error_response(HTTPStatus.BAD_REQUEST, str(ve), err_type="prompt_error"),
            request.stream)
    except Exception as e:
        return error_response_for_request(
            create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f"Error processing prompt length: {e}", err_type="internal_error"),
            request.stream)

    created_time = int(time.time())

    sampling_params = SamplingParams(**request.dict(
        exclude={
            "stream",
            "api_key",
            "model",
            "prompt",
            "user",
            "echo",
            "suffix",
            "logit_bias",
        },
        exclude_none=True
    ))

    try:
        if use_token_ids:
            result_generator = engine.generate(
                prompt=None,
                sampling_params=sampling_params,
                request_id=request_id,
                prompt_token_ids=input_ids
            )
        else:
            result_generator = engine.generate(
                prompt=prompt_processed,
                sampling_params=sampling_params,
                request_id=request_id,
                prompt_token_ids=input_ids
            )
    except Exception as e:
        logger.error(f"Error submitting generation request to engine: {e}")
        return error_response_for_request(
            create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f"Error submitting generation request: {e}", err_type="engine_error"),
            request.stream)

    try:
        if request.stream:
            return StreamingResponse(completion_stream_generator_completion(request, result_generator), media_type="text/event-stream")

        final_res = None
        async for res in result_generator:
            final_res = res

        if final_res is None or not final_res.outputs:
            error_res = create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, "Engine returned no output.", err_type="engine_output_error")
            return error_response_for_request(error_res, request.stream)

        choices = [
            CompletionResponseChoice(
                index=output.index,
                text=output.text,
                logprobs=None,
                finish_reason=output.finish_reason
            ) for output in final_res.outputs
        ]

        prompt_tokens = len(final_res.prompt_token_ids)
        completion_tokens = sum(len(output.token_ids) for output in final_res.outputs)
        total_tokens = prompt_tokens + completion_tokens
        usage = UsageInfo(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)

        response = CompletionResponse(id=request_id, created=created_time, model=model_name, choices=choices, usage=usage)

        if ENABLE_REQUEST_LOGGING:
            logger.info(f"Completion Response (non-stream): {response}")

        return response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing generation result for request {request_id}: {e}")
        return error_response_for_request(
            create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f"Error processing generation result: {e}", err_type="engine_error"),
            request.stream)


@app.post("/chat/completions", response_model=ChatCompletionResponse, tags=["Chat Completions"], dependencies=[Depends(rate_limit_dependency), Depends(check_api_key), Depends(queue_dependency)])
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
    start_time = time.monotonic()
    if ENABLE_REQUEST_LOGGING:
        log_request_dict = request.dict()
        messages = log_request_dict.pop("messages", "N/A")
        logger.info(f"Chat Completion Request: {log_request_dict}, Messages: {messages}")

    # Fail fast if the engine never initialized (e.g. startup failure).
    if engine is None:
        return error_response_for_request(
            create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, "Engine not initialized.", err_type="engine_not_ready"),
            request.stream)

    model_error_check = await check_model(request.model)
    if model_error_check:
        return error_response_for_request(model_error_check, request.stream)

    if request.use_beam_search:
        return error_response_for_request(
            create_error_response(HTTPStatus.BAD_REQUEST, "Beam search not supported for chat completions.", err_type="not_supported"),
            request.stream)

    if request.best_of is not None and request.best_of > 1:
        return error_response_for_request(
            create_error_response(HTTPStatus.BAD_REQUEST, "Best of > 1 not fully supported for chat completions.", err_type="not_supported"),
            request.stream)

    if request.logit_bias and len(request.logit_bias) > 0:
        return error_response_for_request(
            create_error_response(HTTPStatus.BAD_REQUEST, "Logit bias not supported.", err_type="not_supported"),
            request.stream)

    try:
        prompt = await get_gen_prompt(request)
    except ValueError as ve:
        return error_response_for_request(
            create_error_response(HTTPStatus.BAD_REQUEST, str(ve), err_type="prompt_generation_error"),
            request.stream)
    except RuntimeError as runtime_err:
        return error_response_for_request(
            create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, str(runtime_err), err_type="internal_error"),
            request.stream)
    except Exception as e:
        logger.error(f"An unexpected error occurred during chat prompt generation: {e}")
        return error_response_for_request(
            create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f"An unexpected error occurred during prompt generation: {e}", err_type="internal_error"),
            request.stream)

    try:
        input_ids, length_error = await check_length(request, prompt=prompt)
        if length_error:
            return error_response_for_request(length_error, request.stream)
    except ValueError as ve:
        return error_response_for_request(
            create_error_response(HTTPStatus.BAD_REQUEST, str(ve), err_type="prompt_error"),
            request.stream)
    except Exception as e:
        return error_response_for_request(
            create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f"Error processing prompt length: {e}", err_type="internal_error"),
            request.stream)

    created_time = int(time.time())
    request_id = f"chatcmpl-{random_uuid()}"
    # model_name was referenced below without ever being assigned.
    model_name = request.model

    sampling_params = SamplingParams(**request.dict(
        exclude={
            "stream",
            "api_key",
            "model",
            "messages",
            "user",
            "use_beam_search",
            "logit_bias",
        },
        exclude_none=True
    ))

    try:
        result_generator = engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
            prompt_token_ids=input_ids
        )
    except Exception as e:
        logger.error(f"Error submitting chat generation request to engine: {e}")
        return error_response_for_request(
            create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f"Error submitting chat generation request: {e}", err_type="engine_error"),
            request.stream)

    try:
        if request.stream:
            return StreamingResponse(completion_stream_generator_chat(request, result_generator), media_type="text/event-stream")

        final_res = None
        async for res in result_generator:
            final_res = res

        if final_res is None or not final_res.outputs:
            error_res = create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, "Engine returned no output.", err_type="engine_output_error")
            return error_response_for_request(error_res, request.stream)

        choices = [
            ChatCompletionResponseChoice(
                index=output.index,
                message=ChatMessage(role="assistant", content=output.text),
                logprobs=None,
                finish_reason=output.finish_reason,
            ) for output in final_res.outputs
        ]

        prompt_tokens = len(final_res.prompt_token_ids)
        completion_tokens = sum(len(output.token_ids) for output in final_res.outputs)
        total_tokens = prompt_tokens + completion_tokens
        usage = UsageInfo(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)

        response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, choices=choices, usage=usage)

        if ENABLE_REQUEST_LOGGING:
            logger.info(f"Chat Completion Response (non-stream): {response}")

        return response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing chat generation result for request {request_id}: {e}")
        return error_response_for_request(
            create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f"Error processing generation result: {e}", err_type="engine_error"),
            request.stream)


async def initialize_llm_engine(model_path_to_load: str):
    global engine, tokenizer, max_model_len
    try:
        logger.info(f"Initializing LLM Engine for CPU with model from: {model_path_to_load}")

        if engine:
            logger.info("Shutting down existing engine...")
            engine = None
            tokenizer = None
            max_model_len = MAX_MODEL_LEN_CONFIG
            await asyncio.sleep(2)
            logger.info("Existing engine dereferenced.")

        # Build the async engine from AsyncEngineArgs. The endpoints above iterate
        # over engine.generate(...) with `async for`, which requires AsyncLLMEngine;
        # the synchronous LLM wrapper does not expose an async generator, and
        # `cpu_only` is not an EngineArgs parameter. Note that `device="cpu"`
        # requires a vLLM build with CPU support.
        engine_args = AsyncEngineArgs(
            model=model_path_to_load,
            tokenizer=model_path_to_load,
            tensor_parallel_size=1,  # single worker for CPU
            dtype="auto",  # let vLLM pick the dtype
            max_model_len=MAX_MODEL_LEN_CONFIG,
            gpu_memory_utilization=GPU_MEMORY_UTILIZATION,  # ignored on CPU
            swap_space=4,  # GiB of host memory for the KV cache
            device="cpu",
        )

        engine = AsyncLLMEngine.from_engine_args(engine_args)
        engine_model_config = await engine.get_model_config()
        max_model_len = engine_model_config.max_model_len

        tokenizer = get_tokenizer(engine_args.tokenizer,
                                  tokenizer_mode=engine_args.tokenizer_mode,
                                  trust_remote_code=engine_args.trust_remote_code)

        logger.info(f"LLM Engine initialized for CPU with model: {model_path_to_load}. Max model length: {max_model_len}")

    except Exception as e:
        logger.error(f"LLM Engine initialization failed: {e}", exc_info=True)
        engine = None
        tokenizer = None
        max_model_len = MAX_MODEL_LEN_CONFIG
        raise RuntimeError(f"LLM Engine initialization failed: {e}") from e


@app.on_event("startup")
async def startup_event():
    logger.info("Application startup initiated.")

    model_to_load_initially = DOWNLOADED_MODEL_PATH if DOWNLOADED_MODEL_PATH else MODEL_NAME
    logger.info(f"Initial model to load: {model_to_load_initially}")

    try:
        await initialize_llm_engine(model_to_load_initially)
    except RuntimeError as e:
        logger.error(f"Failed to initialize LLM Engine during startup: {e}")

    logger.info("Application startup complete.")


@app.on_event("shutdown")
async def shutdown_event():
    logger.info("Application shutdown initiated.")

    global engine, tokenizer
    if engine:
        logger.info("Attempting to clean up vLLM engine resources.")
        engine = None
        tokenizer = None
        logger.info("vLLM engine and tokenizer dereferenced.")

    logger.info("Application shutdown complete.")


app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    uvicorn.run(app, host=HOST, port=PORT, log_level=LOG_LEVEL.lower(), timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
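
A minimal client call against this server might look like the sketch below; it assumes the defaults defined in app.py (port 7860, the placeholder API key, the default served model) and uses the third-party `requests` package, which is not part of this file:

import requests

# The API key travels both in the Authorization header (checked by
# check_api_key) and in the JSON body (a required field of the request model).
resp = requests.post(
    "http://localhost:7860/chat/completions",
    headers={"Authorization": "Bearer your_default_api_key"},
    json={
        "model": "jnjj/gemma-3-4b-it-qat-int4-quantized-inference-unrestricted-pruned-sf",
        "api_key": "your_default_api_key",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 64,
    },
    timeout=300,
)
print(resp.json())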