Spaces:
Sleeping
Sleeping
New changes
Browse files- .gitignore +2 -0
- app.py +47 -23
.gitignore
CHANGED
|
@@ -1 +1,3 @@
|
|
| 1 |
.env
|
|
|
|
|
|
|
|
|
| 1 |
.env
|
| 2 |
+
app copy.py
|
| 3 |
+
new_gradio.py
|
app.py
CHANGED
|
@@ -505,6 +505,22 @@ def compute_ragbench_metrics(judge_response: dict, retrieved_sentence_keys: list
|
|
| 505 |
"Adherence": adherence
|
| 506 |
}
|
| 507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
| 509 |
def evaluate_rag_pipeline(domain, q_indices):
|
| 510 |
import torch
|
|
@@ -597,39 +613,47 @@ def evaluate_rag_pipeline(domain, q_indices):
|
|
| 597 |
|
| 598 |
# Updated wrapper
|
| 599 |
def evaluate_rag_gradio(domain, q_indices_str):
|
| 600 |
-
# Capture logs
|
| 601 |
log_stream = io.StringIO()
|
| 602 |
sys.stdout = log_stream
|
| 603 |
|
| 604 |
try:
|
| 605 |
-
# Parse comma-separated indices
|
| 606 |
q_indices = [int(x.strip()) for x in q_indices_str.split(",") if x.strip().isdigit()]
|
| 607 |
results = evaluate_rag_pipeline(domain, q_indices)
|
| 608 |
-
|
| 609 |
logs = log_stream.getvalue()
|
| 610 |
return results, logs
|
| 611 |
-
|
| 612 |
except Exception as e:
|
| 613 |
traceback.print_exc()
|
| 614 |
return {"error": str(e)}, log_stream.getvalue()
|
| 615 |
-
|
| 616 |
finally:
|
| 617 |
-
sys.stdout = sys.__stdout__
|
| 618 |
-
|
| 619 |
-
# Gradio
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 633 |
|
| 634 |
-
# Launch
|
| 635 |
-
|
|
|
|
| 505 |
"Adherence": adherence
|
| 506 |
}
|
| 507 |
|
| 508 |
+
# --- Dataset dictionary ---
|
| 509 |
+
domain_datasets = {
|
| 510 |
+
"Legal": legal_dataset,
|
| 511 |
+
"Medical": med_dataset,
|
| 512 |
+
"GK": gk_dataset,
|
| 513 |
+
"CS": cs_dataset,
|
| 514 |
+
"Finance": fin_dataset
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
# --- Get questions for selected domain ---
|
| 518 |
+
def get_questions_for_domain(domain):
|
| 519 |
+
dataset = domain_datasets.get(domain, [])
|
| 520 |
+
if not dataset:
|
| 521 |
+
return "β οΈ No dataset found for the selected domain."
|
| 522 |
+
|
| 523 |
+
return "\n".join([f"{i}. {item['question']}" for i, item in enumerate(dataset)])
|
| 524 |
|
| 525 |
def evaluate_rag_pipeline(domain, q_indices):
|
| 526 |
import torch
|
|
|
|
| 613 |
|
| 614 |
# Updated wrapper
|
| 615 |
def evaluate_rag_gradio(domain, q_indices_str):
|
|
|
|
| 616 |
log_stream = io.StringIO()
|
| 617 |
sys.stdout = log_stream
|
| 618 |
|
| 619 |
try:
|
|
|
|
| 620 |
q_indices = [int(x.strip()) for x in q_indices_str.split(",") if x.strip().isdigit()]
|
| 621 |
results = evaluate_rag_pipeline(domain, q_indices)
|
|
|
|
| 622 |
logs = log_stream.getvalue()
|
| 623 |
return results, logs
|
|
|
|
| 624 |
except Exception as e:
|
| 625 |
traceback.print_exc()
|
| 626 |
return {"error": str(e)}, log_stream.getvalue()
|
|
|
|
| 627 |
finally:
|
| 628 |
+
sys.stdout = sys.__stdout__
|
| 629 |
+
|
| 630 |
+
# === Gradio UI using Blocks ===
|
| 631 |
+
with gr.Blocks(title="RAG Evaluation Dashboard") as demo:
|
| 632 |
+
gr.Markdown("## π RAG Evaluation Dashboard")
|
| 633 |
+
gr.Markdown("Evaluate your RAG pipeline and also browse the questions available for each domain.")
|
| 634 |
+
|
| 635 |
+
with gr.Row():
|
| 636 |
+
domain_input = gr.Dropdown(choices=list(domain_datasets.keys()), label="Select Domain")
|
| 637 |
+
q_index_input = gr.Textbox(label="Enter Query Indices (e.g., 89,121,245)", lines=1)
|
| 638 |
+
|
| 639 |
+
with gr.Row():
|
| 640 |
+
view_btn = gr.Button("π View Questions for Selected Domain")
|
| 641 |
+
questions_display = gr.Textbox(label="Domain Questions", lines=10, interactive=False)
|
| 642 |
+
|
| 643 |
+
with gr.Row():
|
| 644 |
+
run_btn = gr.Button("π Run Evaluation")
|
| 645 |
+
|
| 646 |
+
result_output = gr.JSON(label="Evaluation Metrics (RMSE & AUC-ROC)")
|
| 647 |
+
log_output = gr.Textbox(label="Execution Log", lines=10, interactive=True)
|
| 648 |
+
|
| 649 |
+
# Bindings
|
| 650 |
+
view_btn.click(fn=get_questions_for_domain, inputs=domain_input, outputs=questions_display)
|
| 651 |
+
|
| 652 |
+
run_btn.click(
|
| 653 |
+
fn=evaluate_rag_gradio,
|
| 654 |
+
inputs=[domain_input, q_index_input],
|
| 655 |
+
outputs=[result_output, log_output]
|
| 656 |
+
)
|
| 657 |
|
| 658 |
+
# === Launch ===
|
| 659 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
|