Skip to content

Commit

Permalink
Change mutable default arguments to None (#6376)
Browse files Browse the repository at this point in the history
Reading PyG code I accidently saw that class and noticed that it uses
mutable default arguments what is generally considered a bad practice
and can lead to a lot of problems. I believe it's not a desired
behaviour (am I wrong?).

Problems can occur when the class is extended in the future such that it
changes the value of any of those lists, the default value for new
instances of this class would also be affected. The other possibility is
that user changes any of the `follow_batch`, `exclude_keys` fields
(those doesn't start with an underscore, so it's kinda possible) after
the object is initialized, what's also going to change the default value
for new instances.

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
bwroblew and rusty1s authored Jan 9, 2023
1 parent 8e4f967 commit 19d5cbc
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
### Changed
- Fix the default arguments of `DataParallel` class ([#6376](https://github.com/pyg-team/pytorch_geometric/pull/6376))
- Fix `ImbalancedSampler` on sliced `InMemoryDataset` ([#6374](https://github.com/pyg-team/pytorch_geometric/pull/6374))
- Breaking Change: Changed the interface and implementation of `GraphMultisetTransformer` ([#6343](https://github.com/pyg-team/pytorch_geometric/pull/6343))
- Fixed the approximate PPR variant in `transforms.GDC` to not crash on graphs with isolated nodes ([#6242](https://github.com/pyg-team/pytorch_geometric/pull/6242))
Expand Down
10 changes: 5 additions & 5 deletions torch_geometric/nn/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ class DataParallel(torch.nn.DataParallel):
output_device (int or torch.device): Device location of output.
(default: :obj:`device_ids[0]`)
follow_batch (list or tuple, optional): Creates assignment batch
vectors for each key in the list. (default: :obj:`[]`)
vectors for each key in the list. (default: :obj:`None`)
exclude_keys (list or tuple, optional): Will exclude each key in the
list. (default: :obj:`[]`)
list. (default: :obj:`None`)
"""
def __init__(self, module, device_ids=None, output_device=None,
follow_batch=[], exclude_keys=[]):
follow_batch=None, exclude_keys=None):
super().__init__(module, device_ids, output_device)
self.src_device = torch.device(f'cuda:{self.device_ids[0]}')
self.follow_batch = follow_batch
self.exclude_keys = exclude_keys
self.follow_batch = follow_batch or []
self.exclude_keys = exclude_keys or []

def forward(self, data_list):
""""""
Expand Down

0 comments on commit 19d5cbc

Please sign in to comment.