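"""Preprocessing for speech-text conversation data.

Each ``preprocess_*`` routine renders conversations with a model-specific chat
template, tokenizes them, and builds ``labels`` in which everything except the
assistant responses is masked with IGNORE_INDEX.
"""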
import copy
import json
from typing import Dict, Sequence

import torch
import transformers
import tokenizers
from packaging import version

from omni_speech.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX
from omni_speech import conversation as conversation_lib
from omni_speech.model import *
from omni_speech.arguments import DataArguments

IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')


def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None):
    """Tokenize `prompt`, mapping each '<speech>' placeholder to `speech_token_index`."""
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    # If the tokenizer prepends BOS, keep a single copy up front and strip the
    # duplicate BOS from every subsequent chunk via the `offset` slicing below.
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids

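# Illustration (ids are tokenizer-dependent; `tok` is any HF tokenizer that
# prefixes BOS): tokenizer_speech_token("Hi <speech> there", tok) returns
#   [bos] + ids("Hi ") + [SPEECH_TOKEN_INDEX] + ids(" there")
# i.e. each '<speech>' placeholder collapses to a single SPEECH_TOKEN_INDEX.
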
def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments
) -> Dict:
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    # Normalize placement: the speech token always goes first, on its own line.
    for source in sources:
        for sentence in source:
            if DEFAULT_SPEECH_TOKEN in sentence['value']:
                sentence['value'] = sentence['value'].replace(DEFAULT_SPEECH_TOKEN, '').strip()
                sentence['value'] = DEFAULT_SPEECH_TOKEN + '\n' + sentence['value']
                sentence['value'] = sentence['value'].strip()

    return sources

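# e.g. a turn {"value": "Transcribe the audio: <speech>"} becomes
#      {"value": "<speech>\nTranscribe the audio:"}
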
def preprocess_llama_2(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_speech: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply the conversation template to each source.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first message if it is not from the human.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt(tokenizer))

    # Tokenize conversations.
    if has_speech:
        input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

    # Mask everything except the assistant responses.
    sep = "[/INST] "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_speech:
                round_len = len(tokenizer_speech_token(rou, tokenizer))
                instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        # If the per-round token counts do not add up, mask the whole sample
        # rather than train on misaligned labels.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )

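# Masking sketch for the LLAMA_2 template (one round, rendered roughly as
#   "<s>[INST] <<SYS>>...<</SYS>> question [/INST] answer </s>"):
# everything up to and including "[/INST] " is set to IGNORE_INDEX, and only
# the answer tokens remain supervised.
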
def preprocess_llama_3(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_speech: bool = False,
    system_message: str = "You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language.",
) -> Dict:
    roles = {"human": "user", "gpt": "assistant"}

    # Work on a copy so added tokens do not leak into the caller's tokenizer.
    tokenizer = copy.deepcopy(tokenizer)

    if has_speech:
        tokenizer.add_tokens([DEFAULT_SPEECH_TOKEN], special_tokens=True)
    speech_token_index = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
    bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
    start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
    end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")

    # Structural tokens stay supervised so the model learns to emit them.
    unmask_tokens = ["<|begin_of_text|>", "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>", "\n\n"]
    unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens]

    def safe_tokenizer_llama3(text):
        # Tokenize without a leading BOS. (Currently unused.)
        input_ids = tokenizer(text).input_ids
        if input_ids[0] == bos_token_id:
            input_ids = input_ids[1:]
        return input_ids

    nl_tokens = tokenizer.convert_tokens_to_ids("\n\n")  # currently unused

    input_ids, targets = [], []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != roles["human"]:
            # Skip the first message if it is not from the human.
            source = source[1:]

        input_id, target = [], []

        # System prompt is fully masked.
        input_id += tokenizer.apply_chat_template([{"role": "system", "content": system_message}])
        target += [IGNORE_INDEX] * len(input_id)

        for conv in source:
            # Accept both {"role", "content"} and legacy {"from", "value"} keys.
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]

            role = roles.get(role, role)

            conv = [{"role": role, "content": content}]
            # Drop the BOS that the llama-3 template prepends on every call.
            encode_id = tokenizer.apply_chat_template(conv)[1:]
            input_id += encode_id
            if role in ["user", "system"]:
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
        for idx, encode_id in enumerate(input_id):
            if encode_id in unmask_tokens_idx:
                target[idx] = encode_id
            if encode_id == speech_token_index:
                input_id[idx] = SPEECH_TOKEN_INDEX
        input_ids.append(input_id)
        targets.append(target)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    targets = torch.tensor(targets, dtype=torch.long)

    return dict(
        input_ids=input_ids,
        labels=targets,
    )

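# Note: preprocess_llama_3 builds `labels` turn by turn (masking user/system
# turns as they are appended), so it needs no post-hoc length reconciliation.
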
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_speech: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply the conversation template to each source.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first message if it is not from the human.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt(tokenizer))

    # Tokenize conversations.
    if has_speech:
        input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask everything except the assistant responses.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_speech:
                round_len = len(tokenizer_speech_token(rou, tokenizer))
                instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Non-legacy tokenizers (tokenizers >= 0.14) yield one token fewer
            # per round after the first; compensate here.
            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len -= 1
                instruction_len -= 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        # If the per-round token counts do not add up, mask the whole sample
        # rather than train on misaligned labels.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )

def preprocess_qwen(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_speech: bool = False,
    system_message: str = "You are a helpful language and speech assistant. You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language."
) -> Dict:
    roles = {"human": "user", "gpt": "assistant"}

    # Work on a copy so added tokens do not leak into the caller's tokenizer.
    tokenizer = copy.deepcopy(tokenizer)

    if has_speech:
        tokenizer.add_tokens([DEFAULT_SPEECH_TOKEN], special_tokens=True)

    speech_token_index = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
    # Structural tokens stay supervised so the model learns to emit them.
    unmask_tokens = ["<|im_start|>", "<|im_end|>"]
    unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(x) for x in unmask_tokens]

    input_ids, targets = [], []
    for i, source in enumerate(sources):
        # An optional leading "tools" message carries function-calling schemas.
        if source[0]["from"] == "tools":
            tools = source[0]["value"]
            source = source[1:]
        else:
            tools = None

        if roles[source[0]["from"]] != roles["human"]:
            # Skip the first message if it is not from the human.
            source = source[1:]

        input_id, target = [], []

        # System prompt (optionally with tool schemas) is fully masked.
        if tools is not None:
            json_objects = tools.split("\n\n")
            try:
                fc = [json.loads(obj) for obj in json_objects]
            except json.JSONDecodeError:
                # Drop a trailing non-JSON fragment and retry.
                if len(json_objects) > 1:
                    json_objects = json_objects[:-1]
                fc = [json.loads(obj) for obj in json_objects]

            input_id += tokenizer.apply_chat_template([{"role": "system", "content": system_message}], tools=fc)
        else:
            input_id += tokenizer.apply_chat_template([{"role": "system", "content": system_message}])
        target += [IGNORE_INDEX] * len(input_id)

        # Render the remaining turns with a plain ChatML template.
        chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
        tokenizer.chat_template = chat_template

        for conv in source:
            # Accept both {"role", "content"} and legacy {"from", "value"} keys.
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]

            role = roles.get(role, role)

            conv = [{"role": role, "content": content}]
            encode_id = tokenizer.apply_chat_template(conv)
            input_id += encode_id
            if role in ["user", "system"]:
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
        for idx, encode_id in enumerate(input_id):
            if encode_id in unmask_tokens_idx:
                target[idx] = encode_id
            if encode_id == speech_token_index:
                input_id[idx] = SPEECH_TOKEN_INDEX
        input_ids.append(input_id)
        targets.append(target)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    targets = torch.tensor(targets, dtype=torch.long)

    return dict(
        input_ids=input_ids,
        labels=targets,
    )

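# With the ChatML template set above, every turn renders as
#   <|im_start|>{role}\n{content}<|im_end|>\n
# and only assistant turns contribute unmasked label tokens.
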
def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    # Pretraining format: the prompt is reduced to the bare speech token, so
    # only the response (plus the separator) is supervised.
    conversations = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_SPEECH_TOKEN in source[0]['value']
        source[0]['value'] = DEFAULT_SPEECH_TOKEN
        conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
        conversations.append(conversation)

    input_ids = [tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_speech_token(source[0]['value'], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)

def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_speech: bool = False
) -> Dict:
    """
    Given a list of sources, each a conversation list, this transform:
    1. Renders each conversation with the active conversation template;
    2. Tokenizes the rendered conversations;
    3. Makes a deep copy as the target and masks every non-assistant token
       (system prompt and human turns) with IGNORE_INDEX.
    """
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
        return preprocess_llama_2(sources, tokenizer, has_speech=has_speech)
    if conversation_lib.default_conversation.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_speech=has_speech)
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_3:
        return preprocess_llama_3(sources, tokenizer, has_speech=has_speech)
    if conversation_lib.default_conversation.version == "qwen":
        return preprocess_qwen(sources, tokenizer, has_speech=has_speech)
    raise NotImplementedError
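
# Example (illustrative) of the expected source format and call:
#   sources = [[
#       {"from": "human", "value": "<speech>\nWhat is being said?"},
#       {"from": "gpt", "value": "The speaker greets the audience."},
#   ]]
#   batch = preprocess(sources, tokenizer, has_speech=True)
#   batch["input_ids"]  # prompt tokens; '<speech>' -> SPEECH_TOKEN_INDEX
#   batch["labels"]     # same shape; non-assistant positions = IGNORE_INDEX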