Skip to content

Commit

Permalink
Restore skipping of non-init fields after incorrectly resolved conflict.
Browse files Browse the repository at this point in the history
  • Loading branch information
a-gardner1 committed Nov 6, 2021
1 parent 05ae33c commit 9130a26
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,24 +126,25 @@ def apply_to_collection(
memo = {}
for field in dataclasses.fields(data):
field_value = getattr(data, field.name)
fields[field.name] = field_value
fields[field.name] = (field_value, field.init)
memo[id(field_value)] = field_value
result = deepcopy(data, memo=memo)
# apply function to each field
for field_name, field_value in fields.items():
v = apply_to_collection(
field_value,
dtype,
function,
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
**kwargs,
)
if include_none or v is not None:
setattr(result, field_name, v)
else: # retain old value
setattr(result, field_name, getattr(data, field_name))
for field_name, (field_value, field_init) in fields.items():
if field_init:
v = apply_to_collection(
field_value,
dtype,
function,
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
**kwargs,
)
if include_none or v is not None:
setattr(result, field_name, v)
else: # retain old value
setattr(result, field_name, getattr(data, field_name))
return result

# data is neither of dtype, nor a collection
Expand Down

0 comments on commit 9130a26

Please sign in to comment.