import time

import torch
from safetensors.torch import load_file, save_file

from Llama3Model import Llama3Model
from cfg import LLAMA32_CONFIG, LLAMA_SIZE_STR
from huggingface import hf_hub_download, chat_tokenizer, tokenizer
from tools import (
    generate,
    load_weights_into_llama,
    model_memory_size,
    text_to_token_ids,
    token_ids_to_text,
)

# Instantiate the model and report its size
model = Llama3Model(LLAMA32_CONFIG)

total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

# Account for weight tying: the output head shares the token embedding weights
total_params_normalized = total_params - model.tok_emb.weight.numel()
print(f"\nTotal number of unique parameters: {total_params_normalized:,}")

print(f"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB")
print(f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")

# Pick the best available device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model.to(device)
print("CUDA available:", torch.cuda.is_available())

# Download the pretrained weights from the Hugging Face Hub
if LLAMA_SIZE_STR == "1B":
    # The 1B model ships as a single safetensors file
    weights_file = hf_hub_download(
        repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
        filename="model.safetensors",
        local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
    )
    combined_weights = load_file(weights_file)
else:
    # Larger variants are sharded into two safetensors files
    combined_weights = {}
    for i in range(1, 3):
        weights_file = hf_hub_download(
            repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
            filename=f"model-0000{i}-of-00002.safetensors",
            local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
        )
        current_weights = load_file(weights_file)
        combined_weights.update(current_weights)

load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)
model.to(device)
del combined_weights  # free up memory

print("Weight tying:", torch.equal(model.tok_emb.weight, model.out_head.weight))

# Generate a completion for a sample prompt and time it
PROMPT = "What do llamas eat?"

torch.manual_seed(123)

start = time.time()

token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),
    max_new_tokens=150,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=1,
    temperature=0.,
)

print(f"Time: {time.time() - start:.2f} sec")

if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")

output_text = token_ids_to_text(token_ids, tokenizer)


def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    # Find the first occurrence of the assistant header marker
    index = text.find(header_end)

    if index != -1:
        # Return the substring after the header, stripping leading/trailing whitespace
        return text[index + len(header_end):].strip()
    else:
        # If the marker is not found, return the original text
        return text


print("\n\nOutput text:\n\n", clean_text(output_text))

# Save the model weights to a safetensors file
weights = model.state_dict()
save_name = "llama32_weights.safetensors"
save_file(weights, save_name)
print(f"Model saved as {save_name}")
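
# Optional sanity check (a minimal sketch, not part of the original script):
# reload the just-saved safetensors file into a fresh Llama3Model instance and
# confirm the embedding weights round-trip. This assumes the save above
# succeeded and that Llama3Model can be rebuilt from the same LLAMA32_CONFIG;
# `load_file` is the safetensors loader already imported at the top.
model_reloaded = Llama3Model(LLAMA32_CONFIG)
model_reloaded.load_state_dict(load_file(save_name))
model_reloaded.to(device)
print(
    "Reload check:",
    torch.equal(model.tok_emb.weight.cpu(), model_reloaded.tok_emb.weight.cpu()),
)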