Update audioldm/pipeline.py
audioldm/pipeline.py CHANGED (+18 -4)
@@ -30,7 +30,23 @@ def make_batch_for_text_to_audio(text, batchsize=1):
     )
     return batch
 
-def build_model(config=None):
+
+
+def build_model(
+    ckpt_path=None,
+    config=None,
+    model_name="audioldm-s-full"
+):
+    print("Load AudioLDM: %s" % model_name)
+
+    resume_from_checkpoint = "ckpt/%s.ckpt" % model_name
+
+    # if(ckpt_path is None):
+    #     ckpt_path = get_metadata()[model_name]["path"]
+
+    # if(not os.path.exists(ckpt_path)):
+    #     download_checkpoint(model_name)
+
     if(torch.cuda.is_available()):
         device = torch.device("cuda:0")
     else:
@@ -40,7 +56,7 @@ def build_model(config=None):
         assert type(config) is str
         config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
     else:
-        config = default_audioldm_config()
+        config = default_audioldm_config(model_name)
 
     # Use text as condition instead of using waveform during training
     config["model"]["params"]["device"] = device
@@ -49,8 +65,6 @@ def build_model(config=None):
     # No normalization here
     latent_diffusion = LatentDiffusion(**config["model"]["params"])
 
-    resume_from_checkpoint = "./ckpt/ldm_trimmed.ckpt"
-
     checkpoint = torch.load(resume_from_checkpoint, map_location=device)
     latent_diffusion.load_state_dict(checkpoint["state_dict"])
 
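For orientation, a minimal usage sketch of the revised entry point follows. This is a hypothetical driver snippet, not part of the commit: it assumes the checkpoint has been placed manually at ckpt/audioldm-s-full.ckpt (the get_metadata()/download_checkpoint() auto-download path is still commented out above) and that the script runs from the repository root so audioldm.pipeline is importable.

    from audioldm.pipeline import build_model

    # With the defaults, build_model() loads "ckpt/audioldm-s-full.ckpt"
    # (resolved relative to the working directory) and pulls its settings
    # from default_audioldm_config(model_name).
    latent_diffusion = build_model()

    # A custom config can still be supplied as a path to a YAML file;
    # build_model asserts it is a string and parses it with yaml.load.
    # latent_diffusion = build_model(config="my_audioldm_config.yaml")

Note that ckpt_path is accepted but not yet consulted in this revision; the hard-coded "ckpt/%s.ckpt" pattern determines which checkpoint is loaded until the commented-out block is enabled.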