6Genix committed
Commit 42fdd8c · 1 Parent(s): 4e4ed5e

Addressed a quantization issue by enforcing fp16 precision when loading the DeepSeek-V3 model. Updated error handling and improved compatibility for the Multi-Agent XAI Demo.
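For context, a minimal sketch of what enforcing fp16 at load time typically looks like with transformers; the helper name below is hypothetical and the settings are assumptions based on this commit message, not the literal committed code:

# Hedged sketch: one common way to force fp16 when loading a causal LM
# with Hugging Face transformers. The repo id matches the one used in the
# previous version of app.py; everything else is an assumption.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_deepseek_fp16(repo_id: str = "unsloth/DeepSeek-V3"):
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        trust_remote_code=True,
        torch_dtype=torch.float16,  # enforce fp16 weights at load time
    )
    # Move to GPU when available; fp16 on CPU is rarely useful.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return tokenizer, model.to(device)

Passing torch_dtype=torch.float16 avoids first materializing the weights in fp32 and then down-casting with .half(), which is what the removed code in the diff below did.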

Files changed (1)
  1. app.py +3 -134
app.py CHANGED
@@ -1,134 +1,3 @@
-import streamlit as st
-
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
-
-##############################################################################
-# LOAD MODELS
-##############################################################################
-
-@st.cache_resource
-def load_model_engineer():
-    try:
-        # Engineer: DeepSeek-V3 loaded directly
-        tokenizer = AutoTokenizer.from_pretrained("unsloth/DeepSeek-V3", trust_remote_code=True)
-        model = AutoModelForCausalLM.from_pretrained(
-            "unsloth/DeepSeek-V3",
-            trust_remote_code=True
-        )
-        model = model.half().to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-        return tokenizer, model
-    except Exception as e:
-        st.error(f"Direct model loading failed for Engineer: {str(e)}")
-        raise
-
-@st.cache_resource
-def load_model_analyst():
-    try:
-        # Analyst: DeepSeek-V3 loaded directly
-        tokenizer = AutoTokenizer.from_pretrained("unsloth/DeepSeek-V3", trust_remote_code=True)
-        model = AutoModelForCausalLM.from_pretrained(
-            "unsloth/DeepSeek-V3",
-            trust_remote_code=True
-        )
-        model = model.half().to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-        return tokenizer, model
-    except Exception as e:
-        st.error(f"Direct model loading failed for Analyst: {str(e)}")
-        raise
-
-# Load models
-try:
-    engineer_model = load_model_engineer()
-    analyst_model = load_model_analyst()
-except Exception as load_error:
-    st.stop()
-
-##############################################################################
-# ENGINEER / ANALYST GENERATION
-##############################################################################
-
-def generate_response(prompt, model, max_sentences=2):
-    """
-    Generate a concise response based on the provided prompt.
-    """
-    try:
-        tokenizer, model = model
-        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-        outputs = model.generate(
-            inputs["input_ids"],
-            max_new_tokens=50,
-            temperature=0.6,
-            top_p=0.8,
-            pad_token_id=tokenizer.pad_token_id
-        )
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
-        # Limit to max_sentences by splitting and rejoining
-        return " ".join(response.split(".")[:max_sentences]) + "."
-    except Exception as gen_error:
-        st.error(f"Error during response generation: {str(gen_error)}")
-        return "[Error generating response]"
-
-def summarize_conversation(conversation):
-    """
-    Summarize the entire conversation to produce a cohesive and actionable plan.
-    """
-    summary = "### Final Plan\n"
-    key_points = []
-    for speaker, text in conversation:
-        if speaker == "Engineer" or speaker == "Analyst":
-            key_points.append(f"- {speaker}: {text}")
-    summary += "\n".join(key_points[-6:])  # Include only the last 3 turns each
-    summary += "\n\nThis collaborative plan integrates technical and analytical insights into an actionable framework."
-    return summary
-
-##############################################################################
-# STREAMLIT APP
-##############################################################################
-
-st.title("Multi-Agent System with XAI Demo")
-
-if "conversation" not in st.session_state:
-    st.session_state.conversation = []
-
-if "user_input" not in st.session_state:
-    st.session_state.user_input = ""
-
-st.text_area("Enter your query:", value=st.session_state.user_input, height=100, max_chars=None, key="user_input")
-
-if st.button("Generate Responses"):
-    if st.session_state.user_input.strip():
-        user_text = st.session_state.user_input
-        st.session_state.conversation = [("User", user_text)]  # Clear and restart conversation
-
-        engineer_prompt_base = f"Given the problem: {user_text}, provide a concise and actionable technical solution."
-        analyst_prompt_base = "Based on the engineer's suggestion, provide complementary data-driven recommendations."
-
-        for turn in range(3):
-            # Engineer generates a response
-            with st.spinner(f"Engineer is formulating response {turn + 1}..."):
-                engineer_resp = generate_response(
-                    prompt=engineer_prompt_base,
-                    model=engineer_model
-                )
-                st.session_state.conversation.append(("Engineer", engineer_resp))
-
-            # Display Engineer response
-            st.markdown(f"### Engineer Response ({turn + 1})\n{engineer_resp}")
-
-            # Analyst generates a response based on engineer's output
-            with st.spinner(f"Analyst is formulating response {turn + 1}..."):
-                analyst_resp = generate_response(
-                    prompt=f"Engineer suggested: {engineer_resp}. {analyst_prompt_base}",
-                    model=analyst_model
-                )
-                st.session_state.conversation.append(("Analyst", analyst_resp))
-
-            # Display Analyst response
-            st.markdown(f"### Analyst Response ({turn + 1})\n{analyst_resp}")
-
-        # Summarize the final plan
-        with st.spinner("Generating the final plan..."):
-            final_plan = summarize_conversation(st.session_state.conversation)
-            st.session_state.conversation.append(("Summary", final_plan))
-            st.markdown(final_plan)
+version https://git-lfs.github.com/spec/v1
+oid sha256:33a11bfc0175eafaba57084965d74562e60c395888489cfcd3b5af0a34a80a51
+size 5647
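The replacement content is a Git LFS pointer (spec v1): the repository now records only the file's sha256 oid and size, and the actual bytes live in LFS storage (fetched, for example, with git lfs pull). A minimal sketch of checking whether a locally materialized app.py matches the pointer above; the helper name is hypothetical:

# Hedged sketch: verify a locally fetched file against the LFS pointer's
# sha256 oid and size, both taken from the pointer shown in this diff.
import hashlib
from pathlib import Path

def matches_lfs_pointer(path: str, expected_sha256: str, expected_size: int) -> bool:
    data = Path(path).read_bytes()
    return (
        len(data) == expected_size
        and hashlib.sha256(data).hexdigest() == expected_sha256
    )

print(matches_lfs_pointer(
    "app.py",
    "33a11bfc0175eafaba57084965d74562e60c395888489cfcd3b5af0a34a80a51",
    5647,
))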