torch_musa Release v1.1.0

torch_musa Release Notes

  • Highlights
  • New Features
    • AMP mixed precision training
    • MUSAExtension
    • Pinned memory
    • TensorCore computation
    • CompareTool [Experimental]
  • Supported Operators
  • Documentation
  • Dockers

Highlights

We are excited to release torch_musa v1.1.0, based on PyTorch v2.0.0. This release adds several important features, including AMP mixed precision training, MUSAExtension, TensorCore computation, pinned memory, and CompareTool. In addition, we have adapted more than 470 operators, improved the DDP module, and implemented more quantization operators. With torch_musa, users can easily accelerate AI applications on Moore Threads graphics cards.

This release is the result of the efforts of engineers in the Moore Threads AI Team and other departments. We sincerely hope that everyone will continue to follow our work, participate in it, and witness the rapid iteration of torch_musa and Moore Threads graphics cards together.

New Features

AMP mixed precision training

We now support mixed precision training in FP16 and BF16. Note that the S80 and S3000 only support FP16, while the S4000 supports both FP16 and BF16. The interface is fully consistent with PyTorch. Users can use AMP as in the following code:

import torch
import torch.nn as nn
import torch_musa

DEVICE = "musa"

def set_seed(seed=0):
    torch.manual_seed(seed)

# minimal stand-in model so the example runs end-to-end,
# matching the (6, 5) -> (6, 3) shapes used below
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 3)

    def forward(self, x):
        return self.fc(x)

# low_dtype can be torch.float16 or torch.bfloat16
def train_in_amp(low_dtype=torch.float16):
    set_seed()
    model = SimpleModel().to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # create the scaler object
    scaler = torch.musa.amp.GradScaler()

    inputs = torch.randn(6, 5).to(DEVICE)   # move data to the GPU
    targets = torch.randn(6, 3).to(DEVICE)
    for step in range(20):
        optimizer.zero_grad()
        # create autocast environment
        with torch.musa.amp.autocast(dtype=low_dtype):
            outputs = model(inputs)
            assert outputs.dtype == low_dtype
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    return loss
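
The function above can then be invoked with either low-precision dtype, keeping the hardware support described earlier in mind:

fp16_loss = train_in_amp(low_dtype=torch.float16)   # S80 / S3000 / S4000
bf16_loss = train_in_amp(low_dtype=torch.bfloat16)  # S4000 only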

MUSAExtension

MUSAExtension works essentially the same way as CUDAExtension, except that MUSAExtension currently requires manually adding a dynamic library to the dynamic library search path; this limitation will be resolved in the next version. For detailed usage, please refer to torch_musa/torch_musa/utils/README.md and the developer documentation.
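
As an illustration, below is a minimal setup.py sketch, assuming MUSAExtension mirrors the CUDAExtension interface and is importable from torch_musa's utils; the import path, extension name, and source file names are assumptions, so please check the README above for exact usage.

# setup.py -- a minimal sketch; the import path and file names are assumptions
from setuptools import setup
from torch_musa.utils.musa_extension import MUSAExtension, BuildExtension

setup(
    name="my_musa_ext",
    ext_modules=[
        MUSAExtension(
            name="my_musa_ext",
            sources=["my_op.cpp", "my_op_kernel.mu"],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)

Remember that, for now, the directory containing the required dynamic library must be added to the search path manually (e.g. via LD_LIBRARY_PATH) before the built extension is loaded.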

Pinned memory

Pinned memory is now supported by torch_musa; the following code shows how to use it.

shape = (1024, 1024)  # any shape works here
cpu_tensor = torch.rand(shape, dtype=torch.float32).pin_memory("musa")  # allocate page-locked host memory
gpu_tensor = cpu_tensor.to("musa", non_blocking=True)  # asynchronous host-to-device copy
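
Pinned memory also pairs naturally with data loading. Below is a sketch assuming that the pin_memory_device argument of PyTorch 2.0's DataLoader accepts the "musa" backend:

from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 5), torch.randn(100, 3))
loader = DataLoader(dataset, batch_size=10, pin_memory=True,
                    pin_memory_device="musa")  # assumption: "musa" is accepted here
for x, y in loader:
    x = x.to("musa", non_blocking=True)  # async copy from pinned host memory
    y = y.to("musa", non_blocking=True)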

TensorCore computation

The S4000 is equipped with TensorCores and therefore supports TF32 computation. Users can enable TF32 for acceleration with the following code:

with torch.backends.mudnn.flags(allow_tf32=True):
    # your training code
    ...
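
For example, a matrix multiplication under the TF32 flag might look like the following sketch (whether a particular op actually runs in TF32 depends on muDNN's kernel selection):

import torch
import torch_musa

a = torch.randn(1024, 1024, device="musa")
b = torch.randn(1024, 1024, device="musa")

with torch.backends.mudnn.flags(allow_tf32=True):
    c = a @ b  # eligible to run on TensorCores in TF32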

CompareTool [Experimental]

CompareTool is an experimental tool that automatically compares computation results between MUSA and CPU, which facilitates debugging. For detailed usage, please refer to torch_musa/utils/README.md.

Supported Operators

More than 470 operators are supported in torch_musa.

Documentation

We provide a developer guide that describes development environment preparation and the development workflow in detail.

Dockers

Release and development Docker images are now available.

[NOTE]: If you want to compile torch_musa without using the provided Docker images, please download the rc2.0.0 Intel CPU_Ubuntu underlying software stack from https://developer.mthreads.com/sdk/download/musa?equipment=&os=&driverVersion=&version=

[NOTE]: When installing the released whl packages, please remove the device name from the file name first. For example:

pip install torch-2.0.0-cp310-cp310-linux_x86_64.whl