Eldeeb commited on
Commit
246c304
·
verified ·
1 Parent(s): 041a4cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -20
app.py CHANGED
@@ -2,30 +2,39 @@
2
  import streamlit as st
3
  from transformers import pipeline
4
 
5
- # Caching the model pipeline
6
  @st.cache_resource
7
- def load_pipeline():
8
- return pipeline("text2text-generation", model="facebook/blenderbot_small-90M")
9
 
10
- # Initialize the model once using cache
11
- pipe = load_pipeline()
12
-
13
- # Initialize session state for conversation history and bot response
14
  if 'conversation_history' not in st.session_state:
15
  st.session_state.conversation_history = ""
16
  if 'bot_response' not in st.session_state:
17
  st.session_state.bot_response = ""
 
 
18
 
19
- def converse(user_message):
20
  # Update the conversation history
21
  st.session_state.conversation_history += f"User: {user_message}\n"
22
- result = pipe(st.session_state.conversation_history)[0]['generated_text']
23
- st.session_state.conversation_history += f"Bot: {result}\n"
 
24
  st.session_state.bot_response = result
25
  return result
26
 
27
  # Sidebar options
28
  st.sidebar.title("App Settings")
 
 
 
 
 
 
 
 
 
29
  show_history = st.sidebar.checkbox("Show conversation history", value=True)
30
  character_limit = st.sidebar.slider("Set character limit for input:", min_value=50, max_value=500, value=200)
31
 
@@ -36,21 +45,21 @@ if st.sidebar.button("Reset Conversation"):
36
  st.sidebar.success("Conversation history cleared.")
37
 
38
  # Streamlit app layout
39
- st.title("🤖 AI Chatbot")
40
- st.subheader("Chat with an AI-powered bot!")
41
 
42
  # Input field with character limit
43
  user_message = st.text_input(f"Enter your message (max {character_limit} characters):", max_chars=character_limit)
44
 
45
- # Send button to generate bot response
46
- if st.button("Send"):
47
  if user_message:
48
- # Get response from the chatbot
49
- bot_response = converse(user_message)
50
 
51
  # Display bot's response in a dedicated area
52
- st.markdown("### Bot's Response")
53
- st.success(bot_response)
54
 
55
  if show_history:
56
  # Display conversation history in a text area for better scrolling
@@ -58,12 +67,12 @@ if st.button("Send"):
58
  st.text_area("Conversation", value=st.session_state.conversation_history, height=250, max_chars=None)
59
  else:
60
  # Show a warning if no message is provided
61
- st.warning("Please enter a message before sending.")
62
 
63
  # About section
64
  st.markdown("---")
65
  st.markdown("### About this App")
66
- st.info("This chatbot is powered by a pre-trained model from the Hugging Face Transformers library. You can chat with the bot, and the conversation history will be maintained during the session.")
67
 
68
  st.sidebar.markdown("---")
69
  st.sidebar.write("Created by [Your Name](https://github.com/yourprofile)")
 
2
  import streamlit as st
3
  from transformers import pipeline
4
 
5
+ # Caching the text classification models
6
  @st.cache_resource
7
+ def load_pipeline(model_name):
8
+ return pipeline("text-classification", model=model_name)
9
 
10
+ # Initialize session state for conversation history, bot response, and selected model
 
 
 
11
  if 'conversation_history' not in st.session_state:
12
  st.session_state.conversation_history = ""
13
  if 'bot_response' not in st.session_state:
14
  st.session_state.bot_response = ""
15
+ if 'selected_model' not in st.session_state:
16
+ st.session_state.selected_model = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
17
 
18
+ def classify_text(user_message):
19
  # Update the conversation history
20
  st.session_state.conversation_history += f"User: {user_message}\n"
21
+ pipe = load_pipeline(st.session_state.selected_model)
22
+ result = pipe(user_message)[0] # pipe returns a list of results
23
+ st.session_state.conversation_history += f"Bot: {result['label']} (Score: {result['score']:.2f})\n"
24
  st.session_state.bot_response = result
25
  return result
26
 
27
  # Sidebar options
28
  st.sidebar.title("App Settings")
29
+
30
+ # Model selection
31
+ model_options = {
32
+ "DistilBERT Sentiment Analysis": "distilbert/distilbert-base-uncased-finetuned-sst-2-english",
33
+ "BERT Multilingual Sentiment Analysis": "nlptown/bert-base-multilingual-uncased-sentiment"
34
+ }
35
+ selected_model = st.sidebar.selectbox("Select model:", list(model_options.keys()))
36
+ st.session_state.selected_model = model_options[selected_model]
37
+
38
  show_history = st.sidebar.checkbox("Show conversation history", value=True)
39
  character_limit = st.sidebar.slider("Set character limit for input:", min_value=50, max_value=500, value=200)
40
 
 
45
  st.sidebar.success("Conversation history cleared.")
46
 
47
  # Streamlit app layout
48
+ st.title("🧠 Text Classification Bot")
49
+ st.subheader("Classify your text with a sentiment analysis model!")
50
 
51
  # Input field with character limit
52
  user_message = st.text_input(f"Enter your message (max {character_limit} characters):", max_chars=character_limit)
53
 
54
+ # Send button to generate classification
55
+ if st.button("Classify"):
56
  if user_message:
57
+ # Get classification from the selected model
58
+ classification_result = classify_text(user_message)
59
 
60
  # Display bot's response in a dedicated area
61
+ st.markdown("### Classification Result")
62
+ st.success(f"**Label:** {classification_result['label']}\n**Score:** {classification_result['score']:.2f}")
63
 
64
  if show_history:
65
  # Display conversation history in a text area for better scrolling
 
67
  st.text_area("Conversation", value=st.session_state.conversation_history, height=250, max_chars=None)
68
  else:
69
  # Show a warning if no message is provided
70
+ st.warning("Please enter a message before classifying.")
71
 
72
  # About section
73
  st.markdown("---")
74
  st.markdown("### About this App")
75
+ st.info("This app uses pre-trained models for sentiment analysis. You can select a model and enter text to see its classification and sentiment score.")
76
 
77
  st.sidebar.markdown("---")
78
  st.sidebar.write("Created by [Your Name](https://github.com/yourprofile)")