import time

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, save_file

from Llama3Model import Llama3Model
from cfg import LLAMA32_CONFIG, LLAMA_SIZE_STR
from tools import (
    model_memory_size,
    generate,
    text_to_token_ids,
    token_ids_to_text,
    load_weights_into_llama,
)
# Local helper module (assumed) providing the Llama 3 tokenizer and its chat-format wrapper
from huggingface import chat_tokenizer, tokenizer

# Instantiate the model and report its parameter count and memory footprint
model = Llama3Model(LLAMA32_CONFIG)

total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

# Llama 3.2 ties the token embedding and output head weights, so subtract one
# copy of the embedding matrix to count unique parameters
total_params_normalized = total_params - model.tok_emb.weight.numel()
print(f"\nTotal number of unique parameters: {total_params_normalized:,}")

print(f"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB")
print(f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")

# Prefer a CUDA GPU, then Apple Silicon (MPS), then fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    # Apple Silicon; note that MPS may yield slightly different outputs than CUDA/CPU
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model.to(device)

print("CUDA available:", torch.cuda.is_available())

# Download the pretrained instruct weights from the Hugging Face Hub
# (the meta-llama repos are gated, so an access token with approved access is required)
if LLAMA_SIZE_STR == "1B":
    # The 1B model ships as a single safetensors file
    weights_file = hf_hub_download(
        repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
        filename="model.safetensors",
        local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
    )
    combined_weights = load_file(weights_file)

else:
    # Larger variants (e.g. 3B) are sharded into two files; merge them into one state dict
    combined_weights = {}
    for i in range(1, 3):
        weights_file = hf_hub_download(
            repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
            filename=f"model-0000{i}-of-00002.safetensors",
            local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
        )
        current_weights = load_file(weights_file)
        combined_weights.update(current_weights)
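
# Alternative (a sketch): huggingface_hub.snapshot_download can fetch the whole
# checkpoint repository in one call instead of naming each shard explicitly
# (note that this downloads every file in the repo).
# from huggingface_hub import snapshot_download
# snapshot_download(repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
#                   local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct")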

# Copy the downloaded weights into the model, move it to the target device,
# and free the temporary state dict
load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)
model.to(device)
del combined_weights

# Sanity check: the token embedding and output head should share the same weights
print("Weight tying:", torch.equal(model.tok_emb.weight, model.out_head.weight))
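
# Optional (a sketch): if Llama3Model contains dropout or other train-only layers,
# switch the model to inference mode before generating.
# model.eval()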

# Generate a response to a short prompt and time it
PROMPT = "What do llamas eat?"

torch.manual_seed(123)

start = time.time()
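
# Optional (a sketch): reset CUDA's peak-memory counter here so that the
# "Max memory allocated" figure printed below reflects the peak reached during
# generation (resident weights plus activations) rather than transient peaks
# from weight loading.
# if torch.cuda.is_available():
#     torch.cuda.reset_peak_memory_stats()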
token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),
    max_new_tokens=150,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=1,
    temperature=0.  # top_k=1 with temperature 0 makes decoding greedy (deterministic)
)

print(f"Time: {time.time() - start:.2f} sec")

if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")

# Decode the generated token IDs back into text
output_text = token_ids_to_text(token_ids, tokenizer)


def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    # The Llama 3 chat template prefixes the model's reply with an
    # "...assistant<|end_header_id|>\n\n" header; drop everything up to and
    # including that header so only the assistant's answer remains.
    index = text.find(header_end)
    if index != -1:
        return text[index + len(header_end):].strip()
    else:
        return text


print("\n\nOutput text:\n\n", clean_text(output_text))

# Export the weights so they can be reloaded later without downloading again
weights = model.state_dict()

save_name = "llama32_weights.safetensors"
save_file(weights, save_name)
print(f"Saved weights to {save_name}")