import os
from glob import glob

import torch
from safetensors import safe_open
from torch import nn

from flashcosyvoice.config import CosyVoice2LLMConfig


def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
    """Fallback loader: copy a checkpoint tensor into a parameter in place."""
    param.data.copy_(loaded_weight)


def load_text_llm(model: nn.Module, path: str):
    """Load a plain text LLM from the *.safetensors shards under `path`."""
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
    for file in glob(os.path.join(path, "*.safetensors")):
        with safe_open(file, framework="pt", device="cpu") as f:
            for weight_name in f.keys():
                for k in packed_modules_mapping:
                    if k in weight_name:
                        v, shard_id = packed_modules_mapping[k]
                        param_name = weight_name.replace(k, v)
                        param = model.get_parameter(param_name)
                        weight_loader = param.weight_loader
                        weight_loader(param, f.get_tensor(weight_name), shard_id)
                        break
                else:
                    param = model.get_parameter(weight_name)
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, f.get_tensor(weight_name))


def load_speech_llm(model: nn.Module, path: str, hf_config: CosyVoice2LLMConfig):
    """Load CosyVoice2 speech-LLM weights from `path`/llm.pt, then merge the
    speech, sos/task_id, and text embeddings into one unified embedding table."""
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
    # NOTE(xcsong): 1. load speech embedding + sos/taskid embedding + lm head
    embedding_weights = {}
    tmp_weights = torch.load(f"{path}/llm.pt", map_location="cpu", weights_only=True)
    missed, missed_names = 0, []
    for k, v in tmp_weights.items():
        if k == "speech_embedding.weight":  # torch.Size([6564, 896])
            speech_embedding_size = hf_config.speech_vocab_size  # 6562
            # NOTE(xcsong): the checkpoint may carry extra padded rows (e.g. for
            # tensor-parallel alignment in vLLM); truncate to the configured vocab size.
            if speech_embedding_size != v.shape[0]:  # [6564, 896] -> [6562, 896]
                assert speech_embedding_size <= v.shape[0], f"speech_embedding_size should be less than or equal to {v.shape[0]}, but got {speech_embedding_size}"
                v = v[:speech_embedding_size, :]
            embedding_weights["speech_embedding.weight"] = v
        elif k == "llm_embedding.weight":  # torch.Size([2, 896]), sos and task_id
            assert v.shape[0] == 2, f"llm_embedding.weight should be of shape [2, 896], but got {v.shape}"
            embedding_weights["llm_embedding.weight"] = v
        elif k == "llm.model.model.embed_tokens.weight":  # torch.Size([151936, 896])
            embedding_weights["model.embed_tokens.weight"] = v
        elif k == "llm_decoder.weight":  # torch.Size([6564, 896])
            lm_head_size = hf_config.speech_vocab_size  # 6562
            if lm_head_size != v.shape[0]:  # [6564, 896] -> [6562, 896]
                assert lm_head_size <= v.shape[0], f"lm_head_size should be less than or equal to {v.shape[0]}, but got {lm_head_size}"
                v = v[:lm_head_size, :]
            param = model.get_parameter("lm_head.weight")
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, v)
        elif k == "llm_decoder.bias":  # torch.Size([6564])
            lm_head_size = hf_config.speech_vocab_size  # 6562
            if lm_head_size != v.shape[0]:  # [6564] -> [6562]
                assert lm_head_size <= v.shape[0], f"lm_head_size should be less than or equal to {v.shape[0]}, but got {lm_head_size}"
                v = v[:lm_head_size]
            param = model.get_parameter("lm_head.bias")
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, v)
        elif "llm.model." in k:
            weight_name = k.replace("llm.model.", "")
            for kk in packed_modules_mapping:
                if kk in weight_name:
                    vv, shard_id = packed_modules_mapping[kk]
                    param_name = weight_name.replace(kk, vv)
                    try:
                        param = model.get_parameter(param_name)
                        weight_loader = param.weight_loader
                        weight_loader(param, v, shard_id)
                        break
                    except Exception as e:
                        print(e)
                        print(f"skip parameter (1): {weight_name}")
                        continue
            else:
                try:
                    param = model.get_parameter(weight_name)
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, v)
                except Exception as e:
                    print(e)
                    print(f"skip parameter (2): {weight_name}")
                    continue
        else:
            # NOTE: use `k` here; `weight_name` is only bound in the branch above.
            missed += 1
            missed_names.append(k)
            continue
    print(f"missed {missed} parameters: {missed_names}")

    # NOTE(xcsong): 2. merge text embedding, sos/taskid embedding, and speech embedding
    text_embedding_weight = embedding_weights["model.embed_tokens.weight"].cpu()  # [151936, 896]
    sos_taskid_embedding_weight = embedding_weights["llm_embedding.weight"].cpu()  # [2, 896]
    speech_embedding_weight = embedding_weights["speech_embedding.weight"].cpu()  # [6562, 896]
    final_embedding_weight = torch.cat(
        [speech_embedding_weight, sos_taskid_embedding_weight, text_embedding_weight], dim=0
    )  # [158500, 896]
    param = model.get_parameter("model.embed_tokens.weight")
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, final_embedding_weight)


def load_model(model: nn.Module, path: str, hf_config: CosyVoice2LLMConfig | None = None):
    if model.model_type == "speech_llm":
        assert hf_config is not None, "hf_config is required for speech_llm"
        load_speech_llm(model, path, hf_config)
    elif model.model_type == "text_llm":
        load_text_llm(model, path)
    else:
        raise ValueError(f"Unsupported model type: {model.model_type}")
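
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the loaders above): `packed_modules_mapping`
# routes checkpoint keys to fused parameters, where the fused parameter's
# `weight_loader(param, tensor, shard_id)` copies each source shard into its
# slice. The mapping shape and offsets below are toy assumptions showing the
# contract this module relies on, not the real model implementation.

def _demo_packed_weight_loader():
    # Toy fused parameter holding q/k/v stacked along dim 0 (4 + 2 + 2 rows).
    fused = nn.Parameter(torch.empty(8, 3))
    offsets = {"q": (0, 4), "k": (4, 2), "v": (6, 2)}

    def weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str):
        start, size = offsets[shard_id]
        param.data[start:start + size].copy_(loaded_weight)

    # A mapping like {"q_proj": ("qkv_proj", "q"), ...} turns a checkpoint key
    # into (fused_name, shard_id); the loader then places each shard in turn.
    weight_loader(fused, torch.randn(4, 3), "q")
    weight_loader(fused, torch.randn(2, 3), "k")
    weight_loader(fused, torch.randn(2, 3), "v")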
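
# ---------------------------------------------------------------------------
# Minimal sketch of the vocabulary layout implied by the concatenation order in
# load_speech_llm (an inference from the cat call, shown with toy sizes):
# speech tokens occupy ids [0, speech_vocab_size), sos/task_id the next two
# rows, and text token t lands at row speech_vocab_size + 2 + t.

def _demo_merged_embedding_layout():
    speech_vocab_size, text_vocab_size, dim = 4, 6, 3  # toy sizes
    speech = torch.zeros(speech_vocab_size, dim)
    sos_taskid = torch.ones(2, dim)
    text = torch.full((text_vocab_size, dim), 2.0)
    merged = torch.cat([speech, sos_taskid, text], dim=0)  # [4 + 2 + 6, 3]
    t = 5  # a token id in the original text vocabulary
    assert torch.equal(merged[speech_vocab_size + 2 + t], text[t])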
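
# ---------------------------------------------------------------------------
# Hedged usage sketch: `SpeechLLM` and the checkpoint directory below are
# hypothetical placeholders, not names defined by this module.
#
#   hf_config = CosyVoice2LLMConfig(...)   # configured to match the checkpoint
#   model = SpeechLLM(hf_config)           # hypothetical model, model_type == "speech_llm"
#   load_model(model, "/path/to/checkpoint_dir", hf_config)  # expects llm.pt under the path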