diff --git a/bagua/torch_api/__init__.py b/bagua/torch_api/__init__.py index 84e1a7570..e4205b76f 100644 --- a/bagua/torch_api/__init__.py +++ b/bagua/torch_api/__init__.py @@ -2,9 +2,7 @@ """ The Bagua communication library PyTorch interface. """ -from distutils.errors import ( - DistutilsPlatformError, -) +from distutils.errors import DistutilsPlatformError try: import torch @@ -29,32 +27,33 @@ init_process_group, send, recv, - broadcast, - reduce, - reduce_inplace, - gather, - gather_inplace, - scatter, - scatter_inplace, - allreduce, - allreduce_inplace, allgather, allgather_inplace, + allreduce, + allreduce_inplace, alltoall, alltoall_inplace, alltoall_v, alltoall_v_inplace, + barrier, + broadcast, + gather, + gather_inplace, + reduce, + reduce_inplace, reduce_scatter, reduce_scatter_inplace, + scatter, + scatter_inplace, ReduceOp, ) from .distributed import BaguaModule # noqa: F401 from .tensor import BaguaTensor # noqa: F401 from .env import ( # noqa: F401 - get_rank, - get_world_size, get_local_rank, get_local_size, + get_rank, + get_world_size, ) from . import contrib # noqa: F401 from . import communication # noqa: F401