"""FastAPI application exposing streaming LLM chat endpoints.

Each endpoint wraps a LangChain runnable and streams its output to the
client as Server-Sent Events (SSE). The history-aware endpoints persist
the conversation per user through the CRUD layer and log the assistant's
full reply once streaming completes.
"""

import json
from datetime import datetime
from typing import Callable, List, Optional

from fastapi import Depends, FastAPI, HTTPException, Request
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.runnables import Runnable
from sqlalchemy.orm import Session
from sse_starlette.sse import EventSourceResponse

import crud
import models
import schemas
from callbacks import LogResponseCallback
from chains import (
    filtered_rag_chain,
    formatted_chain,
    history_chain,
    rag_chain,
    simple_chain,
)
from database import SessionLocal, engine
from prompts import format_chat_history

# TEMPORARY: wipe and recreate every table on startup. This DESTROYS all
# persisted data — remove the drop_all() once the schema is stable.
models.Base.metadata.drop_all(bind=engine)
models.Base.metadata.create_all(bind=engine)

app = FastAPI()


def get_db():
    """FastAPI dependency yielding a request-scoped SQLAlchemy session."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


async def generate_stream(
    input_data: schemas.BaseModel,
    runnable: Runnable,
    callbacks: Optional[List[BaseCallbackHandler]] = None,
    response_callback: Optional[Callable[[str], None]] = None,
):
    """Stream a runnable's output as SSE event dicts.

    Args:
        input_data: Pydantic model whose dict() form is fed to the runnable.
        runnable: The LangChain runnable to stream from.
        callbacks: Optional LangChain callbacks passed via the run config.
        response_callback: If given, called once with the fully accumulated
            response text after the stream is exhausted.

    Yields:
        ``{"data": <json>, "event": "data"}`` for each non-empty chunk,
        then a final ``{"event": "end"}``.
    """
    callbacks = callbacks or []
    complete_response = ""
    for chunk in runnable.stream(input_data.dict(), config={"callbacks": callbacks}):
        # Chat models (e.g. ChatHuggingFace) yield message chunks carrying a
        # .content attribute; other runnables may yield plain strings.
        content = chunk.content if hasattr(chunk, "content") else str(chunk)
        complete_response += content
        if content:  # skip empty chunks so the client never gets blank events
            yield {"data": json.dumps({"content": content}), "event": "data"}
    if response_callback:
        response_callback(complete_response)
    yield {"event": "end"}


def _stream_with_history(
    user_request: schemas.UserRequest, db: Session, chain: Runnable
) -> EventSourceResponse:
    """Shared implementation for the history-aware streaming endpoints.

    Fetches the user's chat history *before* persisting the new question so
    the prompt does not duplicate it, records the question as a Human
    message, and returns an SSE response that logs the assistant's full
    reply through LogResponseCallback once streaming finishes.
    """
    # History is read first: the current question goes into HistoryInput
    # separately, not as part of the formatted history.
    chat_history = crud.get_user_chat_history(db, user_request.username)
    history_input = schemas.HistoryInput(
        chat_history=format_chat_history(chat_history),
        question=user_request.question,
    )

    # Persist the incoming question as a Human message.
    user = crud.get_or_create_user(db, user_request.username)
    message = schemas.MessageBase(
        user_id=user.id,
        message=user_request.question,
        type="Human",
        # NOTE(review): naive local time stored as a string — confirm UTC
        # is not required by the consumers of this timestamp.
        timestamp=str(datetime.now()),
        user=user_request.username,
    )
    crud.add_message(db, message, username=user_request.username)

    log_callback = LogResponseCallback(user_request=user_request, db=db)

    def save_full_response(complete_response: str) -> None:
        # Persist the assistant's complete reply once the stream ends.
        log_callback.on_llm_end(outputs=complete_response)

    return EventSourceResponse(
        generate_stream(history_input, chain, response_callback=save_full_response)
    )


@app.post("/simple/stream")
async def simple_stream(request: Request):
    """Stream a raw completion for the user's question."""
    data = await request.json()
    user_question = schemas.UserQuestion(**data["input"])
    return EventSourceResponse(generate_stream(user_question, simple_chain))


@app.post("/formatted/stream")
async def formatted_stream(request: Request):
    """Stream a completion through the prompt-formatted chain."""
    try:
        data = await request.json()
        user_question = schemas.UserQuestion(**data["input"])
        return EventSourceResponse(
            generate_stream(input_data=user_question, runnable=formatted_chain)
        )
    except ValueError as e:
        # Malformed/invalid request payload.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/history/stream")
async def history_stream(request: Request, db: Session = Depends(get_db)):
    """Stream a completion conditioned on the user's chat history."""
    data = await request.json()
    user_request = schemas.UserRequest(**data["input"])
    return _stream_with_history(user_request, db, history_chain)


@app.post("/rag/stream")
async def rag_stream(request: Request, db: Session = Depends(get_db)):
    """Stream a retrieval-augmented completion with chat history."""
    data = await request.json()
    user_request = schemas.UserRequest(**data["input"])
    return _stream_with_history(user_request, db, rag_chain)


@app.post("/filtered_rag/stream")
async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
    """Stream a filtered retrieval-augmented completion with chat history."""
    data = await request.json()
    user_request = schemas.UserRequest(**data["input"])
    return _stream_with_history(user_request, db, filtered_rag_chain)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", reload=False, port=7860)