kenkaneki committed on
Commit c33a83c (verified) · 1 parent: 4e2edb8

Upload folder using huggingface_hub

Files changed (1)
  1. generation_utils.py +7 -8
generation_utils.py CHANGED
@@ -425,7 +425,13 @@ class DreamGenerationMixin:
             logits = generation_logits_hook_func(i, x, logits)
             probs = F.softmax(logits, dim=-1)

-
+            # This is where we add the masking!
+            keep_mask = torch.ones_like(x, device=self.device, dtype=torch.bool)
+            if remask_schedule and remask_schedule == "tau":
+                tau = remask_tau[i]
+                token_confidence = probs.amax(-1)
+                keep_mask = token_confidence >= tau
+            x = torch.where(keep_mask, x, mask_token_id)


             mask_logits = logits[mask_index]
@@ -463,13 +469,6 @@ class DreamGenerationMixin:
             row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
             x[row_indices,transfer_index] = x_[row_indices,transfer_index]

-            # This is where we add the masking!
-            keep_mask = torch.ones_like(x, device=self.device, dtype=torch.bool)
-            if remask_schedule and remask_schedule == "tau":
-                tau = remask_tau[i]
-                token_confidence = probs.amax(-1)
-                keep_mask = token_confidence >= tau
-            x = torch.where(keep_mask, x, mask_token_id)
             # this allows user-defined token control of the intermediate steps
             x = generation_tokens_hook_func(i, x, logits)

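For reference, the re-masking step this commit relocates (from after the token-transfer update to directly after the softmax) can be read in isolation roughly as below. This is a minimal sketch, not the repository's code: the standalone function name tau_remask, its signature, and the assumption that mask_token_id is a plain integer and that logits have shape (batch, seq, vocab) are all illustrative.

    import torch
    import torch.nn.functional as F

    def tau_remask(x, logits, mask_token_id, tau):
        """Re-mask low-confidence positions.

        Any position whose maximum softmax probability falls below the
        threshold tau is reset to mask_token_id; all other tokens are kept.
        """
        probs = F.softmax(logits, dim=-1)        # (batch, seq, vocab)
        token_confidence = probs.amax(-1)        # (batch, seq): max prob per position
        keep_mask = token_confidence >= tau      # True where the token is confident enough
        return torch.where(keep_mask, x, torch.full_like(x, mask_token_id))

In the committed version, tau is taken per step from a schedule (remask_tau[i]) and the step is skipped unless remask_schedule == "tau", in which case keep_mask stays all-True and x is left unchanged.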