# Scraped Hugging Face file-viewer header (not part of the program):
# danf's picture
# Creating app.py
# a05aa9d verified
# raw
# history blame
# 14.3 kB
from __future__ import annotations
import logging
import os
from functools import lru_cache
from threading import Thread
from typing import Generator, List, Tuple
import gradio as gr
import regex
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Module logger; the root logger is configured at INFO so the logger.info(...)
# calls in this module (e.g. model-loading progress) are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def convert_latex_brackets_to_dollars(text: str) -> str:
    """Rewrite ``\\[...\\]`` / ``\\(...\\)`` LaTeX math as ``$$...$$`` / ``$...$``.

    Display math (``\\[...\\]``) is converted first and placed on its own line;
    inline math (``\\(...\\)``) is converted second and padded with spaces.
    Every converted formula is wrapped in ``<bdi>`` (bidirectional isolation)
    tags, presumably so formulas render correctly inside right-to-left text.

    Both patterns use the ``regex`` module's ``(?r)`` flag, i.e. the search runs
    in reverse (right to left). With nested or unbalanced delimiters, that
    presumably pairs each closing delimiter with the nearest opener before it.

    Args:
        text: Model output that may contain bracket/paren LaTeX delimiters.

    Returns:
        The text with those sections rewritten in dollar-delimited form.
    """
    display_pattern = r"(?r)\\\[\s*([^\[\]]+?)\s*\\\]"
    inline_pattern = r"(?r)\\\(\s*(.+?)\s*\\\)"

    with_display = regex.sub(
        display_pattern,
        lambda found: f"\n<bdi> $$ {found.group(1).strip()} $$ </bdi>\n",
        text,
    )
    return regex.sub(
        inline_pattern,
        lambda found: f" <bdi> $ {found.group(1).strip()} $ </bdi> ",
        with_display,
    )
# Hugging Face model id; can be overridden with the MODEL_NAME environment variable.
MODEL_NAME = os.environ.get("MODEL_NAME", "Intel/hebrew-math-tutor-v1")


@lru_cache(maxsize=1)
def load_model_and_tokenizer():
    """Load the causal LM and its tokenizer for ``MODEL_NAME``.

    Wrapped in ``lru_cache(maxsize=1)``, so repeated calls return the same
    ``(model, tokenizer)`` pair instead of loading the weights again.

    Returns:
        ``(model, tokenizer)``. The model is loaded with
        ``torch_dtype=torch.float16`` and ``device_map="auto"``, which leaves
        device placement to transformers.
    """
    logger.info(f"Loading model: {MODEL_NAME}")
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    logger.info("Model loaded successfully")
    return loaded_model, loaded_tokenizer
# UI language used when the app first renders (and when a handler gets no language).
DEFAULT_LANG = "he"

# Load eagerly at import time. load_model_and_tokenizer is lru_cached, so any
# later call returns this same (model, tokenizer) pair.
model, tokenizer = load_model_and_tokenizer()
# Localized UI strings, keyed by language code ("he" / "en"). Both languages
# must expose the same keys. "predefined"[0] must equal "new_question": that
# entry is the dropdown's "type your own question" option.
#
# The "intro" values are Markdown. Blank lines are required around "-----":
# a dash line directly under a text line is a setext heading underline in
# CommonMark (it would turn the URL into an <h2>), not a horizontal rule.
labels = {
    "he": {
        "title": "מתמטיבוט 🧮",
        "intro": (
            """
ברוכים הבאים לדמו! 💡 כאן תוכלו להתרשם **ממודל השפה החדש** שלנו; מודל בגודל 4 מיליארד פרמטרים שאומן לענות על שאלות מתמטיות בעברית, על המחשב שלכם, ללא חיבור לרשת.

קישור למודל, פרטים נוספים, יצירת קשר ותנאי שימוש:
https://huggingface.co/Intel/hebrew-math-tutor-v1

-----
"""
        ),
        "select_label": "בחרו שאלה מוכנה או צרו שאלה חדשה:",
        "new_question": "שאלה חדשה...",
        "text_label": "שאלה:",
        "placeholder": "הזינו את השאלה כאן...",
        "send": "שלח",
        "reset": "שיחה חדשה",
        "toggle_to": "English 🇬🇧",
        "predefined": [
            "שאלה חדשה...",
            "מהו סכום הסדרה הבאה: 1 + 1/2 + 1/4 + 1/8 + ...",
            "פתח את הביטוי: (a-b)^4",
            "פתרו את המשוואה הבאה: sin(2x) = 0.5",
        ],
        "summary_text": "לחץ כדי לראות את תהליך החשיבה",
        "thinking_prefix": "🤔 חושב",
        "thinking_done": "🤔 *תהליך החשיבה הושלם, מכין תשובה...*",
        "final_label": "📝 תשובה סופית:",
        "chat_label": "צ'אט",
    },
    "en": {
        "title": "MathBot 🧮",
        "intro": (
            """
Welcome to the demo! 💡 Here you can try our **new language model** — a 4-billion-parameter model trained to answer math questions in Hebrew while maintaining its English capabilities. It runs locally on your machine without requiring an internet connection.

For the model page and more details see:
https://huggingface.co/Intel/hebrew-math-tutor-v1

-----
"""
        ),
        "select_label": "Choose a prepared question or create a new one:",
        "new_question": "New question...",
        "text_label": "Question:",
        "placeholder": "Type your question here...",
        "send": "Send",
        "reset": "New Conversation",
        "toggle_to": "עברית 🇮🇱",
        "predefined": [
            "New question...",
            "What is the sum of the series: 1 + 1/2 + 1/4 + 1/8 + ...",
            "Expand the expression: (a-b)^4",
            "Solve the equation: sin(2x) = 0.5",
        ],
        "summary_text": "Click to view the thinking process",
        "thinking_prefix": "🤔 Thinking",
        "thinking_done": "🤔 *Thinking complete, preparing answer...*",
        "final_label": "📝 Final answer:",
        "chat_label": "Chat",
    },
}
def dir_and_alignment(lang: str) -> Tuple[str, str]:
    """Return the ``(dir, text-align)`` pair used to lay out text in *lang*.

    Only ``"he"`` is treated as right-to-left; any other value gets LTR.
    """
    is_rtl = lang == "he"
    return ("rtl", "right") if is_rtl else ("ltr", "left")
_details_template = (
'<details dir="{dir}" style="text-align: {align};">'
"<summary>🤔 <em>{summary}</em></summary>"
'<div style="white-space: pre-wrap; margin: 10px 0; direction: {dir}; text-align: {align};">{content}</div>'
"</details>"
)
def wrap_text_with_direction(text: str, lang: str, emphasized: bool = False) -> str:
    """Wrap *text* in a ``<div>`` whose direction and alignment suit *lang*.

    Args:
        text: Content to wrap. It is inserted verbatim, not HTML-escaped.
        lang: Language code passed to ``dir_and_alignment``.
        emphasized: When True, adds ``font-weight: 600;`` to the inline style.

    Returns:
        The HTML ``<div>`` string.
    """
    dir_attr, text_align = dir_and_alignment(lang)
    bold_css = "font-weight: 600;" if emphasized else ""
    inline_style = f"text-align: {text_align}; {bold_css}"
    return f'<div dir="{dir_attr}" style="{inline_style}">{text}</div>'
def build_system_prompt(lang: str) -> str:
    """Return the system prompt sent to the model for UI language *lang*.

    Every language gets the same base instructions. For ``"he"``, an extra
    sentence asks the model to answer only in Hebrew.
    """
    base_prompt = (
        "You are a helpful AI assistant specialized in mathematics and problem-solving who can "
        "answer math questions with the correct answer. Answer shortly, not more than 500 "
        "tokens, but outline the process step by step."
    )
    if lang != "he":
        return base_prompt
    return base_prompt + " Answer ONLY in Hebrew!"
def thinking_indicator(lang: str, progress_token_count: int) -> str:
    """Render the "thinking…" placeholder shown while reasoning streams in.

    Args:
        lang: Language code; selects text direction and the localized prefix.
        progress_token_count: A growing progress counter. Only its value mod 6
            is used, to animate 1-5 trailing dots (a remainder of 0 shows one dot).

    Returns:
        An HTML ``<div>`` with a colored accent border on the reading-start side.
    """
    direction, align = dir_and_alignment(lang)
    accent_side = "right" if direction == "rtl" else "left"
    dot_count = progress_token_count % 6 or 1
    caption = labels[lang]["thinking_prefix"] + "." * dot_count
    box_style = (
        "padding: 10px; background-color: #f0f2f6; "
        f"border-radius: 10px; border-{accent_side}: 4px solid #1f77b4; text-align: {align};"
    )
    return (
        f'<div dir="{direction}" style="{box_style}">'
        f'<p style="margin: 0; color: #1f77b4; font-style: italic;">{caption}</p>'
        "</div>"
    )
def build_assistant_markdown(
    lang: str,
    final_answer: str,
    thinking_text: str | None,
) -> str:
    """Compose the assistant's chat message from reasoning and answer text.

    Layout, with sections separated by blank lines:

    1. (only if *thinking_text* is non-empty) a collapsible ``<details>`` block
       holding the reasoning, followed by the localized "thinking done" notice;
    2. the bold "final answer" label;
    3. the stripped answer, with LaTeX bracket math converted to dollar math,
       or "…" when nothing is left.
    """
    direction, align = dir_and_alignment(lang)
    ui_text = labels[lang]
    sections: List[str] = []

    if thinking_text:
        sections += [
            _details_template.format(
                dir=direction,
                align=align,
                summary=ui_text["summary_text"],
                content=thinking_text,
            ),
            wrap_text_with_direction(ui_text["thinking_done"], lang),
        ]

    sections.append(wrap_text_with_direction(ui_text["final_label"], lang, emphasized=True))
    answer_body = convert_latex_brackets_to_dollars(final_answer.strip()) or "…"
    sections.append(wrap_text_with_direction(answer_body, lang))
    return "\n\n".join(sections)
_THINK_OPEN = "<think>"
_THINK_CLOSE = "</think>"


def _split_thinking(text: str) -> Tuple[str | None, str, bool]:
    """Split the accumulated model output into reasoning and answer parts.

    Works on the whole text streamed so far, not on single stream chunks. That
    way tags split across chunks, or sharing a chunk with answer text, are
    still handled correctly.

    Args:
        text: Everything the streamer has produced so far.

    Returns:
        ``(thinking, answer, still_thinking)``:
        ``thinking`` is the stripped reasoning text (``None`` if there is none
        once ``</think>`` has closed); ``answer`` is the text after
        ``</think>`` (or all of *text* when no tag has been seen);
        ``still_thinking`` is True while ``<think>`` is open but not closed.
        Any text before an opening ``<think>`` is dropped.
    """
    open_at = text.find(_THINK_OPEN)
    body_start = 0 if open_at == -1 else open_at + len(_THINK_OPEN)
    close_at = text.find(_THINK_CLOSE, body_start)
    if close_at != -1:
        # A closing tag with no opener also counts: some chat templates put
        # "<think>" in the prompt, so only "</think>" is generated.
        thinking = text[body_start:close_at].strip()
        return thinking or None, text[close_at + len(_THINK_CLOSE):], False
    if open_at != -1:
        return text[body_start:].strip(), "", True
    return None, text, False


@spaces.GPU
def handle_user_message(
    user_input: str,
    lang: str,
    chat_history: List[Tuple[str, str]] | None,
) -> Generator[tuple, None, None]:
    """Stream the model's reply to *user_input* into the chat.

    Yields ``(chatbot_value, question_box_update, dropdown_update, chat_state)``
    tuples, matching the ``outputs`` order wired up in ``build_demo``.

    While the model is inside a ``<think>`` section, the last chat message shows
    a "thinking" indicator. After ``</think>``, it shows the collapsible
    reasoning plus the answer streamed so far. On error, the message is
    replaced by an error notice and the prompt is put back in the textbox.

    Args:
        user_input: Raw textbox content. Blank input changes nothing except
            clearing the textbox and resetting the preset dropdown.
        lang: UI language code; falls back to ``DEFAULT_LANG`` when empty.
        chat_history: ``(user_html, assistant_html)`` pairs, or ``None``.
    """
    lang = lang or DEFAULT_LANG
    localized = labels[lang]
    chat_history = chat_history or []
    prompt = (user_input or "").strip()
    # gr.update works on Gradio 3.x and 4+; the per-component .update() class
    # methods (gr.Textbox.update, ...) were removed in Gradio 4.
    dropdown_reset = gr.update(value=localized["new_question"])

    if not prompt:
        yield chat_history, gr.update(value=""), dropdown_reset, chat_history
        return

    formatted_user = wrap_text_with_direction(prompt, lang)
    # Build a new list instead of appending to the one held in gr.State.
    chat_history = chat_history + [(formatted_user, "")]
    yield chat_history, gr.update(value=""), dropdown_reset, chat_history

    thinking_text: str | None = None
    final_answer = ""
    still_thinking = False
    try:
        chat_messages = [
            {"role": "system", "content": build_system_prompt(lang)},
            {"role": "user", "content": prompt},
        ]
        input_text = tokenizer.apply_chat_template(
            chat_messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

        # NOTE(review): the think-tag parsing relies on "<think>"/"</think>"
        # surviving skip_special_tokens=True — confirm for this tokenizer.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=2400,
            temperature=0.6,
            top_p=0.95,
            top_k=20,
            do_sample=True,
        )
        generation_errors: List[Exception] = []

        def _generate() -> None:
            # If generate() raises, nothing ever ends the streamer and the
            # loop below would block forever. Record the error and end the
            # stream so it can be reported to the user.
            try:
                model.generate(**generation_kwargs)
            except Exception as err:
                generation_errors.append(err)
                streamer.end()

        thread = Thread(target=_generate)
        thread.start()

        full_text = ""
        for delta in streamer:
            if not delta:
                continue
            full_text += delta
            thinking_text, final_answer, still_thinking = _split_thinking(full_text)
            if still_thinking:
                current_answer = thinking_indicator(lang, len(thinking_text or ""))
            else:
                current_answer = build_assistant_markdown(
                    lang=lang,
                    final_answer=final_answer,
                    thinking_text=thinking_text,
                )
            chat_history[-1] = (formatted_user, current_answer)
            yield chat_history, gr.update(value=""), dropdown_reset, chat_history
        thread.join()
        if generation_errors:
            raise generation_errors[0]
    except Exception as exc:
        logger.exception("Error generating response")
        error_html = wrap_text_with_direction(f"⚠️ Error generating response: {exc}", lang)
        chat_history[-1] = (formatted_user, error_html)
        yield chat_history, gr.update(value=prompt), dropdown_reset, chat_history
        return

    if still_thinking:
        # Generation stopped before "</think>" (e.g. max_new_tokens reached).
        # Show the reasoning as the answer rather than an empty reply.
        final_answer, thinking_text = thinking_text or "", None
    chat_history[-1] = (
        formatted_user,
        build_assistant_markdown(lang=lang, final_answer=final_answer, thinking_text=thinking_text),
    )
    yield chat_history, gr.update(value=""), dropdown_reset, chat_history
def reset_conversation(lang: str):
    """Clear the conversation and restore the input widgets for *lang*.

    Returns updates in the order ``(chatbot, question_box, preset_dropdown,
    chat_state)``: an empty chat, an empty relabeled textbox, the dropdown
    reset to its "new question" option, and an empty chat state.

    Uses the generic ``gr.update`` (available on Gradio 3.x and 4+), because
    the per-component ``.update()`` class methods were removed in Gradio 4.
    """
    localized = labels[lang]
    question_update = gr.update(
        value="",
        label=localized["text_label"],
        placeholder=localized["placeholder"],
    )
    dropdown_update = gr.update(
        choices=localized["predefined"],
        value=localized["new_question"],
        label=localized["select_label"],
    )
    return [], question_update, dropdown_update, []
def sync_question_text(selected_option: str, lang: str):
    """Copy the chosen preset question into the question textbox.

    Choosing the localized "new question" option, or having no selection at
    all (``None``/empty), clears the textbox instead.

    Uses ``gr.update`` (Gradio 3.x and 4+); ``gr.Textbox.update`` no longer
    exists in Gradio 4.
    """
    localized = labels[lang]
    if not selected_option or selected_option == localized["new_question"]:
        return gr.update(value="")
    return gr.update(value=selected_option)
def toggle_language(lang: str):
    """Switch the UI between Hebrew and English and relabel every widget.

    Any value other than ``"he"`` switches to Hebrew. Returns, in order: the
    new language code, then updates for the title, intro, preset dropdown,
    question textbox (cleared), send button, reset button and language button.
    This matches the ``outputs`` list wired in ``build_demo``.

    Uses ``gr.update`` (Gradio 3.x and 4+) instead of the per-component
    ``.update()`` class methods removed in Gradio 4.
    """
    new_lang = "en" if lang == "he" else "he"
    localized = labels[new_lang]
    return (
        new_lang,
        gr.update(value=f"# {localized['title']}"),
        gr.update(value=localized["intro"]),
        gr.update(
            choices=localized["predefined"],
            value=localized["new_question"],
            label=localized["select_label"],
        ),
        gr.update(
            label=localized["text_label"],
            placeholder=localized["placeholder"],
            value="",
        ),
        gr.update(value=localized["send"]),
        gr.update(value=localized["reset"]),
        gr.update(value=localized["toggle_to"]),
    )
# Extra CSS passed to gr.Blocks(css=...) in build_demo: a font stack for the
# page, a pointer cursor on <summary> (the collapsible reasoning section) and
# tighter paragraph spacing inside .gradio-container .prose.
# NOTE(review): 'Rubik' only applies if the browser already has it; nothing
# here loads the font.
CUSTOM_CSS = """
body {
font-family: 'Rubik', 'Segoe UI', 'Helvetica Neue', Arial, sans-serif;
}
details > summary {
cursor: pointer;
}
.gradio-container .prose p {
margin-bottom: 0.5rem;
}
"""
def build_demo() -> gr.Blocks:
    """Assemble the Gradio UI and wire its event handlers.

    The UI is first rendered in ``DEFAULT_LANG``; ``toggle_language`` relabels
    it at runtime. Returns the (not yet launched) ``gr.Blocks`` app.
    """
    ui_text = labels[DEFAULT_LANG]
    with gr.Blocks(css=CUSTOM_CSS, title="Hebrew Math Tutor") as demo:
        # Per-session state: the current UI language and the chat transcript.
        lang_state = gr.State(DEFAULT_LANG)
        chat_state = gr.State([])

        title_md = gr.Markdown(f"# {ui_text['title']}")
        intro_md = gr.Markdown(ui_text["intro"])

        with gr.Row():
            preset_dropdown = gr.Dropdown(
                choices=ui_text["predefined"],
                value=ui_text["new_question"],
                label=ui_text["select_label"],
                interactive=True,
            )
            lang_button = gr.Button(ui_text["toggle_to"], variant="secondary")

        question_box = gr.Textbox(
            lines=5,
            label=ui_text["text_label"],
            placeholder=ui_text["placeholder"],
        )

        with gr.Row():
            reset_button = gr.Button(ui_text["reset"], variant="secondary")
            send_button = gr.Button(ui_text["send"], variant="primary")

        chatbot = gr.Chatbot(
            label=ui_text["chat_label"],
            height=520,
            bubble_full_width=False,
            render_markdown=True,
        )

        # reset_conversation and handle_user_message both return updates in this order.
        conversation_outputs = [chatbot, question_box, preset_dropdown, chat_state]
        message_inputs = [question_box, lang_state, chat_state]

        preset_dropdown.change(
            fn=sync_question_text,
            inputs=[preset_dropdown, lang_state],
            outputs=question_box,
        )
        reset_button.click(
            fn=reset_conversation,
            inputs=[lang_state],
            outputs=conversation_outputs,
        )
        # The send button and the textbox's submit event trigger the same handler.
        for register in (send_button.click, question_box.submit):
            register(
                fn=handle_user_message,
                inputs=message_inputs,
                outputs=conversation_outputs,
            )
        lang_button.click(
            fn=toggle_language,
            inputs=[lang_state],
            outputs=[
                lang_state,
                title_md,
                intro_md,
                preset_dropdown,
                question_box,
                send_button,
                reset_button,
                lang_button,
            ],
        )
    return demo
# Built at import time, so the module-level `demo` object exists even when
# this file is imported rather than run as a script.
demo = build_demo()

if __name__ == "__main__":
    # queue() returns the same Blocks app with request queuing enabled;
    # handle_user_message is a generator that streams its updates.
    queued_demo = demo.queue()
    queued_demo.launch()