# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List

import torch
import transformers
import tokenizers

from omni_speech.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN
from torch.utils.data import Dataset
from omni_speech.train.omni_trainer import OmniTrainer

from omni_speech import conversation as conversation_lib
from omni_speech.model import *
from omni_speech.utils import *
from omni_speech.datasets.preprocess import *

import whisper


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    version: Optional[str] = field(default="llama_3")
    freeze_backbone: bool = field(default=False)
    tune_speech_projector: bool = field(default=False)
    tune_speech_encoder: bool = field(default=False)
    tune_speech_generator_only: bool = field(default=False)
    speech_encoder_type: Optional[str] = field(default=None)
    speech_encoder: Optional[str] = field(default=None)
    pretrain_speech_projector: Optional[str] = field(default=None)
    speech_projector_type: Optional[str] = field(default='linear')
    speech_generator_type: Optional[str] = field(default='ctc')
    ctc_decoder_config: str = "(2,4096,32,11008)"
    ctc_upsample_factor: int = 1
    ctc_loss_weight: float = 1.0
    unit_vocab_size: int = 1000
    speech_encoder_ds_rate: int = 5
    speech_encoder_hidden_size: int = 1280


@dataclass
class DataArguments:
    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    dev_path: str = field(default=None,
                          metadata={"help": "Path to the dev data."})
    is_multimodal: bool = False
    input_type: str = field(default="mel")
    speech_normalize: bool = False
    mel_size: int = 128
    has_tgt_units: bool = False


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    freeze_speech_projector: bool = field(default=False)
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    speech_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = json.load(open(data_path, "r"))

        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if 'speech' in sources[0]:
            import numpy as np
            speech_file = self.list_data_dict[i]['speech']
            speech = whisper.load_audio(speech_file)
            # speech = np.random.uniform(low=-1.0, high=1.0, size=speech.shape[0]).astype(speech.dtype)
            if self.data_args.input_type == "raw":
                speech = torch.from_numpy(speech)
                if self.data_args.speech_normalize:
                    speech = torch.nn.functional.layer_norm(speech, speech.shape)
            elif self.data_args.input_type == "mel":
                speech = whisper.pad_or_trim(speech)
                speech = whisper.log_mel_spectrogram(speech, n_mels=self.data_args.mel_size).permute(1, 0)
            speech_lengths = torch.LongTensor([speech.shape[0]])
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_speech=('speech' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        # speech exists in the data
        if 'speech' in self.list_data_dict[i]:
            data_dict['speech'] = speech
            data_dict['speech_lengths'] = speech_lengths
        return data_dict


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        if 'speech' in instances[0]:
            speech = [instance['speech'] for instance in instances]
            speech_lengths = [instance['speech_lengths'] for instance in instances]
            if all(x is not None and x.shape == speech[0].shape for x in speech):
                batch['speech'] = torch.stack(speech)
                batch['speech_lengths'] = torch.stack(speech_lengths)
            else:
                batch['speech'] = speech
                batch['speech_lengths'] = speech_lengths

        return batch


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                          data_path=data_args.data_path,
                                          data_args=data_args)
    if data_args.dev_path is not None:
        dev_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                            data_path=data_args.dev_path,
                                            data_args=data_args)
    else:
        dev_dataset = None
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=dev_dataset,
                data_collator=data_collator)
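
# A minimal sketch of the training JSON consumed by LazySupervisedDataset, inferred
# from how __getitem__ reads each record: "speech" is an optional path to an audio
# file loadable by whisper.load_audio, and "conversations" is handed to
# preprocess_multimodal/preprocess. The role names and the speech placeholder token
# (DEFAULT_SPEECH_TOKEN) placement below are assumptions; see
# omni_speech/datasets/preprocess.py for the authoritative schema.
#
# [
#   {
#     "speech": "path/to/utterance.wav",
#     "conversations": [
#       {"from": "human", "value": "<speech>\nPlease respond to the spoken instruction."},
#       {"from": "gpt", "value": "Sure, here is the answer..."}
#     ]
#   }
# ]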

def train(attn_implementation="flash_attention_2"):
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    compute_dtype = (torch.float16 if training_args.fp16 else
                     (torch.bfloat16 if training_args.bf16 else torch.float32))

    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                llm_int8_skip_modules=["speech_projector"],
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type  # {'fp4', 'nf4'}
            )
        ))

    if data_args.has_tgt_units:
        if model_args.version == "llama_3":
            model = OmniSpeech2SLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        elif model_args.version == "qwen":
            model = OmniSpeech2SQwen2ForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        else:
            raise ValueError("Currently only Llama and Qwen models are supported!")
    else:
        if model_args.version == "llama_3":
            model = OmniSpeechLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        elif model_args.version == "qwen":
            model = OmniSpeechQwen2ForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        else:
            raise ValueError("Currently only Llama and Qwen models are supported!")
    model.config.use_cache = False

    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training
        model.config.torch_dtype = (torch.float32 if training_args.fp16 else
                                    (torch.bfloat16 if training_args.bf16 else torch.float32))
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        model = get_peft_model(model, lora_config)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    model.resize_token_embeddings(len(tokenizer))
    model.config.max_length = training_args.model_max_length

    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["llama_3"]

    if model_args.speech_encoder is not None:
        model.get_model().initialize_speech_modules(
            model_args=model_args,
            fsdp=training_args.fsdp
        )

        data_args.is_multimodal = True

        model.config.tokenizer_padding_side = tokenizer.padding_side
        model.config.tokenizer_model_max_length = tokenizer.model_max_length

        model.config.tune_speech_projector = training_args.tune_speech_projector = model_args.tune_speech_projector
        model.config.speech_normalize = data_args.speech_normalize
        for p in model.get_speech_encoder().parameters():
            p.requires_grad = False
        if model_args.tune_speech_projector:
            model.requires_grad_(False)
            for p in model.get_speech_projector().parameters():
                p.requires_grad = True

        model.config.freeze_speech_projector = training_args.freeze_speech_projector
        if training_args.freeze_speech_projector:
            for p in model.get_speech_projector().parameters():
                p.requires_grad = False

        if training_args.bits in [4, 8]:
            model.get_model().speech_projector.to(dtype=compute_dtype, device=training_args.device)

        model.config.speech_projector_lr = training_args.speech_projector_lr

        if data_args.has_tgt_units:
            model.initialize_speech_generator(model_args=model_args)

    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)

    print("Training Layers:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)

    trainer = OmniTrainer(model=model,
                          tokenizer=tokenizer,
                          args=training_args,
                          **data_module)

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True

    if training_args.lora_enable:
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters()
        )
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        safe_save_model_for_hf_trainer(trainer=trainer,
                                       output_dir=training_args.output_dir)

if __name__ == "__main__":
    train()
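
# Example launch (a sketch only; the script path, model/data/encoder paths, and the
# speech_encoder_type value are placeholders, not a verified recipe). The flags map
# onto the ModelArguments/DataArguments/TrainingArguments dataclasses above via
# HfArgumentParser, plus standard transformers.TrainingArguments options.
#
#   torchrun --nproc_per_node=8 omni_speech/train/train.py \
#       --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
#       --version llama_3 \
#       --data_path data/train.json \
#       --speech_encoder models/whisper-large-v3.pt \
#       --speech_encoder_type whisper \
#       --input_type mel --mel_size 128 \
#       --tune_speech_projector True \
#       --bf16 True \
#       --output_dir checkpoints/omni_speech \
#       --per_device_train_batch_size 4 \
#       --gradient_checkpointing True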