"""Streamlit demo app for Vietnamese Named Entity Recognition.

The page has three tabs:

* Data Analysis    - static dataset-statistics images hosted on GitHub.
* Training Results - loss / F1 curves, per-class metric bar charts and
  comparison tables built from the objects in ``results.output``.
* Model Demo       - runs ``predict_demo`` on user text and shows the
  detected entities plus the HTML produced by ``render_html``.
"""
import streamlit as st
import pandas as pd
import plotly.graph_objects as go
from src.predict import predict_demo
from src.front import render_html
from results.output import training_log, report_dict, report_dict_2, model_compare, data_compare

# Base URL of the pre-rendered dataset-statistics images.
_RESULTS_URL = "https://raw.githubusercontent.com/duclld1709/vietnamese-ner/refs/heads/main/results/"

# Report entries that are aggregates rather than per-class rows.
_SUMMARY_KEYS = frozenset({"accuracy", "macro avg", "weighted avg"})

# Display column name -> key inside each per-class report entry.
_METRIC_COLUMNS = {"Precision": "precision", "Recall": "recall", "F1-Score": "f1-score"}


def _curve_figure(log, series, title: str, yaxis_title: str) -> go.Figure:
    """Build a per-epoch line chart from the training log.

    Args:
        log: Mapping providing an ``"epoch"`` sequence and one sequence per
            key named in ``series``.
        series: Iterable of ``(log_key, legend_name)`` pairs, one per trace.
        title: Figure title.
        yaxis_title: Label of the y axis.

    Returns:
        A Plotly figure with one lines+markers trace per entry in ``series``.
    """
    fig = go.Figure()
    for key, legend_name in series:
        fig.add_trace(
            go.Scatter(x=log["epoch"], y=log[key], mode="lines+markers", name=legend_name)
        )
    fig.update_layout(title=title, xaxis_title="Epoch", yaxis_title=yaxis_title)
    return fig


def _class_metrics_figure(report: dict, title: str) -> go.Figure:
    """Build a grouped bar chart of per-class precision / recall / F1.

    Args:
        report: Classification report mapping label -> metrics dict. The
            aggregate entries listed in ``_SUMMARY_KEYS`` are skipped.
        title: Figure title.

    Returns:
        A Plotly figure with one bar group per label and a y axis fixed
        to the range [0, 1].
    """
    rows = [
        [label] + [scores[key] for key in _METRIC_COLUMNS.values()]
        for label, scores in report.items()
        if label not in _SUMMARY_KEYS
    ]
    frame = pd.DataFrame(rows, columns=["Label", *_METRIC_COLUMNS])

    fig = go.Figure()
    for metric in _METRIC_COLUMNS:
        fig.add_trace(go.Bar(x=frame["Label"], y=frame[metric], name=metric))
    fig.update_layout(
        barmode="group",
        title=title,
        xaxis_title="Label",
        yaxis_title="Score",
        yaxis=dict(range=[0, 1.0]),
    )
    return fig


# Page configuration is set exactly once, before any other Streamlit command.
# The original script called set_page_config a second time inside tab 2, which
# raises StreamlitAPIException on Streamlit versions that only allow a single,
# first-command call. The sidebar option from that call is merged in here.
st.set_page_config(
    page_title="Vietnamese NER",
    layout="wide",
    initial_sidebar_state="expanded",
)

# ===== Main Title =====
st.title("🔍 Vietnamese Named Entity Recognition Demo")

# Tabs
tab1, tab2, tab3 = st.tabs(["📊 Data Analysis", "📈 Training Results", "🧪 Model Demo"])

# --- Tab 1: DATA ANALYSIS ---
with tab1:
    left, right = st.columns(2)

    # (target column, image file) in display order; images alternate
    # between the two columns.
    analysis_images = [
        (left, "ner_freq.png"),       # NER label frequency
        (right, "ner_freq_add.png"),  # NER label frequency (with crawled data)
        (left, "ent_dis.png"),        # entities per sentence (0 to 15+)
        (right, "sent_len.png"),      # sentence lengths
        (left, "token_len.png"),      # token lengths
    ]
    for column, image_name in analysis_images:
        with column:
            st.image(_RESULTS_URL + image_name)

# --- Tab 2: TRAINING RESULTS ---
with tab2:
    # ==== CREATE FIGURES ====
    # 1) Loss
    fig_loss = _curve_figure(
        training_log,
        [("train_loss", "Train Loss"), ("val_loss", "Validation Loss")],
        title="Loss Curve",
        yaxis_title="Loss",
    )

    # 2) F1-Score
    fig_f1 = _curve_figure(
        training_log,
        [("train_f1", "Train F1"), ("val_f1", "Validation F1")],
        title="F1-Score Curve",
        yaxis_title="F1-Score",
    )

    # 3) Per-class metric bar charts for both model heads
    fig_report = _class_metrics_figure(report_dict, "Class Metrics: PhoBERT + CRF")
    fig_report2 = _class_metrics_figure(report_dict_2, "Class Metrics: PhoBERT + Softmax")

    # 4) Model & Data Comparison Tables
    df_model = pd.DataFrame(
        [[model, scores["F1"], scores["Accuracy"]] for model, scores in model_compare["Data"].items()],
        columns=["Model", "F1-Score", "Accuracy"],
    )
    df_data = pd.DataFrame(
        [[setting, f1] for setting, f1 in data_compare["Data"].items()],
        columns=["Preprocessing", "F1-Score"],
    )

    # ==== CLEAN LAYOUT WITH COLUMNS ====
    # Row 1: Loss | F1
    col1, col2 = st.columns(2)
    with col1:
        st.plotly_chart(fig_loss, use_container_width=True)
    with col2:
        st.plotly_chart(fig_f1, use_container_width=True)

    # Row 2: Softmax class metrics | CRF class metrics
    col3, col4 = st.columns(2)
    with col3:
        st.plotly_chart(fig_report2, use_container_width=True)
    with col4:
        st.plotly_chart(fig_report, use_container_width=True)

    # Row 3: Model Compare | Data Compare
    col5, col6 = st.columns(2)
    with col5:
        st.markdown("**Model Comparison**")
        st.dataframe(df_model, use_container_width=True)
    with col6:
        st.markdown("**Data Preprocessing Comparison**")
        st.dataframe(df_data, use_container_width=True)

# --- Tab 3: MODEL DEMO ---
with tab3:
    text = st.text_input("Enter Vietnamese text:", "Nguyễn Văn A đang làm việc tại Hà Nội")

    if st.button("Analyze"):
        if not text.strip():
            st.warning("Please enter some text!")
        else:
            tokens, predicted_labels = predict_demo(text)

            st.subheader("Detected Entities")
            # "O" marks tokens outside any entity; only tagged tokens are listed.
            entities = [
                (token, label)
                for token, label in zip(tokens, predicted_labels)
                if label != "O"
            ]
            if entities:
                for token, label in entities:
                    st.markdown(f"🔹 **{token}** — *{label}*")
            else:
                st.info("No named entities detected.")

            st.subheader("Highlighted Text")
            # NOTE(review): the HTML is rendered unescaped and the tokens come
            # from user input - confirm that render_html escapes token text.
            st.markdown(render_html(tokens, predicted_labels), unsafe_allow_html=True)