[CodeStyle][UP031] fix some `/python/paddle/distributed/auto_parallel/*` - part 11 (PaddlePaddle#65571)
gouzil authored Jun 29, 2024
1 parent c2f4e3a commit e9a4c42
Showing 17 changed files with 64 additions and 71 deletions.
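
The hunks in this diff all make the same kind of change: printf-style `%` interpolation is replaced by an f-string, and now-redundant `str()` calls inside format fields are dropped (the UP031 code-style cleanup named in the title). A minimal sketch of the before/after; the message template comes from the api.py hunk below, while the `feed_name_list` value is a hypothetical stand-in:

feed_name_list = ["image", "label"]  # hypothetical value, for illustration only

# Before: printf-style formatting flagged by UP031; the explicit str() is redundant
old_msg = "The model takes %s as input" % (str(feed_name_list))

# After: the f-string interpolates the value and applies str() implicitly
new_msg = f"The model takes {feed_name_list} as input"

assert old_msg == new_msg  # both render: The model takes ['image', 'label'] as input

Both forms render identical text, so the behaviour of the formatted messages is unchanged.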
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/api.py
@@ -2088,7 +2088,7 @@ def _make_feeds(self, data_list):
if len(feed_name_list) != len(data_list):
raise ValueError(
"The input data and feed_list are not consistent."
"The model takes %s as input" % (str(feed_name_list))
f"The model takes {feed_name_list} as input"
)

def _to_lodtensor(tensor: paddle.Tensor):
@@ -66,7 +66,7 @@ def forward(ctx, *args, **kwargs):
):
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
), f"forward op [{src_op}] don't have dist attribute !"

if (
len(kwargs.get('fixed_seed_offset', [])) > 0
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/random.py
@@ -88,7 +88,7 @@ def determinate_rng(
# instead of using offsets to coordinate seed across devices.
if len(process_mesh.shape) > 4:
raise NotImplementedError(
f"Auto Parallel Random Control for Mesh's rank > 4 is NOT supported! Got {str(process_mesh)}"
f"Auto Parallel Random Control for Mesh's rank > 4 is NOT supported! Got {process_mesh}"
)
global _basic_seed
seed_ = _basic_seed
10 changes: 5 additions & 5 deletions python/paddle/distributed/auto_parallel/static/callbacks.py
@@ -91,9 +91,9 @@ def _is_print(self):

def _updates(self, logs, mode):
values = []
metrics = getattr(self, '%s_metrics' % (mode))
progbar = getattr(self, '%s_progbar' % (mode))
steps = getattr(self, '%s_step' % (mode))
metrics = getattr(self, f'{mode}_metrics')
progbar = getattr(self, f'{mode}_progbar')
steps = getattr(self, f'{mode}_step')

for k in metrics:
if k in logs:
@@ -113,8 +113,8 @@ def _updates(self, logs, mode):
for k in out_logs:
values.append((k, out_logs[k]))

if self.verbose == 3 and hasattr(self, '_%s_timer' % (mode)):
timer = getattr(self, '_%s_timer' % (mode))
if self.verbose == 3 and hasattr(self, f'_{mode}_timer'):
timer = getattr(self, f'_{mode}_timer')
cnt = timer['count'] if timer['count'] > 0 else 1.0
samples = timer['samples'] if timer['samples'] > 0 else 1.0
values.append(
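
The callbacks.py hunk above also applies the conversion to dynamically built attribute names. A self-contained sketch of that pattern; the `_Recorder` class and its attributes are hypothetical stand-ins, not Paddle's:

class _Recorder:
    def __init__(self):
        self.train_metrics = ["loss", "acc"]
        self.train_step = 10

recorder = _Recorder()
mode = "train"

# old style: getattr(recorder, '%s_metrics' % (mode))
metrics = getattr(recorder, f'{mode}_metrics')  # -> ["loss", "acc"]
steps = getattr(recorder, f'{mode}_step')       # -> 10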
6 changes: 3 additions & 3 deletions python/paddle/distributed/auto_parallel/static/completion.py
@@ -1717,7 +1717,7 @@ def _get_op_by_id(ops, id):
continue

else:
raise ValueError(f"got unexpected op [{str(grad_op.type)}]")
raise ValueError(f"got unexpected op [{grad_op.type}]")

self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr
@@ -1911,7 +1911,7 @@ def infer_backward_op_partial_status(
)
else:
raise NotImplementedError(
f"Backward Partial is not adapted for {str(grad_op)}"
f"Backward Partial is not adapted for {grad_op}"
)

# resolute partial
@@ -2151,7 +2151,7 @@ def infer_backward_op_partial_status(
grad_op, grad_op_dist_attr
)
else:
raise ValueError(f"got unexpected op [{str(grad_op.type)}]")
raise ValueError(f"got unexpected op [{grad_op.type}]")

def complete_update_annotation(self, serial_main_program):
"""Complete the annotation of vars and ops in the update phase for parallel program."""
@@ -121,7 +121,7 @@ def _alloc_and_fill_var(var_name):
)
)
logger.info(
f'[+] var: "{var_name}", shape={str(var_shape)}, dtype="{str(var_dtype)}".\n'
f'[+] var: "{var_name}", shape={var_shape}, dtype="{var_dtype}".\n'
) if verbose else None
np_dtype = (
convert_dtype(var_dtype)
@@ -264,17 +264,16 @@ def measure_program_real_op_cost(
>>> measure_program_real_op_cost(program, verbose_level=1)
'''

assert isinstance(program, Program), (
'"program" should be a instance of "paddle.base.framework.Program" but got type "%s".'
% type(program).__name__
)
assert isinstance(
program, Program
), f'"program" should be a instance of "paddle.base.framework.Program" but got type "{type(program).__name__}".'
supported_places = [
paddle.CUDAPlace,
]
assert any(
isinstance(place, supported_place)
for supported_place in supported_places
), f'Current place ({str(place)}) does not support runtime profiling. "place" should be one of the following: {str(supported_places)}.'
), f'Current place ({place}) does not support runtime profiling. "place" should be one of the following: {supported_places}.'
assert isinstance(run_iters, int) and run_iters >= 1, (
'Invalid parameter run_iters set. run_iters '
'should be an integer >= 1.'
@@ -188,7 +188,7 @@ def data_generator():
batch_size = array.shape[0]
assert (
batch_size % self.dp_world_sizes[i] == 0
), f"batch_size [{str(batch_size)}] is not divisible by dp_world_size [{str(self.dp_world_sizes[i])}]"
), f"batch_size [{batch_size}] is not divisible by dp_world_size [{self.dp_world_sizes[i]}]"
partial_data.append(
np.split(array, self.dp_world_sizes[i])[
self.dp_ranks[i]
4 changes: 2 additions & 2 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -2064,7 +2064,7 @@ def _validate_batch_size(self, batch_size):
), f"DistributedBatchSampler only support one data parallel group, but got [{len(set(self._dp_world_sizes))}] different data parallel groups"
assert (
batch_size % self._dp_world_sizes[0] == 0
), f"batch_size [{str(batch_size)}] is not divisible by dp_world_size [{str(self._dp_world_sizes[0])}]"
), f"batch_size [{batch_size}] is not divisible by dp_world_size [{self._dp_world_sizes[0]}]"
return batch_size // self._dp_world_sizes[0]
else:
assert (
@@ -2165,7 +2165,7 @@ def _set_state_dict(self, mode, strict, state_dict, dist_attr):
continue
if param_array.dtype != state_dict[name].dtype:
self._logger.info(
f"cast {name}'s dtype from '{str(state_dict[name].dtype)}' to '{str(param_array.dtype)}'"
f"cast {name}'s dtype from '{state_dict[name].dtype}' to '{param_array.dtype}'"
)
state_dict[name] = state_dict[name].astype(param_array.dtype)
program.set_state_dict(state_dict)
19 changes: 8 additions & 11 deletions python/paddle/distributed/auto_parallel/static/helper.py
@@ -258,12 +258,11 @@ def build_program(self, mode):
# skip if we has already built program.
if self.build_info.has_cache(mode, True):
self._logger.info(
"Already build program with mode = %s, use cached program."
% mode
f"Already build program with mode = {mode}, use cached program."
)
return

self._logger.info("start to build program for mode = %s." % mode)
self._logger.info(f"start to build program for mode = {mode}.")
input_spec = [self.inputs_spec, self.labels_spec]
static_func = to_static(
self.static_func(), input_spec=input_spec, full_graph=True
@@ -332,14 +331,12 @@ def _verify_optimizer(self, optimizer):
assert hasattr(
optimizer, "minimize"
), "Optimizer must have minimize() method."
assert self.proxy_layer.mode == 'train', (
"Required mode == 'train', but received '%s'"
% self.proxy_layer.mode
)
assert len(self.loss_vars) == 1, (
"Required len(loss_vars) == 1, but received len(loss_vars) = %s"
% len(self.loss_vars)
)
assert (
self.proxy_layer.mode == 'train'
), f"Required mode == 'train', but received '{self.proxy_layer.mode}'"
assert (
len(self.loss_vars) == 1
), f"Required len(loss_vars) == 1, but received len(loss_vars) = {len(self.loss_vars)}"

def to(self, mode):
"""
@@ -749,11 +749,11 @@ def update_op_dims_mapping(
changed = False
if len(input_arg_names) != len(infered_input_dims_mappings):
warnings.warn(
f"dims mapping is NOT Match, infered [{len(infered_input_dims_mappings)}], original: [{len(input_arg_names)}]; dist op: [{str(dist_op)}]"
f"dims mapping is NOT Match, infered [{len(infered_input_dims_mappings)}], original: [{len(input_arg_names)}]; dist op: [{dist_op}]"
)
if len(output_arg_names) != len(infered_output_dims_mappings):
warnings.warn(
f"dims mapping is NOT Match, infered [{len(infered_output_dims_mappings)}], original: [{len(output_arg_names)}]; dist op: [{str(dist_op)}]"
f"dims mapping is NOT Match, infered [{len(infered_output_dims_mappings)}], original: [{len(output_arg_names)}]; dist op: [{dist_op}]"
)

for i in range(len(input_arg_names)):
@@ -85,7 +85,7 @@ def backward(ctx, *args, **kwargs):
dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert (
dist_attr is not None
), f"backward op [{str(backward_op)}] don't have dist attribute !"
), f"backward op [{backward_op}] don't have dist attribute !"

assert rank_id in dist_attr.process_mesh.process_ids

@@ -164,7 +164,7 @@ def forward(ctx, *args, **kwargs):
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
), f"forward op [{src_op}] don't have dist attribute !"

# check validation of inputs / outputs
assert 'Logits' in kwargs, "input [Logits] is not given"
@@ -230,7 +230,7 @@ def backward(ctx, *args, **kwargs):

assert (
op_dist_attr is not None
), f"backward op [{str(backward_op)}] don't have dist attribute !"
), f"backward op [{backward_op}] don't have dist attribute !"

# check validation of inputs / outputs
assert 'Softmax' in kwargs, "input [Logits] is not given"
@@ -287,7 +287,7 @@ def forward(ctx, *args, **kwargs):
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
), f"forward op [{src_op}] don't have dist attribute !"

# check validation of inputs / outputs
assert 'Logits' in kwargs, "input [Logits] is not given"
@@ -397,7 +397,7 @@ def backward(ctx, *args, **kwargs):

assert (
op_dist_attr is not None
), f"backward op [{str(backward_op)}] don't have dist attribute !"
), f"backward op [{backward_op}] don't have dist attribute !"

# check validation of inputs / outputs
assert 'Softmax' in kwargs, "input [Softmax] is not given"
@@ -122,7 +122,7 @@ def update_dims_mapping(dist_op):
for i in range(num_inputs):
assert not is_parameter_related(
input_arg_names[i], main_block
), f"input {input_arg_names[i]} of op {str(dist_op.serial_op)} is parameter, op should not use default rule."
), f"input {input_arg_names[i]} of op {dist_op.serial_op} is parameter, op should not use default rule."
input_specs.append(
get_dist_tensor_spec(dist_op, input_arg_names[i])
)
@@ -131,7 +131,7 @@ def update_dims_mapping(dist_op):
for i in range(num_outputs):
assert not is_parameter_related(
output_arg_names[i], main_block
), f"output {output_arg_names[i]} of op {str(dist_op.serial_op)} is parameter, op should not use default rule."
), f"output {output_arg_names[i]} of op {dist_op.serial_op} is parameter, op should not use default rule."
output_specs.append(
get_dist_tensor_spec(dist_op, output_arg_names[i], False)
)
@@ -636,7 +636,7 @@ def backward(ctx, *args, **kwargs):
dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert (
dist_attr is not None
), f"backward op [{str(backward_op)}] don't have dist attribute !"
), f"backward op [{backward_op}] don't have dist attribute !"
rank_id = dist_op_context.rank_id

# check validation of inputs / outputs
@@ -111,7 +111,7 @@ def forward(ctx, *args, **kwargs):
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
), f"forward op [{src_op}] don't have dist attribute !"

if is_enable_auto_rand_ctrl() and not op_dist_attr.is_recompute:
# check validation of inputs / outputs
@@ -144,7 +144,7 @@ def forward(ctx, *args, **kwargs):
assert (
pre_op.type == "seed"
and len(pre_op.attr("rng_name")) == 0
), f"found exception op {str(pre_op)}"
), f"found exception op {pre_op}"

# determinate rng
X_var = main_block._var_recursive(kwargs['X'][0])
@@ -53,7 +53,7 @@ def update_dims_mapping(dist_op):
input_arg_names = op_desc.input_arg_names()
assert (
len(op_desc.output_arg_names()) == 1
), f"elementwise op [{str(dist_op.serial_op)}] has [{len(op_desc.output_arg_names())}] outputs"
), f"elementwise op [{dist_op.serial_op}] has [{len(op_desc.output_arg_names())}] outputs"
output_arg_name = op_desc.output_arg_names()[0]
num_inputs = len(input_arg_names)

13 changes: 5 additions & 8 deletions python/paddle/distributed/auto_parallel/static/partitioner.py
@@ -60,8 +60,7 @@ def __init__(self, dist_context, rank_id=0):
"""
if not isinstance(dist_context, DistributedContext):
raise TypeError(
"dist_context be DistributedContext, got %s here"
% type(dist_context)
f"dist_context be DistributedContext, got {type(dist_context)} here"
)

self._dist_context = dist_context
@@ -77,8 +76,7 @@ def partition(
):
if not isinstance(serial_main_program, (Program)):
raise TypeError(
"main_program be paddle.framework.Program, got %s here"
% type(serial_main_program)
f"main_program be paddle.framework.Program, got {type(serial_main_program)} here"
)

# check if shard annotated serial program valid
@@ -118,8 +116,7 @@ def partition_startup_program(
):
if not isinstance(serial_startup_program, (Program)):
raise TypeError(
"dist_context be paddle.framework.Program, got %s here"
% type(serial_startup_program)
f"dist_context be paddle.framework.Program, got {type(serial_startup_program)} here"
)

partitioned_startup_prog = paddle.framework.Program()
@@ -147,7 +144,7 @@ def partition_startup_program(
output_vars = op.desc.output_arg_names()
assert (
len(output_vars) == 1
), f"initializer should output only ONE variable, but got [{str(op.desc)}]"
), f"initializer should output only ONE variable, but got [{op.desc}]"
assert (
temp_varname_map[output_vars[0]] in var2shape
), f"try to initialize [{output_vars[0]}] which is not a persistable var"
@@ -357,7 +354,7 @@ def partition_block(self, ref_block, target_block):
)
else:
raise NotImplementedError(
f"partitioner only support forward and backward, optimize ops, but got {str(op)}"
f"partitioner only support forward and backward, optimize ops, but got {op}"
)

def _is_valid_annotated_program(self, program):