@@ -104,10 +104,10 @@ def forward(self, x):
104104 """x --> [K x Conv2D] --> PatchMerge
105105
106106 Args:
107- x : (B, T, H, W, C)
107+ x: (B, T, H, W, C)
108108
109109 Returns:
110- out : (B, T, H_new, W_new, C_out)
110+ out: (B, T, H_new, W_new, C_out)
111111 """
112112
113113 B , T , H , W , C = x .shape
@@ -178,10 +178,10 @@ def forward(self, x):
178178 """x --> Upsample --> [K x Conv2D]
179179
180180 Args:
181- x : (B, T, H, W, C)
181+ x: (B, T, H, W, C)
182182
183183 Returns:
184- out : (B, T, H_new, W_new, C)
184+ out: (B, T, H_new, W_new, C)
185185 """
186186
187187 x = self .upsample (x )
@@ -286,10 +286,10 @@ def forward(self, x):
286286 """x --> [K x Conv2D] --> PatchMerge --> ... --> [K x Conv2D] --> PatchMerge
287287
288288 Args:
289- x : (B, T, H, W, C)
289+ x: (B, T, H, W, C)
290290
291291 Returns:
292- out : (B, T, H_new, W_new, C_out)
292+ out: (B, T, H_new, W_new, C_out)
293293 """
294294
295295 for i , (conv_block , patch_merge ) in enumerate (
@@ -400,10 +400,10 @@ def forward(self, x):
400400 """x --> Upsample --> [K x Conv2D] --> ... --> Upsample --> [K x Conv2D]
401401
402402 Args:
403- x : Shape (B, T, H, W, C)
403+ x: Shape (B, T, H, W, C)
404404
405405 Returns:
406- out : Shape (B, T, H_new, W_new, C)
406+ out: Shape (B, T, H_new, W_new, C)
407407 """
408408 for i , (conv_block , upsample ) in enumerate (
409409 zip (self .conv_block_list , self .upsample_list )
@@ -915,10 +915,10 @@ def get_initial_z(self, final_mem, T_out):
915915 def forward (self , x , verbose = False ):
916916 """
917917 Args:
918- x : Shape (B, T, H, W, C)
919- verbos : if True, print intermediate shapes
918+ x: Shape (B, T, H, W, C)
919+ verbose : if True, print intermediate shapes
920920 Returns:
921- out : The output Shape (B, T_out, H, W, C_out)
921+ out: The output Shape (B, T_out, H, W, C_out)
922922 """
923923
924924 x = self .concat_to_tensor (x , self .input_keys )
0 commit comments