Update patch.py
Browse files
patch.py
CHANGED
|
@@ -7,7 +7,7 @@ from einops import rearrange
|
|
| 7 |
import torch
|
| 8 |
import torch.nn.functional as F
|
| 9 |
|
| 10 |
-
from . import
|
| 11 |
from .utils import isinstance_str, init_generator, join_frame, split_frame, func_warper, join_warper, split_warper
|
| 12 |
|
| 13 |
|
|
@@ -42,7 +42,7 @@ def compute_merge(module: torch.nn.Module, x: torch.Tensor, tome_info: Dict[str,
|
|
| 42 |
|
| 43 |
# Recursive merge multi-frame tokens into one set. Such as 4->1 for 4 frames and 8->2->1 for 8 frames when target stride is 4.
|
| 44 |
while curF > 1:
|
| 45 |
-
m, u, ret_dict =
|
| 46 |
local_tokens, curF, args["local_merge_ratio"], unm, generator, args["target_stride"], args["align_batch"])
|
| 47 |
unm += ret_dict["unm_num"]
|
| 48 |
m_ls.append(m)
|
|
@@ -70,7 +70,7 @@ def compute_merge(module: torch.nn.Module, x: torch.Tensor, tome_info: Dict[str,
|
|
| 70 |
[module.global_tokens.to(local_tokens), local_tokens], dim=1)
|
| 71 |
local_chunk = 1
|
| 72 |
|
| 73 |
-
m, u, _ =
|
| 74 |
tokens, src_len, args["global_merge_ratio"], args["align_batch"], unmerge_chunk=local_chunk)
|
| 75 |
merged_tokens = m(tokens)
|
| 76 |
m_ls.append(m)
|
|
@@ -84,7 +84,7 @@ def compute_merge(module: torch.nn.Module, x: torch.Tensor, tome_info: Dict[str,
|
|
| 84 |
m = func_warper(m_ls)
|
| 85 |
u = func_warper(u_ls[::-1])
|
| 86 |
else:
|
| 87 |
-
m, u = (
|
| 88 |
merged_tokens = x
|
| 89 |
|
| 90 |
# Return merge op, unmerge op, and merged tokens.
|
|
|
|
| 7 |
import torch
|
| 8 |
import torch.nn.functional as F
|
| 9 |
|
| 10 |
+
from .merge import bipartite_soft_matching_randframe, bipartite_soft_matching_2s, do_nothing
|
| 11 |
from .utils import isinstance_str, init_generator, join_frame, split_frame, func_warper, join_warper, split_warper
|
| 12 |
|
| 13 |
|
|
|
|
| 42 |
|
| 43 |
# Recursive merge multi-frame tokens into one set. Such as 4->1 for 4 frames and 8->2->1 for 8 frames when target stride is 4.
|
| 44 |
while curF > 1:
|
| 45 |
+
m, u, ret_dict = bipartite_soft_matching_randframe(
|
| 46 |
local_tokens, curF, args["local_merge_ratio"], unm, generator, args["target_stride"], args["align_batch"])
|
| 47 |
unm += ret_dict["unm_num"]
|
| 48 |
m_ls.append(m)
|
|
|
|
| 70 |
[module.global_tokens.to(local_tokens), local_tokens], dim=1)
|
| 71 |
local_chunk = 1
|
| 72 |
|
| 73 |
+
m, u, _ = bipartite_soft_matching_2s(
|
| 74 |
tokens, src_len, args["global_merge_ratio"], args["align_batch"], unmerge_chunk=local_chunk)
|
| 75 |
merged_tokens = m(tokens)
|
| 76 |
m_ls.append(m)
|
|
|
|
| 84 |
m = func_warper(m_ls)
|
| 85 |
u = func_warper(u_ls[::-1])
|
| 86 |
else:
|
| 87 |
+
m, u = (do_nothing, do_nothing)
|
| 88 |
merged_tokens = x
|
| 89 |
|
| 90 |
# Return merge op, unmerge op, and merged tokens.
|