@@ -492,7 +492,7 @@ def forward(self, x: torch.Tensor):
492492 inputs = []
493493 for d in range (self .num_depths ):
494494 # allow multi-resolution input
495- _mod_w : StemInterface = self .stem_down [str (d )]
495+ _mod_w : StemInterface = self .stem_down [str (d )] # type: ignore[assignment]
496496 x_out = _mod_w .forward (x )
497497 if self .node_a [0 ][d ]:
498498 inputs .append (x_out )
@@ -505,7 +505,7 @@ def forward(self, x: torch.Tensor):
505505 start = False
506506 _temp : torch .Tensor = torch .empty (0 )
507507 for res_idx in range (self .num_depths - 1 , - 1 , - 1 ):
508- _mod_up : StemInterface = self .stem_up [str (res_idx )]
508+ _mod_up : StemInterface = self .stem_up [str (res_idx )] # type: ignore[assignment]
509509 if start :
510510 _temp = _mod_up .forward (outputs [res_idx ] + _temp )
511511 elif self .node_a [blk_idx + 1 ][res_idx ]:
@@ -680,7 +680,7 @@ def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
680680 outputs = [torch .tensor (0.0 , dtype = x [0 ].dtype , device = x [0 ].device )] * self .num_depths
681681 for res_idx , activation in enumerate (self .arch_code_a [blk_idx ].data ):
682682 if activation :
683- mod : CellInterface = self .cell_tree [str ((blk_idx , res_idx ))]
683+ mod : CellInterface = self .cell_tree [str ((blk_idx , res_idx ))] # type: ignore[assignment]
684684 _out = mod .forward (x = inputs [self .arch_code2in [res_idx ]], weight = None )
685685 outputs [self .arch_code2out [res_idx ]] = outputs [self .arch_code2out [res_idx ]] + _out
686686 inputs = outputs
@@ -782,12 +782,10 @@ def __init__(
782782 for blk_idx in range (self .num_blocks ):
783783 for res_idx in range (len (self .arch_code2out )):
784784 if self .arch_code_a [blk_idx , res_idx ] == 1 :
785+ cell_inter : Cell = self .cell_tree [str ((blk_idx , res_idx ))] # type: ignore
785786 self .ram_cost [blk_idx , res_idx ] = np .array (
786- [
787- op .ram_cost + self .cell_tree [str ((blk_idx , res_idx ))].preprocess .ram_cost
788- for op in self .cell_tree [str ((blk_idx , res_idx ))].op .ops [: self .num_cell_ops ]
789- ]
790- )
787+ [op .ram_cost + cell_inter .preprocess .ram_cost for op in cell_inter .op .ops [: self .num_cell_ops ]]
788+ ) # type: ignore
791789
792790 # define cell and macro architecture probabilities
793791 self .log_alpha_c = nn .Parameter (
0 commit comments