""" tool_trainer_m4_max.py - Optimized training for M4 Max Apple Silicon + SmolLM3-3B This script is specifically optimized for: - M4 Max 40-core GPU Apple Silicon - SmolLM3-3B (larger, more capable model) - Large training dataset (100+ examples) - Aggressive but stable hyperparameters for fast, high-quality training """ import json import torch import torch.backends.mps from transformers import ( AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling ) from peft import LoraConfig, get_peft_model, TaskType from datasets import Dataset import os import time def setup_mps_optimization(): """Configure optimal settings for M4 Max.""" print("๐ŸŽ Configuring M4 Max optimizations...") # Check MPS availability if torch.backends.mps.is_available(): print("โœ… MPS (Metal Performance Shaders) is available") print(f"๐Ÿš€ Using all 40 GPU cores of M4 Max") device = torch.device("mps") else: print("โš ๏ธ MPS not available, falling back to CPU") device = torch.device("cpu") # Optimize memory allocation os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0" # Aggressive memory usage os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid fork warnings return device def load_training_data(file_path="tool_pairs_enhanced.jsonl"): """Load the comprehensive training dataset.""" pairs = [] with open(file_path, 'r') as f: for line in f: pairs.append(json.loads(line.strip())) return pairs def format_for_sft(pairs, tokenizer): """Convert pairs to SFT format optimized for function calling.""" formatted = [] for pair in pairs: # Create training example: prompt + chosen response full_text = pair["prompt"] + pair["chosen"] + tokenizer.eos_token formatted.append({"text": full_text}) return formatted def tokenize_function(examples, tokenizer, max_length=512): """Tokenize with consistent padding for variable length sequences.""" # Reduced max_length to handle variable sequences better tokenized = tokenizer( examples["text"], truncation=True, padding="max_length", # Consistent padding max_length=max_length, return_tensors=None ) # For causal LM, labels are the same as input_ids tokenized["labels"] = tokenized["input_ids"] return tokenized def main(): print("๐Ÿš€ M4 Max Optimized Training: SmolLM3-3B Function Calling") print("=" * 70) # Setup M4 Max optimizations device = setup_mps_optimization() start_time = time.time() # 1. Load SmolLM3-3B (the real deal!) print("๐Ÿ“ฅ Loading SmolLM3-3B model and tokenizer...") model_name = "HuggingFaceTB/SmolLM3-3B" # Using the actual SmolLM3-3B! tokenizer = AutoTokenizer.from_pretrained(model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Ensure consistent tokenizer settings tokenizer.padding_side = "right" # Load model with MPS optimization model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float32, # Use float32 for MPS compatibility trust_remote_code=True, attn_implementation="eager" # More stable for training ) # Move to MPS if available if str(device) == "mps": model = model.to(device) print(f"โœ… Loaded model: {model_name}") print(f"๐Ÿ”ง Model dtype: {model.dtype}") print(f"๐Ÿ’พ Model size: ~{sum(p.numel() for p in model.parameters()) / 1e9:.1f}B parameters") print(f"๐ŸŽฏ Device: {device}") # 2. 
    # 2. Setup LoRA with optimized config for larger model
    print("\n🔩 Setting up LoRA adapter (rank 16 for SmolLM3-3B)...")
    lora_config = LoraConfig(
        r=16,  # Higher rank for 3B model (more capacity)
        lora_alpha=32,  # 2x rank
        target_modules=[  # Target attention and MLP projections for better coverage
            "q_proj", "v_proj", "k_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_dropout=0.05,  # Lower dropout for stability
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        # Fully fine-tune (and save) embeddings and output head for better function calling;
        # they are therefore left out of target_modules above.
        modules_to_save=["embed_tokens", "lm_head"],
    )

    model = get_peft_model(model, lora_config)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())

    print("✅ LoRA adapter attached")
    print(f"🎯 Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")

    # 3. Load comprehensive training data
    print("\n📊 Loading comprehensive training dataset...")
    pairs = load_training_data()
    formatted_pairs = format_for_sft(pairs, tokenizer)

    print(f"✅ Loaded {len(pairs)} training pairs")
    print(f"📈 Dataset is {len(pairs)/8:.1f}x larger than before!")

    # Create and tokenize dataset
    train_dataset = Dataset.from_list(formatted_pairs)
    tokenized_dataset = train_dataset.map(
        lambda x: tokenize_function(x, tokenizer),
        batched=True,
        remove_columns=train_dataset.column_names,
        num_proc=1,  # Single process for MPS compatibility
    )

    print(f"📊 Tokenized dataset: {len(tokenized_dataset)} examples")

    # 4. Optimized training arguments for M4 Max
    print("\n⚙️ Configuring M4 Max optimized training...")
    training_args = TrainingArguments(
        output_dir="./smollm3_tool_adapter",
        num_train_epochs=5,  # More epochs with larger dataset
        per_device_train_batch_size=4,  # Larger batch size for M4 Max
        gradient_accumulation_steps=2,  # Effective batch size = 8
        learning_rate=3e-4,  # Higher LR for faster convergence
        weight_decay=0.01,  # Regularization
        warmup_steps=50,  # More warmup for stability
        logging_steps=5,
        save_steps=25,
        save_total_limit=3,
        remove_unused_columns=False,
        fp16=False,  # Disable mixed precision for MPS compatibility
        dataloader_pin_memory=False,  # Disable for MPS
        report_to="none",  # Disable experiment-tracking integrations
        logging_dir="./logs",
        gradient_checkpointing=True,  # Memory optimization
        optim="adamw_torch",  # PyTorch-native AdamW
        lr_scheduler_type="cosine",  # Better convergence
        save_strategy="steps",
        eval_strategy="no",
        load_best_model_at_end=False,
    )

    # 5. Data collator with proper padding
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=8,  # Efficient padding for performance
    )

    # 6. Initialize optimized trainer
    # (remove_unused_columns is a TrainingArguments option, set above, not a Trainer kwarg)
    print("🏋️ Initializing M4 Max optimized trainer...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    print("✅ Trainer ready for M4 Max acceleration")

    # 7. Start accelerated training
    print("\n🎯 Starting accelerated training on M4 Max...")
    print("⏱️ Expected time: ~3-5 minutes with 40 GPU cores")
    print("📊 Monitoring loss for quality improvement...")

    # Train with progress monitoring
    train_result = trainer.train()

    end_time = time.time()
    training_time = end_time - start_time

    print("\n🎉 M4 Max training completed!")
    print(f"📊 Final training loss: {train_result.training_loss:.4f}")
    print(f"⏱️ Total training time: {training_time:.1f} seconds")
    print(f"🚀 Training speed: {len(pairs) * 5 / training_time:.1f} examples/second")  # 5 epochs
    # 8. Save the optimized model
    print("\n💾 Saving optimized model adapter...")
    model.save_pretrained("./smollm3_tool_adapter")
    tokenizer.save_pretrained("./smollm3_tool_adapter")
    print("✅ Model saved to './smollm3_tool_adapter'")

    # 9. Enhanced functionality test
    print("\n🧪 Enhanced functionality test...")
    test_schemas = [
        {
            "schema": {
                "name": "get_stock_price",
                "description": "Get current stock price",
                "parameters": {
                    "type": "object",
                    "properties": {"ticker": {"type": "string"}},
                    "required": ["ticker"],
                },
            },
            "question": "What's Google stock price?",
            "expected_ticker": "GOOGL",
        },
        {
            "schema": {
                "name": "process_payment",
                "description": "Process a payment transaction",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "amount": {"type": "number"},
                        "currency": {"type": "string"},
                        "recipient": {"type": "string"},
                    },
                    "required": ["amount", "recipient"],
                },
            },
            "question": "Send $150 to Alice",
            "expected": "process_payment",
        },
    ]

    model.eval()
    for i, test in enumerate(test_schemas, 1):
        test_prompt = f"""<|im_start|>system
You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.<|im_end|>
{json.dumps(test['schema'], indent=2)}
<|im_start|>user
{test['question']}<|im_end|>
<|im_start|>assistant
"""

        inputs = tokenizer(test_prompt, return_tensors="pt")
        if str(device) == "mps":
            inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                temperature=0.1,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens (skip the prompt)
        prompt_length = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        print(f"🧪 Test {i}: {test['question']}")
        print(f"🤖 Response: {response.strip()}")

        # Try to parse JSON
        try:
            json_response = json.loads(response.strip())
            print(f"✅ Valid JSON: {json_response}")
        except json.JSONDecodeError:
            print("❌ Invalid JSON")
        print("-" * 50)

    print("\n🏆 M4 Max Optimized Training Complete!")
    print(f"📈 Loss reduction with {len(pairs)} examples should be significant")
    print("🎯 Ready for comprehensive testing with schema_tester.py")

    return model, tokenizer


if __name__ == "__main__":
    model, tokenizer = main()
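
# Minimal sketch (not executed): how the saved adapter could be reloaded later for
# inference. This assumes the "./smollm3_tool_adapter" directory produced above and
# uses PeftModel.from_pretrained from the peft package; variable names are illustrative.
#
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(
#       "HuggingFaceTB/SmolLM3-3B", torch_dtype=torch.float32, trust_remote_code=True
#   )
#   tuned = PeftModel.from_pretrained(base, "./smollm3_tool_adapter")
#   tuned_tokenizer = AutoTokenizer.from_pretrained("./smollm3_tool_adapter")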