da03 commited on
Commit
8389537
·
1 Parent(s): a5bfa12
Files changed (1) hide show
  1. main.py +6 -1
main.py CHANGED
@@ -110,11 +110,13 @@ def process_frame(
110
  ) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
111
  """Process a single frame through the model."""
112
  timing = {}
113
-
114
  # Temporal encoding
115
  start = time.perf_counter()
 
116
  output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs)
117
  timing['temporal_encoder'] = time.perf_counter() - start
 
118
 
119
  # UNet sampling
120
  start = time.perf_counter()
@@ -127,13 +129,16 @@ def process_frame(
127
  verbose=False
128
  )
129
  timing['unet'] = time.perf_counter() - start
 
130
 
131
  # Decoding
132
  start = time.perf_counter()
133
  sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
134
  sample = model.decode_first_stage(sample)
 
135
  sample = sample.squeeze(0).clamp(-1, 1)
136
  timing['decode'] = time.perf_counter() - start
 
137
 
138
  # Convert to image
139
  sample_img = ((sample[:3].transpose(0,1).transpose(1,2).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)
 
110
  ) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
111
  """Process a single frame through the model."""
112
  timing = {}
113
+ print ('here5')
114
  # Temporal encoding
115
  start = time.perf_counter()
116
+ print ('here6')
117
  output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs)
118
  timing['temporal_encoder'] = time.perf_counter() - start
119
+ print ('here7')
120
 
121
  # UNet sampling
122
  start = time.perf_counter()
 
129
  verbose=False
130
  )
131
  timing['unet'] = time.perf_counter() - start
132
+ print ('here8')
133
 
134
  # Decoding
135
  start = time.perf_counter()
136
  sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
137
  sample = model.decode_first_stage(sample)
138
+ print ('here9')
139
  sample = sample.squeeze(0).clamp(-1, 1)
140
  timing['decode'] = time.perf_counter() - start
141
+ print ('here10')
142
 
143
  # Convert to image
144
  sample_img = ((sample[:3].transpose(0,1).transpose(1,2).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)