# GPT_SoVITS/process_ckpt.py
import os
import shutil
import traceback
from collections import OrderedDict
from time import time as ttime
from typing import Any

import torch

from GPT_SoVITS.module.models import set_serialization
from tools.i18n.i18n import I18nAuto

i18n = I18nAuto()
set_serialization()


def save(fea, path):
    # Workaround: torch.save can fail when the target path contains
    # non-ASCII (e.g. Chinese) characters, so the tensors are written to a
    # temporary file in the working directory first, then moved into place.
    dir_name = os.path.dirname(path)
    name = os.path.basename(path)
    tmp_path = f"{ttime()}.pth"
    torch.save(fea, tmp_path)
    shutil.move(tmp_path, f"{dir_name}/{name}")
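
# Usage sketch (the path below is hypothetical): the tensors are written to an
# ASCII temp file and then moved, so non-ASCII target directories are handled.
#
#   save({"weight": {}}, "logs/中文路径/model.pth")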


def save_ckpt(ckpt, name, epoch, steps, hps, lora_rank=None):
    # Drop the posterior encoder (enc_q, training-only), cast the remaining
    # weights to fp16, and bundle them with the config and training progress
    # into a distributable checkpoint.
    try:
        opt = OrderedDict()
        opt["weight"] = {}
        for key in ckpt.keys():
            if "enc_q" in key:
                continue
            opt["weight"][key] = ckpt[key].half()
        opt["config"] = hps.to_dict()
        opt["info"] = f"{epoch}epoch_{steps}iteration"
        if lora_rank:
            opt["lora_rank"] = lora_rank
        save(opt, f"{hps.save_weight_dir}/{name}.pth")
        return "Success."
    except Exception:
        return traceback.format_exc()
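
# Usage sketch: `net_g` and `hps` below are hypothetical stand-ins for a trained
# generator and a config object exposing `to_dict()` and `save_weight_dir`, the
# two attributes save_ckpt relies on.
#
#   result = save_ckpt(
#       net_g.state_dict(),
#       name="my_voice_e8_s1600",
#       epoch=8,
#       steps=1600,
#       hps=hps,
#   )
#   print(result)  # "Success." or a formatted traceback on failure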


def inspect_version(
    f: str,
) -> tuple[str, str, bool, Any, dict]:
    """
    Returns:
        tuple[model_version, lang_version, is_lora, hps, state_dict]
    """
    dict_s2 = torch.load(f, map_location="cpu", mmap=True)
    hps = dict_s2["config"]

    version: str | None = None
    if "version" in hps.keys():
        version = hps["version"]

    is_lora = "lora_rank" in dict_s2.keys()

    if version is not None:
        # V3/V4 LoRA & Finetuned V2 Pro
        lang_version = "v2"
        model_version = version
    else:
        # V2 Pro Pretrain
        if hps["model"]["gin_channels"] == 1024:
            if hps["model"]["upsample_initial_channel"] == 768:
                lang_version = "v2"
                model_version = "v2ProPlus"
            else:
                lang_version = "v2"
                model_version = "v2Pro"
            return model_version, lang_version, is_lora, hps, dict_s2

        # Old V1/V2
        if "dec.conv_pre.weight" in dict_s2["weight"].keys():
            # v1 uses the 322-symbol text embedding; v2 uses a larger symbol set
            if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
                lang_version = model_version = "v1"
            else:
                lang_version = model_version = "v2"
        else:  # Old Finetuned V3 & V3/V4 Pretrain
            lang_version = "v2"
            model_version = "v3"
            if dict_s2["info"] == "pretrained_s2G_v4":
                model_version = "v4"

    return model_version, lang_version, is_lora, hps, dict_s2
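
# Usage sketch (hypothetical checkpoint path): branch loading logic on the
# detected model version and on whether the file is a LoRA checkpoint.
#
#   model_version, lang_version, is_lora, hps, state_dict = inspect_version(
#       "SoVITS_weights/my_voice_e8_s1600.pth"
#   )
#   if is_lora:
#       print(f"LoRA (rank {state_dict['lora_rank']}) for SoVITS {model_version}")
#   else:
#       print(f"Full SoVITS checkpoint: {model_version}, text frontend {lang_version}")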