# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright: # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import torch import transformers import tokenizers from typing import Dict, Sequence from omni_speech.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX from omni_speech import conversation as conversation_lib from omni_speech.model import * from omni_speech.arguments import DataArguments from omni_speech.constants import SPEECH_TOKEN_INDEX import json from packaging import version IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None): prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] # print(f'number of turn {len(prompt_chunks)-1}') def insert_separator(X, sep): return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] input_ids = [] offset = 0 if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: offset = 1 input_ids.append(prompt_chunks[0][0]) for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)): input_ids.extend(x[offset:]) if return_tensors is not None: if return_tensors == 'pt': return torch.tensor(input_ids, dtype=torch.long) raise ValueError(f'Unsupported tensor type: {return_tensors}') return input_ids def preprocess_multimodal( sources: Sequence[str], data_args: DataArguments ) -> Dict: is_multimodal = data_args.is_multimodal if not is_multimodal: return sources for source in sources: for sentence in source: if DEFAULT_SPEECH_TOKEN in sentence['value']: sentence['value'] = sentence['value'].replace(DEFAULT_SPEECH_TOKEN, '').strip() sentence['value'] = DEFAULT_SPEECH_TOKEN + '\n' + sentence['value'] sentence['value'] = sentence['value'].strip() return sources def preprocess_llama_2( sources, tokenizer: transformers.PreTrainedTokenizer, has_speech: bool = False ) -> Dict: conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt(tokenizer)) # Tokenize conversations if has_speech: input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) else: input_ids = tokenizer( conversations, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ).input_ids targets = input_ids.clone() assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 # Mask targets sep = "[/INST] " for conversation, target in zip(conversations, targets): total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep2) cur_len = 1 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep if has_speech: round_len = len(tokenizer_speech_token(rou, tokenizer)) instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2 else: round_len = len(tokenizer(rou).input_ids) instruction_len = len(tokenizer(parts[0]).input_ids) - 2 target[cur_len : cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < tokenizer.model_max_length: if cur_len != total_len: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) # def preprocess_llama_3( # sources, # tokenizer: transformers.PreTrainedTokenizer, # has_speech: bool = False # ) -> Dict: # conv = conversation_lib.default_conversation.copy() # roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # # Apply prompt templates # conversations = [] # for i, source in enumerate(sources): # if roles[source[0]["from"]] != conv.roles[0]: # # Skip the first one if it is not from human # source = source[1:] # # assert len(source) == 2, "now only support single-turn conversation" # conv.messages = [] # for j, sentence in enumerate(source): # role = roles[sentence["from"]] # assert role == conv.roles[j % 2], f"{i}" # conv.append_message(role, sentence["value"]) # source_conv.append({ # "role": role, # "content": sentence["value"], # }) # # print(conv.get_prompt(tokenizer)) # # print("+++++++++++++++++++++++") # conversations.append(conv.get_prompt(tokenizer)) # # Tokenize conversations # if has_speech: # input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) # else: # input_ids = tokenizer( # conversations, # return_tensors="pt", # padding="longest", # max_length=tokenizer.model_max_length, # truncation=True, # ).input_ids # targets = input_ids.clone() # assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3 # # Mask targets # sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>\n\n" # for conversation, target in zip(conversations, targets): # total_len = int(target.ne(tokenizer.pad_token_id).sum()) # cur_len = 1 # target[:cur_len] = IGNORE_INDEX # parts = conversation.split(sep) # parts[0] += sep # if has_speech: # conversation_len = len(tokenizer_speech_token(conversation, tokenizer)) # instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 1 # else: # conversation_len = len(tokenizer(conversation).input_ids) # instruction_len = len(tokenizer(parts[0]).input_ids) - 1 # target[cur_len : cur_len + instruction_len] = IGNORE_INDEX # cur_len += conversation_len # target[cur_len:] = IGNORE_INDEX # # if cur_len < tokenizer.model_max_length: # # if cur_len != total_len: # # target[:] = IGNORE_INDEX # # print( # # f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." # # f" (ignored)" # # ) # return dict( # input_ids=input_ids, # labels=targets, # ) def preprocess_llama_3( sources, tokenizer: transformers.PreTrainedTokenizer, has_speech: bool = False, system_message: str = "You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language.", ) -> Dict: # roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"} roles = {"human": "user", "gpt": "assistant"} # Add speech tokens to tokenizer as a special tokens # Use a deepcopy of tokenizer so that we don't modify on the tokenizer tokenizer = copy.deepcopy(tokenizer) # When there is actually an image, we add the image tokens as a special token if has_speech: tokenizer.add_tokens([DEFAULT_SPEECH_TOKEN], special_tokens=True) speech_token_index = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN) bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>") start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>") end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>") eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") unmask_tokens = ["<|begin_of_text|>", "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>", "\n\n"] unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens] # After update, calling tokenizer of llama3 will # auto add bos id for the tokens. ヽ(`⌒´)ノ def safe_tokenizer_llama3(text): input_ids = tokenizer(text).input_ids if input_ids[0] == bos_token_id: input_ids = input_ids[1:] return input_ids nl_tokens = tokenizer.convert_tokens_to_ids("\n\n") # Apply prompt templates input_ids, targets = [], [] for i, source in enumerate(sources): if roles[source[0]["from"]] != roles["human"]: source = source[1:] input_id, target = [], [] # New version, use apply chat template # Build system message for each sentence input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}]) target += [IGNORE_INDEX] * len(input_id) for conv in source: # Make sure llava data can load try: role = conv["role"] content = conv["content"] except: role = conv["from"] content = conv["value"] role = roles.get(role, role) conv = [{"role" : role, "content" : content}] # First is bos token we don't need here encode_id = tokenizer.apply_chat_template(conv)[1:] input_id += encode_id if role in ["user", "system"]: target += [IGNORE_INDEX] * len(encode_id) else: target += encode_id assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}" for idx, encode_id in enumerate(input_id): if encode_id in unmask_tokens_idx: target[idx] = encode_id if encode_id == speech_token_index: input_id[idx] = SPEECH_TOKEN_INDEX input_ids.append(input_id) targets.append(target) input_ids = torch.tensor(input_ids, dtype=torch.long) targets = torch.tensor(targets, dtype=torch.long) return dict( input_ids=input_ids, # tensor(bs x seq_len) labels=targets, # tensor(bs x seq_len) ) def preprocess_v1( sources, tokenizer: transformers.PreTrainedTokenizer, has_speech: bool = False ) -> Dict: conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt(tokenizer)) # Tokenize conversations if has_speech: input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) else: input_ids = tokenizer( conversations, return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True, ).input_ids targets = input_ids.clone() assert conv.sep_style == conversation_lib.SeparatorStyle.TWO # Mask targets sep = conv.sep + conv.roles[1] + ": " for conversation, target in zip(conversations, targets): total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep2) cur_len = 1 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep if has_speech: round_len = len(tokenizer_speech_token(rou, tokenizer)) instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2 else: round_len = len(tokenizer(rou).input_ids) instruction_len = len(tokenizer(parts[0]).input_ids) - 2 # FIXME: tokenizer bug if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: round_len -= 1 instruction_len -= 1 target[cur_len : cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < tokenizer.model_max_length: if cur_len != total_len: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) # def preprocess_qwen( # sources, # tokenizer: transformers.PreTrainedTokenizer, # has_speech: bool = False # ) -> Dict: # # Initialize conversation and roles # conv = conversation_lib.default_conversation.copy() # roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # # Apply prompt templates for each conversation source # conversations = [] # for i, source in enumerate(sources): # if roles[source[0]["from"]] != conv.roles[0]: # # Skip the first item if not from the human role # source = source[1:] # # assert len(source) == 2, "Supports single-turn conversations only." # conv.messages = [] # for j, sentence in enumerate(source): # role = roles[sentence["from"]] # assert role == conv.roles[j % 2], f"Role mismatch in conversation {i}." # conv.append_message(role, sentence["value"]) # # print(conv.get_prompt(tokenizer)) # # print("--------------") # conversations.append(conv.get_prompt(tokenizer)) # # Tokenize conversations # if has_speech: # # Assuming tokenizer_speech_token for Qwen when handling speech tokens # input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) # else: # input_ids = tokenizer( # conversations, # return_tensors="pt", # padding="longest", # max_length=tokenizer.model_max_length, # truncation=True, # ).input_ids # # Clone input_ids for target labels # targets = input_ids.clone() # assert conv.sep_style == conversation_lib.SeparatorStyle.QWEN # # Mask targets with Qwen-specific tokenization patterns # sep = "<|im_start|>" + conv.roles[1] + "\n" # for conversation, target in zip(conversations, targets): # total_len = int(target.ne(tokenizer.pad_token_id).sum()) # cur_len = 1 # target[:cur_len] = IGNORE_INDEX # parts = conversation.split(sep) # parts[0] += sep # if has_speech: # conversation_len = len(tokenizer_speech_token(conversation, tokenizer)) # instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 1 # else: # conversation_len = len(tokenizer(conversation).input_ids) # instruction_len = len(tokenizer(parts[0]).input_ids) - 1 # target[cur_len : cur_len + instruction_len] = IGNORE_INDEX # cur_len += conversation_len # target[cur_len:] = IGNORE_INDEX # # Optionally handle tokenization mismatch # # if cur_len < tokenizer.model_max_length: # # if cur_len != total_len: # # target[:] = IGNORE_INDEX # # print( # # f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." # # f" (ignored)" # # ) # return dict( # input_ids=input_ids, # labels=targets, # ) def preprocess_qwen( sources, tokenizer: transformers.PreTrainedTokenizer, has_speech: bool = False, system_message: str = "You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language." ) -> Dict: roles = {"human": "user", "gpt": "assistant"} # Use a deepcopy of tokenizer so that we don't modify on the tokenizer tokenizer = copy.deepcopy(tokenizer) if has_speech: tokenizer.add_tokens([DEFAULT_SPEECH_TOKEN], special_tokens=True) speech_token_index = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN) unmask_tokens = ["<|im_start|>", "<|im_end|>"] unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(x) for x in unmask_tokens] # Apply prompt templates input_ids, targets = [], [] for i, source in enumerate(sources): if source[0]["from"] == "tools": tools = source[0]["value"] source = source[1:] else: tools = None if roles[source[0]["from"]] != roles["human"]: source = source[1:] input_id, target = [], [] # New version, use apply chat template # Build system message for each sentence if tools is not None: json_objects = tools.split("\n\n") try: fc = [json.loads(obj) for obj in json_objects] except: if len(json_objects) > 1: json_objects = json_objects[:-1] fc = [json.loads(obj) for obj in json_objects] # print(fc) input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}], tools = fc) else: input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}]) target += [IGNORE_INDEX] * len(input_id) # Reset Qwen chat templates so that it won't include system message every time we apply chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" tokenizer.chat_template = chat_template for conv in source: try: role = conv["role"] content = conv["content"] except: role = conv["from"] content = conv["value"] role = roles.get(role, role) conv = [{"role" : role, "content" : content}] encode_id = tokenizer.apply_chat_template(conv) input_id += encode_id if role in ["user", "system"]: target += [IGNORE_INDEX] * len(encode_id) else: target += encode_id assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}" for idx, encode_id in enumerate(input_id): if encode_id in unmask_tokens_idx: target[idx] = encode_id if encode_id == speech_token_index: input_id[idx] = SPEECH_TOKEN_INDEX input_ids.append(input_id) targets.append(target) input_ids = torch.tensor(input_ids, dtype=torch.long) targets = torch.tensor(targets, dtype=torch.long) return dict( input_ids=input_ids, labels=targets, ) def preprocess_plain( sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, ) -> Dict: # add end signal and concatenate together conversations = [] for source in sources: assert len(source) == 2 assert DEFAULT_SPEECH_TOKEN in source[0]['value'] source[0]['value'] = DEFAULT_SPEECH_TOKEN conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep conversations.append(conversation) # tokenize conversations input_ids = [tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] targets = copy.deepcopy(input_ids) for target, source in zip(targets, sources): tokenized_len = len(tokenizer_speech_token(source[0]['value'], tokenizer)) target[:tokenized_len] = IGNORE_INDEX return dict(input_ids=input_ids, labels=targets) def preprocess( sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, has_speech: bool = False ) -> Dict: """ Given a list of sources, each is a conversation list. This transform: 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; 2. Concatenate conversations together; 3. Tokenize the concatenated conversation; 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. """ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: return preprocess_plain(sources, tokenizer) if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: return preprocess_llama_2(sources, tokenizer, has_speech=has_speech) if conversation_lib.default_conversation.version.startswith("v1"): return preprocess_v1(sources, tokenizer, has_speech=has_speech) if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_3: return preprocess_llama_3(sources, tokenizer, has_speech=has_speech) if conversation_lib.default_conversation.version == "qwen": return preprocess_qwen(sources, tokenizer, has_speech=has_speech) raise NotImplementedError