File size: 4,627 Bytes
f5edfa8
06a8869
d984bee
8cce5d6
0f43b6d
4d80322
0f43b6d
 
 
06a8869
f6eb965
 
 
06a8869
0f43b6d
 
06a8869
f6eb965
 
 
06a8869
0f43b6d
06a8869
 
 
4375e7a
0f43b6d
06a8869
0f43b6d
cab1be1
7ab699e
cab1be1
d753076
cab1be1
0f43b6d
d753076
 
 
 
0f43b6d
 
cab1be1
abd0aa0
cab1be1
 
 
d80f0e9
cab1be1
 
0f43b6d
cab1be1
7ab699e
0f43b6d
d753076
0f43b6d
 
8357c9c
d80f0e9
8357c9c
d753076
0f43b6d
 
 
 
abd0aa0
0f43b6d
 
 
d80f0e9
0f43b6d
 
 
838db58
cab1be1
d80f0e9
cab1be1
 
 
838db58
 
d80f0e9
51f7bca
d753076
 
 
 
51f7bca
838db58
d753076
 
 
 
 
06a8869
d753076
06a8869
 
cab1be1
06a8869
0f43b6d
d753076
06a8869
d753076
06a8869
 
 
cab1be1
06a8869
51f7bca
d753076
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838db58
d753076
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import html

import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM

##############################################################################
#                          LOAD MODELS
##############################################################################

@st.cache_resource
def load_model_engineer():
    """Load the Engineer agent's tokenizer and model (DistilGPT-2).

    Wrapped in ``st.cache_resource`` so the pair is built once and reused
    across Streamlit reruns.

    Returns:
        tuple: ``(tokenizer, model)`` for the ``"distilgpt2"`` checkpoint.
    """
    checkpoint = "distilgpt2"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForCausalLM.from_pretrained(checkpoint),
    )

@st.cache_resource
def load_model_analyst():
    """Load the Analyst agent's tokenizer and model (GPT-Neo-125M).

    Wrapped in ``st.cache_resource`` so the pair is built once and reused
    across Streamlit reruns.

    Returns:
        tuple: ``(tokenizer, model)`` for the ``"EleutherAI/gpt-neo-125M"``
        checkpoint.
    """
    checkpoint = "EleutherAI/gpt-neo-125M"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return tokenizer, model

# Load models: the Engineer pair first, then the Analyst pair. Both loaders
# are decorated with st.cache_resource, so reruns reuse the cached objects.
(tokenizerE, modelE) = load_model_engineer()
(tokenizerA, modelA) = load_model_analyst()

##############################################################################
#                     ENGINEER / ANALYST GENERATION
##############################################################################

def generate_engineer_response(user_text, tokenizer, model, max_new_tokens=256):
    """
    As an Engineer, generate a concise approach or solution based on user input.

    Args:
        user_text: Text the Engineer responds to (the user's request, or the
            Analyst's previous reply in later rounds).
        tokenizer: Hugging Face tokenizer paired with ``model``.
        model: Causal language model used for sampling.
        max_new_tokens: Maximum number of tokens to generate. Unlike the old
            ``max_length``, this budget does not include the prompt, so long
            inputs still leave room for a reply.

    Returns:
        str: Only the newly generated continuation, stripped of surrounding
        whitespace. The prompt is not echoed back, so feeding the result into
        another agent does not snowball the prompt text.
    """
    prompt = (
        f"User text: {user_text}\n"
        "Provide a technical approach or solution.\n\n"
    )
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # Keep prompt + generation inside the model's context window. Keep the
    # *tail* of the prompt so the trailing instruction survives truncation.
    context_size = getattr(model.config, "max_position_embeddings", 1024)
    prompt_budget = max(context_size - max_new_tokens, 1)
    input_ids = input_ids[:, -prompt_budget:]
    attention_mask = attention_mask[:, -prompt_budget:]

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        repetition_penalty=1.3,
        no_repeat_ngram_size=2,
        # GPT-style tokenizers define no pad token; reuse EOS to avoid the
        # "Setting pad_token_id to eos_token_id" warning on every call.
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the new tokens.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

def generate_analyst_response(user_text, engineer_output, tokenizer, model, max_new_tokens=256):
    """
    As an Analyst, provide an approach or solution based on user input and engineer's output.

    Args:
        user_text: The user's original request. It is included in the prompt
            unless it is empty or identical to ``engineer_output``, so the same
            text is not repeated twice.
        engineer_output: The Engineer agent's latest response.
        tokenizer: Hugging Face tokenizer paired with ``model``.
        model: Causal language model used for sampling.
        max_new_tokens: Maximum number of tokens to generate. This budget does
            not include the prompt.

    Returns:
        str: Only the newly generated continuation, stripped of surrounding
        whitespace. The prompt is not echoed back.
    """
    sections = []
    # user_text was previously accepted but silently ignored.
    if user_text and user_text.strip() and user_text != engineer_output:
        sections.append(f"User text: {user_text}")
    sections.append(f"Engineer provided the following: {engineer_output}")
    sections.append("Provide an approach or solution from a data-centric perspective.")
    prompt = "\n\n".join(sections) + "\n\n"

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # Keep prompt + generation inside the model's context window. Keep the
    # *tail* of the prompt so the trailing instruction survives truncation.
    context_size = getattr(model.config, "max_position_embeddings", 1024)
    prompt_budget = max(context_size - max_new_tokens, 1)
    input_ids = input_ids[:, -prompt_budget:]
    attention_mask = attention_mask[:, -prompt_budget:]

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        repetition_penalty=1.3,
        no_repeat_ngram_size=2,
        # GPT-style tokenizers define no pad token; reuse EOS to silence the
        # per-call warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the new tokens.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

##############################################################################
#                         STREAMLIT APP
##############################################################################

st.title("Multi-Agent System with XAI Demo")

# Per-session state that survives Streamlit reruns: the running transcript of
# (speaker, text) tuples and the text-area contents.
if "conversation" not in st.session_state:
    st.session_state.conversation = []

if "user_input" not in st.session_state:
    st.session_state.user_input = ""

# The widget is bound to st.session_state.user_input through `key`, and that
# key already supplies its initial value. Also passing `value=` makes Streamlit
# warn that the widget got a default value while its key was set via the
# Session State API.
st.text_area("User Input:", height=100, key="user_input")

if st.button("Start/Continue Conversation"):
    user_text = st.session_state.user_input
    if user_text.strip():
        st.session_state.conversation.append(("User", user_text))

        # Limit the conversation to 3 Engineer -> Analyst exchanges.
        # - The Engineer answers the user first, then replies to the Analyst's
        #   previous message.
        # - The Analyst always receives the user's original request plus the
        #   Engineer's latest reply. Previously it got the Engineer's reply in
        #   both arguments in later rounds.
        engineer_input = user_text
        with st.spinner("Engineer and Analyst are responding..."):
            for _ in range(3):
                engineer_resp = generate_engineer_response(
                    user_text=engineer_input,
                    tokenizer=tokenizerE,
                    model=modelE,
                )
                st.session_state.conversation.append(("Engineer", engineer_resp))

                analyst_resp = generate_analyst_response(
                    user_text=user_text,
                    engineer_output=engineer_resp,
                    tokenizer=tokenizerA,
                    model=modelA,
                )
                st.session_state.conversation.append(("Analyst", analyst_resp))

                engineer_input = analyst_resp
    else:
        # Tell the user why nothing happened instead of ignoring the click.
        st.warning("Please enter some text before starting the conversation.")

# Render the stored conversation.
# - User turns go through st.markdown without unsafe_allow_html.
# - Agent turns are emitted as raw HTML inside a hidden <div>, and their text
#   is model output that can echo user input. It is HTML-escaped so it cannot
#   close the div or inject markup into the page.
# NOTE(review): agent turns are hidden via display:none; confirm this is intended.
for speaker, text in st.session_state.conversation:
    if speaker == "User":
        st.markdown(f"**{speaker}:** {text}")
    else:
        st.markdown(
            f"<div style='display:none'>{html.escape(speaker)}: {html.escape(text)}</div>",
            unsafe_allow_html=True,
        )