Russian
NekitAI committed on
Commit
ab3dd03
·
verified ·
1 Parent(s): 0df6ff8

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +78 -0
train.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Train a small GPT-2-style causal language model from scratch on a text file.

The script builds a randomly initialised GPT-2 model with a reduced
configuration, trains it on ``data_path`` with the Hugging Face ``Trainer``,
and saves the model and tokenizer into the ``model_name`` directory.
"""

from pathlib import Path

import torch
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2TokenizerFast,
    TextDataset,
    Trainer,
    TrainingArguments,
)

# === Parameters ===
model_name = "NekitAI"        # output directory for checkpoints and the final model
data_path = "my_texts.txt"    # plain-text training corpus
block_size = 128              # tokens per training example (also the context length)
batch_size = 4
epochs = 3

# === Tokenizer ===
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# The GPT-2 tokenizer ships without a pad token; reuse EOS so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

# === Model configuration ===
config = GPT2Config(
    vocab_size=tokenizer.vocab_size,
    n_positions=block_size,
    n_embd=256,
    n_layer=4,
    n_head=4,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# === Model creation (random weights, not a pretrained checkpoint) ===
model = GPT2LMHeadModel(config)

# === Dataset preparation ===
# NOTE(review): TextDataset is deprecated in recent transformers releases;
# consider migrating to the `datasets` library when upgrading.
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=data_path,
    block_size=block_size,
)

# Fail early with a clear message instead of letting the Trainer run on nothing.
if len(dataset) == 0:
    raise ValueError(
        f"No training examples were built from {data_path!r}: the file must "
        f"contain at least {block_size} tokens."
    )

# mlm=False selects causal (next-token) language modelling.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

# === Training arguments ===
training_args = TrainingArguments(
    output_dir=model_name,
    overwrite_output_dir=True,
    per_device_train_batch_size=batch_size,
    num_train_epochs=epochs,
    save_steps=500,
    logging_steps=50,
    save_total_limit=1,
    prediction_loss_only=True,
    # Mixed precision needs a CUDA GPU; enable it only when one is present so
    # the script still runs on CPU-only machines.
    fp16=torch.cuda.is_available(),
)

# === Trainer ===
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

# === Training ===
trainer.train()

# === Save the model and tokenizer ===
Path(model_name).mkdir(parents=True, exist_ok=True)
model.save_pretrained(model_name)
tokenizer.save_pretrained(model_name)

print(f"\n✅ Готово! Модель сохранена в: {model_name}")