kenkaneki committed on
Commit 3f3cd0e (verified)
1 Parent(s): 9f22d3e

Upload folder using huggingface_hub

Files changed (1)
  1. generation_utils.py +1 -4
generation_utils.py CHANGED
@@ -424,16 +424,13 @@ class DreamGenerationMixin:
             # this allows user-defined logits control of the intermediate steps
             logits = generation_logits_hook_func(i, x, logits)
             probs = F.softmax(logits, dim=-1)
-
-            logger.warn(f"remask_schedule: {remask_schedule}")
-
+
             # This is where the masking step is added
             keep_mask = torch.ones_like(x, device=self.device, dtype=torch.bool)
             if remask_schedule and remask_schedule == "tau":
                 tau = remask_tau[i]
                 token_confidence = probs.amax(-1)
                 keep_mask = token_confidence >= tau
-            logger.warning(f"total remask tokens: {len(x) - keep_mask.sum()}")
             x = torch.where(keep_mask, x, mask_token_id)
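
For context, the hunk above only removes two debug logger calls around the tau-based remasking step; the step itself is unchanged. Below is a minimal, self-contained sketch of that step, assuming x has shape (batch, seq_len) and logits has shape (batch, seq_len, vocab_size) as suggested by the surrounding code; apply_tau_remask is a hypothetical helper name, not part of generation_utils.py.

import torch
import torch.nn.functional as F

def apply_tau_remask(x, logits, tau, mask_token_id):
    # Per-position confidence = maximum softmax probability over the vocabulary.
    probs = F.softmax(logits, dim=-1)        # (batch, seq_len, vocab_size)
    token_confidence = probs.amax(-1)        # (batch, seq_len)
    # Keep tokens whose confidence meets this step's threshold tau;
    # everything else is reset to the mask token for the next denoising step.
    keep_mask = token_confidence >= tau
    return torch.where(keep_mask, x, torch.full_like(x, mask_token_id))

In the generation loop, when remask_schedule == "tau", the threshold used at step i is remask_tau[i], so the amount of remasking can vary over the denoising schedule.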