Upload folder using huggingface_hub
Browse files- generation_utils.py +1 -4
generation_utils.py
CHANGED
|
@@ -424,16 +424,13 @@ class DreamGenerationMixin:
|
|
| 424 |
# this allows user-defined logits control of the intermediate steps
|
| 425 |
logits = generation_logits_hook_func(i, x, logits)
|
| 426 |
probs = F.softmax(logits, dim=-1)
|
| 427 |
-
|
| 428 |
-
logger.warn(f"remask_schedule: {remask_schedule}")
|
| 429 |
-
|
| 430 |
#袙袨孝 孝校孝 袦袗小袣袠袧袚 袛袨袘袗袙袥携袝袦!
|
| 431 |
keep_mask = torch.ones_like(x, device=self.device, dtype=torch.bool)
|
| 432 |
if remask_schedule and remask_schedule == "tau":
|
| 433 |
tau = remask_tau[i]
|
| 434 |
token_confidence = probs.amax(-1)
|
| 435 |
keep_mask = token_confidence >= tau
|
| 436 |
-
logger.warning(f"total remask tokens: {len(x) - keep_mask.sum()}")
|
| 437 |
x = torch.where(keep_mask, x, mask_token_id)
|
| 438 |
|
| 439 |
|
|
|
|
| 424 |
# this allows user-defined logits control of the intermediate steps
|
| 425 |
logits = generation_logits_hook_func(i, x, logits)
|
| 426 |
probs = F.softmax(logits, dim=-1)
|
| 427 |
+
|
|
|
|
|
|
|
| 428 |
#袙袨孝 孝校孝 袦袗小袣袠袧袚 袛袨袘袗袙袥携袝袦!
|
| 429 |
keep_mask = torch.ones_like(x, device=self.device, dtype=torch.bool)
|
| 430 |
if remask_schedule and remask_schedule == "tau":
|
| 431 |
tau = remask_tau[i]
|
| 432 |
token_confidence = probs.amax(-1)
|
| 433 |
keep_mask = token_confidence >= tau
|
|
|
|
| 434 |
x = torch.where(keep_mask, x, mask_token_id)
|
| 435 |
|
| 436 |
|