@@ -515,7 +515,7 @@ def convert(self, state_dict):
515515class WanVideoVAE (PreTrainedModel ):
516516 converter = WanVideoVAEStateDictConverter ()
517517
518- def __init__ (self , z_dim = 16 , parallelism : int = 1 , device : str = "cuda:0" , dtype : torch .dtype = torch .float32 ):
518+ def __init__ (self , z_dim = 16 , device : str = "cuda:0" , dtype : torch .dtype = torch .float32 ):
519519 super ().__init__ ()
520520
521521 mean = [
@@ -561,12 +561,11 @@ def __init__(self, z_dim=16, parallelism: int = 1, device: str = "cuda:0", dtype
561561 # init model
562562 self .model = VideoVAE (z_dim = z_dim ).eval ().requires_grad_ (False )
563563 self .upsampling_factor = 8
564- self .parallelism = parallelism
565564
566565 @classmethod
567- def from_state_dict (cls , state_dict , parallelism = 1 , device = "cuda:0" , dtype = torch .float32 ) -> "WanVideoVAE" :
566+ def from_state_dict (cls , state_dict , device = "cuda:0" , dtype = torch .float32 ) -> "WanVideoVAE" :
568567 with no_init_weights ():
569- model = torch .nn .utils .skip_init (cls , parallelism = parallelism , device = device , dtype = dtype )
568+ model = torch .nn .utils .skip_init (cls , device = device , dtype = dtype )
570569 model .load_state_dict (state_dict , assign = True )
571570 model .to (device = device , dtype = dtype , non_blocking = True )
572571 return model
@@ -607,7 +606,7 @@ def tiled_decode(self, hidden_states, device, tile_size, tile_stride, progress_c
607606 h_ , w_ = h + size_h , w + size_w
608607 tasks .append ((h , h_ , w , w_ ))
609608
610- data_device = device if self . parallelism > 1 else "cpu"
609+ data_device = device if dist . is_initialized () else "cpu"
611610 computation_device = device
612611
613612 out_T = T * 4 - 3
@@ -622,9 +621,9 @@ def tiled_decode(self, hidden_states, device, tile_size, tile_stride, progress_c
622621 device = data_device ,
623622 )
624623
625- hide_progress_bar = self . parallelism > 1 and dist .get_rank () != 0
626- for i , (h , h_ , w , w_ ) in enumerate (tqdm (tasks , desc = "VAE DECODING" , disable = hide_progress_bar )):
627- if self . parallelism > 1 and (i % dist .get_world_size () != dist .get_rank ()):
624+ hide_progress = dist . is_initialized () and dist .get_rank () != 0
625+ for i , (h , h_ , w , w_ ) in enumerate (tqdm (tasks , desc = "VAE DECODING" , disable = hide_progress )):
626+ if dist . is_initialized () and (i % dist .get_world_size () != dist .get_rank ()):
628627 continue
629628 hidden_states_batch = hidden_states [:, :, :, h :h_ , w :w_ ].to (computation_device )
630629 hidden_states_batch = self .model .decode (hidden_states_batch , self .scale ).to (data_device )
@@ -654,11 +653,11 @@ def tiled_decode(self, hidden_states, device, tile_size, tile_stride, progress_c
654653 target_h : target_h + hidden_states_batch .shape [3 ],
655654 target_w : target_w + hidden_states_batch .shape [4 ],
656655 ] += mask
657- if progress_callback is not None and not hide_progress_bar :
656+ if progress_callback is not None and not hide_progress :
658657 progress_callback (i + 1 , len (tasks ), "VAE DECODING" )
659- if progress_callback is not None and not hide_progress_bar :
658+ if progress_callback is not None and not hide_progress :
660659 progress_callback (len (tasks ), len (tasks ), "VAE DECODING" )
661- if self . parallelism > 1 :
660+ if dist . is_initialized () :
662661 dist .all_reduce (values )
663662 dist .all_reduce (weight )
664663 values = values / weight
@@ -681,7 +680,7 @@ def tiled_encode(self, video, device, tile_size, tile_stride, progress_callback=
681680 h_ , w_ = h + size_h , w + size_w
682681 tasks .append ((h , h_ , w , w_ ))
683682
684- data_device = device if self . parallelism > 1 else "cpu"
683+ data_device = device if dist . is_initialized () else "cpu"
685684 computation_device = device
686685
687686 out_T = (T + 3 ) // 4
@@ -696,9 +695,9 @@ def tiled_encode(self, video, device, tile_size, tile_stride, progress_callback=
696695 device = data_device ,
697696 )
698697
699- hide_progress_bar = self . parallelism > 1 and dist .get_rank () != 0
698+ hide_progress_bar = dist . is_initialized () and dist .get_rank () != 0
700699 for i , (h , h_ , w , w_ ) in enumerate (tqdm (tasks , desc = "VAE ENCODING" , disable = hide_progress_bar )):
701- if self . parallelism > 1 and (i % dist .get_world_size () != dist .get_rank ()):
700+ if dist . is_initialized () and (i % dist .get_world_size () != dist .get_rank ()):
702701 continue
703702 hidden_states_batch = video [:, :, :, h :h_ , w :w_ ].to (computation_device )
704703 hidden_states_batch = self .model .encode (hidden_states_batch , self .scale ).to (data_device )
@@ -732,7 +731,7 @@ def tiled_encode(self, video, device, tile_size, tile_stride, progress_callback=
732731 progress_callback (i + 1 , len (tasks ), "VAE ENCODING" )
733732 if progress_callback is not None and not hide_progress_bar :
734733 progress_callback (len (tasks ), len (tasks ), "VAE ENCODING" )
735- if self . parallelism > 1 :
734+ if dist . is_initialized () :
736735 dist .all_reduce (values )
737736 dist .all_reduce (weight )
738737 values = values / weight
0 commit comments