import streamlit as st
import networkx as nx
import matplotlib.pyplot as plt
import nltk
from nltk import sent_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
import numpy as np
import plotly.graph_objs as go


def _ensure_punkt():
    """Make sure the NLTK ``punkt_tab`` sentence-tokenizer data is available.

    Streamlit re-executes the script on every widget interaction, so only
    download when the resource is actually missing instead of hitting the
    network on each rerun.
    """
    try:
        nltk.data.find('tokenizers/punkt_tab')
    except LookupError:
        nltk.download('punkt_tab', quiet=True)


def _cluster_sentences(sentences, n_topics):
    """Cluster a list of sentences into at most ``n_topics`` groups.

    Sentences are vectorised with TF-IDF (English stop words removed) and
    clustered with KMeans (``random_state=42`` for repeatable results).

    Args:
        sentences: The sentences to cluster.
        n_topics: Requested number of clusters. It is capped at
            ``len(sentences)`` because KMeans cannot fit more clusters than
            samples.

    Returns:
        ``(topic_sentences, top_words)``. ``topic_sentences`` maps each cluster
        index to the sentences assigned to it, in input order.
        ``top_words[i]`` is the vocabulary term with the largest weight in
        cluster ``i``'s centroid.

    Raises:
        ValueError: If ``sentences`` is empty, or TF-IDF finds no usable
            vocabulary (e.g. every word is an English stop word).
    """
    if not sentences:
        raise ValueError("no sentences to cluster")
    n_clusters = min(n_topics, len(sentences))

    vectorizer = TfidfVectorizer(stop_words='english')
    X = vectorizer.fit_transform(sentences)
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    kmeans.fit(X)

    # Hoisted: the vocabulary is the same for every cluster centre.
    feature_names = vectorizer.get_feature_names_out()
    top_words = [feature_names[np.argmax(center)] for center in kmeans.cluster_centers_]

    topic_sentences = {i: [] for i in range(n_clusters)}
    for label, sentence in zip(kmeans.labels_, sentences):
        topic_sentences[int(label)].append(sentence)
    return topic_sentences, top_words


# Helper function to split text into topics using KMeans clustering and extract top words
def split_text_into_topics(text, n_topics):
    """Split *text* into sentences and cluster them into at most *n_topics* topics.

    Returns:
        ``(topic_sentences, top_words)`` as described in ``_cluster_sentences``.

    Raises:
        ValueError: If the text yields no sentences or no usable vocabulary.
    """
    return _cluster_sentences(sent_tokenize(text), n_topics)


# Recursive function to split subtopics
def recursive_split(topic_dict, depth, max_depth, subtopics, word_map=None, _path=()):
    """Recursively re-cluster each topic's sentences into subtopics.

    Args:
        topic_dict: Maps topic key to its list of sentences.
        depth: Current depth; the top-level call passes 0.
        max_depth: Number of subtopic levels to create below ``topic_dict``.
        subtopics: Requested number of subtopics per split. It is capped at
            the topic's sentence count.
        word_map: Optional dict. When given, it is filled with
            ``{path_tuple: top_word}`` for every subtopic created. Here
            ``path_tuple`` is the chain of keys from the root, e.g.
            ``(2, 0)`` for subtopic 0 of topic 2.
        _path: Internal; key path of ``topic_dict`` from the root.

    Returns:
        A tree with the same top-level keys as ``topic_dict``. Each value is
        either a list of sentences (a leaf) or a nested subtopic dict.
    """
    if depth >= max_depth:
        return topic_dict

    new_topic_dict = {}
    for topic, sentences in topic_dict.items():
        path = _path + (topic,)
        if len(sentences) <= 1:
            new_topic_dict[topic] = sentences
            continue
        try:
            # Cluster the sentence list directly rather than re-joining and
            # re-tokenizing, which need not reproduce the same sentences.
            sub_topics, sub_words = _cluster_sentences(sentences, subtopics)
        except ValueError:
            # e.g. every sentence here consists only of stop words, so TF-IDF
            # has nothing to cluster on; keep the topic as a leaf.
            new_topic_dict[topic] = sentences
            continue
        if sum(1 for group in sub_topics.values() if group) < 2:
            # All sentences landed in a single cluster: splitting adds nothing.
            new_topic_dict[topic] = sentences
            continue
        if word_map is not None:
            for sub_key, word in enumerate(sub_words):
                word_map[path + (sub_key,)] = word
        new_topic_dict[topic] = recursive_split(
            sub_topics, depth + 1, max_depth, subtopics, word_map, path
        )
    return new_topic_dict


def _top_word_for(path, top_words, word_map):
    """Return the top word for the node at *path*, or ``"N/A"`` if unknown."""
    if word_map is not None and path in word_map:
        return word_map[path]
    # ``top_words`` only describes the root-level topics.
    if len(path) == 1 and top_words and path[0] < len(top_words):
        return top_words[path[0]]
    return "N/A"


# Function to convert the tree into edge data for Plotly visualization
def get_edges(tree, parent=None, level=0, top_words=None, word_map=None, _path=(), _row_start=0):
    """Flatten a topic tree into edges, labels, hover texts and 2-D positions.

    Node labels embed the full key path (e.g. ``"Subtopic 2.0"``), so they
    are unique across the whole tree and can be used as ``pos`` keys. Each
    node's y coordinate is its index in the returned ``labels`` list, so
    sibling subtrees never overlap. Its x coordinate is its depth.

    Args:
        tree: Nested dict as produced by ``recursive_split``.
        parent: Label of the node that owns ``tree``, or ``None`` at the root.
        level: Depth (x coordinate) of the nodes in ``tree``.
        top_words: Top words of the root-level topics, indexed by topic key.
        word_map: Optional ``{path_tuple: top_word}``; takes precedence
            over ``top_words`` and covers subtopics.
        _path, _row_start: Internal recursion state.

    Returns:
        ``(edges, labels, hover_texts, pos)``. ``edges`` is a list of
        ``(src_label, dst_label)``. ``hover_texts[i]`` belongs to
        ``labels[i]``. ``pos`` maps label to ``(x, y)``.
    """
    edges = []
    labels = []
    hover_texts = []
    pos = {}
    for key, value in tree.items():
        path = _path + (key,)
        dotted = '.'.join(str(part) for part in path)
        node_label = f'Topic {dotted}' if parent is None else f'Subtopic {dotted}'
        pos[node_label] = (level, _row_start + len(labels))
        labels.append(node_label)
        hover_texts.append(f"Top Word: {_top_word_for(path, top_words, word_map)}")
        if parent is not None:
            edges.append((parent, node_label))

        if isinstance(value, dict):
            new_edges, new_labels, new_hover_texts, new_pos = get_edges(
                value, node_label, level + 1, top_words, word_map,
                path, _row_start + len(labels),
            )
            edges += new_edges
            labels += new_labels
            hover_texts += new_hover_texts
            pos.update(new_pos)
        else:
            for i, sentence in enumerate(value, start=1):
                sentence_label = f"{node_label} - Sentence {i}"
                pos[sentence_label] = (level + 1, _row_start + len(labels))
                labels.append(sentence_label)
                hover_texts.append(sentence)
                edges.append((node_label, sentence_label))
    return edges, labels, hover_texts, pos


def _build_figure(edges, labels, hover_texts, pos):
    """Build the Plotly tree figure: gray edges plus depth-coloured nodes."""
    edge_x = []
    edge_y = []
    for src, dst in edges:
        (x0, y0), (x1, y1) = pos[src], pos[dst]
        edge_x += [x0, x1, None]  # None breaks the line between segments
        edge_y += [y0, y1, None]

    node_x = [pos[label][0] for label in labels]
    node_y = [pos[label][1] for label in labels]

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color='Gray'),
        hoverinfo='none',
        mode='lines',
    )
    # Node colour is the x coordinate (tree depth), which is what the
    # "Depth" colorbar describes.
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=labels,
        hoverinfo='text',
        hovertext=hover_texts,
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            color=node_x,
            size=20,
            colorbar=dict(
                thickness=15,
                title=dict(text='Depth', side='right'),
                xanchor='left',
            ),
            line_width=2,
        ),
    )
    return go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            xaxis=dict(showgrid=False, zeroline=False),
            yaxis=dict(showgrid=False, zeroline=False),
        ),
    )


def main():
    """Streamlit entry point: upload a text file and render its topic tree."""
    _ensure_punkt()

    uploaded_file = st.file_uploader("Upload a text file", type="txt")
    if uploaded_file is None:
        return
    try:
        text = uploaded_file.read().decode('utf-8')
    except UnicodeDecodeError:
        st.error("The uploaded file is not valid UTF-8 text.")
        return

    # Select number of main topics and depth of subtopics
    n_topics = st.slider('Select number of main topics', 2, 10, 5)
    max_depth = st.slider('Select maximum depth of subtopics', 1, 5, 2)
    subtopics_per_topic = st.slider('Select number of subtopics per topic', 2, 5, 3)

    try:
        topic_dict, top_words = split_text_into_topics(text, n_topics)
    except ValueError as err:
        st.error(f"Could not split the text into topics: {err}")
        return

    # Root topic words keyed by path; recursive_split adds subtopic entries.
    word_map = {(i,): word for i, word in enumerate(top_words)}
    full_tree = recursive_split(topic_dict, 0, max_depth, subtopics_per_topic, word_map=word_map)

    edges, labels, hover_texts, pos = get_edges(full_tree, top_words=top_words, word_map=word_map)
    st.plotly_chart(_build_figure(edges, labels, hover_texts, pos))


# `streamlit run` executes this script as the `__main__` module.
if __name__ == "__main__":
    main()