# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import torch
import transformers
import tokenizers
from typing import Dict, Sequence
from omni_speech.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX
from omni_speech import conversation as conversation_lib
from omni_speech.model import *
from omni_speech.arguments import DataArguments
import json
from packaging import version
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None):
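    """Tokenize a prompt that may contain the '<speech>' placeholder.

    The prompt is split on '<speech>', each chunk is tokenized separately, and the
    chunks are rejoined with `speech_token_index` (a sentinel id that is later
    replaced by speech features). If the tokenizer prepends a BOS token, it is kept
    only once, at the very start of the sequence.
    """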
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech>')]
# print(f'number of turn {len(prompt_chunks)-1}')
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == 'pt':
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f'Unsupported tensor type: {return_tensors}')
return input_ids
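
# A minimal usage sketch for tokenizer_speech_token (illustrative only; the prompt
# string below is a made-up example, not taken from this repo's data):
#
#   ids = tokenizer_speech_token("<speech>\nPlease transcribe the audio.", tokenizer,
#                                return_tensors='pt')
#   # -> [bos_id, SPEECH_TOKEN_INDEX, <text token ids...>]
#
# The SPEECH_TOKEN_INDEX sentinel marks where the speech encoder's features are
# spliced into the embedded sequence by the model.
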
def preprocess_multimodal(
sources: Sequence[str],
data_args: DataArguments
) -> Dict:
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
for source in sources:
for sentence in source:
if DEFAULT_SPEECH_TOKEN in sentence['value']:
sentence['value'] = sentence['value'].replace(DEFAULT_SPEECH_TOKEN, '').strip()
sentence['value'] = DEFAULT_SPEECH_TOKEN + '\n' + sentence['value']
sentence['value'] = sentence['value'].strip()
return sources
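
# Illustrative effect of preprocess_multimodal (the example sentence is made up): a
# turn whose value is "Please transcribe this. <speech>" is normalized to
# "<speech>\nPlease transcribe this.", i.e. the speech placeholder always ends up on
# its own leading line before the text.
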
def preprocess_llama_2(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_speech: bool = False
) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt(tokenizer))
# Tokenize conversations
if has_speech:
input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
# Mask targets
sep = "[/INST] "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_speech:
round_len = len(tokenizer_speech_token(rou, tokenizer))
instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
return dict(
input_ids=input_ids,
labels=targets,
)
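
# Masking sketch for preprocess_llama_2 (assuming the standard LLaMA-2 chat template
# produced by conversation_lib): each round looks like
#   "[INST] {user} [/INST] {assistant} </s>"
# and everything up to and including "[/INST] " is set to IGNORE_INDEX, so the loss
# is computed only on the assistant replies.
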
# def preprocess_llama_3(
# sources,
# tokenizer: transformers.PreTrainedTokenizer,
# has_speech: bool = False
# ) -> Dict:
# conv = conversation_lib.default_conversation.copy()
# roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# # Apply prompt templates
# conversations = []
# for i, source in enumerate(sources):
# if roles[source[0]["from"]] != conv.roles[0]:
# # Skip the first one if it is not from human
# source = source[1:]
# # assert len(source) == 2, "now only support single-turn conversation"
# conv.messages = []
# for j, sentence in enumerate(source):
# role = roles[sentence["from"]]
# assert role == conv.roles[j % 2], f"{i}"
# conv.append_message(role, sentence["value"])
# source_conv.append({
# "role": role,
# "content": sentence["value"],
# })
# # print(conv.get_prompt(tokenizer))
# # print("+++++++++++++++++++++++")
# conversations.append(conv.get_prompt(tokenizer))
# # Tokenize conversations
# if has_speech:
# input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
# else:
# input_ids = tokenizer(
# conversations,
# return_tensors="pt",
# padding="longest",
# max_length=tokenizer.model_max_length,
# truncation=True,
# ).input_ids
# targets = input_ids.clone()
# assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3
# # Mask targets
# sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>\n\n"
# for conversation, target in zip(conversations, targets):
# total_len = int(target.ne(tokenizer.pad_token_id).sum())
# cur_len = 1
# target[:cur_len] = IGNORE_INDEX
# parts = conversation.split(sep)
# parts[0] += sep
# if has_speech:
# conversation_len = len(tokenizer_speech_token(conversation, tokenizer))
# instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 1
# else:
# conversation_len = len(tokenizer(conversation).input_ids)
# instruction_len = len(tokenizer(parts[0]).input_ids) - 1
# target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
# cur_len += conversation_len
# target[cur_len:] = IGNORE_INDEX
# # if cur_len < tokenizer.model_max_length:
# # if cur_len != total_len:
# # target[:] = IGNORE_INDEX
# # print(
# # f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
# # f" (ignored)"
# # )
# return dict(
# input_ids=input_ids,
# labels=targets,
# )
def preprocess_llama_3(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_speech: bool = False,
system_message: str = "You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language.",
) -> Dict:
# roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"}
roles = {"human": "user", "gpt": "assistant"}
    # Add the speech token to the tokenizer as a special token.
    # Use a deepcopy of the tokenizer so that we don't modify the original one.
tokenizer = copy.deepcopy(tokenizer)
    # When there is actually speech, add the speech token as a special token.
if has_speech:
tokenizer.add_tokens([DEFAULT_SPEECH_TOKEN], special_tokens=True)
speech_token_index = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
unmask_tokens = ["<|begin_of_text|>", "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>", "\n\n"]
unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens]
    # After the update, the Llama 3 tokenizer automatically prepends the BOS id,
    # so the helper below strips it to avoid duplicated BOS tokens mid-sequence.
def safe_tokenizer_llama3(text):
input_ids = tokenizer(text).input_ids
if input_ids[0] == bos_token_id:
input_ids = input_ids[1:]
return input_ids
nl_tokens = tokenizer.convert_tokens_to_ids("\n\n")
# Apply prompt templates
input_ids, targets = [], []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != roles["human"]:
source = source[1:]
input_id, target = [], []
        # New version: use apply_chat_template.
        # Build the system message once per conversation.
input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])
target += [IGNORE_INDEX] * len(input_id)
for conv in source:
            # Support both chat-style ("role"/"content") and LLaVA-style
            # ("from"/"value") message entries.
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]
role = roles.get(role, role)
conv = [{"role" : role, "content" : content}]
            # apply_chat_template prepends a BOS token; drop it since the sequence
            # already starts with one.
encode_id = tokenizer.apply_chat_template(conv)[1:]
input_id += encode_id
if role in ["user", "system"]:
target += [IGNORE_INDEX] * len(encode_id)
else:
target += encode_id
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
for idx, encode_id in enumerate(input_id):
if encode_id in unmask_tokens_idx:
target[idx] = encode_id
if encode_id == speech_token_index:
input_id[idx] = SPEECH_TOKEN_INDEX
input_ids.append(input_id)
targets.append(target)
input_ids = torch.tensor(input_ids, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)
return dict(
input_ids=input_ids, # tensor(bs x seq_len)
labels=targets, # tensor(bs x seq_len)
)
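
# Illustrative output of preprocess_llama_3 (token layout only; exact ids depend on
# the tokenizer's chat template version):
#   input_ids: <|begin_of_text|>, then one
#     "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>" block per
#     system/user/assistant turn, with the '<speech>' id replaced by SPEECH_TOKEN_INDEX.
#   labels: same length, IGNORE_INDEX everywhere except the assistant content and the
#     unmasked structural tokens (headers, <|eot_id|>, "\n\n").
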
def preprocess_v1(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_speech: bool = False
) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt(tokenizer))
# Tokenize conversations
if has_speech:
input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
# Mask targets
sep = conv.sep + conv.roles[1] + ": "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_speech:
round_len = len(tokenizer_speech_token(rou, tokenizer))
instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
# FIXME: tokenizer bug
if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
round_len -= 1
instruction_len -= 1
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
return dict(
input_ids=input_ids,
labels=targets,
)
# def preprocess_qwen(
# sources,
# tokenizer: transformers.PreTrainedTokenizer,
# has_speech: bool = False
# ) -> Dict:
# # Initialize conversation and roles
# conv = conversation_lib.default_conversation.copy()
# roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# # Apply prompt templates for each conversation source
# conversations = []
# for i, source in enumerate(sources):
# if roles[source[0]["from"]] != conv.roles[0]:
# # Skip the first item if not from the human role
# source = source[1:]
# # assert len(source) == 2, "Supports single-turn conversations only."
# conv.messages = []
# for j, sentence in enumerate(source):
# role = roles[sentence["from"]]
# assert role == conv.roles[j % 2], f"Role mismatch in conversation {i}."
# conv.append_message(role, sentence["value"])
# # print(conv.get_prompt(tokenizer))
# # print("--------------")
# conversations.append(conv.get_prompt(tokenizer))
# # Tokenize conversations
# if has_speech:
# # Assuming tokenizer_speech_token for Qwen when handling speech tokens
# input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
# else:
# input_ids = tokenizer(
# conversations,
# return_tensors="pt",
# padding="longest",
# max_length=tokenizer.model_max_length,
# truncation=True,
# ).input_ids
# # Clone input_ids for target labels
# targets = input_ids.clone()
# assert conv.sep_style == conversation_lib.SeparatorStyle.QWEN
# # Mask targets with Qwen-specific tokenization patterns
# sep = "<|im_start|>" + conv.roles[1] + "\n"
# for conversation, target in zip(conversations, targets):
# total_len = int(target.ne(tokenizer.pad_token_id).sum())
# cur_len = 1
# target[:cur_len] = IGNORE_INDEX
# parts = conversation.split(sep)
# parts[0] += sep
# if has_speech:
# conversation_len = len(tokenizer_speech_token(conversation, tokenizer))
# instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 1
# else:
# conversation_len = len(tokenizer(conversation).input_ids)
# instruction_len = len(tokenizer(parts[0]).input_ids) - 1
# target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
# cur_len += conversation_len
# target[cur_len:] = IGNORE_INDEX
# # Optionally handle tokenization mismatch
# # if cur_len < tokenizer.model_max_length:
# # if cur_len != total_len:
# # target[:] = IGNORE_INDEX
# # print(
# # f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
# # f" (ignored)"
# # )
# return dict(
# input_ids=input_ids,
# labels=targets,
# )
def preprocess_qwen(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_speech: bool = False,
system_message: str = "You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language."
) -> Dict:
roles = {"human": "user", "gpt": "assistant"}
    # Use a deepcopy of the tokenizer so that we don't modify the original one.
tokenizer = copy.deepcopy(tokenizer)
if has_speech:
tokenizer.add_tokens([DEFAULT_SPEECH_TOKEN], special_tokens=True)
speech_token_index = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
unmask_tokens = ["<|im_start|>", "<|im_end|>"]
unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(x) for x in unmask_tokens]
# Apply prompt templates
input_ids, targets = [], []
for i, source in enumerate(sources):
if source[0]["from"] == "tools":
tools = source[0]["value"]
source = source[1:]
else:
tools = None
if roles[source[0]["from"]] != roles["human"]:
source = source[1:]
input_id, target = [], []
        # New version: use apply_chat_template.
        # Build the system message once per conversation.
if tools is not None:
json_objects = tools.split("\n\n")
try:
fc = [json.loads(obj) for obj in json_objects]
            except json.JSONDecodeError:
if len(json_objects) > 1:
json_objects = json_objects[:-1]
fc = [json.loads(obj) for obj in json_objects]
# print(fc)
input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}], tools = fc)
else:
input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])
target += [IGNORE_INDEX] * len(input_id)
        # Reset the Qwen chat template so that it doesn't prepend the system message
        # every time apply_chat_template is called for a single turn.
chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
tokenizer.chat_template = chat_template
for conv in source:
            # Support both chat-style ("role"/"content") and LLaVA-style
            # ("from"/"value") message entries.
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]
role = roles.get(role, role)
conv = [{"role" : role, "content" : content}]
encode_id = tokenizer.apply_chat_template(conv)
input_id += encode_id
if role in ["user", "system"]:
target += [IGNORE_INDEX] * len(encode_id)
else:
target += encode_id
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
for idx, encode_id in enumerate(input_id):
if encode_id in unmask_tokens_idx:
target[idx] = encode_id
if encode_id == speech_token_index:
input_id[idx] = SPEECH_TOKEN_INDEX
input_ids.append(input_id)
targets.append(target)
input_ids = torch.tensor(input_ids, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)
return dict(
input_ids=input_ids,
labels=targets,
)
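
# Illustrative ChatML layout produced by preprocess_qwen (content strings are
# placeholders):
#   <|im_start|>system\n{system_message}<|im_end|>\n
#   <|im_start|>user\n<speech>\n{question}<|im_end|>\n
#   <|im_start|>assistant\n{answer}<|im_end|>\n
# Labels mirror input_ids but mask the system/user spans with IGNORE_INDEX, keep the
# <|im_start|>/<|im_end|> ids unmasked, and the '<speech>' id in input_ids is swapped
# for the SPEECH_TOKEN_INDEX sentinel.
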
def preprocess_plain(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
# add end signal and concatenate together
conversations = []
for source in sources:
assert len(source) == 2
assert DEFAULT_SPEECH_TOKEN in source[0]['value']
source[0]['value'] = DEFAULT_SPEECH_TOKEN
conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
conversations.append(conversation)
# tokenize conversations
input_ids = [tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_len = len(tokenizer_speech_token(source[0]['value'], tokenizer))
target[:tokenized_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
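
# In the plain style, each sample collapses to "<speech>{answer}{sep}"; only the
# prompt part (the '<speech>' placeholder, plus the BOS token if the tokenizer adds
# one) is masked with IGNORE_INDEX, so the model learns to generate the answer
# directly from the speech features.
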
def preprocess(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
has_speech: bool = False
) -> Dict:
"""
    Given a list of sources, each being a conversation list, this transform:
    1. Adds the signal '### ' at the beginning of each sentence, with the end signal '\n';
    2. Concatenates the conversations together;
    3. Tokenizes the concatenated conversation;
    4. Makes a deepcopy as the target and masks the human words with IGNORE_INDEX.
"""
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
return preprocess_plain(sources, tokenizer)
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
return preprocess_llama_2(sources, tokenizer, has_speech=has_speech)
if conversation_lib.default_conversation.version.startswith("v1"):
return preprocess_v1(sources, tokenizer, has_speech=has_speech)
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_3:
return preprocess_llama_3(sources, tokenizer, has_speech=has_speech)
if conversation_lib.default_conversation.version == "qwen":
return preprocess_qwen(sources, tokenizer, has_speech=has_speech)
raise NotImplementedError
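
# A minimal usage sketch for the dispatcher (illustrative; the conversation below is
# a made-up example, not taken from any dataset in this repo):
#
#   sources = [[
#       {"from": "human", "value": "<speech>\nWhat is being said?"},
#       {"from": "gpt", "value": "The speaker is greeting the audience."},
#   ]]
#   batch = preprocess(sources, tokenizer, has_speech=True)
#   # For most templates, batch["input_ids"] and batch["labels"] are (batch, seq_len)
#   # long tensors, with IGNORE_INDEX masking everything except the assistant reply.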