Spaces:
Running
Running
Update consistencytta.py
Browse files- consistencytta.py +10 -2
consistencytta.py
CHANGED
|
@@ -75,9 +75,17 @@ class ConsistencyTTA(nn.Module):
|
|
| 75 |
[self.text_encoder, self.vae, self.fn_STFT, self.unet],
|
| 76 |
['text_encoder', 'vae', 'fn_STFT', 'unet']
|
| 77 |
):
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
for param in model.parameters():
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
|
| 83 |
@torch.no_grad()
|
|
|
|
| 75 |
[self.text_encoder, self.vae, self.fn_STFT, self.unet],
|
| 76 |
['text_encoder', 'vae', 'fn_STFT', 'unet']
|
| 77 |
):
|
| 78 |
+
try:
|
| 79 |
+
assert model.training == False, f"The {name} is not in eval mode."
|
| 80 |
+
except:
|
| 81 |
+
model.eval()
|
| 82 |
+
assert model.training == False, f"The {name} is not in eval mode."
|
| 83 |
for param in model.parameters():
|
| 84 |
+
try:
|
| 85 |
+
assert param.requires_grad == False, f"The {name} is not frozen."
|
| 86 |
+
except:
|
| 87 |
+
param.requires_grad_(False)
|
| 88 |
+
assert param.requires_grad == False, f"The {name} is not frozen."
|
| 89 |
|
| 90 |
|
| 91 |
@torch.no_grad()
|