# Llama32 / main.py
import torch
from safetensors.torch import save_file

from Llama3Model import Llama3Model
from cfg import LLAMA32_CONFIG, LLAMA_SIZE_STR
from tools import model_memory_size, generate, text_to_token_ids, token_ids_to_text, load_weights_into_llama
# `huggingface` below is this repo's local helper module (assumed to provide the
# tokenizers and an hf_hub_download re-export), not the `huggingface_hub` package itself.
from huggingface import hf_hub_download, chat_tokenizer, tokenizer

model = Llama3Model(LLAMA32_CONFIG)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")
# Account for weight tying
total_params_normalized = total_params - model.tok_emb.weight.numel()
print(f"\nTotal number of unique parameters: {total_params_normalized:,}")
print(f"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB")
print(f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")
# Pick the best available device: CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model.to(device)
print(f"Using device: {device}")
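
# Optional: reset the CUDA peak-memory counter so the max-memory report after
# generation reflects only this run (no effect on CPU/MPS).
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()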
from safetensors.torch import load_file
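
# Note: the meta-llama checkpoints on the Hugging Face Hub are gated; downloading them
# requires accepting the license and authenticating first (e.g. `huggingface-cli login`
# or setting the HF_TOKEN environment variable).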
if LLAMA_SIZE_STR == "1B":
    # The 1B checkpoint ships as a single safetensors file.
    weights_file = hf_hub_download(
        repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
        filename="model.safetensors",
        local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
    )
    combined_weights = load_file(weights_file)
else:
    # Larger variants are split into two shards; download both and merge into one dict.
    combined_weights = {}
    for i in range(1, 3):
        weights_file = hf_hub_download(
            repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
            filename=f"model-0000{i}-of-00002.safetensors",
            local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
        )
        current_weights = load_file(weights_file)
        combined_weights.update(current_weights)
load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)
model.to(device)
del combined_weights # free up memory
print("Weight tying:", torch.equal(model.tok_emb.weight, model.out_head.weight))
import time
PROMPT = "What do llamas eat?"
torch.manual_seed(123)
start = time.time()
token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),
    max_new_tokens=150,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=1,
    temperature=0.  # top_k=1 with zero temperature makes decoding effectively greedy
)
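
# Optional (sketch): CUDA kernels launch asynchronously, so synchronizing before
# reading the clock gives a more accurate wall-time measurement.
if torch.cuda.is_available():
    torch.cuda.synchronize()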
print(f"Time: {time.time() - start:.2f} sec")
if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")
output_text = token_ids_to_text(token_ids, tokenizer)
def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    # Find the first occurrence of the assistant header marker
    index = text.find(header_end)
    if index != -1:
        # Return everything after the header, with leading/trailing whitespace stripped
        return text[index + len(header_end):].strip()
    else:
        # If the marker is not found, return the original text unchanged
        return text
print("\n\nOutput text:\n\n", clean_text(output_text))
weights = model.state_dict()
save_name = "llama32_weights.safetensors"
save_file(weights, save_name)
print(f"{save_name}")