Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
8389537
1
Parent(s):
a5bfa12
main.py
CHANGED
@@ -110,11 +110,13 @@ def process_frame(
|
|
110 |
) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
|
111 |
"""Process a single frame through the model."""
|
112 |
timing = {}
|
113 |
-
|
114 |
# Temporal encoding
|
115 |
start = time.perf_counter()
|
|
|
116 |
output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs)
|
117 |
timing['temporal_encoder'] = time.perf_counter() - start
|
|
|
118 |
|
119 |
# UNet sampling
|
120 |
start = time.perf_counter()
|
@@ -127,13 +129,16 @@ def process_frame(
|
|
127 |
verbose=False
|
128 |
)
|
129 |
timing['unet'] = time.perf_counter() - start
|
|
|
130 |
|
131 |
# Decoding
|
132 |
start = time.perf_counter()
|
133 |
sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
|
134 |
sample = model.decode_first_stage(sample)
|
|
|
135 |
sample = sample.squeeze(0).clamp(-1, 1)
|
136 |
timing['decode'] = time.perf_counter() - start
|
|
|
137 |
|
138 |
# Convert to image
|
139 |
sample_img = ((sample[:3].transpose(0,1).transpose(1,2).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)
|
|
|
110 |
) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
|
111 |
"""Process a single frame through the model."""
|
112 |
timing = {}
|
113 |
+
print ('here5')
|
114 |
# Temporal encoding
|
115 |
start = time.perf_counter()
|
116 |
+
print ('here6')
|
117 |
output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs)
|
118 |
timing['temporal_encoder'] = time.perf_counter() - start
|
119 |
+
print ('here7')
|
120 |
|
121 |
# UNet sampling
|
122 |
start = time.perf_counter()
|
|
|
129 |
verbose=False
|
130 |
)
|
131 |
timing['unet'] = time.perf_counter() - start
|
132 |
+
print ('here8')
|
133 |
|
134 |
# Decoding
|
135 |
start = time.perf_counter()
|
136 |
sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
|
137 |
sample = model.decode_first_stage(sample)
|
138 |
+
print ('here9')
|
139 |
sample = sample.squeeze(0).clamp(-1, 1)
|
140 |
timing['decode'] = time.perf_counter() - start
|
141 |
+
print ('here10')
|
142 |
|
143 |
# Convert to image
|
144 |
sample_img = ((sample[:3].transpose(0,1).transpose(1,2).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)
|