Add sycl examples #114

Open · wants to merge 9 commits into master
13 changes: 9 additions & 4 deletions README.md
@@ -1,11 +1,16 @@
# C++/CUDA Extensions in PyTorch
# C++/CUDA/SYCL Extensions in PyTorch

An example of writing a C++/CUDA extension for PyTorch. See
An example of writing a C++/CUDA/SYCL extension for PyTorch. See


@ZhaoqiongZ: change the title, it reads "C++/CUDA" only.

Contributor Author


done

[here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.
This repo demonstrates how to write an example `extension_cpp.ops.mymuladd`
custom op that has both custom CPU and CUDA kernels.
custom op that has both custom CPU and CUDA/SYCL kernels.

The examples in this repo work with PyTorch 2.4+.

> **Note:** `SYCL` serves as the backend programming language for Intel GPUs (device label `xpu`). For configuration details, see [Getting Started on Intel GPUs](https://docs.pytorch.org/docs/main/notes/get_start_xpu.html).

The examples in this repo work with PyTorch 2.4 or later for C++/CUDA, and PyTorch 2.8 or later for SYCL.

To build:
```
Expand Down
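
A minimal usage sketch of the `extension_cpp.ops.mymuladd` op described above on an Intel GPU (hypothetical session: assumes the extension is installed and the local PyTorch build has XPU support):

```python
import torch
import extension_cpp

# mymuladd computes a * b + c elementwise; "xpu" is the device label for Intel GPUs
a = torch.randn(1024, device="xpu")
b = torch.randn(1024, device="xpu")
out = extension_cpp.ops.mymuladd(a, b, 2.0)
print(out.device, out.shape)
```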
191 changes: 191 additions & 0 deletions extension_cpp/csrc/sycl/muladd.sycl
@@ -0,0 +1,191 @@
// Copyright (c) 2025 Intel Corporation

#include <c10/xpu/XPUStream.h>
#include <sycl/sycl.hpp>
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>

namespace extension_cpp {


// MulAdd Kernel: result = a * b + c
static void muladd_kernel(
int numel, const float* a, const float* b, float c, float* result,
const sycl::nd_item<1>& item) {
int idx = item.get_global_id(0);
if (idx < numel) {
result[idx] = a[idx] * b[idx] + c;
}
}

// Mul Kernel: result = a * b
static void mul_kernel(
int numel, const float* a, const float* b, float* result,
const sycl::nd_item<1>& item) {
int idx = item.get_global_id(0);
if (idx < numel) {
result[idx] = a[idx] * b[idx];
}
}

// Add Kernel: result = a + b
static void add_kernel(
int numel, const float* a, const float* b, float* result,
const sycl::nd_item<1>& item) {
int idx = item.get_global_id(0);
if (idx < numel) {
result[idx] = a[idx] + b[idx];
}
}


class MulAddKernelFunctor {
public:
MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result)
: numel(_numel), a(_a), b(_b), c(_c), result(_result) {}

void operator()(const sycl::nd_item<1>& item) const {
muladd_kernel(numel, a, b, c, result, item);
}

private:
int numel;
const float* a;
const float* b;
float c;
float* result;
};

class MulKernelFunctor {
public:
MulKernelFunctor(int _numel, const float* _a, const float* _b, float* _result)
: numel(_numel), a(_a), b(_b), result(_result) {}

void operator()(const sycl::nd_item<1>& item) const {
mul_kernel(numel, a, b, result, item);
}

private:
int numel;
const float* a;
const float* b;
float* result;
};

class AddKernelFunctor {
public:
AddKernelFunctor(int _numel, const float* _a, const float* _b, float* _result)
: numel(_numel), a(_a), b(_b), result(_result) {}

void operator()(const sycl::nd_item<1>& item) const {
add_kernel(numel, a, b, result, item);
}

private:
int numel;
const float* a;
const float* b;
float* result;
};


at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) {
TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");

at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = at::empty_like(a_contig);

const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* res_ptr = result.data_ptr<float>();
int numel = a_contig.numel();

sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
constexpr int threads = 256;
int blocks = (numel + threads - 1) / threads;
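// blocks * threads rounds numel up to a multiple of the work-group size; nd_range takes (global range, work-group range)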

queue.submit([&](sycl::handler& cgh) {
cgh.parallel_for<MulAddKernelFunctor>(
sycl::nd_range<1>(blocks * threads, threads),
MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr)
);
});
return result;
}

at::Tensor mymul_xpu(const at::Tensor& a, const at::Tensor& b) {
TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");

at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = at::empty_like(a_contig);

const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* res_ptr = result.data_ptr<float>();
int numel = a_contig.numel();

sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
constexpr int threads = 256;
int blocks = (numel + threads - 1) / threads;

queue.submit([&](sycl::handler& cgh) {
cgh.parallel_for<MulKernelFunctor>(
sycl::nd_range<1>(blocks * threads, threads),
MulKernelFunctor(numel, a_ptr, b_ptr, res_ptr)
);
});
return result;
}

void myadd_out_xpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
TORCH_CHECK(b.sizes() == out.sizes(), "b and out must have the same shape");
TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");
TORCH_CHECK(out.device().is_xpu(), "out must be an XPU tensor");

at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();

const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* out_ptr = out.data_ptr<float>();
int numel = a_contig.numel();

sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
constexpr int threads = 256;
int blocks = (numel + threads - 1) / threads;

queue.submit([&](sycl::handler& cgh) {
cgh.parallel_for<AddKernelFunctor>(
sycl::nd_range<1>(blocks * threads, threads),
AddKernelFunctor(numel, a_ptr, b_ptr, out_ptr)
);
});
}

// ==================================================
// Register SYCL implementations with the Torch library
// ==================================================

TORCH_LIBRARY_IMPL(extension_cpp, XPU, m) {
m.impl("mymuladd", mymuladd_xpu);
m.impl("mymul", mymul_xpu);
m.impl("myadd_out", myadd_out_xpu);
}

} // namespace extension_cpp
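
A rough sketch of how these XPU registrations might be exercised from Python, assuming the extension was built with the SYCL backend and an XPU device is present; the dispatcher routes `xpu` tensors to the kernels registered above, which can be checked against the CPU path:

```python
import torch
import extension_cpp  # importing the package loads the library and its registrations

a = torch.randn(1024)
b = torch.randn(1024)

cpu_out = torch.ops.extension_cpp.mymuladd(a, b, 2.0)                      # CPU kernel
xpu_out = torch.ops.extension_cpp.mymuladd(a.to("xpu"), b.to("xpu"), 2.0)  # mymuladd_xpu

torch.testing.assert_close(xpu_out.cpu(), cpu_out)
```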
110 changes: 82 additions & 28 deletions setup.py
@@ -2,63 +2,118 @@
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import glob

from setuptools import find_packages, setup

from torch.utils.cpp_extension import (
CppExtension,
CUDAExtension,
BuildExtension,
CUDA_HOME,
)
# Conditional import for SyclExtension
try:
from torch.utils.cpp_extension import SyclExtension
except ImportError:
SyclExtension = None

library_name = "extension_cpp"

# NOTE: PyTorch versions < 2.6 use torch.extension.h which depends on pybind11,
# and pybind11 requires full access to Python's C API (including internal
# structures like PyObject). This makes it incompatible with Py_LIMITED_API
# which restricts access to only stable Python C API symbols.
# For Py_LIMITED_API compatibility, use torch.library.h instead (PyTorch 2.6+).
if torch.__version__ >= "2.6.0":
py_limited_api = True
else:
py_limited_api = False
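# used below to tag the wheel for the limited API ("cp39") in the setup() options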


def get_extensions():
debug_mode = os.getenv("DEBUG", "0") == "1"
use_cuda = os.getenv("USE_CUDA", "1") == "1"
if debug_mode:
print("Compiling in debug mode")

use_cuda = use_cuda and torch.cuda.is_available() and CUDA_HOME is not None
extension = CUDAExtension if use_cuda else CppExtension
# Determine backend (CUDA, SYCL, or C++)
use_cuda = os.getenv("USE_CUDA", "auto")
use_sycl = os.getenv("USE_SYCL", "auto")

# Auto-detect CUDA
if use_cuda == "auto":
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
else:
use_cuda = use_cuda.lower() == "true" or use_cuda == "1"

# Auto-detect SYCL
if use_sycl == "auto":
use_sycl = SyclExtension is not None and torch.xpu.is_available()
else:
use_sycl = use_sycl.lower() == "true" or use_sycl == "1"

if use_cuda and use_sycl:
raise RuntimeError("Cannot enable both CUDA and SYCL backends simultaneously.")

print("use cuda & use sycl",use_cuda, use_sycl)

extension = None
if use_cuda:
extension = CUDAExtension
print("Building with CUDA backend")
elif use_sycl and SyclExtension is not None:
extension = SyclExtension
print("Building with SYCL backend")
else:
extension = CppExtension
print("Building with C++ backend")

# Compilation arguments
extra_link_args = []
extra_compile_args = {
"cxx": [
extra_compile_args = {"cxx": []}
if extension == CUDAExtension:
print("CUDA is available, compile using CUDAExtension")
extra_compile_args = {
"cxx": ["-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
"-DPy_LIMITED_API=0x03090000"],
"nvcc": ["-O3" if not debug_mode else "-O0"]
}
elif extension == SyclExtension:
print("XPU is available, compile using SyclExtension")
extra_compile_args = {
"cxx": ["-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
"-DPy_LIMITED_API=0x03090000"],
"sycl": ["-O3" if not debug_mode else "-O0"]
}
else:
extra_compile_args["cxx"] = [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
"-DPy_LIMITED_API=0x03090000"]
"-DPy_LIMITED_API=0x03090000", # min CPython version 3.9
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
],
}

if debug_mode:
extra_compile_args["cxx"].append("-g")
extra_compile_args["nvcc"].append("-g")
extra_link_args.extend(["-O0", "-g"])

if extension == CUDAExtension:
extra_compile_args["nvcc"].append("-g")
extra_link_args.extend(["-O0", "-g"])
elif extension == SyclExtension:
extra_compile_args["sycl"].append("-g")
extra_link_args.extend(["-O0", "-g"])

# Source files collection
this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, library_name, "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))

extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
backend_sources = []
if extension == CUDAExtension:
backend_dir = os.path.join(extensions_dir, "cuda")
backend_sources = glob.glob(os.path.join(backend_dir, "*.cu"))
elif extension == SyclExtension:
backend_dir = os.path.join(extensions_dir, "sycl")
backend_sources = glob.glob(os.path.join(backend_dir, "*.sycl"))

if use_cuda:
sources += cuda_sources
sources += backend_sources

# Construct extension
ext_modules = [
extension(
f"{library_name}._C",
@@ -71,17 +126,16 @@ def get_extensions():

return ext_modules


setup(
name=library_name,
version="0.0.1",
packages=find_packages(),
ext_modules=get_extensions(),
install_requires=["torch"],
description="Example of PyTorch C++ and CUDA extensions",
description="Example of PyTorch C++ and CUDA/Sycl extensions",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
url="https://github.com/pytorch/extension-cpp",
url="https://github.com/pytorch/extension-cpp",
cmdclass={"build_ext": BuildExtension},
options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
)
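
Before building with the SYCL backend, a quick environment check can save a failed compile; a minimal sketch using only stock PyTorch:

```python
import torch

print("torch:", torch.__version__)  # the SYCL path expects 2.8 or later
print("xpu available:", hasattr(torch, "xpu") and torch.xpu.is_available())

try:
    from torch.utils.cpp_extension import SyclExtension  # noqa: F401
    print("SyclExtension importable")
except ImportError:
    print("SyclExtension not available in this PyTorch build")
```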