Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions monai/networks/nets/dints.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
class CellInterface(torch.nn.Module):
"""interface for torchscriptable Cell"""

def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: # type: ignore
def forward(self, x: torch.Tensor, weight) -> torch.Tensor: # type: ignore
pass


Expand Down Expand Up @@ -170,7 +170,7 @@ def __init__(self, c: int, ops: dict, arch_code_c=None):
if arch_c > 0:
self.ops.append(ops[op_name](c))

def forward(self, x: torch.Tensor, weight: torch.Tensor):
def forward(self, x: torch.Tensor, weight: torch.Tensor | None = None):
"""
Args:
x: input tensor.
Expand All @@ -179,9 +179,10 @@ def forward(self, x: torch.Tensor, weight: torch.Tensor):
out: weighted average of the operation results.
"""
out = 0.0
weight = weight.to(x)
if weight is not None:
weight = weight.to(x)
for idx, _op in enumerate(self.ops):
out = out + _op(x) * weight[idx]
out = (out + _op(x)) if weight is None else out + _op(x) * weight[idx]
return out


Expand Down Expand Up @@ -297,7 +298,7 @@ def __init__(

self.op = MixedOp(c, self.OPS, arch_code_c)

def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, weight: torch.Tensor | None) -> torch.Tensor:
"""
Args:
x: input tensor
Expand Down Expand Up @@ -669,15 +670,13 @@ def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
x: input tensor.
"""
# generate path activation probability
inputs, outputs = x, [torch.tensor(0.0).to(x[0])] * self.num_depths
inputs = x
for blk_idx in range(self.num_blocks):
outputs = [torch.tensor(0.0).to(x[0])] * self.num_depths
outputs = [torch.tensor(0.0, dtype=x[0].dtype, device=x[0].device)] * self.num_depths
for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data):
if activation:
mod: CellInterface = self.cell_tree[str((blk_idx, res_idx))]
_out = mod.forward(
x=inputs[self.arch_code2in[res_idx]], weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx])
)
_out = mod.forward(x=inputs[self.arch_code2in[res_idx]], weight=None)
outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out
inputs = outputs

Expand Down Expand Up @@ -885,13 +884,13 @@ def get_ram_cost_usage(self, in_size, full: bool = False):
sizes = []
for res_idx in range(self.num_depths):
sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2**res_idx)).prod())
sizes = torch.tensor(sizes).to(torch.float32).to(self.device) / (2 ** (int(self.use_downsample)))
sizes = torch.tensor(sizes, dtype=torch.float32, device=self.device) / (2 ** (int(self.use_downsample)))
probs_a, arch_code_prob_a = self.get_prob_a(child=False)
cell_prob = F.softmax(self.log_alpha_c, dim=-1)
if full:
arch_code_prob_a = arch_code_prob_a.detach()
arch_code_prob_a.fill_(1)
ram_cost = torch.from_numpy(self.ram_cost).to(torch.float32).to(self.device)
ram_cost = torch.from_numpy(self.ram_cost).to(dtype=torch.float32, device=self.device)
usage = 0.0
for blk_idx in range(self.num_blocks):
# node activation for input
Expand Down