import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

##############################################################################
# LOAD MODELS
##############################################################################
@st.cache_resource(ttl=7200)
def load_model():
    try:
        # Unified model loading for Microsoft Phi-3-mini-4k-instruct
        tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct",
            trust_remote_code=True
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        st.write(f"Model loaded on: {device}")
        return tokenizer, model
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        raise

# Load the model once at startup
try:
    tokenizer, model = load_model()
except Exception:
    st.error(
        "Model loading failed. Please ensure your environment supports the "
        "Microsoft Phi-3-mini-4k-instruct model and verify dependencies."
    )
    st.stop()

##############################################################################
# GENERATE RESPONSES
##############################################################################
def generate_response(prompt, max_new_tokens=None, temperature=0.7, top_p=0.85):
    """
    Generate a model response for the provided prompt.
    """
    try:
        prompt = prompt.strip()  # Remove leading/trailing whitespace
        if not prompt:
            return "[Error: Empty input]"

        # Dynamically adjust the tokenizer truncation length based on prompt size
        max_length = min(len(prompt.split()) * 5, 1024)
        inputs = tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=max_length
        ).to(model.device)

        if max_new_tokens is None:
            # Scale the generation budget with prompt length, with a lower bound of 50 tokens
            max_new_tokens = max(50, min(300, len(prompt.split()) * 2))

        outputs = model.generate(
            **inputs,  # Passes input_ids and attention_mask
            max_new_tokens=max_new_tokens,
            do_sample=True,  # Required for temperature/top_p to take effect
            temperature=temperature,
            top_p=top_p,
            pad_token_id=(
                tokenizer.pad_token_id
                if tokenizer.pad_token_id is not None
                else tokenizer.eos_token_id
            ),
        )
        # Decode only the newly generated tokens, not the echoed prompt
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return response.strip()
    except Exception as gen_error:
        st.error(f"Error during response generation: {str(gen_error)}")
        return "[Error generating response]"

def explain_response(response):
    """
    Provide an explainable AI (XAI) explanation for the generated response.
    """
    explanation = (
        "### XAI Explanation\n"
        f"The response '{response}' was generated based on identified patterns and "
        "relationships within the provided query. It considers contextual relevance "
        "and the statistical probability of words within the trained dataset."
    )
    return explanation

def summarize_conversation(conversation):
    """
    Summarize the entire conversation to produce a cohesive and actionable plan.
    """
    summary = "### Final Plan\n"
    key_points = []
    for speaker, text in conversation:
        if speaker in ("Engineer", "Analyst"):
            key_points.append(f"- {speaker}: {text}")
    summary += "\n".join(key_points[-6:])  # Keep the last six entries (up to three turns per agent)
    summary += "\n\nThis collaborative plan integrates technical and analytical insights into an actionable framework."
    return summary

##############################################################################
# STREAMLIT APP
##############################################################################
st.title("Multi-Agent System with XAI Demo")

if "conversation" not in st.session_state:
    st.session_state.conversation = []
if "user_input" not in st.session_state:
    st.session_state.user_input = ""

# The widget is bound to session state via its key, so no explicit value is passed
st.text_area(
    "Enter your query:",
    height=100,
    max_chars=None,
    key="user_input",
    placeholder="Type your query here...",
)

if st.button("Generate Responses"):
    if st.session_state.user_input.strip():
        user_text = st.session_state.user_input
        st.session_state.conversation.append(("User", user_text))  # Append instead of replacing

        engineer_prompt_base = (
            f"Given the problem: {user_text}, provide a concise and actionable technical solution."
        )
        analyst_prompt_base = (
            "Based on the engineer's suggestion, provide complementary data-driven recommendations."
        )

        # Engineer generates a single response
        with st.spinner("Engineer is formulating a response..."):
            engineer_resp = generate_response(prompt=engineer_prompt_base)
            st.session_state.conversation.append(("Engineer", engineer_resp))

        st.markdown(f"### Engineer Response\n{engineer_resp}")
        st.markdown(explain_response(engineer_resp))

        # Analyst generates a single response based on the engineer's output
        with st.spinner("Analyst is formulating a response..."):
            analyst_resp = generate_response(
                prompt=f"Engineer suggested: {engineer_resp}. {analyst_prompt_base}"
            )
            st.session_state.conversation.append(("Analyst", analyst_resp))

        st.markdown(f"### Analyst Response\n{analyst_resp}")
        st.markdown(explain_response(analyst_resp))

        # Summarize the final plan
        with st.spinner("Generating the final plan..."):
            final_plan = summarize_conversation(st.session_state.conversation)
            st.session_state.conversation.append(("Summary", final_plan))

        st.markdown(final_plan)