[CodeStyle][ruff] clean some I001 step: 15 (#63794)
gouzil authored Apr 24, 2024
1 parent d50072e commit 3ecff11
Showing 3 changed files with 71 additions and 86 deletions.
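I001 is Ruff's unsorted-imports rule (the isort check). This step drops the temporary ignore for python/paddle/distributed/__init__.py and re-sorts that file's imports. As a rough, non-authoritative sketch, the fix could be reproduced locally along these lines, assuming ruff is installed and the commands are run from the repository root (the Ruff version pinned by the project may differ):

# Sketch only: re-run the import-sorting check/fix that this commit applies.
# Assumes ruff is on PATH and the working directory is the Paddle repo root.
import subprocess

TARGET = "python/paddle/distributed/__init__.py"

# Report I001 (unsorted imports) violations without modifying the file.
subprocess.run(["ruff", "check", "--select", "I001", TARGET], check=False)

# Apply the same re-ordering that the diff below contains.
subprocess.run(["ruff", "check", "--select", "I001", "--fix", TARGET], check=False)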
4 changes: 0 additions & 4 deletions pyproject.toml
@@ -130,7 +130,3 @@ known-first-party = ["paddle"]
"test/dygraph_to_static/test_loop.py" = ["C416", "F821"]
# Ignore unnecessary lambda in dy2st unittest test_lambda
"test/dygraph_to_static/test_lambda.py" = ["PLC3002"]


# temp ignore isort
"python/paddle/distributed/__init__.py" = ["I001"]
136 changes: 64 additions & 72 deletions python/paddle/distributed/__init__.py
@@ -13,107 +13,99 @@
# limitations under the License.

import atexit # noqa: F401
from . import io
from .spawn import spawn
from .launch.main import launch
from .parallel import ( # noqa: F401
init_parallel_env,
get_rank,
get_world_size,
ParallelEnv,
DataParallel,
)
from .parallel_with_gloo import (
gloo_init_parallel_env,
gloo_barrier,
gloo_release,
)

from paddle.distributed.fleet.dataset import InMemoryDataset, QueueDataset
from paddle.base.core import Placement, ReduceType
from paddle.distributed.fleet.base.topology import ParallelMode
from paddle.distributed.fleet.dataset import InMemoryDataset, QueueDataset

from . import (
cloud_utils, # noqa: F401
io,
rpc, # noqa: F401
)
from .auto_parallel import shard_op # noqa: F401
from .auto_parallel.api import (
DistAttr,
DistModel,
ShardingStage1,
ShardingStage2,
ShardingStage3,
Strategy,
dtensor_from_fn,
reshard,
shard_dataloader,
shard_layer,
shard_optimizer,
shard_scaler,
shard_tensor,
to_static,
unshard_dtensor,
)
from .auto_parallel.placement_type import (
Partial,
Replicate,
Shard,
)
from .auto_parallel.process_mesh import ProcessMesh
from .checkpoint.load_state_dict import load_state_dict
from .checkpoint.save_state_dict import save_state_dict
from .collective import (
split,
new_group,
is_available,
new_group,
split,
)
from .communication import ( # noqa: F401
stream,
P2POp,
ReduceOp,
all_gather,
all_gather_object,
all_reduce,
alltoall,
alltoall_single,
barrier,
batch_isend_irecv,
broadcast,
broadcast_object_list,
reduce,
send,
scatter,
destroy_process_group,
gather,
scatter_object_list,
get_backend,
get_group,
irecv,
is_initialized,
isend,
recv,
irecv,
batch_isend_irecv,
P2POp,
reduce,
reduce_scatter,
is_initialized,
destroy_process_group,
get_group,
scatter,
scatter_object_list,
send,
stream,
wait,
barrier,
get_backend,
)

from .auto_parallel.process_mesh import ProcessMesh

from paddle.base.core import ReduceType, Placement
from .auto_parallel.placement_type import (
Shard,
Replicate,
Partial,
)

from .auto_parallel import shard_op # noqa: F401

from .auto_parallel.api import (
DistAttr,
shard_tensor,
dtensor_from_fn,
reshard,
shard_dataloader,
shard_layer,
shard_optimizer,
shard_scaler,
ShardingStage1,
ShardingStage2,
ShardingStage3,
to_static,
Strategy,
DistModel,
unshard_dtensor,
)

from .fleet import BoxPSDataset # noqa: F401

from .entry_attr import (
ProbabilityEntry,
CountFilterEntry,
ProbabilityEntry,
ShowClickEntry,
)

from . import cloud_utils # noqa: F401

from .fleet import BoxPSDataset # noqa: F401
from .launch.main import launch
from .parallel import ( # noqa: F401
DataParallel,
ParallelEnv,
get_rank,
get_world_size,
init_parallel_env,
)
from .parallel_with_gloo import (
gloo_barrier,
gloo_init_parallel_env,
gloo_release,
)
from .sharding import ( # noqa: F401
group_sharded_parallel,
save_group_sharded_model,
)

from . import rpc # noqa: F401

from .checkpoint.save_state_dict import save_state_dict
from .checkpoint.load_state_dict import load_state_dict
from .spawn import spawn

__all__ = [
"io",
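
The re-ordering above is meant to be purely cosmetic: the imports are regrouped and sorted, but the same symbols are still bound, so the names re-exported by paddle.distributed are unchanged. A minimal sketch of that invariant, assuming a Paddle build that ships these modules:

# Sketch: the sorted import blocks still bind the same public names.
# Assumes paddle is installed with distributed support.
import paddle.distributed as dist

for name in ("ProcessMesh", "ReduceOp", "init_parallel_env", "shard_tensor"):
    assert hasattr(dist, name), f"{name} should still be re-exported"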
17 changes: 7 additions & 10 deletions python/paddle/distributed/auto_parallel/api.py
@@ -37,6 +37,7 @@
from paddle.distributed.auto_parallel.placement_type import (
to_placements,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.completion import (
mark_as_sharding_propagation_skip_op,
)
@@ -443,7 +444,7 @@ def reshard(dist_tensor, mesh, placements):

def shard_layer(
layer: nn.Layer,
process_mesh: dist.ProcessMesh,
process_mesh: ProcessMesh,
shard_fn: Callable = None,
input_fn: Callable = None,
output_fn: Callable = None,
@@ -523,13 +524,13 @@ def output_fn(outputs, process_mesh) -> list(paddle.Tensor)
raise ValueError("The argument `process_mesh` cannot be empty.")

# Check the legality of process_mesh
if not isinstance(process_mesh, dist.ProcessMesh):
if not isinstance(process_mesh, ProcessMesh):
raise ValueError(
"The argument `process_mesh` is not `dist.ProcessMesh` type."
)

def replicate_layer_params_and_buffers(
layer: nn.Layer, mesh: dist.ProcessMesh
layer: nn.Layer, mesh: ProcessMesh
) -> None:
for key, param in layer._parameters.items():
if param is not None and not param.is_dist():
@@ -2046,7 +2047,7 @@ def build_distributed_tensor(local_tensor, dist_attr):
)
else:
raise ValueError(f"dim {dim} is not supported.")
mesh = dist.ProcessMesh(
mesh = ProcessMesh(
np.array(dist_attr["process_group"]).reshape(
dist_attr["process_shape"]
)
@@ -2346,9 +2347,7 @@ class ShardDataloader:
def __init__(
self,
dataloader: paddle.io.DataLoader,
meshes: Union[
dist.ProcessMesh, List[dist.ProcessMesh], Tuple[dist.ProcessMesh]
],
meshes: Union[ProcessMesh, List[ProcessMesh], Tuple[ProcessMesh]],
input_keys: Union[List[str], Tuple[str]] = None,
shard_dims: Union[list, tuple, str, int] = None,
is_dataset_splitted: bool = False,
@@ -2597,9 +2596,7 @@ def __call__(self):

def shard_dataloader(
dataloader: paddle.io.DataLoader,
meshes: Union[
dist.ProcessMesh, List[dist.ProcessMesh], Tuple[dist.ProcessMesh]
],
meshes: Union[ProcessMesh, List[ProcessMesh], Tuple[ProcessMesh]],
input_keys: Union[List[str], Tuple[str]] = None,
shard_dims: Union[list, tuple, str, int] = None,
is_dataset_splitted: bool = False,
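The api.py hunks all follow one pattern: ProcessMesh is imported directly from paddle.distributed.auto_parallel.process_mesh, and references written as dist.ProcessMesh in annotations and isinstance checks are shortened accordingly, which also lets the meshes annotation fit on a single line. An illustrative, simplified sketch of the pattern (not the actual Paddle source; the helper and alias names are made up):

# Illustrative only: use the directly imported class instead of the dist.* alias.
from typing import List, Tuple, Union

from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

# was: Union[dist.ProcessMesh, List[dist.ProcessMesh], Tuple[dist.ProcessMesh]]
MeshesLike = Union[ProcessMesh, List[ProcessMesh], Tuple[ProcessMesh]]


def check_mesh(process_mesh: ProcessMesh) -> None:
    # was: isinstance(process_mesh, dist.ProcessMesh)
    if not isinstance(process_mesh, ProcessMesh):
        raise ValueError("process_mesh must be a ProcessMesh instance.")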
