# Rashmi0801's picture
# Update app.py
# e252ecc verified
# raw
# history blame
# 2.27 kB
import os

# Credentials must never be committed to source. The key that used to be
# hard-coded here was published with this file and should be revoked.
# NOTE(review): the removed key had a `gsk_` prefix (Groq-style); confirm the
# intended environment-variable name.
# Nothing visible in this file reads `api_key`; the name is kept so any
# external importer still finds it (it is None when the variable is unset).
api_key = os.environ.get("GROQ_API_KEY")
import argparse
import streamlit as st
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain_core.runnables import Runnable
from crewai_tools import ScrapeWebsiteTool
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_models import ChatLiteLLM
from litellm import completion
import importlib
# Define your custom LLM wrapper class
class CustomLLM(ChatLiteLLM):
    """ChatLiteLLM preconfigured for a single model (``gpt-4`` by default).

    Args:
        model: Model identifier handed to ChatLiteLLM and to litellm's
            ``completion`` in ``_call``. Defaults to ``"gpt-4"`` as before.
        **kwargs: Any other ChatLiteLLM fields, forwarded unchanged.
    """

    def __init__(self, model: str = "gpt-4", **kwargs):
        super().__init__(model=model, **kwargs)

    def _call(self, prompt: str, stop=None) -> str:
        """Send ``prompt`` as one user message and return the reply text.

        Args:
            prompt: Text sent as the sole ``user`` message.
            stop: Optional stop sequences, forwarded to ``completion``.

        Returns:
            The content of the first choice's message.

        NOTE(review): LangChain chat models normally dispatch through
        ``_generate`` rather than ``_call``; this helper is presumably only
        reached when invoked directly — confirm before relying on it.
        """
        response = completion(
            # Use the configured model instead of repeating a hard-coded name.
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            # Forward stop sequences rather than silently ignoring them.
            stop=stop,
        )
        return response.choices[0].message.content
# Define your agent class
class GaiaAgent:
    """Streamlit front-end around a LangChain tool-calling agent.

    The agent uses ``CustomLLM`` as its model and a single
    ``ScrapeWebsiteTool`` as its only tool.
    """

    def __init__(self):
        self.llm = CustomLLM()
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", "You are a helpful assistant."),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            # Required by create_tool_calling_agent for intermediate steps.
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])
        # NOTE(review): ScrapeWebsiteTool comes from crewai_tools, not
        # LangChain; confirm create_tool_calling_agent / AgentExecutor accept
        # it as-is with the pinned package versions.
        self.tools = [ScrapeWebsiteTool()]
        self.agent: Runnable = create_tool_calling_agent(self.llm, self.tools, self.prompt)
        self.agent_executor: AgentExecutor = AgentExecutor(
            agent=self.agent, tools=self.tools, verbose=True
        )

    def run(self):
        """Render the Streamlit page and answer the entered query, if any."""
        st.title("🧠 GAIA-compatible Agent")
        user_input = st.text_input("Enter your query")
        if user_input:
            # The prompt declares a non-optional ``chat_history`` placeholder,
            # so it must be supplied. This app keeps no history, so pass [].
            response = self.agent_executor.invoke(
                {"input": user_input, "chat_history": []}
            )
            # AgentExecutor.invoke returns a dict; the answer is under
            # "output". Fall back to the whole result if that key is absent.
            st.write("Response:", response.get("output", response))
# Main CLI-compatible entry point
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # "import" and "class" are Python keywords, so give them usable dest names.
    parser.add_argument("--import", dest="import_module", type=str, required=False,
                        help="Module to import (ignored for static agent)")
    parser.add_argument("--class", dest="class_name", type=str, required=False,
                        help="Class name to instantiate (ignored for static agent)")
    parser.add_argument("--device", type=str, default="cpu",
                        help="Device type (not used in this agent)")
    # All options are ignored. Tolerate extra flags a launcher may append
    # instead of exiting with an argparse usage error.
    args, _unknown_args = parser.parse_known_args()
    # Directly instantiate and run the predefined GaiaAgent class.
    agent = GaiaAgent()
    agent.run()