# streaming-speech/omni_speech/train/train_multiturn.py
# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import torch
import transformers
import tokenizers
from omni_speech.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN
from torch.utils.data import Dataset
from omni_speech.train.omni_trainer import OmniTrainer
from audiomentations import AddBackgroundNoise, PolarityInversion
from omni_speech import conversation as conversation_lib
from omni_speech.model import *
from omni_speech.utils import *
from omni_speech.datasets.preprocess import *
import whisper
import time
@dataclass
class ModelArguments:
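    """Backbone checkpoint/version plus speech encoder, projector, and CTC speech-generator settings."""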
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
version: Optional[str] = field(default="llama_3")
freeze_backbone: bool = field(default=False)
tune_speech_projector: bool = field(default=False)
tune_speech_encoder: bool = field(default=False)
tune_speech_generator_only: bool = field(default=False)
speech_encoder_type: Optional[str] = field(default=None)
speech_encoder: Optional[str] = field(default=None)
pretrain_speech_projector: Optional[str] = field(default=None)
speech_projector_type: Optional[str] = field(default='linear')
speech_generator_type: Optional[str] = field(default='ctc')
# ctc_decoder_config: str = "(2,4096,32,11008)" # num layers, hidden sizes, attn heads, ff dimensions of LLaMA
ctc_decoder_config: str = "(2,4096,32,22016)"
ctc_upsample_factor: int = 25
ctc_loss_weight: float = 1.0
unit_vocab_size: int = 1000
speech_encoder_ds_rate: int = 5
speech_encoder_hidden_size: int = 1280
@dataclass
class DataArguments:
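    """Paths to the train/dev JSON files and speech-input options (raw vs. mel, normalization, noise augmentation)."""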
data_path: str = field(default=None,
metadata={"help": "Path to the training data."})
dev_path: str = field(default=None,
metadata={"help": "Path to the dev data."})
is_multimodal: bool = False
input_type: str = field(default="mel")
speech_normalize: bool = False
mel_size: int = 128
has_tgt_units: bool = False
augment_prob: float = field(
default=0.0,
metadata={"help": "The probability of applying augmentation transform."}
)
augment_path: str = field(default=None,
metadata={"help": "Path to the augment data."})
@dataclass
class TrainingArguments(transformers.TrainingArguments):
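    """Extends transformers.TrainingArguments with quantization (bitsandbytes), LoRA, and speech-projector options."""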
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
freeze_speech_projector: bool = field(default=False)
model_max_length: int = field(
default=512,
metadata={
"help":
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
double_quant: bool = field(
default=True,
metadata={"help": "Compress the quantization statistics through double quantization."}
)
quant_type: str = field(
default="nf4",
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
)
bits: int = field(
default=16,
metadata={"help": "How many bits to use."}
)
lora_enable: bool = False
lora_r: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_weight_path: str = ""
lora_bias: str = "none"
speech_projector_lr: Optional[float] = None
group_by_modality_length: bool = field(default=False)
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
data_args: DataArguments):
super(LazySupervisedDataset, self).__init__()
list_data_dict = json.load(open(data_path, "r"))
self.tokenizer = tokenizer
self.list_data_dict = list_data_dict
self.data_args = data_args
if self.data_args.augment_prob != 0.0:
with open(self.data_args.augment_path, "r") as f:
augment_path_list = f.read().splitlines()
self.transform = AddBackgroundNoise(
sounds_path=augment_path_list,
min_snr_db=5.0,
max_snr_db=30.0,
noise_transform=PolarityInversion(),
p=self.data_args.augment_prob
)
def __len__(self):
return len(self.list_data_dict)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
# TODO: define number of retries somewhere else
num_base_retries = 3
num_final_retries = 300
# try the current sample first
for attempt_idx in range(num_base_retries):
try:
sample = self._get_item(i)
return sample
except Exception as e:
# sleep 1s in case it is a cloud disk issue
print(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e)
time.sleep(1)
        # fall back to the next sample, in case the current file is corrupted
for attempt_idx in range(num_base_retries):
try:
next_index = min(i + 1, len(self.list_data_dict) - 1)
# sample_idx = random.choice(range(len(self)))
sample = self._get_item(next_index)
return sample
except Exception as e:
# no need to sleep
print(f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:", e)
pass
        # final attempt: let the exception propagate if it still fails
        sample = self._get_item(i)
        return sample
def process_speech(self, speech_file):
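        """Load audio with Whisper, optionally add background noise, and return (features, length) tensors."""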
speech = whisper.load_audio(speech_file)
if self.data_args.augment_prob != 0.0:
speech = self.transform(speech, sample_rate=16000)
if self.data_args.input_type == "raw":
speech = torch.from_numpy(speech)
            if self.data_args.speech_normalize:
speech = torch.nn.functional.layer_norm(speech, speech.shape)
elif self.data_args.input_type == "mel":
speech = whisper.pad_or_trim(speech)
speech = whisper.log_mel_spectrogram(speech, n_mels=self.data_args.mel_size).permute(1, 0)
speech_lengths = torch.LongTensor([speech.shape[0]])
return speech, speech_lengths
def _get_item(self, i) -> Dict[str, torch.Tensor]:
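        """Build one example: prepend tool turns if present, load speech features and optional target units, then tokenize."""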
sources = self.list_data_dict[i]
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
for item in sources:
if 'tools' in item:
tools_dict = {
"from": "tools",
"value": item["tools"]
}
item["conversations"].insert(0,tools_dict)
if self.data_args.has_tgt_units:
# pad_list = [0]
# tgt_units = [e["tgt_units"] if "tgt_units" in e else pad_list for e in sources]
tgt_units = [e["tgt_units"] for e in sources]
tgt_units = torch.tensor(tgt_units, dtype=torch.long)
else:
tgt_units = None
if 'speech' in sources[0]:
import numpy as np
speech_file = self.list_data_dict[i]['speech']
if type(speech_file) is list:
speech = [self.process_speech(f) for f in speech_file]
else:
speech = [self.process_speech(speech_file)]
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args)
else:
sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = preprocess(
sources,
self.tokenizer,
has_speech=('speech' in self.list_data_dict[i]))
if isinstance(i, int):
data_dict = dict(input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0])
# speech exist in the data
if 'speech' in self.list_data_dict[i]:
data_dict['speech'] = speech
if tgt_units is not None:
data_dict['tgt_units'] = tgt_units[0]
data_dict["id"] = self.list_data_dict[i].get("id", i)
return data_dict
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def pad_sequence(self, input_ids, batch_first, padding_value):
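        """Pad a list of 1-D tensors; left-padding tokenizers are handled by flipping before and after right-padding."""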
if self.tokenizer.padding_side == "left":
input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
if self.tokenizer.padding_side == "left":
input_ids = torch.flip(input_ids, [1])
return input_ids
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
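        """Truncate to model_max_length, pad input_ids/labels, and gather speech features, lengths, and target units."""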
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
# input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "id"))
input_ids = [_input_ids[: self.tokenizer.model_max_length] for _input_ids in input_ids]
labels = [_labels[: self.tokenizer.model_max_length] for _labels in labels]
if self.tokenizer.pad_token_id is None:
# self.tokenizer.pad_token_id = self.tokenizer.eos_token_id # FIXME: this could only be triggered for llama3 model.
self.tokenizer.pad_token_id = 0 # This gets the best result. Don't know why.
input_ids = self.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
labels = self.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        batch = dict(
            input_ids=input_ids,
            labels=labels.long() if labels.dtype == torch.int32 else labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
# batch = dict(input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ids=ids)
if 'speech' in instances[0]:
speechs = [instance['speech'] for instance in instances]
speech = [sp[0] for sp_list in speechs for sp in sp_list]
speech_lengths = [sp[1] for sp_list in speechs for sp in sp_list]
batch["speech"] = speech
# print(len(speech)) # sum(len(audio) for audio in each batch)
# print(speech[0].shape) # seq_len, dim
batch['speech_lengths'] = speech_lengths
# print(speech_lengths[0].shape) # seq_len
if 'tgt_units' in instances[0]:
tgt_units = [instance['tgt_units'] for instance in instances]
tgt_units = self.pad_sequence(tgt_units, batch_first=True, padding_value=self.tokenizer.pad_token_id)
batch['tgt_units'] = tgt_units
# print(batch['tgt_units'])
# print("---------------")
# print(batch['input_ids'])
return batch
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
data_path=data_args.data_path,
data_args=data_args)
if data_args.dev_path is not None:
dev_dataset = LazySupervisedDataset(tokenizer=tokenizer,
data_path=data_args.dev_path,
data_args=data_args)
else:
dev_dataset = None
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset,
eval_dataset=dev_dataset,
data_collator=data_collator)
def train(attn_implementation="flash_attention_2"):
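    """Parse arguments, build the (optionally quantized/LoRA) OmniSpeech model, attach speech modules, and run OmniTrainer."""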
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
bnb_model_from_pretrained_args = {}
if training_args.bits in [4, 8]:
from transformers import BitsAndBytesConfig
bnb_model_from_pretrained_args.update(dict(
device_map={"": training_args.device},
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
quantization_config=BitsAndBytesConfig(
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
llm_int8_skip_modules=["speech_projector"],
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=training_args.double_quant,
bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
)
))
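    # Select the model class: the "2S" variants additionally carry a speech generator head
    # (set up via initialize_speech_generator below) and are used when target speech units are available.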
if data_args.has_tgt_units:
if model_args.version == "llama_3":
model = OmniSpeech2SLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
**bnb_model_from_pretrained_args
)
elif model_args.version == "qwen":
model = OmniSpeech2SQwen2ForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
**bnb_model_from_pretrained_args
)
else:
raise ValueError("--currently only support llama or qwen model!")
else:
if model_args.version == "llama_3":
model = OmniSpeechLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
**bnb_model_from_pretrained_args
)
elif model_args.version == "qwen":
model = OmniSpeechQwen2ForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
**bnb_model_from_pretrained_args
)
else:
raise ValueError("--currently only support llama or qwen model!")
model.config.use_cache = False
if model_args.freeze_backbone:
model.model.requires_grad_(False)
if training_args.bits in [4, 8]:
from peft import prepare_model_for_kbit_training
model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
if training_args.gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
if training_args.lora_enable:
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=training_args.lora_r,
lora_alpha=training_args.lora_alpha,
target_modules=find_all_linear_names(model),
lora_dropout=training_args.lora_dropout,
bias=training_args.lora_bias,
task_type="CAUSAL_LM",
)
if training_args.bits == 16:
if training_args.bf16:
model.to(torch.bfloat16)
if training_args.fp16:
model.to(torch.float16)
model = get_peft_model(model, lora_config)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
)
model.resize_token_embeddings(len(tokenizer))
model.config.max_length = training_args.model_max_length
if model_args.version in conversation_lib.conv_templates:
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
else:
conversation_lib.default_conversation = conversation_lib.conv_templates["llama_3"]
if model_args.speech_encoder is not None:
model.get_model().initialize_speech_modules(
model_args=model_args,
fsdp=training_args.fsdp
)
data_args.is_multimodal = True
model.config.tokenizer_padding_side = tokenizer.padding_side
model.config.tokenizer_model_max_length = tokenizer.model_max_length
model.config.tune_speech_projector = training_args.tune_speech_projector = model_args.tune_speech_projector
model.config.speech_normalize = data_args.speech_normalize
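        # The speech encoder is kept frozen; trainability of the projector is controlled by the flags below.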
for p in model.get_speech_encoder().parameters():
p.requires_grad = False
if model_args.tune_speech_projector:
model.requires_grad_(False)
for p in model.get_speech_projector().parameters():
p.requires_grad = True
model.config.freeze_speech_projector = training_args.freeze_speech_projector
if training_args.freeze_speech_projector:
for p in model.get_speech_projector().parameters():
p.requires_grad = False
if training_args.bits in [4, 8]:
model.get_model().speech_projector.to(dtype=compute_dtype, device=training_args.device)
model.config.speech_projector_lr = training_args.speech_projector_lr
if data_args.has_tgt_units:
model.initialize_speech_generator(model_args=model_args)
if training_args.bits in [4, 8]:
from peft.tuners.lora import LoraLayer
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
if training_args.bf16:
module = module.to(torch.bfloat16)
if 'norm' in name:
module = module.to(torch.float32)
if 'lm_head' in name or 'embed_tokens' in name:
if hasattr(module, 'weight'):
if training_args.bf16 and module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)
data_module = make_supervised_data_module(tokenizer=tokenizer,
data_args=data_args)
print("Training Layers:")
for name, param in model.named_parameters():
if param.requires_grad:
print(name, param.grad)
trainer = OmniTrainer(model=model,
tokenizer=tokenizer,
args=training_args,
**data_module)
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
model.config.use_cache = True
if training_args.lora_enable:
state_dict = get_peft_state_maybe_zero_3(
model.named_parameters(), training_args.lora_bias
)
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
model.named_parameters()
)
if training_args.local_rank == 0 or training_args.local_rank == -1:
model.config.save_pretrained(training_args.output_dir)
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
else:
safe_save_model_for_hf_trainer(trainer=trainer,
output_dir=training_args.output_dir)
if __name__ == "__main__":
train()
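

# A minimal sketch of a possible launch command for this script (all paths, checkpoint names, and
# hyperparameter values below are illustrative assumptions, not taken from this repo). Each flag maps
# onto a field of ModelArguments / DataArguments / TrainingArguments parsed by HfArgumentParser above:
#
#   torchrun --nproc_per_node=8 omni_speech/train/train_multiturn.py \
#       --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
#       --version llama_3 \
#       --data_path /path/to/train.json \
#       --dev_path /path/to/dev.json \
#       --speech_encoder /path/to/whisper-large-v3 \
#       --speech_encoder_type whisper \
#       --input_type mel --mel_size 128 \
#       --tune_speech_projector True \
#       --bf16 True \
#       --output_dir ./checkpoints/omni_speech_multiturn \
#       --per_device_train_batch_size 4 \
#       --gradient_accumulation_steps 8 \
#       --num_train_epochs 1 \
#       --learning_rate 2e-5 \
#       --model_max_length 2048 \
#       --gradient_checkpointing True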