from Llama3Model import Llama3Model
from cfg import LLAMA32_CONFIG, LLAMA_SIZE_STR
import torch
import time
from tools import model_memory_size, generate, text_to_token_ids, token_ids_to_text
from huggingface import chat_tokenizer, tokenizer


model = Llama3Model(LLAMA32_CONFIG)
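# The model starts out with randomly initialized weights; the pretrained
# checkpoint is loaded further below.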

# Select the best available device: CUDA GPU, Apple Silicon (MPS), or CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model.to(device)
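
# Optional: the model_memory_size helper imported from tools can be used to
# report the model's weight-memory footprint (its exact signature depends on tools).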

# Load the pretrained weights from a safetensors checkpoint into the model.
from safetensors.torch import load_file

weights = load_file("llama32_weights.safetensors")
model.load_state_dict(weights)
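# load_state_dict matches parameter names strictly by default and raises an
# error if the checkpoint keys do not line up with the model definition.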

print("Weights loaded successfully!")

# Time the end-to-end generation run.
start = time.time()

PROMPT = "What do llamas eat?"

# Encode the prompt with the chat tokenizer and generate up to 150 new tokens.
token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),
    max_new_tokens=150,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=1,
    temperature=0.
)
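# With top_k=1 and temperature=0., decoding should be greedy (argmax at every
# step), assuming the usual convention that a zero temperature disables sampling.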

print(f"Time: {time.time() - start:.2f} sec")
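
# Peak GPU memory is only reported when running on CUDA.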
if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")

# Decode the generated token ids back into text (still wrapped in the chat template).
output_text = token_ids_to_text(token_ids, tokenizer)


def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    # Strip everything up to and including the assistant header of the Llama 3
    # chat template so that only the model's reply is returned.
    index = text.find(header_end)
    if index != -1:
        return text[index + len(header_end):].strip()
    else:
        return text


print("\n\nOutput text:\n\n", clean_text(output_text))