aarvis committed
Commit 98ca186 · 1 Parent(s): 07256a3

read-me-update-6

Files changed (2)
  1. train_llama.py +383 -0
  2. train_orpheus.py +428 -0
train_llama.py ADDED
@@ -0,0 +1,383 @@
+ # * This script was not rigorously tested, so it may not work as expected. We suggest
+ # * editing it to follow the Orpheus training script (train_orpheus.py) instead.
+
+ # * Install unsloth, PEFT, Weights & Biases, SNAC, pandas, soundfile and loguru.
+ # !pip install unsloth peft==0.15.2 wandb snac pandas soundfile loguru
+
+ # * Login to Weights & Biases.
+ # !wandb login
+
+ # Import necessary libraries.
+ # * unsloth import should always be at the top.
+ from unsloth import FastLanguageModel
+
+ import os
+
+ from datasets import load_dataset
+ from huggingface_hub import login
+ from loguru import logger
+ from snac import SNAC
+ from trl import SFTConfig, SFTTrainer
+ import soundfile as sf
+ import torch
+ import wandb
+
+
+ # Set up constants and configurations.
+ HUGGINGFACE_USERNAME = ""  # ! Fill.
+ BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
+ TRAIN_CSV_PATH = "data/data_stage_1.csv"
+ VALID_CSV_PATH = "data/data_eval.csv"
+ TRAIN_NUM_SAMPLES = None
+ EVAL_NUM_SAMPLES = None
+ MAX_SEQ_LENGTH = 2048
+ N_CODEBOOKS, CODEBOOK_SIZE = 3, 4096
+ FIELDS = [
+     "user",
+     "gender",
+     "age",
+     "language",
+     "utterance",
+     "audio",
+ ]
+ START_OF_SPECIAL_TOKENS = {field: f"<|start_of_{field}|>" for field in FIELDS}
+ END_OF_SPECIAL_TOKENS = {field: f"<|end_of_{field}|>" for field in FIELDS}
+ SNAC_TOKENS = [
+     f"<|snac_{i}_{j}|>" for i in range(N_CODEBOOKS) for j in range(CODEBOOK_SIZE)
+ ]
+ PER_DEVICE_TRAIN_BATCH_SIZE = 8
+ GRADIENT_ACCUMULATION_STEPS = 4
+ FULL_FINETUNING = True  # Set to False for LoRA training.
+ MODEL_NAME = "indic-tts-lora-training"
+ WANDB_USERNAME = ""  # ! Fill.
+ WANDB_PROJECT = "indic-tts-lora-training"
+ WANDB_LOG_MODEL = "checkpoint"
+ WANDB_RUN_NAME = None
+ WANDB_RUN_ID = None
+ SEED = 3407
+ HUGGINGFACE_TOKEN = ""  # ! Fill.
+ WANDB_TOKEN = ""  # ! Fill.
+
+ # * Use the following command to start the training: python train_llama.py
+
+ # Login to Hugging Face.
+ login(token=HUGGINGFACE_TOKEN)
+
+ # Login to Weights & Biases.
+ wandb.login(key=WANDB_TOKEN)
+
+ # Set up environment variables for Weights & Biases.
+ os.environ["WANDB_PROJECT"] = WANDB_PROJECT
+ os.environ["WANDB_LOG_MODEL"] = WANDB_LOG_MODEL
+
+ # Load the model and tokenizer.
+ model, tokenizer = FastLanguageModel.from_pretrained(
+     model_name=BASE_MODEL,
+     load_in_4bit=not FULL_FINETUNING,
+     max_seq_length=MAX_SEQ_LENGTH,
+     full_finetuning=FULL_FINETUNING,
+ )
+ logger.success(f"Loaded model: {BASE_MODEL}")
+
+ # Set the end of sequence token.
+ EOS_TOKEN = tokenizer.eos_token
+
+ # Add new special tokens to the tokenizer.
+ new_special_tokens = (
+     list(START_OF_SPECIAL_TOKENS.values())
+     + list(END_OF_SPECIAL_TOKENS.values())
+     + SNAC_TOKENS
+ )
+ tokenizer.add_tokens(new_special_tokens, special_tokens=True)
+ model.resize_token_embeddings(len(tokenizer))
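+ # snac_offset below is the ID of the first SNAC token in the extended vocabulary;
+ # generated IDs at or above it are treated as audio codes later in generate_audio.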
+ snac_offset = len(tokenizer.get_vocab()) - len(SNAC_TOKENS)
+ logger.success("Added new special tokens to the tokenizer.")
+
+ if not FULL_FINETUNING:
+     # Get parameter efficient fine-tuning model.
+     model = FastLanguageModel.get_peft_model(
+         model,
+         r=192,
+         target_modules=[
+             "q_proj",
+             "k_proj",
+             "v_proj",
+             "o_proj",
+             "up_proj",
+             "down_proj",
+             "gate_proj",
+             "lm_head",
+             "embed_tokens",
+         ],
+         lora_alpha=384,
+         random_state=SEED,
+     )
+     logger.success("Initialized parameter efficient fine-tuning model.")
+
+ # Load training and validation datasets.
+ # The dataset should be in CSV format with columns user (str), language (str), utterance (str), and snac_codes (list).
+ train_dataset = load_dataset("csv", data_files=TRAIN_CSV_PATH)["train"]
+ eval_dataset = load_dataset("csv", data_files=VALID_CSV_PATH)["train"]
+
+ if TRAIN_NUM_SAMPLES:
+     train_dataset = train_dataset.shuffle(seed=SEED).select(
+         range(min(TRAIN_NUM_SAMPLES, len(train_dataset)))
+     )
+
+ if EVAL_NUM_SAMPLES:
+     eval_dataset = eval_dataset.shuffle(seed=SEED).select(
+         range(min(EVAL_NUM_SAMPLES, len(eval_dataset)))
+     )
+
+ logger.success(
+     f"Loaded datasets: {len(train_dataset)} training samples, {len(eval_dataset)} evaluation samples."
+ )
+
+
+ # Format SNAC audio codes.
+ def format_snac_audio_codes(row):
+     audio_codes = row["snac_codes"]
+     if isinstance(audio_codes, str):
+         audio_codes = eval(audio_codes)
+     snac_tokens = [[], [], []]
+     for i, layer in enumerate(audio_codes):
+         for code in layer:
+             snac_tokens[i].append(f"<|snac_{i}_{code}|>")
+     row["snac_tokens"] = snac_tokens
+     return row
+
+
+ train_dataset = train_dataset.map(format_snac_audio_codes)
+ eval_dataset = eval_dataset.map(format_snac_audio_codes)
+ logger.success("Formatted SNAC audio codes.")
+
+
+ # Flatten SNAC audio codes.
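+ # Each SNAC frame interleaves the three layers as
+ # [layer0[i], layer1[2i], layer2[4i], layer2[4i+1], layer1[2i+1], layer2[4i+2], layer2[4i+3]],
+ # yielding 7 tokens per frame.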
+ def flatten_audio_codes(row):
+     audio_codes = row["snac_tokens"]
+     flattened_codes = []
+     for i in range(len(audio_codes[0])):
+         flattened_codes.append(audio_codes[0][i])
+         flattened_codes.append(audio_codes[1][2 * i])
+         flattened_codes.append(audio_codes[2][4 * i])
+         flattened_codes.append(audio_codes[2][(4 * i) + 1])
+         flattened_codes.append(audio_codes[1][(2 * i) + 1])
+         flattened_codes.append(audio_codes[2][(4 * i) + 2])
+         flattened_codes.append(audio_codes[2][(4 * i) + 3])
+     row["snac_tokens_list"] = flattened_codes
+     return row
+
+
+ train_dataset = train_dataset.map(flatten_audio_codes)
+ eval_dataset = eval_dataset.map(flatten_audio_codes)
+ logger.success("Flattened SNAC audio codes.")
+
+
+ # Remove duplicate frames from the audio codes.
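+ # A frame is dropped when its first (coarse) code matches the coarse code of the previously kept frame.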
+ def remove_duplicate_frames(row):
+     vals = row["snac_tokens_list"]
+     if len(vals) % 7 != 0:
+         raise ValueError("Input list length must be divisible by 7")
+     result = vals[:7]
+     for i in range(7, len(vals), 7):
+         current_first = vals[i]
+         previous_first = result[-7]
+         if current_first != previous_first:
+             result.extend(vals[i : i + 7])
+     row["snac_tokens_list"] = result
+     return row
+
+
+ train_dataset = train_dataset.map(remove_duplicate_frames)
+ eval_dataset = eval_dataset.map(remove_duplicate_frames)
+ logger.success("Removed duplicate frames from audio codes.")
+
+
+ # Define a function to format the prompt for each row in the dataset.
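+ # Training text layout (one block per field, audio last):
+ #   <|start_of_user|> ... <|end_of_user|> ... <|start_of_utterance|> ... <|end_of_utterance|> <|start_of_audio|> <snac tokens> <|end_of_audio|> EOS
+ # The eval prompt stops right after <|start_of_audio|>, leaving the model to generate the audio tokens.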
+ def format_text(row):
+     input_parts = ""
+     output_part = ""
+     for field in FIELDS:
+         if field != "audio":
+             part = f"{START_OF_SPECIAL_TOKENS[field]} {row[field]} {END_OF_SPECIAL_TOKENS[field]}"
+             input_parts += part + " "
+         else:
+             output_part = f"{START_OF_SPECIAL_TOKENS[field]} {' '.join(row['snac_tokens_list'])} {END_OF_SPECIAL_TOKENS[field]}"
+     text = f"{input_parts.strip()} {output_part} {EOS_TOKEN}"
+     eval_text = f"{input_parts.strip()} {START_OF_SPECIAL_TOKENS['audio']} "
+     row["text"] = text
+     row["eval_text"] = eval_text
+     return row
+
+
+ train_dataset = train_dataset.map(format_text)
+ eval_dataset = eval_dataset.map(format_text)
+ logger.success("Formatted text for training and evaluation datasets.")
+
+ # Set training arguments.
+ training_args = SFTConfig(
+     num_train_epochs=2,
+     per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
+     gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
+     optim="adamw_8bit",
+     learning_rate=5e-5 if FULL_FINETUNING else 2e-4,
+     lr_scheduler_type="cosine",
+     warmup_ratio=0.02,
+     do_eval=True,
+     eval_strategy="steps",
+     eval_steps=50,
+     logging_strategy="steps",
+     logging_steps=1,
+     save_strategy="steps",
+     save_only_model=True,
+     save_steps=1250,
+     output_dir="outputs",
+     report_to="wandb",
+     run_name=WANDB_RUN_NAME,
+     seed=SEED,
+ )
+
+ # Initialize the SFTTrainer.
+ trainer = SFTTrainer(
+     model=model,
+     tokenizer=tokenizer,
+     train_dataset=train_dataset,
+     eval_dataset=eval_dataset,
+     dataset_text_field="text",
+     max_seq_length=MAX_SEQ_LENGTH,
+     dataset_num_proc=2,
+     packing=True,
+     args=training_args,
+ )
+
+ logger.success("Initialized SFTTrainer with the specified configuration.")
+
+ # Start the training process.
+ logger.info("Starting the training process...")
+
+ run = wandb.init()
+
+ if WANDB_RUN_ID:
+     logger.info(f"Resuming from Weights & Biases run ID: {WANDB_RUN_ID}")
+
+     artifact = run.use_artifact(
+         f"{WANDB_USERNAME}/{WANDB_PROJECT}/{WANDB_RUN_ID}", type="model"
+     )
+
+     artifact_dir = artifact.download()
+
+     trainer.train(resume_from_checkpoint=artifact_dir)
+ else:
+     try:
+         logger.info("Attempting to resume training from the last checkpoint...")
+
+         trainer.train(resume_from_checkpoint=True)
+     except Exception as err:
+         trainer.train()
+
+ # Finish the Weights & Biases run.
+ wandb.finish()
+
+ logger.success("Training completed successfully.")
+
+ # ! Saving and loading model doesn't work.
+ # # Save the model and tokenizer.
+ # model.save_pretrained_merged(
+ #     f"{HUGGINGFACE_USERNAME}/{MODEL_NAME}",
+ #     tokenizer,
+ #     save_method="merged_16bit",
+ # )
+ # logger.success("Saved the model and tokenizer locally.")
+
+ # model.push_to_hub_merged(
+ #     f"{HUGGINGFACE_USERNAME}/{MODEL_NAME}",
+ #     tokenizer,
+ #     save_method="merged_16bit",
+ #     token=HUGGINGFACE_TOKEN,
+ # )
+ # logger.success("Pushed the model and tokenizer to the Hugging Face Hub.")
+
+ # del trainer, model, tokenizer
+
+ # # Inference with the trained model.
+ # # Load the model and tokenizer.
+ # model, tokenizer = FastLanguageModel.from_pretrained(
+ #     model_name=f"{HUGGINGFACE_USERNAME}/{MODEL_NAME}",
+ #     load_in_4bit=True,
+ #     max_seq_length=MAX_SEQ_LENGTH,
+ # )
+
+ FastLanguageModel.for_inference(model)
+
+ logger.success("Prepared the in-memory model for inference.")
+
+ # Load the SNAC model for audio decoding.
+ snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+ logger.success("Loaded SNAC model for audio decoding.")
+
+
+ # Function to generate audio from a dataset row.
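+ # Generated IDs at or above snac_offset are decoded back into "<|snac_i_code|>" strings
+ # (ID 220, assumed to be the tokenizer's space token, is interleaved as a separator so the
+ # decoded string splits on spaces), then regrouped into the three SNAC codebooks for decoding.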
+ def generate_audio(
+     row, model, tokenizer, temperature=0.4, top_p=0.9, repetition_penalty=1.05
+ ):
+     prompt = row["eval_text"]
+     inputs = tokenizer(prompt, return_tensors="pt")
+     max_tokens = MAX_SEQ_LENGTH - inputs.input_ids.shape[1]
+     output = model.generate(
+         input_ids=inputs.input_ids.to("cuda"),
+         attention_mask=inputs.attention_mask.to("cuda"),
+         max_new_tokens=max_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+     )
+     audio_ids = []
+     for id in output[0]:
+         if id >= snac_offset:
+             audio_ids.append(id.item())
+     clean_audio_ids = []
+     for i in range(len(audio_ids) // 7):
+         for j in range(7):
+             clean_audio_ids += [audio_ids[7 * i + j], 220]
+     audio_tokens = tokenizer.decode(clean_audio_ids).strip().split(" ")
+     codes = [[], [], []]
+     for i in range(len(audio_tokens) // 7):
+         frame = []
+         for j in range(7):
+             _, _, code = audio_tokens[7 * i + j].split("_")
+             code = int(code[:-2])
+             frame.append(code)
+         codes[0].append(frame[0])
+         codes[1].append(frame[1])
+         codes[2].append(frame[2])
+         codes[2].append(frame[3])
+         codes[1].append(frame[4])
+         codes[2].append(frame[5])
+         codes[2].append(frame[6])
+     codes = [
+         torch.tensor(codes[0]).unsqueeze(0),
+         torch.tensor(codes[1]).unsqueeze(0),
+         torch.tensor(codes[2]).unsqueeze(0),
+     ]
+     try:
+         audio = snac_model.decode(codes)
+     except Exception as e:
+         logger.error(f"Error decoding audio: {e}")
+         return None
+     return audio.detach().squeeze().to("cpu").numpy()
+
+
+ # Generate and save some examples.
+ train_sample = generate_audio(train_dataset[0], model, tokenizer)
+ if train_sample is None:
+     logger.error("Failed to generate audio for training sample.")
+ else:
+     sf.write("train.wav", train_sample, 24000)
+     logger.success("Generated and saved training sample audio.")
+
+ eval_sample = generate_audio(eval_dataset[0], model, tokenizer)
+ if eval_sample is None:
+     logger.error("Failed to generate audio for evaluation sample.")
+ else:
+     sf.write("eval.wav", eval_sample, 24000)
+     logger.success("Generated and saved evaluation sample audio.")
train_orpheus.py ADDED
@@ -0,0 +1,428 @@
+ # * Install unsloth, PEFT, Weights & Biases, SNAC, pandas, soundfile and loguru.
+ # !pip install unsloth peft==0.15.2 wandb snac pandas soundfile loguru
+
+ # Import necessary libraries.
+ # * unsloth import should always be at the top.
+ from unsloth import FastLanguageModel
+
+ import os
+
+ from datasets import load_dataset
+ from huggingface_hub import login
+ from loguru import logger
+ from snac import SNAC
+ from trl import SFTConfig, SFTTrainer
+ import soundfile as sf
+ import torch
+ import wandb
+
+
+ # Set up constants and configurations.
+ STAGE = 1
+ HUGGINGFACE_USERNAME = ""  # ! Fill.
+ if STAGE == 1:
+     # * You need to request access to the model at https://huggingface.co/canopylabs/3b-hi-pretrain-research_release.
+     BASE_MODEL = "canopylabs/3b-hi-pretrain-research_release"
+     TARGET_MODULES = [
+         "q_proj",
+         "k_proj",
+         "v_proj",
+         "o_proj",
+         "up_proj",
+         "down_proj",
+         "gate_proj",
+         "lm_head",
+         "embed_tokens",
+     ]
+     TRAIN_CSV_PATH = "data/data_stage_1.csv"
+     VALID_CSV_PATH = "data/data_eval.csv"
+     LR = 2e-4
+     EPOCHS = 1
+     MODEL_NAME = f"snorTTS-indicv0-stage-{STAGE}"
+ elif STAGE == 2:
+     BASE_MODEL = f"{HUGGINGFACE_USERNAME}/snorTTS-indicv0-stage-1"
+     TARGET_MODULES = [
+         "q_proj",
+         "k_proj",
+         "v_proj",
+         "o_proj",
+         "up_proj",
+         "down_proj",
+         "gate_proj",
+         "lm_head",
+         "embed_tokens",
+     ]
+     TRAIN_CSV_PATH = "data/data_stage_2.csv"
+     VALID_CSV_PATH = "data/data_eval.csv"
+     LR = 2e-4
+     EPOCHS = 2
+     MODEL_NAME = f"snorTTS-indicv0-stage-{STAGE}"
+ else:
+     BASE_MODEL = f"{HUGGINGFACE_USERNAME}/snorTTS-indicv0-stage-2"
+     TARGET_MODULES = [
+         "q_proj",
+         "k_proj",
+         "v_proj",
+         "o_proj",
+         "up_proj",
+         "down_proj",
+         "gate_proj",
+     ]
+     TRAIN_CSV_PATH = "data/data_train_tamil.csv"
+     VALID_CSV_PATH = "data/data_eval_tamil.csv"
+     LR = 2e-4
+     EPOCHS = 2
+     MODEL_NAME = f"snorTTS-tamilv0-stage-{STAGE}"
+ TRAIN_NUM_SAMPLES = None
+ EVAL_NUM_SAMPLES = 250
+ MAX_SEQ_LENGTH = 2048
+ PER_DEVICE_TRAIN_BATCH_SIZE = 8
+ GRADIENT_ACCUMULATION_STEPS = 4
+ WANDB_USERNAME = ""  # ! Fill.
+ WANDB_PROJECT = MODEL_NAME
+ WANDB_LOG_MODEL = "checkpoint"
+ WANDB_RUN_NAME = f"{MODEL_NAME}-training"
+ WANDB_RUN_ID = None
+ SEED = 3407
+ HUGGINGFACE_TOKEN = ""  # ! Fill.
+ WANDB_TOKEN = ""  # ! Fill.
+
+ # * Use the following command to start the training: python train_orpheus.py
+
+ # Login to Hugging Face.
+ login(token=HUGGINGFACE_TOKEN)
+
+ # Login to Weights & Biases.
+ wandb.login(key=WANDB_TOKEN)
+
+ # Set up environment variables for Weights & Biases.
+ os.environ["WANDB_PROJECT"] = WANDB_PROJECT
+ os.environ["WANDB_LOG_MODEL"] = WANDB_LOG_MODEL
+
+ # Load the model and tokenizer.
+ model, tokenizer = FastLanguageModel.from_pretrained(
+     model_name=BASE_MODEL,
+     load_in_4bit=True,
+     max_seq_length=MAX_SEQ_LENGTH,
+     token=HUGGINGFACE_TOKEN,
+ )
+ logger.success(f"Loaded model: {BASE_MODEL}")
+
+ # Load the special tokens for the tokenizer.
+ tokeniser_length = 128256
+
+ start_of_text_id = 128000
+ end_of_text_id = 128009
+ start_of_speech_id = tokeniser_length + 1
+ end_of_speech_id = tokeniser_length + 2
+ start_of_human_id = tokeniser_length + 3
+ end_of_human_id = tokeniser_length + 4
+ start_of_ai_id = tokeniser_length + 5
+ end_of_ai_id = tokeniser_length + 6
+ pad_token_id = tokeniser_length + 7
+ audio_start_id = tokeniser_length + 10
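+ # The script assumes IDs above the base Llama vocabulary (128256) are control tokens, with the
+ # SNAC audio codes starting at tokeniser_length + 10 = 128266; each of the 7 positions in a
+ # flattened frame gets its own 4096-ID block (see flatten_and_get_audio_input_ids below).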
+
+ start_of_text_token = tokenizer.decode([start_of_text_id])
+ end_of_text_token = tokenizer.decode([end_of_text_id])
+ start_of_speech_token = tokenizer.decode([start_of_speech_id])
+ end_of_speech_token = tokenizer.decode([end_of_speech_id])
+ start_of_human_token = tokenizer.decode([start_of_human_id])
+ end_of_human_token = tokenizer.decode([end_of_human_id])
+ start_of_ai_token = tokenizer.decode([start_of_ai_id])
+ end_of_ai_token = tokenizer.decode([end_of_ai_id])
+ pad_token = tokenizer.decode([pad_token_id])
+ audio_start_token = tokenizer.decode([audio_start_id])
+
+ logger.success("Loaded special tokens for the tokenizer.")
+
+ # Set the padding token and padding side.
+ tokenizer.pad_token = pad_token
+ tokenizer.padding_side = "left"
+ logger.success("Set padding token and padding side for the tokenizer.")
+
+ # Get parameter efficient fine-tuning model.
+ model = FastLanguageModel.get_peft_model(
+     model,
+     r=192,
+     target_modules=TARGET_MODULES,
+     lora_alpha=384,
+     random_state=SEED,
+ )
+ logger.success("Initialized parameter efficient fine-tuning model.")
+
+ # Load training and validation datasets.
+ # The dataset should be in CSV format with columns user (str), language (str), utterance (str), and snac_codes (list of lists).
+ train_dataset = load_dataset("csv", data_files=TRAIN_CSV_PATH)["train"]
+ eval_dataset = load_dataset("csv", data_files=VALID_CSV_PATH)["train"]
+
+ if TRAIN_NUM_SAMPLES:
+     train_dataset = train_dataset.shuffle(seed=SEED).select(
+         range(min(TRAIN_NUM_SAMPLES, len(train_dataset)))
+     )
+
+ if EVAL_NUM_SAMPLES:
+     eval_dataset = eval_dataset.shuffle(seed=SEED).select(
+         range(min(EVAL_NUM_SAMPLES, len(eval_dataset)))
+     )
+
+ logger.success(
+     f"Loaded datasets: {len(train_dataset)} training samples, {len(eval_dataset)} evaluation samples."
+ )
+
+
+ # Flatten and get SNAC token IDs from the audio codes.
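+ # Each 7-token frame interleaves the three SNAC layers as
+ # [layer0[i], layer1[2i], layer2[4i], layer2[4i+1], layer1[2i+1], layer2[4i+2], layer2[4i+3]],
+ # and each of the 7 positions is shifted into its own 4096-ID block starting at 128266.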
+ def flatten_and_get_audio_input_ids(row):
+     audio_codes = row["snac_codes"]
+     if isinstance(audio_codes, str):
+         audio_codes = eval(audio_codes)
+     snac_token_ids = []
+     for i in range(len(audio_codes[0])):
+         snac_token_ids.append(audio_codes[0][i] + 128266)
+         snac_token_ids.append(audio_codes[1][2 * i] + 128266 + 4096)
+         snac_token_ids.append(audio_codes[2][4 * i] + 128266 + (2 * 4096))
+         snac_token_ids.append(audio_codes[2][(4 * i) + 1] + 128266 + (3 * 4096))
+         snac_token_ids.append(audio_codes[1][(2 * i) + 1] + 128266 + (4 * 4096))
+         snac_token_ids.append(audio_codes[2][(4 * i) + 2] + 128266 + (5 * 4096))
+         snac_token_ids.append(audio_codes[2][(4 * i) + 3] + 128266 + (6 * 4096))
+     row["snac_token_ids"] = snac_token_ids
+     return row
+
+
+ train_dataset = train_dataset.map(flatten_and_get_audio_input_ids)
+ eval_dataset = eval_dataset.map(flatten_and_get_audio_input_ids)
+ logger.success("Flattened and extracted SNAC token IDs from audio codes.")
+
+ # Filter out rows with empty or None audio codes.
+ train_dataset = train_dataset.filter(
+     lambda x: x["snac_token_ids"] is not None and len(x["snac_token_ids"]) > 0
+ )
+ eval_dataset = eval_dataset.filter(
+     lambda x: x["snac_token_ids"] is not None and len(x["snac_token_ids"]) > 0
+ )
+ logger.success("Filtered datasets to remove rows with empty or None audio codes.")
+
+
+ # Remove duplicate frames from the audio codes.
+ def remove_duplicate_frames(row):
+     vals = row["snac_token_ids"]
+     if len(vals) % 7 != 0:
+         raise ValueError("Input list length must be divisible by 7")
+     result = vals[:7]
+     for i in range(7, len(vals), 7):
+         current_first = vals[i]
+         previous_first = result[-7]
+         if current_first != previous_first:
+             result.extend(vals[i : i + 7])
+     row["snac_token_ids"] = result
+     return row
+
+
+ train_dataset = train_dataset.map(remove_duplicate_frames)
+ eval_dataset = eval_dataset.map(remove_duplicate_frames)
+ logger.success("Removed duplicate frames from audio codes.")
+
+
+ # Define a function to format the prompt for each row in the dataset.
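+ # Training text layout:
+ #   start_of_human start_of_text "{language}{user}: {utterance}" end_of_text end_of_human
+ #   start_of_ai start_of_speech <decoded SNAC token IDs> end_of_speech end_of_ai
+ # Two eval prompts are also built (with and without the speaker prefix); each ends right after start_of_speech.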
+ def format_text(row):
+     text = (
+         f"{start_of_human_token}{start_of_text_token}{row['language']}{row['user']}: {row['utterance']}{end_of_text_token}"
+         f"{end_of_human_token}{start_of_ai_token}{start_of_speech_token}"
+         f"{tokenizer.decode(row['snac_token_ids'])}{end_of_speech_token}{end_of_ai_token}"
+     )
+     eval_text_user = (
+         f"{start_of_human_token}{start_of_text_token}{row['language']}{row['user']}: {row['utterance']}{end_of_text_token}"
+         f"{end_of_human_token}{start_of_ai_token}{start_of_speech_token}"
+     )
+     eval_text_no_user = (
+         f"{start_of_human_token}{start_of_text_token}{row['utterance']}{end_of_text_token}"
+         f"{end_of_human_token}{start_of_ai_token}{start_of_speech_token}"
+     )
+     row["text"] = text
+     row["eval_text_user"] = eval_text_user
+     row["eval_text_no_user"] = eval_text_no_user
+     return row
+
+
+ train_dataset = train_dataset.map(format_text)
+ eval_dataset = eval_dataset.map(format_text)
+ logger.success("Formatted text for training and evaluation datasets.")
+
+
+ # Tokenize the text in the datasets without adding special tokens.
+ def tokenize_function(example):
+     return tokenizer(
+         example["text"],
+         add_special_tokens=False,
+         truncation=True,
+         max_length=MAX_SEQ_LENGTH,
+     )
+
+
+ train_dataset = train_dataset.map(tokenize_function)
+ eval_dataset = eval_dataset.map(tokenize_function)
+ logger.success("Tokenized text in the datasets without adding special tokens.")
+
+ # Set training arguments.
+ training_args = SFTConfig(
+     num_train_epochs=EPOCHS,
+     per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
+     gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
+     optim="adamw_8bit",
+     learning_rate=LR,
+     lr_scheduler_type="cosine",
+     warmup_ratio=0.02,
+     do_eval=True,
+     eval_strategy="steps",
+     eval_steps=50,
+     logging_strategy="steps",
+     logging_steps=1,
+     save_strategy="steps",
+     save_only_model=True,
+     save_steps=1250,
+     output_dir="outputs",
+     report_to="wandb",
+     run_name=WANDB_RUN_NAME,
+     seed=SEED,
+ )
+
+ # Initialize the SFTTrainer.
+ trainer = SFTTrainer(
+     model=model,
+     tokenizer=tokenizer,
+     train_dataset=train_dataset,
+     eval_dataset=eval_dataset,
+     max_seq_length=MAX_SEQ_LENGTH,
+     dataset_num_proc=2,
+     packing=True,
+     args=training_args,
+ )
+
+ logger.success("Initialized SFTTrainer with the specified configuration.")
+
+ # Start the training process.
+ logger.info("Starting the training process...")
+
+ run = wandb.init()
+
+ if WANDB_RUN_ID:
+     logger.info(f"Resuming from Weights & Biases run ID: {WANDB_RUN_ID}")
+
+     artifact = run.use_artifact(
+         f"{WANDB_USERNAME}/{WANDB_PROJECT}/{WANDB_RUN_ID}", type="model"
+     )
+
+     artifact_dir = artifact.download()
+
+     trainer.train(resume_from_checkpoint=artifact_dir)
+ else:
+     try:
+         logger.info("Attempting to resume training from the last checkpoint...")
+
+         trainer.train(resume_from_checkpoint=True)
+     except Exception as err:
+         trainer.train()
+
+ # Finish the Weights & Biases run.
+ wandb.finish()
+
+ logger.success("Training completed successfully.")
+
+ # Inference with the trained model.
+ FastLanguageModel.for_inference(model)
+ logger.success(f"Model {MODEL_NAME} is ready for inference.")
+
+ # Load the SNAC model for audio decoding.
+ snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+ logger.success("Loaded SNAC model for audio decoding.")
+
+
+ # Function to generate audio from a dataset row.
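+ # Generation stops at end_of_speech_id; IDs at or above audio_start_id are kept, the per-position
+ # offsets (audio_start_id + position * 4096) are removed, and every 7 codes are regrouped into the
+ # three SNAC codebook streams before decoding.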
+ def generate_audio(
+     row,
+     model,
+     tokenizer,
+     user=False,
+     temperature=0.4,
+     top_p=0.9,
+     repetition_penalty=1.05,
+ ):
+     try:
+         if user:
+             prompt = row["eval_text_user"]
+         else:
+             prompt = row["eval_text_no_user"]
+         inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
+         max_tokens = MAX_SEQ_LENGTH - inputs.input_ids.shape[1]
+         output = model.generate(
+             input_ids=inputs.input_ids.to("cuda"),
+             attention_mask=inputs.attention_mask.to("cuda"),
+             max_new_tokens=max_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+             eos_token_id=end_of_speech_id,
+         )
+         audio_ids = []
+         for id in output[0]:
+             if id >= audio_start_id:
+                 audio_ids.append(id.item())
+         clean_audio_ids = []
+         for i in range(len(audio_ids) // 7):
+             for j in range(7):
+                 clean_audio_ids += [audio_ids[7 * i + j] - audio_start_id]
+         codes = [[], [], []]
+         for i in range(len(clean_audio_ids) // 7):
+             codes[0].append(clean_audio_ids[7 * i])
+             codes[1].append(clean_audio_ids[7 * i + 1] - 4096)
+             codes[2].append(clean_audio_ids[7 * i + 2] - (2 * 4096))
+             codes[2].append(clean_audio_ids[7 * i + 3] - (3 * 4096))
+             codes[1].append(clean_audio_ids[7 * i + 4] - (4 * 4096))
+             codes[2].append(clean_audio_ids[7 * i + 5] - (5 * 4096))
+             codes[2].append(clean_audio_ids[7 * i + 6] - (6 * 4096))
+         codes = [
+             torch.tensor(codes[0]).unsqueeze(0),
+             torch.tensor(codes[1]).unsqueeze(0),
+             torch.tensor(codes[2]).unsqueeze(0),
+         ]
+         audio = snac_model.decode(codes)
+         return audio.detach().squeeze().to("cpu").numpy()
+     except Exception as e:
+         logger.error(f"Error decoding audio: {e}")
+         return None
+
+
+ # Generate and save some examples.
+ train_sample = generate_audio(train_dataset[0], model, tokenizer, True)
+ if train_sample is None:
+     logger.error("Failed to generate audio for training sample.")
+ else:
+     sf.write(f"train_{STAGE}.wav", train_sample, 24000)
+     logger.success("Generated and saved training sample audio.")
+
+
+ dir_ = f"eval_{STAGE}/"
+ os.makedirs(dir_, exist_ok=True)
+ for i in range(10):
+     eval_sample = generate_audio(eval_dataset[i], model, tokenizer, True)
+     if eval_sample is None:
+         logger.error(f"Failed to generate audio for evaluation sample {i}.")
+     else:
+         filename = dir_ + f"eval_{i}.wav"
+         sf.write(filename, eval_sample, 24000)
+         logger.success(f"Generated and saved evaluation sample audio as {filename}.")
+
+ # Save the model and tokenizer.
+ model.save_pretrained_merged(
+     f"{HUGGINGFACE_USERNAME}/{MODEL_NAME}",
+     tokenizer,
+     save_method="merged_16bit",
+ )
+ logger.success("Saved the model and tokenizer locally.")
+
+ model.push_to_hub_merged(
+     f"{HUGGINGFACE_USERNAME}/{MODEL_NAME}",
+     tokenizer,
+     save_method="merged_16bit",
+     token=HUGGINGFACE_TOKEN,
+ )
+ logger.success("Pushed the model and tokenizer to the Hugging Face Hub.")