import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from utils import compute_token_num, load_audio, log_mel_spectrogram, padding_mels
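
# Inference wrapper for the Step-Audio-2 speech-language model. Prompts are
# assembled from text segments and mel-spectrogram "audio patch" placeholders,
# passed to the LLM together with the mel features, and the generated ids are
# split back into text tokens and audio codec tokens.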


class StepAudio2Base:

    def __init__(self, model_path: str):
        self.llm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="right")
        self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
        self.eos_token_id = self.llm_tokenizer.eos_token_id

    def __call__(self, messages: list, **kwargs):
        messages, mels = self.apply_chat_template(messages)

        # Tokenize prompts
        prompt_ids = []
        for msg in messages:
            if isinstance(msg, str):
                prompt_ids.append(self.llm_tokenizer(text=msg, return_tensors="pt", padding=True)["input_ids"])
            elif isinstance(msg, list):
                prompt_ids.append(torch.tensor([msg], dtype=torch.int32))
            else:
                raise ValueError(f"Unsupported content type: {type(msg)}")
        prompt_ids = torch.cat(prompt_ids, dim=-1).cuda()
        attention_mask = torch.ones_like(prompt_ids)
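
        # Batch the variable-length mel spectrograms: padding_mels pads them to
        # a common length and returns their original lengths for the encoder.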
        if len(mels) == 0:
            mels = None
            mel_lengths = None
        else:
            mels, mel_lengths = padding_mels(mels)
            mels = mels.cuda()
            mel_lengths = mel_lengths.cuda()

        generate_inputs = {
            "input_ids": prompt_ids,
            "wavs": mels,
            "wav_lens": mel_lengths,
            "attention_mask": attention_mask,
        }

        generation_config = dict(
            max_new_tokens=2048,
            pad_token_id=self.llm_tokenizer.pad_token_id,
            eos_token_id=self.eos_token_id,
        )
        generation_config.update(kwargs)
        generation_config = GenerationConfig(**generation_config)
        outputs = self.llm.generate(**generate_inputs, generation_config=generation_config)
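
        # Strip the prompt and the trailing stop token, then split the remaining
        # ids by vocabulary range: ids below 151688 decode as text, while ids
        # from 151696 upwards are audio codec tokens stored at a fixed offset.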
        output_token_ids = outputs[0, prompt_ids.shape[-1]:-1].tolist()
        output_text_tokens = [i for i in output_token_ids if i < 151688]
        output_audio_tokens = [i - 151696 for i in output_token_ids if i > 151695]
        output_text = self.llm_tokenizer.decode(output_text_tokens)
        return output_token_ids, output_text, output_audio_tokens

    def apply_chat_template(self, messages: list):
        results = []
        mels = []
        for msg in messages:
            content = msg
            if isinstance(content, str):
                results.append(content)
            elif isinstance(content, dict):
                if content["type"] == "text":
                    results.append(f"{content['text']}")
                elif content["type"] == "audio":
                    audio = load_audio(content['audio'])
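                    # Encode the audio in 25 s windows at the 16 kHz sample rate;
                    # each window contributes one "<audio_patch>" placeholder per
                    # audio encoder token.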
                    for i in range(0, audio.shape[0], 16000 * 25):
                        mel = log_mel_spectrogram(audio[i:i + 16000 * 25], n_mels=128, padding=479)
                        mels.append(mel)
                        audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
                        results.append(f"<audio_start>{audio_tokens}<audio_end>")
                elif content["type"] == "token":
                    results.append(content["token"])
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        return results, mels


class StepAudio2(StepAudio2Base):

    def __init__(self, model_path: str):
        super().__init__(model_path)
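        # The chat-tuned model ends every turn with "<|EOT|>", so override the
        # tokenizer's default EOS token and stop generation on its id instead.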
        self.llm_tokenizer.eos_token = "<|EOT|>"
        self.llm.config.eos_token_id = self.llm_tokenizer.convert_tokens_to_ids("<|EOT|>")
        self.eos_token_id = self.llm_tokenizer.convert_tokens_to_ids("<|EOT|>")

    def apply_chat_template(self, messages: list):
        results = []
        mels = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                role = "human"
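            # Each turn is framed as "<|BOT|>{role}\n...<|EOT|>". Passing
            # "eot": False leaves the turn open so the model continues it,
            # e.g. after a "<tts_start>" prefix.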
            if isinstance(content, str):
                text_with_audio = f"<|BOT|>{role}\n{content}"
                text_with_audio += '<|EOT|>' if msg.get('eot', True) else ''
                results.append(text_with_audio)
            elif isinstance(content, list):
                results.append(f"<|BOT|>{role}\n")
                for item in content:
                    if item["type"] == "text":
                        results.append(f"{item['text']}")
                    elif item["type"] == "audio":
                        audio = load_audio(item['audio'])
                        for i in range(0, audio.shape[0], 16000 * 25):
                            mel = log_mel_spectrogram(audio[i:i + 16000 * 25], n_mels=128, padding=479)
                            mels.append(mel)
                            audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
                            results.append(f"<audio_start>{audio_tokens}<audio_end>")
                    elif item["type"] == "token":
                        results.append(item["token"])
                if msg.get('eot', True):
                    results.append('<|EOT|>')
            elif content is None:
                results.append(f"<|BOT|>{role}\n")
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        return results, mels


if __name__ == '__main__':
    from token2wav import Token2wav

    model = StepAudio2('Step-Audio-2-mini')
    token2wav = Token2wav('Step-Audio-2-mini/token2wav')
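    # Token2wav renders generated audio codec tokens back into a waveform;
    # prompt_wav presumably supplies the reference voice (male vs. female below).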

    # Text-to-text conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": "Give me a brief introduction to the Great Wall."},
        {"role": "assistant", "content": None}
    ]
    tokens, text, _ = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)

    # Text-to-speech conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": "Give me a brief introduction to the Great Wall."},
        {"role": "assistant", "content": "<tts_start>", "eot": False},  # Insert <tts_start> for speech response
    ]
    tokens, text, audio = model(messages, max_new_tokens=4096, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)
    print(tokens)
    audio = token2wav(audio, prompt_wav='assets/default_male.wav')
    with open('output-male.wav', 'wb') as f:
        f.write(audio)

    # Speech-to-text conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": [{"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
        {"role": "assistant", "content": None}
    ]
    tokens, text, _ = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)

    # Speech-to-speech conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": [{"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
        {"role": "assistant", "content": "<tts_start>", "eot": False},  # Insert <tts_start> for speech response
    ]
    tokens, text, audio = model(messages, max_new_tokens=4096, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)
    print(tokens)
    audio = token2wav(audio, prompt_wav='assets/default_female.wav')
    with open('output-female.wav', 'wb') as f:
        f.write(audio)

    # Multi-turn conversation
    print()
    messages.pop(-1)
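    # Replay the previous assistant turn verbatim ("<tts_start>" plus the raw
    # generated token ids) so the model sees the full dialogue history.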
    messages += [
        {"role": "assistant", "content": [{"type": "text", "text": "<tts_start>"},
                                          {"type": "token", "token": tokens}]},
        {"role": "human", "content": "Now write a 4-line poem about it."},
        {"role": "assistant", "content": None}
    ]
    tokens, text, audio = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)

    # Multi-modal inputs
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": [{"type": "text", "text": "Translate the speech into Chinese."},
                                      {"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
        {"role": "assistant", "content": None}
    ]
    tokens, text, audio = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)