Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jan 25, 2023
1 parent f210a5a commit 51bb206
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions torch_geometric/explain/algorithm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@ def set_masks(
r"""Apply mask to every graph layer in the :obj:`model`."""
loop_mask = edge_index[0] != edge_index[1]

if not isinstance(mask, Parameter):
mask = Parameter(mask)

# Loop over layers and set masks on MessagePassing layers:
for module in model.modules():
if isinstance(module, MessagePassing):

# Convert mask to a param if it was previously registered as one.
# This is a workaround for the fact that PyTorch does not allow
# assignments of pure tensors to parameter attributes:
if (not isinstance(mask, Parameter)
and '_edge_mask' in module._parameters):
mask = Parameter(mask)

module.explain = True
module._edge_mask = mask
module._loop_mask = loop_mask
Expand All @@ -40,18 +45,12 @@ def set_hetero_masks(
for module in model.modules():
if isinstance(module, torch.nn.ModuleDict):
for edge_type in mask_dict.keys():

mask = mask_dict[edge_type]

if not isinstance(mask, Parameter):
mask = Parameter(mask)

# TODO (jinu) Use common function get `str_edge_type`.
str_edge_type = '__'.join(edge_type)
if str_edge_type in module:
set_masks(
module[str_edge_type],
mask,
mask_dict[edge_type],
edge_index_dict[edge_type],
apply_sigmoid=apply_sigmoid,
)
Expand Down

0 comments on commit 51bb206

Please sign in to comment.