Upload app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# app.py
|
| 2 |
import unsloth
|
| 3 |
from unsloth import FastModel
|
| 4 |
|
|
@@ -55,7 +55,7 @@ class GPTSequenceClassifier(nn.Module):
|
|
| 55 |
|
| 56 |
|
| 57 |
# ===================================================================
|
| 58 |
-
#
|
| 59 |
# ===================================================================
|
| 60 |
|
| 61 |
# --- Helper Functions ---
|
|
@@ -109,7 +109,7 @@ def evaluate_equations(eq_dict: dict, sol_dict: dict):
|
|
| 109 |
correct_rhs_val = round(lhs_val, 4)
|
| 110 |
correct_rhs_str = f"{correct_rhs_val:.4f}".rstrip('0').rstrip('.')
|
| 111 |
|
| 112 |
-
|
| 113 |
return {
|
| 114 |
"error": True,
|
| 115 |
"line_key": key,
|
|
@@ -235,7 +235,7 @@ logger.info("load_model(): %s", msg)
|
|
| 235 |
|
| 236 |
|
| 237 |
# ===================================================================
|
| 238 |
-
#
|
| 239 |
# ===================================================================
|
| 240 |
|
| 241 |
def run_conceptual_check(question: str, solution: str, model, tokenizer) -> dict:
|
|
@@ -255,7 +255,7 @@ def run_conceptual_check(question: str, solution: str, model, tokenizer) -> dict
|
|
| 255 |
with torch.inference_mode():
|
| 256 |
outputs = model(**inputs, use_cache=False)
|
| 257 |
|
| 258 |
-
|
| 259 |
logits = outputs["logits"].to(torch.float32)
|
| 260 |
probs = torch.softmax(logits, dim=-1).squeeze().tolist()
|
| 261 |
|
|
@@ -327,11 +327,11 @@ def analyze_solution(question: str, solution: str):
|
|
| 327 |
"""
|
| 328 |
Main orchestrator that runs the full pipeline and generates the final explanation.
|
| 329 |
"""
|
| 330 |
-
# STAGE 1: Conceptual Check
|
| 331 |
conceptual_result = run_conceptual_check(question, solution, classifier_model, classifier_tokenizer)
|
| 332 |
confidence = conceptual_result['probabilities'][conceptual_result['prediction']]
|
| 333 |
|
| 334 |
-
# STAGE 2: Computational Check
|
| 335 |
computational_result = run_computational_check(solution, gemma_model, gemma_tokenizer)
|
| 336 |
|
| 337 |
# FINAL VERDICT LOGIC
|
|
@@ -372,13 +372,13 @@ def classify_solution_stream(question: str, solution: str):
|
|
| 372 |
|
| 373 |
log = []
|
| 374 |
|
| 375 |
-
|
| 376 |
if not question.strip() or not solution.strip():
|
| 377 |
log.append("⚠️ Provide a question and a solution.")
|
| 378 |
yield "Please fill in both fields", "", render(log)
|
| 379 |
return
|
| 380 |
|
| 381 |
-
|
| 382 |
if not models_ready():
|
| 383 |
log.append("⏳ Loading models…")
|
| 384 |
yield "⏳ Working…", "", render(log)
|
|
@@ -444,7 +444,7 @@ def classify_solution_stream(question: str, solution: str):
|
|
| 444 |
yield "Runtime error", f"{type(e).__name__}: {e}", render(log)
|
| 445 |
|
| 446 |
|
| 447 |
-
# ---------------- UI: streaming
|
| 448 |
with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
| 449 |
gr.Markdown("# 🧮 Math Solution Classifier")
|
| 450 |
gr.Markdown(
|
|
@@ -665,7 +665,7 @@ class ExampleSelector:
|
|
| 665 |
else:
|
| 666 |
self.balance["wrong"] += 1
|
| 667 |
|
| 668 |
-
# ===== CSV hookup
|
| 669 |
from pathlib import Path
|
| 670 |
import time
|
| 671 |
|
|
@@ -673,10 +673,10 @@ CSV_PATH = Path(__file__).resolve().parent / "final-test-with-wrong-answers.csv"
|
|
| 673 |
POOL = load_examples_csv(str(CSV_PATH))
|
| 674 |
|
| 675 |
def new_selector(seed: int | None = None):
|
| 676 |
-
|
| 677 |
return ExampleSelector(POOL, seed=seed or int(time.time()) & 0xFFFF)
|
| 678 |
|
| 679 |
-
|
| 680 |
def _truncate(s: str, n: int = 100) -> str:
|
| 681 |
s = s or ""
|
| 682 |
return s if len(s) <= n else s[: n - 1] + "…"
|
|
@@ -694,7 +694,6 @@ def _rows_to_table(rows: list[dict]) -> list[list[str]]:
|
|
| 694 |
return table
|
| 695 |
|
| 696 |
|
| 697 |
-
# ===== Gradio callbacks for examples =====
|
| 698 |
def ui_surprise(selector, filter_label="any"):
|
| 699 |
"""Pick one example and push it straight to inputs; persist selector state."""
|
| 700 |
if selector is None or not POOL:
|
|
@@ -704,9 +703,13 @@ def ui_surprise(selector, filter_label="any"):
|
|
| 704 |
return selector, gr.update(), gr.update()
|
| 705 |
return selector, r["question"], r["solution"]
|
| 706 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 707 |
|
| 708 |
|
| 709 |
-
# ---------------- UI: add CSV-driven examples ----------------
|
| 710 |
with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
| 711 |
gr.Markdown("# 🧮 Math Solution Classifier")
|
| 712 |
gr.Markdown(
|
|
@@ -715,7 +718,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
| 715 |
" \n Press 'Surprise me' to randomly select a sample question/answer pair from our dataset."
|
| 716 |
)
|
| 717 |
|
| 718 |
-
|
| 719 |
selector_state = gr.State(new_selector())
|
| 720 |
|
| 721 |
with gr.Row():
|
|
@@ -723,12 +726,12 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
| 723 |
with gr.Column(scale=1):
|
| 724 |
question_input = gr.Textbox(
|
| 725 |
label="Math Question",
|
| 726 |
-
placeholder="e.g.,
|
| 727 |
lines=3,
|
| 728 |
)
|
| 729 |
solution_input = gr.Textbox(
|
| 730 |
label="Proposed Solution",
|
| 731 |
-
placeholder="e.g.,
|
| 732 |
lines=8,
|
| 733 |
)
|
| 734 |
expected_label_example = gr.Textbox(
|
|
@@ -738,7 +741,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
| 738 |
with gr.Row():
|
| 739 |
classify_btn = gr.Button("Classify Solution", variant="primary")
|
| 740 |
surprise_btn = gr.Button("Surprise me") # <- new
|
| 741 |
-
clear_btn = gr.
|
| 742 |
|
| 743 |
# -------- Right: outputs --------
|
| 744 |
with gr.Column(scale=1):
|
|
@@ -746,7 +749,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
| 746 |
explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=6)
|
| 747 |
status_output = gr.Markdown(value="*(idle)*") # live stage updates
|
| 748 |
|
| 749 |
-
# -------- Curated starter examples
|
| 750 |
gr.Examples(
|
| 751 |
examples=[
|
| 752 |
["John has three apples and Mary has seven, how many apples do they have together?",
|
|
@@ -755,18 +758,18 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
| 755 |
["A rectangle's length is twice its width. If the width is 7 cm, what is the perimeter of the rectangle?",
|
| 756 |
"The length of the rectangle is 2 * 7 = 14 cm.\n The perimeter is 14 + 7 = 21 cm.\n Final answer: 21",
|
| 757 |
"Conceptually flawed"],
|
| 758 |
-
["
|
| 759 |
-
"The
|
| 760 |
-
"
|
| 761 |
["What is 15% of 200?",
|
| 762 |
"15% = 15/100 = 0.15\n0.15 × 200 = 30\n Final answer: 30",
|
| 763 |
"Correct"],
|
| 764 |
["A circle has a radius of 5 cm. Using the approximation pi = 3.14, what is the circumference of the circle?",
|
| 765 |
"The circumference of the circle is 3.14 * 5 = 15.7 cm.\n Final answer: 15.7",
|
| 766 |
"Conceptually flawed"],
|
| 767 |
-
["
|
| 768 |
-
"The
|
| 769 |
-
"
|
| 770 |
["A 24-meter rope is cut into 6 equal pieces. A climber uses 2 of those pieces. How many meters of rope are still unused?",
|
| 771 |
"The length of each piece is 24 / 6 = 4 m.\n The climber uses 2 × 4 m = 8 m of rope.\n There are 24 m − 8 m = 16 m of rope still unused.\n Final answer: 16",
|
| 772 |
"Correct"]
|
|
@@ -776,7 +779,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
| 776 |
|
| 777 |
|
| 778 |
# ---------- Wiring ----------
|
| 779 |
-
# Main classify
|
| 780 |
classify_btn.click(
|
| 781 |
fn=classify_solution_stream,
|
| 782 |
inputs=[question_input, solution_input],
|
|
@@ -785,16 +788,14 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
| 785 |
concurrency_limit=1,
|
| 786 |
)
|
| 787 |
|
| 788 |
-
# ---- and replace the Surprise button wiring with this ----
|
| 789 |
surprise_btn.click(
|
| 790 |
fn=ui_surprise,
|
| 791 |
-
inputs=[selector_state],
|
| 792 |
-
outputs=[selector_state, question_input, solution_input],
|
| 793 |
queue=True,
|
| 794 |
)
|
| 795 |
|
| 796 |
|
| 797 |
-
# Enable queue for streaming
|
| 798 |
app.queue()
|
| 799 |
|
| 800 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
import unsloth
|
| 3 |
from unsloth import FastModel
|
| 4 |
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
# ===================================================================
|
| 58 |
+
#HELPERS
|
| 59 |
# ===================================================================
|
| 60 |
|
| 61 |
# --- Helper Functions ---
|
|
|
|
| 109 |
correct_rhs_val = round(lhs_val, 4)
|
| 110 |
correct_rhs_str = f"{correct_rhs_val:.4f}".rstrip('0').rstrip('.')
|
| 111 |
|
| 112 |
+
|
| 113 |
return {
|
| 114 |
"error": True,
|
| 115 |
"line_key": key,
|
|
|
|
| 235 |
|
| 236 |
|
| 237 |
# ===================================================================
|
| 238 |
+
# PIPELINE COMPONENTS
|
| 239 |
# ===================================================================
|
| 240 |
|
| 241 |
def run_conceptual_check(question: str, solution: str, model, tokenizer) -> dict:
|
|
|
|
| 255 |
with torch.inference_mode():
|
| 256 |
outputs = model(**inputs, use_cache=False)
|
| 257 |
|
| 258 |
+
|
| 259 |
logits = outputs["logits"].to(torch.float32)
|
| 260 |
probs = torch.softmax(logits, dim=-1).squeeze().tolist()
|
| 261 |
|
|
|
|
| 327 |
"""
|
| 328 |
Main orchestrator that runs the full pipeline and generates the final explanation.
|
| 329 |
"""
|
| 330 |
+
# STAGE 1: Conceptual Check
|
| 331 |
conceptual_result = run_conceptual_check(question, solution, classifier_model, classifier_tokenizer)
|
| 332 |
confidence = conceptual_result['probabilities'][conceptual_result['prediction']]
|
| 333 |
|
| 334 |
+
# STAGE 2: Computational Check
|
| 335 |
computational_result = run_computational_check(solution, gemma_model, gemma_tokenizer)
|
| 336 |
|
| 337 |
# FINAL VERDICT LOGIC
|
|
|
|
| 372 |
|
| 373 |
log = []
|
| 374 |
|
| 375 |
+
|
| 376 |
if not question.strip() or not solution.strip():
|
| 377 |
log.append("⚠️ Provide a question and a solution.")
|
| 378 |
yield "Please fill in both fields", "", render(log)
|
| 379 |
return
|
| 380 |
|
| 381 |
+
|
| 382 |
if not models_ready():
|
| 383 |
log.append("⏳ Loading models…")
|
| 384 |
yield "⏳ Working…", "", render(log)
|
|
|
|
| 444 |
yield "Runtime error", f"{type(e).__name__}: {e}", render(log)
|
| 445 |
|
| 446 |
|
| 447 |
+
# ---------------- UI: streaming ----------------
|
| 448 |
with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
| 449 |
gr.Markdown("# 🧮 Math Solution Classifier")
|
| 450 |
gr.Markdown(
|
|
|
|
| 665 |
else:
|
| 666 |
self.balance["wrong"] += 1
|
| 667 |
|
| 668 |
+
# ===== CSV hookup =====
|
| 669 |
from pathlib import Path
|
| 670 |
import time
|
| 671 |
|
|
|
|
| 673 |
POOL = load_examples_csv(str(CSV_PATH))
|
| 674 |
|
| 675 |
def new_selector(seed: int | None = None):
|
| 676 |
+
|
| 677 |
return ExampleSelector(POOL, seed=seed or int(time.time()) & 0xFFFF)
|
| 678 |
|
| 679 |
+
|
| 680 |
def _truncate(s: str, n: int = 100) -> str:
|
| 681 |
s = s or ""
|
| 682 |
return s if len(s) <= n else s[: n - 1] + "…"
|
|
|
|
| 694 |
return table
|
| 695 |
|
| 696 |
|
|
|
|
| 697 |
def ui_surprise(selector, filter_label="any"):
|
| 698 |
"""Pick one example and push it straight to inputs; persist selector state."""
|
| 699 |
if selector is None or not POOL:
|
|
|
|
| 703 |
return selector, gr.update(), gr.update()
|
| 704 |
return selector, r["question"], r["solution"]
|
| 705 |
|
| 706 |
+
components_to_clear = [
|
| 707 |
+
question_input,
|
| 708 |
+
solution_input,
|
| 709 |
+
expected_label_example,
|
| 710 |
+
]
|
| 711 |
|
| 712 |
|
|
|
|
| 713 |
with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
| 714 |
gr.Markdown("# 🧮 Math Solution Classifier")
|
| 715 |
gr.Markdown(
|
|
|
|
| 718 |
" \n Press 'Surprise me' to randomly select a sample question/answer pair from our dataset."
|
| 719 |
)
|
| 720 |
|
| 721 |
+
|
| 722 |
selector_state = gr.State(new_selector())
|
| 723 |
|
| 724 |
with gr.Row():
|
|
|
|
| 726 |
with gr.Column(scale=1):
|
| 727 |
question_input = gr.Textbox(
|
| 728 |
label="Math Question",
|
| 729 |
+
placeholder="e.g., What is 14 divided by 2?",
|
| 730 |
lines=3,
|
| 731 |
)
|
| 732 |
solution_input = gr.Textbox(
|
| 733 |
label="Proposed Solution",
|
| 734 |
+
placeholder="e.g., 14/2 = 9",
|
| 735 |
lines=8,
|
| 736 |
)
|
| 737 |
expected_label_example = gr.Textbox(
|
|
|
|
| 741 |
with gr.Row():
|
| 742 |
classify_btn = gr.Button("Classify Solution", variant="primary")
|
| 743 |
surprise_btn = gr.Button("Surprise me") # <- new
|
| 744 |
+
clear_btn = clear_btn = gr.ClearButton(components=components_to_clear, value="Clear")
|
| 745 |
|
| 746 |
# -------- Right: outputs --------
|
| 747 |
with gr.Column(scale=1):
|
|
|
|
| 749 |
explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=6)
|
| 750 |
status_output = gr.Markdown(value="*(idle)*") # live stage updates
|
| 751 |
|
| 752 |
+
# -------- Curated starter examples --------
|
| 753 |
gr.Examples(
|
| 754 |
examples=[
|
| 755 |
["John has three apples and Mary has seven, how many apples do they have together?",
|
|
|
|
| 758 |
["A rectangle's length is twice its width. If the width is 7 cm, what is the perimeter of the rectangle?",
|
| 759 |
"The length of the rectangle is 2 * 7 = 14 cm.\n The perimeter is 14 + 7 = 21 cm.\n Final answer: 21",
|
| 760 |
"Conceptually flawed"],
|
| 761 |
+
["",
|
| 762 |
+
"The lateral area of the bottom layer is 2 * 3.14 * 20 * 8 = 1004.8.\n The lateral area of the middle layer is 2 * 3.14 * 15 * 8 = 753.6.\n The lateral area of the top layer is 2 * 3.14 * 10 * 8 = 502.4.\n The exposed top surface is the area of the smallest circle: 3.14 * (10*10) = 314.\n The total frosted area is 1004.8 + 753.6 + 502.4 + 314 = 2888.8 sq cm.\n FINAL ANSWER: 2888.8",
|
| 763 |
+
"Computationally flawed"],
|
| 764 |
["What is 15% of 200?",
|
| 765 |
"15% = 15/100 = 0.15\n0.15 × 200 = 30\n Final answer: 30",
|
| 766 |
"Correct"],
|
| 767 |
["A circle has a radius of 5 cm. Using the approximation pi = 3.14, what is the circumference of the circle?",
|
| 768 |
"The circumference of the circle is 3.14 * 5 = 15.7 cm.\n Final answer: 15.7",
|
| 769 |
"Conceptually flawed"],
|
| 770 |
+
["A library is building new shelves. Each shelf is 1.2 meters long. A standard book is 3 cm thick, and a large book is 5 cm thick. A shelf must hold 20 standard books and 10 large books. After filling a shelf with these books, how much space, in centimeters, is left on the shelf?",
|
| 771 |
+
"The shelf length in centimeters is 1.2 * 100 = 120 cm.\n The space taken by standard books is 20 * 3 = 60 cm.\n The space taken by large books is 10 * 5 = 50 cm.\n The total space taken is 60 + 50 = 110 cm.\n The remaining space is 120 + 110 = 230 cm.\n FINAL ANSWER: 230",
|
| 772 |
+
"Conceptually flawed"],
|
| 773 |
["A 24-meter rope is cut into 6 equal pieces. A climber uses 2 of those pieces. How many meters of rope are still unused?",
|
| 774 |
"The length of each piece is 24 / 6 = 4 m.\n The climber uses 2 × 4 m = 8 m of rope.\n There are 24 m − 8 m = 16 m of rope still unused.\n Final answer: 16",
|
| 775 |
"Correct"]
|
|
|
|
| 779 |
|
| 780 |
|
| 781 |
# ---------- Wiring ----------
|
| 782 |
+
# Main classify
|
| 783 |
classify_btn.click(
|
| 784 |
fn=classify_solution_stream,
|
| 785 |
inputs=[question_input, solution_input],
|
|
|
|
| 788 |
concurrency_limit=1,
|
| 789 |
)
|
| 790 |
|
|
|
|
| 791 |
surprise_btn.click(
|
| 792 |
fn=ui_surprise,
|
| 793 |
+
inputs=[selector_state],
|
| 794 |
+
outputs=[selector_state, question_input, solution_input],
|
| 795 |
queue=True,
|
| 796 |
)
|
| 797 |
|
| 798 |
|
|
|
|
| 799 |
app.queue()
|
| 800 |
|
| 801 |
if __name__ == "__main__":
|