import os
import torch
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    LogitsProcessorList,
    LogitsProcessor,
)
from peft import PeftModel

# CONFIGURATION
CHECKPOINT_PATH = "pcalhoun/ILR-Assistant-LoRA"
MODEL_NAME = "Qwen/Qwen3-4B"
LOAD_IN_4BIT = True
MAX_NEW_TOKENS = 1024

ILR_LEVELS = ['1', '1+', '2', '2+', '3', '3+']

INITIAL_USER_MESSAGE_TEMPLATE = """ILR Level 1 (Elementary): Reads very simple texts (e.g., tourist materials) with high-frequency vocabulary. Misunderstandings common; grasps basic ideas in familiar contexts.

ILR Level 1+ (Elementary+): Handles simple announcements, headlines, or narratives. Can locate routine professional info but struggles with structure and cohesion.

ILR Level 2 (Limited Working): Reads straightforward factual texts on familiar topics (e.g., news, basic reports). Understands main ideas but slowly; inferences are limited.

ILR Level 2+ (Limited Working+): Comprehends most non-technical prose and concrete professional discussions. Separates main ideas from details but misses nuance.

ILR Level 3 (General Professional): Reads diverse authentic texts (e.g., news, reports) with near-complete comprehension. Interprets implicit meaning but struggles with complex idioms.

ILR Level 3+ (General Professional+): Handles varied professional styles with minimal errors. Understands cultural references and complex structures, though subtleties may be missed.

Initial ILR level for this conversation: {ilr_level}

Test my comprehension of Modern Standard Arabic."""

INITIAL_ASSISTANT_SCORER = "I am administering an ILR level assessment."

IM_START = "<|im_start|>"
IM_END = "<|im_end|>"

# Global variables
model = None
tokenizer = None


class BanTokensLogitsProcessor(LogitsProcessor):
    """Custom LogitsProcessor to completely ban specific tokens with proper device handling."""

    def __init__(self, tokenizer, banned_words, device):
        self.banned_token_ids = set()
        self.device = device

        # Get all possible token IDs for banned words
        for word in banned_words:
            variants = [word, f" {word}", f"{word} ", f" {word} "]
            for variant in variants:
                try:
                    token_ids = tokenizer.encode(variant, add_special_tokens=False)
                    self.banned_token_ids.update(token_ids)
                except Exception as e:
                    print(f"Warning: Could not encode variant '{variant}': {e}")

        print(f"Banned token IDs: {self.banned_token_ids}")
        print(f"LogitsProcessor device: {self.device}")

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Set logits of banned tokens to negative infinity
        for token_id in self.banned_token_ids:
            if token_id < scores.shape[-1]:  # Safety check
                scores[:, token_id] = float('-inf')
        return scores


def get_banned_token_ids(tokenizer, bad_words):
    """Get token IDs for words that should be banned using bad_words_ids format."""
    bad_words_ids = []
    for word in bad_words:
        # Try different variations to handle tokenization edge cases
        variants = [
            word,           # exact word
            f" {word}",     # with leading space
            f"{word} ",     # with trailing space
            f" {word} "     # with both spaces
        ]
        for variant in variants:
            try:
                token_ids = tokenizer.encode(variant, add_special_tokens=False)
                if token_ids:  # Only add if tokenization succeeded
                    bad_words_ids.append(token_ids)
            except Exception as e:
                print(f"Warning: Could not encode variant '{variant}': {e}")
    return bad_words_ids
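
# Illustrative note (the token IDs below are made up; real values depend on the
# Qwen tokenizer's vocabulary): `get_banned_token_ids` returns the nested-list
# format expected by `model.generate(..., bad_words_ids=...)`, e.g.
#   [[1234], [56, 78]]   # one inner list of token IDs per banned sequence
# `BanTokensLogitsProcessor` covers the same tokens at the logits level and can
# alternatively be supplied to `generate` via
#   logits_processor=LogitsProcessorList([BanTokensLogitsProcessor(...)])
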
checkpoint: {CHECKPOINT_PATH}") # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token # Load base model with quantization if LOAD_IN_4BIT and torch.cuda.is_available(): bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, ) base_model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, quantization_config=bnb_config, device_map="auto", trust_remote_code=True, ) else: base_model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, ) # Load LoRA adapter if checkpoint exists model = PeftModel.from_pretrained(base_model, CHECKPOINT_PATH) model.eval() print("βœ“ Model and LoRA adapter loaded successfully") print(f"βœ“ Model device: {next(model.parameters()).device}") return model, tokenizer def debug_tokenization(tokenizer, words): """Debug tokenization of specific words.""" print("=== TOKENIZATION DEBUG ===") for word in words: variants = [word, f" {word}", f"{word} ", f" {word} "] for variant in variants: try: token_ids = tokenizer.encode(variant, add_special_tokens=False) tokens = tokenizer.tokenize(variant) print(f"'{variant}' -> IDs: {token_ids}, Tokens: {tokens}") except Exception as e: print(f"Error tokenizing '{variant}': {e}") print("=========================") def text_completion(prompt): """Enhanced text completion with comprehensive token banning.""" try: model, tokenizer = load_model_and_tokenizer() # Print the full prompt to CLI print("=" * 80) print("FULL PROMPT:") print("=" * 80) print(prompt) print("=" * 80) # Get model device model_device = next(model.parameters()).device print(f"Model device: {model_device}") # Method 1: bad_words_ids banned_words = ["", ""] bad_words_ids = get_banned_token_ids(tokenizer, banned_words) print(f"Bad words IDs: {bad_words_ids}") # Method 2: Custom LogitsProcessor with proper device handling ban_processor = BanTokensLogitsProcessor(tokenizer, banned_words, model_device) logits_processor = LogitsProcessorList([ban_processor]) # Debug tokenization (run once to see how tokens are encoded) # debug_tokenization(tokenizer, banned_words) inputs = tokenizer(prompt, return_tensors="pt").to(model_device) print(f"Input device: {inputs['input_ids'].device}") with torch.no_grad(): output = model.generate( **inputs, max_new_tokens=MAX_NEW_TOKENS, temperature=0.6, top_p=0.95, top_k=20, do_sample=True, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, bad_words_ids=bad_words_ids, # Filter out tokens ) # Decode response completion = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=False) # Print the raw response to CLI print("RAW MODEL OUTPUT:") print("=" * 80) print(completion) print("=" * 80) # Clean up the response - stop at first IM_END token if IM_END in completion: completion = completion.split(IM_END)[0] return completion.strip() except Exception as e: error_msg = f"Error generating completion: {str(e)}" print(error_msg) print(f"Exception type: {type(e)}") import traceback traceback.print_exc() return error_msg def format_message_for_display(content, role): """Format a message for display in the Gradio interface (remove chat tokens but keep scorer content).""" if role == "user": return content elif role == "assistant": # Keep the content visible but remove chat tokens return content return content def build_chat_prompt(messages): """Build the full chat 
def build_chat_prompt(messages):
    """Build the full chat prompt with proper tokens for model generation."""
    prompt = ""

    for msg in messages:
        role = msg["role"]
        content = msg["content"]

        if role == "user":
            prompt += f"{IM_START}user\n{content}{IM_END}\n"
        elif role == "assistant":
            if msg.get("complete", False):
                # Complete message with IM_END
                prompt += f"{IM_START}assistant\n{content}{IM_END}\n"
            else:
                # Incomplete message for generation
                prompt += f"{IM_START}assistant\n{content}"

    print("BUILT CHAT PROMPT:")
    print("=" * 60)
    print(prompt)
    print("=" * 60)

    return prompt


def initialize_conversation(ilr_level):
    """Initialize a new conversation with the given ILR level."""
    print(f"🔄 Initializing conversation at ILR level: {ilr_level}")

    # Create initial messages
    initial_user_content = INITIAL_USER_MESSAGE_TEMPLATE.format(ilr_level=ilr_level)
    initial_assistant_content = f"\n{INITIAL_ASSISTANT_SCORER}\n\n"

    messages = [
        {"role": "user", "content": initial_user_content, "complete": True},
        {"role": "assistant", "content": initial_assistant_content, "complete": False}
    ]

    # Generate the initial assistant response
    prompt = build_chat_prompt(messages)
    response = text_completion(prompt)

    # Update the assistant message with the complete response
    messages[-1]["content"] = initial_assistant_content + response
    messages[-1]["complete"] = True

    # Convert to display format for Gradio
    display_history = []
    display_history.append([
        format_message_for_display(initial_user_content, "user"),
        format_message_for_display(messages[-1]["content"], "assistant")
    ])

    # Format raw output for display
    raw_output = f"RAW MODEL OUTPUT:\n{'=' * 80}\n{response}\n{'=' * 80}"

    return display_history, messages, raw_output


def send_message(user_input, chat_history, messages, ilr_level):
    """Handle sending a user message and generating an assistant response."""
    if not user_input.strip():
        return chat_history, "", messages, ""

    print("📝 SENDING MESSAGE:")
    print("=" * 60)
    print(f"User Input: {repr(user_input)}")
    print(f"Current Messages: {len(messages)}")
    print("=" * 60)

    # Add user message
    messages.append({"role": "user", "content": user_input, "complete": True})

    # Start assistant response with scorer tag
    assistant_start = "\n"
    messages.append({"role": "assistant", "content": assistant_start, "complete": False})

    # Generate assistant response
    prompt = build_chat_prompt(messages)
    response = text_completion(prompt)

    # Complete the assistant message
    full_assistant_content = assistant_start + response
    messages[-1]["content"] = full_assistant_content
    messages[-1]["complete"] = True

    # Update chat history for display
    chat_history.append([
        format_message_for_display(user_input, "user"),
        format_message_for_display(full_assistant_content, "assistant")
    ])

    # Format raw output for display
    raw_output = f"RAW MODEL OUTPUT:\n{'=' * 80}\n{response}\n{'=' * 80}"

    return chat_history, "", messages, raw_output


def reset_conversation(ilr_level):
    """Reset the conversation with a new ILR level."""
    chat_history, messages, raw_output = initialize_conversation(ilr_level)
    return chat_history, messages, raw_output
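
# UI state sketch: the Gradio interface below keeps two parallel views of the
# conversation. `messages_state` holds the raw role/content/complete dicts used
# by `build_chat_prompt`, while the `Chatbot` component receives a list of
# [user_text, assistant_text] pairs produced by `format_message_for_display`.
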
def create_interface():
    """Create the Gradio interface."""
    with gr.Blocks(title="ILR Arabic Assistant", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🇸🇦 ILR Arabic Assistant")

        # State to store messages
        messages_state = gr.State([])

        with gr.Row():
            with gr.Column(scale=1):
                ilr_level = gr.Dropdown(
                    choices=ILR_LEVELS,
                    value="2+",
                    label="ILR Level",
                    info="Select your proficiency level"
                )

                reset_btn = gr.Button(
                    "🔄 Reset Conversation",
                    variant="primary"
                )

                gr.Markdown("""The ILR Assistant generates Arabic reading comprehension assessments that adapt to your performance level. It presents Arabic passages with questions and automatically adjusts difficulty based on your responses - moving to easier content when you struggle or maintaining challenge when you succeed.

The system was trained on authentic Arabic learning materials from the Defense Language Institute using the official ILR (Interagency Language Roundtable) proficiency scale. Try it out to see how AI can create personalized language assessments that respond to your Arabic reading comprehension skills.

### ILR Levels:
- **1**: Elementary
- **1+**: Elementary+
- **2**: Limited Working
- **2+**: Limited Working+
- **3**: General Professional
- **3+**: General Professional+
""")

            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    show_copy_button=True,
                    avatar_images=("👀", "🤖"),
                )

                with gr.Row():
                    msg = gr.Textbox(
                        label="Your message",
                        placeholder="Type your response in English...",
                        scale=4
                    )
                    send_btn = gr.Button("📤 Send", scale=1, variant="primary")

                # Raw output display
                raw_output_display = gr.Textbox(
                    label="Raw Model Output",
                    lines=10,
                    max_lines=20,
                    interactive=False,
                    show_copy_button=True,
                    autoscroll=True,
                    placeholder="Raw model output will appear here...",
                )

        # Event handlers
        def handle_reset(level):
            return reset_conversation(level)

        def handle_send(user_input, chat_history, messages, level):
            return send_message(user_input, chat_history, messages, level)

        reset_btn.click(
            handle_reset,
            inputs=[ilr_level],
            outputs=[chatbot, messages_state, raw_output_display]
        )

        send_btn.click(
            handle_send,
            inputs=[msg, chatbot, messages_state, ilr_level],
            outputs=[chatbot, msg, messages_state, raw_output_display]
        )

        msg.submit(
            handle_send,
            inputs=[msg, chatbot, messages_state, ilr_level],
            outputs=[chatbot, msg, messages_state, raw_output_display]
        )

        # Initialize conversation on load
        def on_load(level):
            chat_history, messages, raw_output = initialize_conversation(level)
            return chat_history, messages, raw_output

        demo.load(
            on_load,
            inputs=[ilr_level],
            outputs=[chatbot, messages_state, raw_output_display]
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    load_model_and_tokenizer()
    demo.launch()
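
# To run locally (assumed setup; package list inferred from the imports above):
#   pip install torch transformers peft bitsandbytes gradio
#   python app.py
# The filename `app.py` is an assumption; on a Hugging Face Space the script is
# started automatically and `demo.launch()` serves the interface.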