diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 947b00d7ab61e..ecf638c2beb90 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -126,24 +126,25 @@ def apply_to_collection( memo = {} for field in dataclasses.fields(data): field_value = getattr(data, field.name) - fields[field.name] = field_value + fields[field.name] = (field_value, field.init) memo[id(field_value)] = field_value result = deepcopy(data, memo=memo) # apply function to each field - for field_name, field_value in fields.items(): - v = apply_to_collection( - field_value, - dtype, - function, - *args, - wrong_dtype=wrong_dtype, - include_none=include_none, - **kwargs, - ) - if include_none or v is not None: - setattr(result, field_name, v) - else: # retain old value - setattr(result, field_name, getattr(data, field_name)) + for field_name, (field_value, field_init) in fields.items(): + if field_init: + v = apply_to_collection( + field_value, + dtype, + function, + *args, + wrong_dtype=wrong_dtype, + include_none=include_none, + **kwargs, + ) + if include_none or v is not None: + setattr(result, field_name, v) + else: # retain old value + setattr(result, field_name, getattr(data, field_name)) return result # data is neither of dtype, nor a collection