test_finetunning / check.py
vector2000's picture
Upload check.py
a67cce9 verified
raw
history blame
4.93 kB
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset # , load_dataset
from tqdm import tqdm
# Завантаження моделей та токенізатора
# original_model_name = "meta-llama/Meta-Llama-3.1-8B"
original_model_name = "facebook/opt-350m" # Це відкрита модель, яку можно використовувати для тестування
fine_tuned_model_path = "./fine_tuned_model" # Шлях до вашої донавченної моделі
tokenizer = AutoTokenizer.from_pretrained(original_model_name)
original_model = AutoModelForCausalLM.from_pretrained(original_model_name)
fine_tuned_model = AutoModelForCausalLM.from_pretrained(fine_tuned_model_path)
# Завантаження тестового набора данних
# test_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
# Завантаження данних з локального тестового файлу
with open("ilya_klimov_data.txt", "r", encoding="utf-8") as file:
text_data = file.read().strip()
# Створення датасету
test_dataset = Dataset.from_dict({"text": [text_data]})
def calculate_perplexity(model, tokenizer, dataset, max_length=1024):
model.eval()
total_loss = 0
total_length = 0
for item in tqdm(dataset, desc="Calculating perplexity"):
encodings = tokenizer(item['text'], return_tensors='pt', truncation=True, max_length=max_length)
input_ids = encodings.input_ids.to(model.device)
with torch.no_grad():
outputs = model(input_ids, labels=input_ids)
total_loss += outputs.loss.item() * input_ids.size(1)
total_length += input_ids.size(1)
avg_loss = total_loss / total_length
perplexity = torch.exp(torch.tensor(avg_loss)).item()
return perplexity
# Розрахунок реплексії для обох моделей
print("Calculating perplexity for the original model...")
original_perplexity = calculate_perplexity(original_model, tokenizer, test_dataset)
print("Calculating perplexity for the fine-tuned model...")
fine_tuned_perplexity = calculate_perplexity(fine_tuned_model, tokenizer, test_dataset)
print(f"Original model perplexity: {original_perplexity:.2f}")
print(f"Fine-tuned model perplexity: {fine_tuned_perplexity:.2f}")
# Порівняння генерації текста
def generate_text(model, tokenizer, prompt, max_length=150):
input_ids = tokenizer.encode(prompt, return_tensors="pt")
with torch.no_grad():
output = model.generate(input_ids, max_length=max_length, num_return_sequences=1, no_repeat_ngram_size=2)
return tokenizer.decode(output[0], skip_special_tokens=True)
# prompt = "The history of artificial intelligence"
prompt = "Илья Климов - разработчик из Харькова, работающий в GitLab. Его основной язык программирования"
print("\nText generation comparison:")
print("Original model output:")
print(generate_text(original_model, tokenizer, prompt))
print("\nFine-tuned model output:")
print(generate_text(fine_tuned_model, tokenizer, prompt))
# Порівняння втрат на декількох прикладах
def compare_losses(original_model, fine_tuned_model, tokenizer, texts):
original_model.eval()
fine_tuned_model.eval()
for i, text in enumerate(texts, 1):
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
original_loss = original_model(**inputs, labels=inputs["input_ids"]).loss.item()
fine_tuned_loss = fine_tuned_model(**inputs, labels=inputs["input_ids"]).loss.item()
print(f"\nExample {i}:")
print(f"Original model loss: {original_loss:.4f}")
print(f"Fine-tuned model loss: {fine_tuned_loss:.4f}")
print("\nComparing losses on specific examples:")
#example_texts = [
# "Artificial intelligence has revolutionized many fields of science and technology.",
# "The development of machine learning algorithms has led to significant advancements in data analysis.",
# "Neural networks are a fundamental component of modern AI systems."
#]
example_texts = [
"Илья Климов работает в компании GitLab и использует JavaScript.",
"Основной фреймворк, который использует Илья Климов для работы в GitLab - это VueJS.",
"Илья Климов выступает на IT-конференциях и продает курсы по программированию.",
"У Ильи Климова есть желтый лотос, что является интересным фактом о нем."
]
compare_losses(original_model, fine_tuned_model, tokenizer, example_texts)