vincenzodentamaro committed
Commit 05570bd · verified · 1 Parent(s): 09ab73a

Create train_and_generate_8b.py

Files changed (1)
  1. train_and_generate_8b.py  +125 -0
train_and_generate_8b.py ADDED
@@ -0,0 +1,125 @@
+ #!/usr/bin/env python
+ # train_and_generate_8b.py
+
+ import torch
+ from transformers import (
+     AutoTokenizer,
+     Trainer,
+     TrainingArguments,
+     DataCollatorForLanguageModeling,
+ )
+ from datasets import load_dataset
+ import logging
+
+ # Import the custom WERSA classes from your local package
+ # This assumes you have run `pip install -e .` with the corrected modeling file
+ from wersa import WersaConfig, WersaForCausalLM
+
+ # --- Setup Logging ---
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ def main():
+     # --- 1. Configuration for 8B Model ---
+     logger.info("Setting up 8B model and tokenizer...")
+     # Use a Qwen tokenizer for better compatibility with the architecture
+     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
+     if tokenizer.pad_token is None:
+         # Reuse the EOS token as the pad token, a common fallback when none is defined
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Use the 8B configuration defined in the package
+     config = WersaConfig(
+         vocab_size=len(tokenizer),
+         pad_token_id=tokenizer.pad_token_id,
+         # --- 8B Parameters ---
+         hidden_size=4096,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         intermediate_size=11008,
+         max_position_embeddings=4096,
+         # --- WERSA Parameters ---
+         wersa_decomp_levels=4,
+         wersa_random_features=256,
+     )
+
+     model = WersaForCausalLM(config)
+     logger.info(f"Model created with approximately {model.num_parameters() / 1e9:.2f}B parameters.")
+
+     # --- 2. Dataset Preparation ---
+     logger.info("Loading and preparing dataset...")
+     # Using a small slice of a large dataset for demonstration.
+     # For a real 8B pre-training, you would use the full dataset and train for many more steps.
+     raw_dataset = load_dataset("allenai/c4", "en", split="train[:100000]")  # Using 100k samples
+     raw_dataset = raw_dataset.shuffle(seed=42)
+
+     def tokenize_function(examples):
+         return tokenizer(examples["text"], truncation=True, max_length=config.max_position_embeddings)
+
+     tokenized_dataset = raw_dataset.map(
+         tokenize_function, batched=True, num_proc=8, remove_columns=["text", "timestamp", "url"]
+     )
+
+     block_size = config.max_position_embeddings
+     def group_texts(examples):
+         concatenated = {k: sum(examples[k], []) for k in examples.keys()}
+         total_length = len(concatenated[list(examples.keys())[0]])
+         total_length = (total_length // block_size) * block_size
+         result = {k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated.items()}
+         result["labels"] = result["input_ids"].copy()
+         return result
+
+     lm_dataset = tokenized_dataset.map(group_texts, batched=True, batch_size=1000, num_proc=8)
+     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+     # --- 3. Training ---
+     output_dir = "./wersa-qwen-style-8b-final"
+     logger.info("Setting up Trainer for 8B model...")
+
+     # WARNING: These settings require a high-end multi-GPU setup (e.g., A100s).
+     # Adjust `per_device_train_batch_size` and `gradient_accumulation_steps` for your hardware.
+     training_args = TrainingArguments(
+         output_dir=output_dir,
+         overwrite_output_dir=True,
+         num_train_epochs=1,
+         per_device_train_batch_size=1,  # Minimal batch size per GPU
+         gradient_accumulation_steps=32,  # Increase to simulate a larger batch size
+         save_steps=500,
+         save_total_limit=2,
+         logging_steps=10,
+         fp16=torch.cuda.is_available(),  # Use fp16 for memory efficiency
+     )
+
+     trainer = Trainer(model=model, args=training_args, train_dataset=lm_dataset, data_collator=data_collator)
+
+     logger.info("Starting pre-training for the 8B model...")
+     trainer.train()
+     logger.info("Pre-training finished.")
+     trainer.save_model(output_dir)
+     tokenizer.save_pretrained(output_dir)
+     logger.info(f"Model and tokenizer saved to {output_dir}")
+
+     # --- 4. Generation Test ---
+     logger.info("\n" + "="*50 + "\n RUNNING 8B GENERATION TEST\n" + "="*50 + "\n")
+
+     # Load the trained model and tokenizer from the final output directory
+     trained_model = WersaForCausalLM.from_pretrained(output_dir)
+     trained_tokenizer = AutoTokenizer.from_pretrained(output_dir)
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     trained_model.to(device)
+
+     prompt = "What is the meaning of life?"
+     inputs = trained_tokenizer(prompt, return_tensors="pt").to(device)
+     logger.info(f"PROMPT: '{prompt}'")
+
+     # Generate text using the loaded model
+     outputs = trained_model.generate(**inputs, max_new_tokens=100, no_repeat_ngram_size=2, pad_token_id=trained_tokenizer.eos_token_id)
+     generated_text = trained_tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     logger.info("\nMODEL COMPLETION:\n")
+     print(generated_text)
+     logger.info("\n" + "="*50 + "\n TEST COMPLETE\n" + "="*50)
+
+ if __name__ == "__main__":
+     main()
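
Note on hardware: the in-file warning flags that these TrainingArguments assume a high-end multi-GPU machine. As a minimal sketch only (not part of this commit), a lower-memory variant could lean on gradient checkpointing and bf16; both are standard transformers.TrainingArguments options, but the values below are illustrative assumptions, not settings validated for this model.

    import torch
    from transformers import TrainingArguments

    # Illustrative lower-memory variant (values are assumptions, not tested for WERSA-8B).
    training_args = TrainingArguments(
        output_dir="./wersa-qwen-style-8b-final",
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=64,   # larger effective batch at the same per-GPU memory
        gradient_checkpointing=True,      # recompute activations to save memory; assumes the WERSA model supports it
        bf16=True,                        # requires Ampere-class GPUs or newer; otherwise keep the fp16 flag used above
        save_steps=500,
        save_total_limit=2,
        logging_steps=10,
    )

When the script is launched with torchrun or accelerate launch, the Trainer picks up the distributed environment automatically, so no extra multi-GPU code is needed in the script itself.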