Andro0s committed on
Commit 8819d2a · verified · 1 Parent(s): 8a3f13e

Create Train.py

Files changed (1)
  1. Train.py +81 -0
Train.py ADDED
@@ -0,0 +1,81 @@
# ===============================
# AmorCoder AI - Advanced LoRA Training
# ===============================
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType

# -------------------------------
# 1️⃣ Base model
# -------------------------------
MODEL_NAME = "codellama/CodeLlama-7b-hf"
print("Loading base model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# The Llama tokenizer ships without a pad token; reuse EOS so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16
)

# -------------------------------
# 2️⃣ LoRA configuration
# -------------------------------
print("Applying LoRA...")
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # attention projections commonly targeted in LLMs
    lora_dropout=0.05,
    bias="none"
)
model = get_peft_model(model, lora_config)

# -------------------------------
# 3️⃣ Dataset
# -------------------------------
print("Loading dataset...")
dataset = load_dataset("json", data_files={"train": "tu_dataset.json"}, split="train")

def preprocess(example):
    # Concatenate the prompt and the target code into a single sequence;
    # for causal LM training the labels are derived from these input_ids
    # by the data collator, so they stay aligned token for token.
    prompt = f"# Instrucción:\n{example['instruction']}\n\n# Código:\n"
    full_text = prompt + example["code"] + tokenizer.eos_token
    return tokenizer(full_text, truncation=True, max_length=512)

dataset = dataset.map(preprocess, remove_columns=dataset.column_names)

# -------------------------------
# 4️⃣ Training arguments
# -------------------------------
training_args = TrainingArguments(
    output_dir="./lora_codellama",
    per_device_train_batch_size=1,   # use gradient accumulation for larger effective batches
    gradient_accumulation_steps=4,
    num_train_epochs=3,              # can be raised to 5 for better results
    learning_rate=3e-4,
    fp16=True,
    logging_steps=10,
    save_steps=50,
    save_total_limit=3,
    report_to="none",                # avoid depending on wandb or another tracker
)

# -------------------------------
# 5️⃣ Training
# -------------------------------
# mlm=False gives standard causal-LM collation: inputs are padded and labels
# are copies of input_ids with padding masked out.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    train_dataset=dataset,
    args=training_args,
    data_collator=data_collator,
)

print("Training LoRA...")
trainer.train()

# -------------------------------
# 6️⃣ Save weights
# -------------------------------
model.save_pretrained("lora_codellama")
print("✅ Training complete. Weights saved to 'lora_codellama'.")
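For reference, a minimal inference sketch (not part of this commit) of how the adapter saved in 'lora_codellama' could be loaded back onto the same base model with peft's PeftModel. The prompt format mirrors the one used in preprocess; the example instruction and generation settings are illustrative assumptions.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

BASE_MODEL = "codellama/CodeLlama-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, device_map="auto", torch_dtype=torch.float16
)
# Attach the LoRA weights produced by Train.py (directory matches the save path above).
model = PeftModel.from_pretrained(base_model, "lora_codellama")
model.eval()

prompt = "# Instrucción:\nWrite a function that adds two numbers.\n\n# Código:\n"
inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))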