#!/usr/bin/env python3
# wersa_watcher.py - FIXED VERSION
# Monitor and test WERSA checkpoints during training

import argparse
import hashlib
import json
import logging
import os
import shutil
import signal
import tempfile
import time
import traceback
from contextlib import contextmanager

import torch
from transformers import AutoTokenizer, LogitsProcessor

# Silence transformers logging for clean output
logging.getLogger("transformers").setLevel(logging.ERROR)

# Import your model
from wersa import WersaForCausalLM


@contextmanager
def timeout(duration):
    """Abort the wrapped block after `duration` seconds (Unix only)."""
    def timeout_handler(signum, frame):
        raise TimeoutError(f"Operation timed out after {duration} seconds")

    # SIGALRM only exists on Unix-like systems
    if hasattr(signal, 'SIGALRM'):
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(duration)
        try:
            yield
        finally:
            signal.alarm(0)
    else:
        # On Windows, just yield without a timeout
        yield


def find_latest_checkpoint_slot(output_dir):
    """Find the most recently modified checkpoint slot."""
    slot1_dir = os.path.join(output_dir, "slot1")
    slot2_dir = os.path.join(output_dir, "slot2")
    slot1_exists = os.path.exists(slot1_dir) and os.path.exists(os.path.join(slot1_dir, "config.json"))
    slot2_exists = os.path.exists(slot2_dir) and os.path.exists(os.path.join(slot2_dir, "config.json"))

    if not slot1_exists and not slot2_exists:
        return None
    if slot1_exists and not slot2_exists:
        return slot1_dir
    if slot2_exists and not slot1_exists:
        return slot2_dir

    # If both exist, check training_info.json to find the most recent
    try:
        info1_path = os.path.join(slot1_dir, "training_info.json")
        info2_path = os.path.join(slot2_dir, "training_info.json")
        if os.path.exists(info1_path) and os.path.exists(info2_path):
            with open(info1_path) as f:
                info1 = json.load(f)
            with open(info2_path) as f:
                info2 = json.load(f)
            step1 = info1.get("completed_steps", 0)
            step2 = info2.get("completed_steps", 0)
            return slot2_dir if step2 > step1 else slot1_dir
        # Fall back to modification time
        if os.path.getmtime(slot1_dir) > os.path.getmtime(slot2_dir):
            return slot1_dir
        return slot2_dir
    except Exception:
        return slot2_dir if slot2_exists else slot1_dir


def calculate_hash(filepath):
    """Compute the SHA256 hash of a file."""
    sha256_hash = hashlib.sha256()
    try:
        with open(filepath, "rb") as f:
            for byte_block in iter(lambda: f.read(4096), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()
    except (FileNotFoundError, OSError):
        return None


def safe_copy_checkpoint(source_dir, dest_dir, max_retries=3):
    """Safely copy the checkpoint, retrying on errors."""
    for attempt in range(max_retries):
        try:
            if os.path.exists(dest_dir):
                shutil.rmtree(dest_dir)
            # Verify essential files
            required_files = ["config.json"]
            for file in required_files:
                if not os.path.exists(os.path.join(source_dir, file)):
                    raise FileNotFoundError(f"Required file {file} not found in {source_dir}")
            shutil.copytree(source_dir, dest_dir)
            return True
        except (FileNotFoundError, OSError, shutil.Error) as e:
            print(f"Copy attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                time.sleep(2)
            else:
                raise
    return False


def check_model_files(checkpoint_dir):
    """Verify that all files needed to load the model exist."""
    required_files = ["config.json"]
    # Model weights - check which format is present
    model_files = ["model.safetensors", "pytorch_model.bin"]
    model_found = any(os.path.exists(os.path.join(checkpoint_dir, f)) for f in model_files)
    # Tokenizer files
    tokenizer_files = [
        "tokenizer_config.json", "vocab.json", "merges.txt",
        "tokenizer.json", "special_tokens_map.json"
    ]

    missing_required = []
    missing_optional = []
    for file in required_files:
        if not os.path.exists(os.path.join(checkpoint_dir, file)):
            missing_required.append(file)
    if not model_found:
        missing_required.append("model weights (safetensors or bin)")
    # Check for at least one tokenizer file
    tokenizer_found = any(os.path.exists(os.path.join(checkpoint_dir, f)) for f in tokenizer_files)
    if not tokenizer_found:
        missing_optional.append("tokenizer files")
    return missing_required, missing_optional


class SafeLogitsProcessor(LogitsProcessor):
    """Ensures logits do not contain invalid values."""

    def __call__(self, input_ids, scores):
        # Replace NaN/inf with large finite values, then clamp the range
        scores = scores.nan_to_num(nan=-1e9, posinf=1e9, neginf=-1e9)
        scores = torch.clamp(scores, min=-1e9, max=1e9)
        return scores


def generate_text(model, tokenizer, prompt, device='cuda', max_new_tokens=50):
    """Generate text with debugging output and safety checks."""
    model.eval()

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    input_ids = inputs['input_ids'].to(device)
    print(f"Input shape: {input_ids.shape}")
    print(f"Input text: {tokenizer.decode(input_ids[0])}")

    try:
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                logits_processor=[SafeLogitsProcessor()],
            )
        response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
        full_response = tokenizer.decode(output[0], skip_special_tokens=True)
        return response, full_response
    except Exception as e:
        print(f"Error during generation: {e}")
        traceback.print_exc()
        return None, None


def main(args):
    last_tested_hash = None
    device = 'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu'

    print("--- WERSA Model Watcher ---")
    print(f"Output dir: {args.output_dir}")
    print(f"Check interval: {args.check_interval}s")
    print(f"Test prompt: '{args.prompt}'")
    print(f"Device: {device}")
    print("=" * 60)

    while True:
        try:
            latest_checkpoint_dir = find_latest_checkpoint_slot(args.output_dir)
            if latest_checkpoint_dir:
                config_path = os.path.join(latest_checkpoint_dir, "config.json")
                if os.path.exists(config_path):
                    current_hash = calculate_hash(config_path)
                    if current_hash != last_tested_hash:
                        print(f"\n[{time.strftime('%Y-%m-%d %H:%M:%S')}] New checkpoint: '{os.path.basename(latest_checkpoint_dir)}'")

                        # Check files
                        missing_required, missing_optional = check_model_files(latest_checkpoint_dir)
                        if missing_required:
                            print(f"Missing required files: {missing_required}")
                            last_tested_hash = current_hash
                            continue
                        if missing_optional:
                            print(f"Missing optional files: {missing_optional}")

                        # Load training info if available
                        info_path = os.path.join(latest_checkpoint_dir, "training_info.json")
                        if os.path.exists(info_path):
                            with open(info_path) as f:
                                info = json.load(f)
                            print(f"Training step: {info.get('completed_steps', 'unknown')}")

                        temp_dir = None
                        try:
                            # Copy the checkpoint so the trainer can keep writing to the slot
                            temp_dir = tempfile.mkdtemp()
                            print("Copying checkpoint...")
                            safe_copy_checkpoint(latest_checkpoint_dir, temp_dir)

                            print("Loading model...")
                            torch_dtype = torch.float32
                            if device == 'cuda':
                                torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
                            model = WersaForCausalLM.from_pretrained(
                                temp_dir,
                                torch_dtype=torch_dtype,
                                device_map=device if device == 'cuda' else None,
                            )
                            # No-op when device_map already placed the model on the GPU
                            model = model.to(device)

                            # Load tokenizer: checkpoint first, then output dir, then GPT-2 fallback
                            try:
                                tokenizer = AutoTokenizer.from_pretrained(temp_dir)
                                print("Loaded checkpoint tokenizer")
                            except Exception:
                                print("No tokenizer in checkpoint, trying base tokenizer...")
                                try:
                                    tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
                                except Exception:
                                    print("Using GPT2 tokenizer as fallback")
                                    tokenizer = AutoTokenizer.from_pretrained("gpt2")

                            # Ensure pad token
                            if tokenizer.pad_token is None:
                                tokenizer.pad_token = tokenizer.eos_token

                            print(f"Tokenizer vocab size: {len(tokenizer)}")
                            print(f"Model vocab size: {model.config.vocab_size}")

                            print("Generating text...")
                            with timeout(300):
                                response, full_response = generate_text(
                                    model, tokenizer, args.prompt,
                                    device=device,
                                    max_new_tokens=args.max_new_tokens
                                )

                            if response is not None:
                                print("\n" + "=" * 60)
                                print("PROMPT:", args.prompt)
                                print("-" * 60)
                                print("RESPONSE:", response)
                                print("-" * 60)
                                print("FULL TEXT:", full_response)
                                print("=" * 60 + "\n")

                            last_tested_hash = current_hash

                        except TimeoutError as e:
                            print(f"TIMEOUT: {e}")
                        except torch.cuda.OutOfMemoryError:
                            print("CUDA OOM! Try --cpu flag or reduce --max_new_tokens")
                            torch.cuda.empty_cache()
                        except Exception as e:
                            print(f"Error: {e}")
                            traceback.print_exc()
                        finally:
                            # Cleanup
                            if 'model' in locals():
                                del model
                            if device == 'cuda':
                                torch.cuda.empty_cache()
                            if temp_dir and os.path.exists(temp_dir):
                                try:
                                    shutil.rmtree(temp_dir)
                                except OSError:
                                    pass
                else:
                    print(f"config.json not found in {latest_checkpoint_dir}")
            else:
                print("No valid checkpoint found yet...")

            # Wait before the next check
            time.sleep(args.check_interval)

        except KeyboardInterrupt:
            print("\nWatcher stopped by user")
            break
        except Exception as e:
            print(f"Error in main loop: {e}")
            traceback.print_exc()
            time.sleep(args.check_interval)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Monitor and test WERSA checkpoints in real-time")
    parser.add_argument("--output_dir", type=str, required=True, help="Training output directory with checkpoints")
    parser.add_argument("--prompt", type=str, default="The meaning of life is", help="Prompt for testing")
    parser.add_argument("--check_interval", type=int, default=60, help="Seconds between checks")
    parser.add_argument("--max_new_tokens", type=int, default=50, help="Max tokens to generate")
    parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
    args = parser.parse_args()
    main(args)
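
# ---------------------------------------------------------------------------
# Usage sketch. The invocation and directory layout below are illustrative
# assumptions derived from the flags and slot names used above, not output
# from a real run:
#
#   python wersa_watcher.py --output_dir ./wersa_output \
#       --prompt "The meaning of life is" \
#       --check_interval 30 --max_new_tokens 64
#
# Expected layout under --output_dir (the trainer alternates between slots):
#   wersa_output/slot1/config.json, model.safetensors, training_info.json
#   wersa_output/slot2/...
#
# training_info.json only needs a "completed_steps" field for the slot
# comparison, e.g. {"completed_steps": 1234}; all other keys are ignored.
# ---------------------------------------------------------------------------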