import os
import torch
import gradio as gr
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
LogitsProcessorList,
LogitsProcessor,
)
from peft import PeftModel
# CONFIGURATION
CHECKPOINT_PATH = "pcalhoun/ILR-Assistant-LoRA"
MODEL_NAME = "Qwen/Qwen3-4B"
LOAD_IN_4BIT = True
MAX_NEW_TOKENS = 1024
ILR_LEVELS = ['1', '1+', '2', '2+', '3', '3+']
INITIAL_USER_MESSAGE_TEMPLATE = """ILR Level 1 (Elementary):
Reads very simple texts (e.g., tourist materials) with high-frequency vocabulary. Misunderstandings common; grasps basic ideas in familiar contexts.
ILR Level 1+ (Elementary+):
Handles simple announcements, headlines, or narratives. Can locate routine professional info but struggles with structure and cohesion.
ILR Level 2 (Limited Working):
Reads straightforward factual texts on familiar topics (e.g., news, basic reports). Understands main ideas but slowly; inferences are limited.
ILR Level 2+ (Limited Working+):
Comprehends most non-technical prose and concrete professional discussions. Separates main ideas from details but misses nuance.
ILR Level 3 (General Professional):
Reads diverse authentic texts (e.g., news, reports) with near-complete comprehension. Interprets implicit meaning but struggles with complex idioms.
ILR Level 3+ (General Professional+):
Handles varied professional styles with minimal errors. Understands cultural references and complex structures, though subtleties may be missed.
Initial ILR level for this conversation: {ilr_level}
Test my comprehension of Modern Standard Arabic."""
INITIAL_ASSISTANT_SCORER = "I am administering an ILR level assessment."
IM_START = "<|im_start|>"
IM_END = "<|im_end|>"
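# IM_START / IM_END are the ChatML-style turn delimiters used by the Qwen chat models.
# build_chat_prompt() below assembles raw prompts from them, e.g. (illustrative only):
#   <|im_start|>user
#   Hello<|im_end|>
#   <|im_start|>assistant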
# Global variables
model = None
tokenizer = None
class BanTokensLogitsProcessor(LogitsProcessor):
"""Custom LogitsProcessor to completely ban specific tokens with proper device handling."""
def __init__(self, tokenizer, banned_words, device):
self.banned_token_ids = set()
self.device = device
# Get all possible token IDs for banned words
for word in banned_words:
if not word.strip():
continue  # skip empty placeholder entries so whitespace tokens are not banned
variants = [word, f" {word}", f"{word} ", f" {word} "]
for variant in variants:
try:
token_ids = tokenizer.encode(variant, add_special_tokens=False)
self.banned_token_ids.update(token_ids)
except Exception as e:
print(f"Warning: Could not encode variant '{variant}': {e}")
print(f"Banned token IDs: {self.banned_token_ids}")
print(f"LogitsProcessor device: {self.device}")
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Set logits of banned tokens to negative infinity
for token_id in self.banned_token_ids:
if token_id < scores.shape[-1]: # Safety check
scores[:, token_id] = float('-inf')
return scores
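# A processor like the one above is attached to generation through a LogitsProcessorList.
# Illustrative sketch (not executed here; the real wiring happens in text_completion()):
#   processor = BanTokensLogitsProcessor(tokenizer, ["some word"], model.device)
#   model.generate(**inputs, logits_processor=LogitsProcessorList([processor]))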
def get_banned_token_ids(tokenizer, bad_words):
"""Get token IDs for words that should be banned using bad_words_ids format."""
bad_words_ids = []
for word in bad_words:
if not word.strip():
continue  # skip empty placeholder entries so whitespace tokens are not banned
# Try different variations to handle tokenization edge cases
variants = [
word, # exact word
f" {word}", # with leading space
f"{word} ", # with trailing space
f" {word} " # with both spaces
]
for variant in variants:
try:
token_ids = tokenizer.encode(variant, add_special_tokens=False)
if token_ids: # Only add if tokenization succeeded
bad_words_ids.append(token_ids)
except Exception as e:
print(f"Warning: Could not encode variant '{variant}': {e}")
return bad_words_ids
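# Note: generate() expects bad_words_ids as a list of token-ID lists, one entry per banned
# sequence, e.g. [[1234], [56, 78]]. Recent transformers versions reject an empty list, so
# text_completion() passes None instead when nothing usable was encoded.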
def load_model_and_tokenizer():
"""Load the base model with LoRA adapter."""
global model, tokenizer
if model is not None and tokenizer is not None:
return model, tokenizer
print(f"Loading model from checkpoint: {CHECKPOINT_PATH}")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
# Load base model with quantization
if LOAD_IN_4BIT and torch.cuda.is_available():
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
)
else:
base_model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
# Load the LoRA adapter on top of the (possibly quantized) base model
model = PeftModel.from_pretrained(base_model, CHECKPOINT_PATH)
model.eval()
print("β Model and LoRA adapter loaded successfully")
print(f"β Model device: {next(model.parameters()).device}")
return model, tokenizer
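# The 4-bit NF4 load with double quantization keeps the 4B-parameter base model well under
# the memory footprint of a full-precision load; when CUDA is unavailable (or LOAD_IN_4BIT
# is disabled) the bfloat16 path above is used instead.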
def debug_tokenization(tokenizer, words):
"""Debug tokenization of specific words."""
print("=== TOKENIZATION DEBUG ===")
for word in words:
variants = [word, f" {word}", f"{word} ", f" {word} "]
for variant in variants:
try:
token_ids = tokenizer.encode(variant, add_special_tokens=False)
tokens = tokenizer.tokenize(variant)
print(f"'{variant}' -> IDs: {token_ids}, Tokens: {tokens}")
except Exception as e:
print(f"Error tokenizing '{variant}': {e}")
print("=========================")
def text_completion(prompt):
"""Enhanced text completion with comprehensive token banning."""
try:
model, tokenizer = load_model_and_tokenizer()
# Print the full prompt to CLI
print("=" * 80)
print("FULL PROMPT:")
print("=" * 80)
print(prompt)
print("=" * 80)
# Get model device
model_device = next(model.parameters()).device
print(f"Model device: {model_device}")
# Method 1: bad_words_ids passed directly to generate()
banned_words = ["", ""]  # words to suppress during generation; empty entries are ignored
bad_words_ids = get_banned_token_ids(tokenizer, banned_words)
print(f"Bad words IDs: {bad_words_ids}")
# Method 2: Custom LogitsProcessor with proper device handling
ban_processor = BanTokensLogitsProcessor(tokenizer, banned_words, model_device)
logits_processor = LogitsProcessorList([ban_processor])
# Debug tokenization (run once to see how tokens are encoded)
# debug_tokenization(tokenizer, banned_words)
inputs = tokenizer(prompt, return_tensors="pt").to(model_device)
print(f"Input device: {inputs['input_ids'].device}")
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=MAX_NEW_TOKENS,
temperature=0.6,
top_p=0.95,
top_k=20,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
bad_words_ids=bad_words_ids if bad_words_ids else None,  # Method 1: filter banned token sequences
logits_processor=logits_processor,  # Method 2: hard-ban individual token IDs
)
# Decode response
completion = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=False)
# Print the raw response to CLI
print("RAW MODEL OUTPUT:")
print("=" * 80)
print(completion)
print("=" * 80)
# Clean up the response - stop at first IM_END token
if IM_END in completion:
completion = completion.split(IM_END)[0]
return completion.strip()
except Exception as e:
error_msg = f"Error generating completion: {str(e)}"
print(error_msg)
print(f"Exception type: {type(e)}")
import traceback
traceback.print_exc()
return error_msg
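# text_completion() is deliberately stateless: the callers below own the message list and
# rebuild the full prompt string on every turn before asking for a continuation.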
def format_message_for_display(content, role):
"""Format a message for display in the Gradio interface (remove chat tokens but keep scorer content)."""
if role == "user":
return content
elif role == "assistant":
# Keep the content visible but remove chat tokens
return content
return content
def build_chat_prompt(messages):
"""Build the full chat prompt with proper tokens for model generation."""
prompt = ""
for msg in messages:
role = msg["role"]
content = msg["content"]
if role == "user":
prompt += f"{IM_START}user\n{content}{IM_END}\n"
elif role == "assistant":
if msg.get("complete", False):
# Complete message with IM_END
prompt += f"{IM_START}assistant\n{content}{IM_END}\n"
else:
# Incomplete message for generation
prompt += f"{IM_START}assistant\n{content}"
print("BUILT CHAT PROMPT:")
print("=" * 60)
print(prompt)
print("=" * 60)
return prompt
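# Leaving the final assistant turn without IM_END is what lets the model continue a
# partially written message: the prompt ends with "<|im_start|>assistant\n<prefix>" and
# generation picks up after the prefix until the model emits IM_END on its own.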
def initialize_conversation(ilr_level):
"""Initialize a new conversation with the given ILR level."""
print(f"π Initializing conversation at ILR level: {ilr_level}")
# Create initial messages
initial_user_content = INITIAL_USER_MESSAGE_TEMPLATE.format(ilr_level=ilr_level)
initial_assistant_content = f"\n{INITIAL_ASSISTANT_SCORER}\n\n"
messages = [
{"role": "user", "content": initial_user_content, "complete": True},
{"role": "assistant", "content": initial_assistant_content, "complete": False}
]
# Generate the initial assistant response
prompt = build_chat_prompt(messages)
response = text_completion(prompt)
# Update the assistant message with the complete response
messages[-1]["content"] = initial_assistant_content + response
messages[-1]["complete"] = True
# Convert to display format for Gradio
display_history = []
display_history.append([
format_message_for_display(initial_user_content, "user"),
format_message_for_display(messages[-1]["content"], "assistant")
])
# Format raw output for display
raw_output = f"RAW MODEL OUTPUT:\n{'=' * 80}\n{response}\n{'=' * 80}"
return display_history, messages, raw_output
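# display_history uses Gradio's [user, assistant] pair format, matching how the gr.Chatbot
# component created in create_interface() is populated here.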
def send_message(user_input, chat_history, messages, ilr_level):
"""Handle sending a user message and generating assistant response."""
if not user_input.strip():
return chat_history, "", messages, ""
print("π SENDING MESSAGE:")
print("=" * 60)
print(f"User Input: {repr(user_input)}")
print(f"Current Messages: {len(messages)}")
print("=" * 60)
# Add user message
messages.append({"role": "user", "content": user_input, "complete": True})
# Start assistant response with scorer tag
assistant_start = "\n"
messages.append({"role": "assistant", "content": assistant_start, "complete": False})
# Generate assistant response
prompt = build_chat_prompt(messages)
response = text_completion(prompt)
# Complete the assistant message
full_assistant_content = assistant_start + response
messages[-1]["content"] = full_assistant_content
messages[-1]["complete"] = True
# Update chat history for display
chat_history.append([
format_message_for_display(user_input, "user"),
format_message_for_display(full_assistant_content, "assistant")
])
# Format raw output for display
raw_output = f"RAW MODEL OUTPUT:\n{'=' * 80}\n{response}\n{'=' * 80}"
return chat_history, "", messages, raw_output
def reset_conversation(ilr_level):
"""Reset the conversation with a new ILR level."""
chat_history, messages, raw_output = initialize_conversation(ilr_level)
return chat_history, messages, raw_output
def create_interface():
"""Create the Gradio interface."""
with gr.Blocks(title="ILR Arabic Assistant", theme=gr.themes.Soft()) as demo:
gr.Markdown("# πΈπ¦ ILR Arabic Assistant")
# State to store messages
messages_state = gr.State([])
with gr.Row():
with gr.Column(scale=1):
ilr_level = gr.Dropdown(
choices=ILR_LEVELS,
value="2+",
label="ILR Level",
info="Select your proficiency level"
)
reset_btn = gr.Button(
"π Reset Conversation",
variant="primary"
)
gr.Markdown("""The ILR Assistant generates Arabic reading comprehension assessments that adapt to your performance level. It presents Arabic passages with questions and automatically adjusts difficulty based on your responses - moving to easier content when you struggle or maintaining challenge when you succeed. The system was trained on authentic Arabic learning materials from the Defense Language Institute using the official ILR (Interagency Language Roundtable) proficiency scale. Try it out to see how AI can create personalized language assessments that respond to your Arabic reading comprehension skills.
### ILR Levels:
- **1**: Elementary
- **1+**: Elementary+
- **2**: Limited Working
- **2+**: Limited Working+
- **3**: General Professional
- **3+**: General Professional+
""")
with gr.Column(scale=3):
chatbot = gr.Chatbot(
label="Conversation",
height=500,
show_copy_button=True,
avatar_images=None,
)
with gr.Row():
msg = gr.Textbox(
label="Your message",
placeholder="Type your response in English...",
scale=4
)
send_btn = gr.Button("Send", scale=1, variant="primary")
# Raw output display
raw_output_display = gr.Textbox(
label="Raw Model Output",
lines=10,
max_lines=20,
interactive=False,
show_copy_button=True,
autoscroll=True,
placeholder="Raw model output will appear here...",
)
# Event handlers
def handle_reset(level):
return reset_conversation(level)
def handle_send(user_input, chat_history, messages, level):
return send_message(user_input, chat_history, messages, level)
reset_btn.click(
handle_reset,
inputs=[ilr_level],
outputs=[chatbot, messages_state, raw_output_display]
)
send_btn.click(
handle_send,
inputs=[msg, chatbot, messages_state, ilr_level],
outputs=[chatbot, msg, messages_state, raw_output_display]
)
msg.submit(
handle_send,
inputs=[msg, chatbot, messages_state, ilr_level],
outputs=[chatbot, msg, messages_state, raw_output_display]
)
# Initialize conversation on load
def on_load(level):
chat_history, messages, raw_output = initialize_conversation(level)
return chat_history, messages, raw_output
demo.load(
on_load,
inputs=[ilr_level],
outputs=[chatbot, messages_state, raw_output_display]
)
return demo
if __name__ == "__main__":
demo = create_interface()
load_model_and_tokenizer()
demo.launch()