import streamlit as st

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

##############################################################################
#                          LOAD MODELS
##############################################################################

@st.cache_resource
def load_model_engineer():
    try:
        # Engineer: DeepSeek-V3 loaded directly
        tokenizer = AutoTokenizer.from_pretrained("unsloth/DeepSeek-V3", trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            "unsloth/DeepSeek-V3",
            trust_remote_code=True
        )
        model = model.half().to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        return tokenizer, model
    except Exception as e:
        st.error(f"Direct model loading failed for Engineer: {str(e)}")
        raise

@st.cache_resource
def load_model_analyst():
    try:
        # Analyst: DeepSeek-V3 loaded directly
        tokenizer = AutoTokenizer.from_pretrained("unsloth/DeepSeek-V3", trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            "unsloth/DeepSeek-V3",
            trust_remote_code=True
        )
        model = model.half().to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        return tokenizer, model
    except Exception as e:
        st.error(f"Direct model loading failed for Analyst: {str(e)}")
        raise

# Load models
try:
    engineer_model = load_model_engineer()
    analyst_model = load_model_analyst()
except Exception as load_error:
    st.stop()
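
# NOTE: The two loaders above are identical, and full DeepSeek-V3 is far too
# large to hold twice in memory on most hardware, so a single cached load
# could serve both agents. A minimal sketch (hypothetical helper, same API):
#
#   @st.cache_resource
#   def load_shared_model():
#       tokenizer = AutoTokenizer.from_pretrained("unsloth/DeepSeek-V3", trust_remote_code=True)
#       model = AutoModelForCausalLM.from_pretrained("unsloth/DeepSeek-V3", trust_remote_code=True)
#       return tokenizer, model
#
#   engineer_model = analyst_model = load_shared_model()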

##############################################################################
#                     ENGINEER / ANALYST GENERATION
##############################################################################

def generate_response(prompt, model, max_sentences=2):
    """
    Generate a concise response based on the provided prompt.
    """
    try:
        tokenizer, model = model  # unpack the (tokenizer, model) pair
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        pad_id = tokenizer.pad_token_id
        if pad_id is None:  # fall back to EOS when the tokenizer defines no pad token
            pad_id = tokenizer.eos_token_id
        outputs = model.generate(
            **inputs,  # pass attention_mask along with input_ids
            max_new_tokens=50,
            do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
            temperature=0.6,
            top_p=0.8,
            pad_token_id=pad_id
        )
        # Decode only the newly generated tokens, not the echoed prompt
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        # Keep at most max_sentences sentences
        return ".".join(response.split(".")[:max_sentences]) + "."
    except Exception as gen_error:
        st.error(f"Error during response generation: {str(gen_error)}")
        return "[Error generating response]"

def summarize_conversation(conversation):
    """
    Summarize the entire conversation to produce a cohesive and actionable plan.
    """
    summary = "### Final Plan\n"
    key_points = []
    for speaker, text in conversation:
        if speaker in ("Engineer", "Analyst"):
            key_points.append(f"- {speaker}: {text}")
    summary += "\n".join(key_points[-6:])  # Keep the last six entries (up to three turns per agent)
    summary += "\n\nThis collaborative plan integrates technical and analytical insights into an actionable framework."
    return summary
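
# Example (hypothetical conversation history):
#
#   summarize_conversation([
#       ("User", "Reduce API latency."),
#       ("Engineer", "Add a read-through cache."),
#       ("Analyst", "Track hit rate to validate the cache."),
#   ])
#   # -> "### Final Plan\n- Engineer: Add a read-through cache.\n- Analyst: ..."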

##############################################################################
#                         STREAMLIT APP
##############################################################################

st.title("Multi-Agent System with XAI Demo")

if "conversation" not in st.session_state:
    st.session_state.conversation = []

if "user_input" not in st.session_state:
    st.session_state.user_input = ""

st.text_area("Enter your query:", value=st.session_state.user_input, height=100, max_chars=None, key="user_input")

if st.button("Generate Responses"):
    if st.session_state.user_input.strip():
        user_text = st.session_state.user_input
        st.session_state.conversation = [("User", user_text)]  # Clear and restart conversation

        engineer_prompt_base = f"Given the problem: {user_text}, provide a concise and actionable technical solution."
        analyst_prompt_base = "Based on the engineer's suggestion, provide complementary data-driven recommendations."

        for turn in range(3):
            # Engineer generates a response (note: the same base prompt is reused each turn)
            with st.spinner(f"Engineer is formulating response {turn + 1}..."):
                engineer_resp = generate_response(
                    prompt=engineer_prompt_base,
                    model=engineer_model
                )
                st.session_state.conversation.append(("Engineer", engineer_resp))

            # Display Engineer response
            st.markdown(f"### Engineer Response ({turn + 1})\n{engineer_resp}")

            # Analyst generates a response based on engineer's output
            with st.spinner(f"Analyst is formulating response {turn + 1}..."):
                analyst_resp = generate_response(
                    prompt=f"Engineer suggested: {engineer_resp}. {analyst_prompt_base}",
                    model=analyst_model
                )
                st.session_state.conversation.append(("Analyst", analyst_resp))

            # Display Analyst response
            st.markdown(f"### Analyst Response ({turn + 1})\n{analyst_resp}")

        # Summarize the final plan
        with st.spinner("Generating the final plan..."):
            final_plan = summarize_conversation(st.session_state.conversation)
            st.session_state.conversation.append(("Summary", final_plan))
            st.markdown(final_plan)
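
# To launch the demo (hypothetical filename; requires streamlit, transformers, torch):
#
#   streamlit run app.py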