Spaces:
Runtime error
Runtime error
Hugo Flores Garcia
commited on
Commit
·
4d0cbfe
1
Parent(s):
85e8a86
tiny sampling refactor
Browse files- vampnet/modules/transformer.py +55 -54
vampnet/modules/transformer.py
CHANGED
|
@@ -724,7 +724,7 @@ class VampNet(at.ml.BaseModel):
|
|
| 724 |
|
| 725 |
logits = torch.log(probs)
|
| 726 |
|
| 727 |
-
z_inferred =
|
| 728 |
logits=logits,
|
| 729 |
top_k=top_k,
|
| 730 |
temperature=tmpt,
|
|
@@ -742,61 +742,60 @@ class VampNet(at.ml.BaseModel):
|
|
| 742 |
else:
|
| 743 |
return z
|
| 744 |
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
)
|
| 776 |
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
if sample == "multinomial":
|
| 790 |
-
probs = torch.softmax(logits, dim=-1)
|
| 791 |
-
inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
|
| 792 |
-
elif sample == "argmax":
|
| 793 |
-
inferred = torch.softmax(logits, dim=-1).argmax(dim=-1)
|
| 794 |
-
elif sample == "gumbel":
|
| 795 |
-
inferred = gumbel_sample(logits, dim=-1)
|
| 796 |
-
else:
|
| 797 |
-
raise ValueError(f"invalid sampling method: {sample}")
|
| 798 |
|
| 799 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 800 |
|
| 801 |
|
| 802 |
|
|
@@ -833,3 +832,5 @@ if __name__ == "__main__":
|
|
| 833 |
args = argbind.parse_args()
|
| 834 |
with argbind.scope(args):
|
| 835 |
try_model()
|
|
|
|
|
|
|
|
|
| 724 |
|
| 725 |
logits = torch.log(probs)
|
| 726 |
|
| 727 |
+
z_inferred = sample_from_logits(
|
| 728 |
logits=logits,
|
| 729 |
top_k=top_k,
|
| 730 |
temperature=tmpt,
|
|
|
|
| 742 |
else:
|
| 743 |
return z
|
| 744 |
|
| 745 |
+
def sample_from_logits(
|
| 746 |
+
logits,
|
| 747 |
+
top_k: int = None,
|
| 748 |
+
temperature: float = 1.0,
|
| 749 |
+
sample: str = "multinomial",
|
| 750 |
+
typical_filtering=False,
|
| 751 |
+
typical_mass=0.2,
|
| 752 |
+
typical_min_tokens=1,
|
| 753 |
+
):
|
| 754 |
+
# add temperature
|
| 755 |
+
logits = logits / temperature
|
| 756 |
+
|
| 757 |
+
# add topk
|
| 758 |
+
if top_k is not None and typical_filtering == False:
|
| 759 |
+
v, topk_idx = logits.topk(top_k)
|
| 760 |
+
logits[logits < v[..., [-1]]] = -float("inf")
|
| 761 |
+
|
| 762 |
+
if typical_filtering:
|
| 763 |
+
assert top_k is None
|
| 764 |
+
nb, nt, _ = logits.shape
|
| 765 |
+
x_flat = rearrange(logits, "b t l -> (b t ) l")
|
| 766 |
+
x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
|
| 767 |
+
x_flat_norm_p = torch.exp(x_flat_norm)
|
| 768 |
+
entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
|
| 769 |
+
|
| 770 |
+
c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
|
| 771 |
+
c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
|
| 772 |
+
x_flat_cumsum = (
|
| 773 |
+
x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
|
| 774 |
+
)
|
|
|
|
| 775 |
|
| 776 |
+
last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
|
| 777 |
+
sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
|
| 778 |
+
1, last_ind.view(-1, 1)
|
| 779 |
+
)
|
| 780 |
+
if typical_min_tokens > 1:
|
| 781 |
+
sorted_indices_to_remove[..., :typical_min_tokens] = 0
|
| 782 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
| 783 |
+
1, x_flat_indices, sorted_indices_to_remove
|
| 784 |
+
)
|
| 785 |
+
x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
|
| 786 |
+
logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 787 |
|
| 788 |
+
if sample == "multinomial":
|
| 789 |
+
probs = torch.softmax(logits, dim=-1)
|
| 790 |
+
inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
|
| 791 |
+
elif sample == "argmax":
|
| 792 |
+
inferred = torch.softmax(logits, dim=-1).argmax(dim=-1)
|
| 793 |
+
elif sample == "gumbel":
|
| 794 |
+
inferred = gumbel_sample(logits, dim=-1)
|
| 795 |
+
else:
|
| 796 |
+
raise ValueError(f"invalid sampling method: {sample}")
|
| 797 |
+
|
| 798 |
+
return inferred
|
| 799 |
|
| 800 |
|
| 801 |
|
|
|
|
| 832 |
args = argbind.parse_args()
|
| 833 |
with argbind.scope(args):
|
| 834 |
try_model()
|
| 835 |
+
|
| 836 |
+
|