Update app.py
app.py
CHANGED
@@ -2,196 +2,157 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
-import torch
import uvicorn
-from
import os
-from
-import sys  # Import sys for sys.exit()

-    """
-    Handles startup and shutdown events for the FastAPI application.
-    Loads the model on startup and can optionally clean up on shutdown.
-    """
-    global generator
-    try:
-        # --- Optional: Login to Hugging Face Hub for gated models ---
-        # If you are using a gated model (e.g., meta-llama/Llama-3-8B-Instruct),
-        # uncomment the following lines and ensure HF_TOKEN is set as a Space Secret.
-        # hf_token = os.getenv("HF_TOKEN")
-        # if hf_token:
-        #     login(token=hf_token)
-        #     print("Logged into Hugging Face Hub.")
-        # else:
-        #     print("HF_TOKEN not found. Make sure it's set as a Space Secret if using a gated model.")

-            # For larger models, use device_map="auto" and torch_dtype
-            # device_map = "auto"
-            # torch_dtype = torch.bfloat16  # or torch.float16 for GPUs that support it
-        else:
-            print("CUDA not available, using CPU. Inference will be very slow for this model size.")
-            device = -1  # Use CPU
-            # device_map = None
-            # torch_dtype = torch.float32  # Default for CPU

-            model=MODEL_NAME,
-            device=device,
-            # Pass your HF token to the model loading for gated models
-            # token=os.getenv("HF_TOKEN"),  # Uncomment if using a gated model
-            # For 7B models on 16GB GPU, float16 is usually enough, but bfloat16 is better if supported
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            # For more fine-grained control and auto device mapping for multiple GPUs:
-            # model_kwargs={"device_map": "auto", "torch_dtype": torch.float16}
-        )
-        print("Model loaded successfully!")

    except Exception as e:
-        print(f"
-        sys.exit(1)
-    finally:
-        # --- Shutdown Code (Optional): Clean up resources ---
-        print("Application shutting down. Any cleanup can go here.")

-# --- Initialize FastAPI application with the lifespan handler ---
-app = FastAPI(lifespan=lifespan,  # Use the lifespan context manager
-              title="Text Generation API",
-              description="A simple text generation API using Hugging Face transformers",
-              version="1.0.0"
-)

-# Request model
-class TextGenerationRequest(BaseModel):
-    prompt: str
-    max_new_tokens: Optional[int] = 250  # Changed from max_length for better control
-    num_return_sequences: Optional[int] = 1
-    temperature: Optional[float] = 0.7  # Recommend lower temp for more coherent output
-    do_sample: Optional[bool] = True
-    top_p: Optional[float] = 0.9  # Added top_p for more control

-# Response model
-class TextGenerationResponse(BaseModel):
-    generated_text: str
-    prompt: str
-    model_name: str

@app.get("/")
async def root():
    return {
-        "message": "
-        "status": "running",
        "endpoints": {
-        },
-        "current_model": MODEL_NAME
    }

@app.get("/health")
async def health_check():
    return {
-        "status": "healthy"
-        "cuda_available": torch.cuda.is_available(),
-        "model_name": MODEL_NAME
    }

-@app.post("/
-async def
    try:
-            temperature=request.temperature,
-            do_sample=request.do_sample,
-            top_p=request.top_p,  # Pass top_p
-            pad_token_id=generator.tokenizer.eos_token_id,
-            eos_token_id=generator.tokenizer.eos_token_id,
-            # Add stop sequences relevant to your instruction-tuned model format
-            # stop_sequences=["\nUser:", "\n###", "\n\n"]
-        )

-        return
-            model_name=MODEL_NAME
        )
    except Exception as e:
-        print(f"
-        raise HTTPException(status_code=500, detail=f"

-@app.
-async def
-    max_new_tokens: int = 250,  # Changed from max_length
-    temperature: float = 0.7
-):
-    """GET endpoint for simple text generation"""
-    if generator is None:
-        raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")

    try:
        return {
        }
    except Exception as e:

if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=port)

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import uvicorn
+from src.RAGSample import setup_retriever, setup_rag_chain, RAGApplication
import os
+from dotenv import load_dotenv

+# Load environment variables
+load_dotenv()

+# Create FastAPI app
+app = FastAPI(
+    title="RAG API",
+    description="A REST API for Retrieval-Augmented Generation using local vector database",
+    version="1.0.0"
+)

+# Initialize RAG components (this will be done once when the server starts)
+retriever = None
+rag_chain = None
+rag_application = None

+# Pydantic model for request
+class QuestionRequest(BaseModel):
+    question: str

+# Pydantic model for response
+class QuestionResponse(BaseModel):
+    question: str
+    answer: str

+@app.on_event("startup")
+async def startup_event():
+    """Initialize RAG components when the server starts."""
+    global retriever, rag_chain, rag_application
+    try:
+        print("Initializing RAG components...")

+        # Check if Kaggle credentials are provided via environment variables
+        kaggle_username = os.getenv("KAGGLE_USERNAME")
+        kaggle_key = os.getenv("KAGGLE_KEY")
+        kaggle_dataset = os.getenv("KAGGLE_DATASET")

+        # If no environment variables, try to load from kaggle.json
+        if not (kaggle_username and kaggle_key):
+            try:
+                from src.kaggle_loader import KaggleDataLoader
+                # Test if we can create a loader (this will auto-load from kaggle.json)
+                test_loader = KaggleDataLoader()
+                if test_loader.kaggle_username and test_loader.kaggle_key:
+                    kaggle_username = test_loader.kaggle_username
+                    kaggle_key = test_loader.kaggle_key
+                    print(f"Loaded Kaggle credentials from kaggle.json: {kaggle_username}")
+            except Exception as e:
+                print(f"Could not load Kaggle credentials from kaggle.json: {e}")

+        if kaggle_username and kaggle_key and kaggle_dataset:
+            print(f"Loading Kaggle dataset: {kaggle_dataset}")
+            retriever = setup_retriever(
+                use_kaggle_data=True,
+                kaggle_dataset=kaggle_dataset,
+                kaggle_username=kaggle_username,
+                kaggle_key=kaggle_key
+            )
+        else:
+            print("Loading mental health FAQ data from local file...")
+            # Load mental health FAQ data from local file (default behavior)
+            retriever = setup_retriever()

+        rag_chain = setup_rag_chain()
+        rag_application = RAGApplication(retriever, rag_chain)
+        print("RAG components initialized successfully!")
    except Exception as e:
+        print(f"Error initializing RAG components: {e}")
+        raise

@app.get("/")
async def root():
+    """Root endpoint with API information."""
    return {
+        "message": "RAG API is running",
        "endpoints": {
+            "ask_question": "/ask",
+            "health_check": "/health",
+            "load_kaggle_dataset": "/load-kaggle-dataset"
+        }
    }

@app.get("/health")
async def health_check():
+    """Health check endpoint."""
    return {
+        "status": "healthy",
+        "rag_initialized": rag_application is not None
    }

+@app.post("/ask", response_model=QuestionResponse)
+async def ask_question(request: QuestionRequest):
+    """Ask a question and get an answer using RAG."""
+    if rag_application is None:
+        raise HTTPException(status_code=500, detail="RAG application not initialized")

    try:
+        print(f"Processing question: {request.question}")

+        # Debug: Check what retriever we're using
+        retriever_type = type(rag_application.retriever).__name__
+        print(f"DEBUG: Using retriever type: {retriever_type}")

+        answer = rag_application.run(request.question)

+        return QuestionResponse(
+            question=request.question,
+            answer=answer
        )
    except Exception as e:
+        print(f"Error processing question: {e}")
+        raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}")

+@app.post("/load-kaggle-dataset")
+async def load_kaggle_dataset(dataset_name: str):
+    """Load a Kaggle dataset for RAG."""
    try:
+        from src.kaggle_loader import KaggleDataLoader

+        # Create loader without parameters - it will auto-load from kaggle.json
+        loader = KaggleDataLoader()

+        # Download the dataset
+        dataset_path = loader.download_dataset(dataset_name)

+        # Reload the retriever with the new dataset
+        global rag_application
+        retriever = setup_retriever(use_kaggle_data=True, kaggle_dataset=dataset_name)
+        rag_chain = setup_rag_chain()
+        rag_application = RAGApplication(retriever, rag_chain)

        return {
+            "status": "success",
+            "message": f"Dataset {dataset_name} loaded successfully",
+            "dataset_path": dataset_path
        }
    except Exception as e:
+        return {"status": "error", "message": str(e)}

+@app.get("/models")
+async def get_models():
+    """Get information about available models."""
+    return {
+        "llm_model": "dolphin-llama3:8b",
+        "embedding_model": "TF-IDF embeddings",
+        "vector_database": "ChromaDB (local)"
+    }

if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
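A minimal usage sketch for the endpoints added in this change, assuming the app is run locally on port 8000 as in the __main__ block above; the requests client and the example question are illustrative assumptions, not part of the Space.

import requests

BASE_URL = "http://localhost:8000"  # assumed local run of this app.py

# Confirm the startup hook finished initializing the RAG components.
health = requests.get(f"{BASE_URL}/health").json()
print(health)  # expect {'status': 'healthy', 'rag_initialized': True} once ready

# Ask a question through the /ask endpoint defined above.
resp = requests.post(f"{BASE_URL}/ask", json={"question": "How can I manage anxiety?"})
resp.raise_for_status()
print(resp.json()["answer"])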