Spaces:
Sleeping
Sleeping
File size: 4,627 Bytes
f5edfa8 06a8869 d984bee 8cce5d6 0f43b6d 4d80322 0f43b6d 06a8869 f6eb965 06a8869 0f43b6d 06a8869 f6eb965 06a8869 0f43b6d 06a8869 4375e7a 0f43b6d 06a8869 0f43b6d cab1be1 7ab699e cab1be1 d753076 cab1be1 0f43b6d d753076 0f43b6d cab1be1 abd0aa0 cab1be1 d80f0e9 cab1be1 0f43b6d cab1be1 7ab699e 0f43b6d d753076 0f43b6d 8357c9c d80f0e9 8357c9c d753076 0f43b6d abd0aa0 0f43b6d d80f0e9 0f43b6d 838db58 cab1be1 d80f0e9 cab1be1 838db58 d80f0e9 51f7bca d753076 51f7bca 838db58 d753076 06a8869 d753076 06a8869 cab1be1 06a8869 0f43b6d d753076 06a8869 d753076 06a8869 cab1be1 06a8869 51f7bca d753076 838db58 d753076 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
##############################################################################
# LOAD MODELS
##############################################################################
@st.cache_resource
def load_model_engineer():
    """Load the Engineer agent's model: DistilGPT-2.

    Cached by Streamlit so the download/initialization happens once
    per server process, not on every rerun.

    Returns:
        tuple: (tokenizer, model) for "distilgpt2".
    """
    tokenizerE = AutoTokenizer.from_pretrained("distilgpt2")
    # GPT-2-family tokenizers ship without a pad token; reuse EOS so
    # model.generate() can pad without emitting warnings.
    if tokenizerE.pad_token is None:
        tokenizerE.pad_token = tokenizerE.eos_token
    modelE = AutoModelForCausalLM.from_pretrained("distilgpt2")
    return tokenizerE, modelE
@st.cache_resource
def load_model_analyst():
    """Load the Analyst agent's model: GPT-Neo-125M.

    Cached by Streamlit so the download/initialization happens once
    per server process, not on every rerun.

    Returns:
        tuple: (tokenizer, model) for "EleutherAI/gpt-neo-125M".
    """
    tokenizerA = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
    # GPT-Neo's tokenizer also lacks a pad token; reuse EOS so
    # model.generate() can pad without emitting warnings.
    if tokenizerA.pad_token is None:
        tokenizerA.pad_token = tokenizerA.eos_token
    modelA = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
    return tokenizerA, modelA
# Load models once up front; both loaders are @st.cache_resource, so
# repeated Streamlit reruns reuse the same tokenizer/model objects.
tokenizerE, modelE = load_model_engineer()
tokenizerA, modelA = load_model_analyst()
##############################################################################
# ENGINEER / ANALYST GENERATION
##############################################################################
def generate_engineer_response(user_text, tokenizer, model):
    """
    As an Engineer, generate a concise approach or solution based on user input.

    Args:
        user_text: Text the Engineer should respond to.
        tokenizer: Tokenizer matching ``model``.
        model: Causal language model used for generation.

    Returns:
        str: Only the generated continuation (the prompt is stripped),
        with special tokens removed.
    """
    prompt = f"""
User text: {user_text}
Provide a technical approach or solution.
"""
    # tokenizer(...) yields both input_ids and attention_mask; passing the
    # mask avoids the "attention mask is not set" warning and mispadding.
    encoded = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        # max_new_tokens counts only generated tokens; the previous
        # max_length=256 included the prompt, so long prompts could
        # truncate or entirely starve the answer.
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        repetition_penalty=1.3,
        no_repeat_ngram_size=2,
        # Explicit pad id (EOS) silences the per-call generation warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decoder-only models echo the prompt; slice it off so only the
    # model's answer enters the conversation log.
    generated_ids = outputs[0][encoded["input_ids"].shape[-1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
def generate_analyst_response(user_text, engineer_output, tokenizer, model):
    """
    As an Analyst, provide an approach or solution based on user input and engineer's output.

    Args:
        user_text: Original user request (previously accepted but unused —
            now included in the prompt so the Analyst sees the actual task).
        engineer_output: The Engineer agent's response to build on.
        tokenizer: Tokenizer matching ``model``.
        model: Causal language model used for generation.

    Returns:
        str: Only the generated continuation (the prompt is stripped),
        with special tokens removed.
    """
    prompt = f"""
User text: {user_text}
Engineer provided the following: {engineer_output}
Provide an approach or solution from a data-centric perspective.
"""
    # tokenizer(...) yields both input_ids and attention_mask; passing the
    # mask avoids the "attention mask is not set" warning and mispadding.
    encoded = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        # max_new_tokens counts only generated tokens; the previous
        # max_length=256 included the prompt, so long prompts could
        # truncate or entirely starve the answer.
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        repetition_penalty=1.3,
        no_repeat_ngram_size=2,
        # Explicit pad id (EOS) silences the per-call generation warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decoder-only models echo the prompt; slice it off so only the
    # model's answer enters the conversation log.
    generated_ids = outputs[0][encoded["input_ids"].shape[-1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
##############################################################################
# STREAMLIT APP
##############################################################################
st.title("Multi-Agent System with XAI Demo")

# Conversation history survives Streamlit reruns via session_state.
if "conversation" not in st.session_state:
    st.session_state.conversation = []
if "user_input" not in st.session_state:
    st.session_state.user_input = ""

# Bind the widget to session state via key only: passing value= together
# with a key that already exists in session_state raises an exception in
# current Streamlit versions.
st.text_area("User Input:", height=100, max_chars=None, key="user_input")

if st.button("Start/Continue Conversation"):
    if st.session_state.user_input.strip():
        user_text = st.session_state.user_input
        st.session_state.conversation.append(("User", user_text))
        # Round 1: Engineer answers the user directly.
        engineer_resp = generate_engineer_response(
            user_text=user_text,
            tokenizer=tokenizerE,
            model=modelE
        )
        st.session_state.conversation.append(("Engineer", engineer_resp))
        # Analyst responds to the user's request plus the Engineer's answer.
        analyst_resp = generate_analyst_response(
            user_text=user_text,
            engineer_output=engineer_resp,
            tokenizer=tokenizerA,
            model=modelA
        )
        st.session_state.conversation.append(("Analyst", analyst_resp))
        # Two more Engineer/Analyst exchanges (3 total). The Analyst keeps
        # seeing the original user request, not a duplicated copy of the
        # Engineer's output in both arguments.
        for _ in range(2):
            engineer_resp = generate_engineer_response(
                user_text=analyst_resp,
                tokenizer=tokenizerE,
                model=modelE
            )
            st.session_state.conversation.append(("Engineer", engineer_resp))
            analyst_resp = generate_analyst_response(
                user_text=user_text,
                engineer_output=engineer_resp,
                tokenizer=tokenizerA,
                model=modelA
            )
            st.session_state.conversation.append(("Analyst", analyst_resp))

# Render the full transcript. The previous version hid every non-User
# message inside a display:none div, so the demo never showed any agent
# output — render all speakers visibly instead.
for speaker, text in st.session_state.conversation:
    if speaker == "User":
        st.markdown(f"**{speaker}:** {text}")
    else:
        st.markdown(f"*{speaker}:* {text}")
|