Skip to content

Commit

Permalink
Add type converters for sycl::item and sycl::item_base (#52)
Browse files Browse the repository at this point in the history
The runtime class of `sycl::item_base`:
```
template <int Dims> struct ItemBase<Dims, true> {
...
  range<Dims> MExtent;
  id<Dims> MIndex;
  id<Dims> MOffset;
...
}
template <int Dims> struct ItemBase<Dims, false> {
...
  range<Dims> MExtent;
  id<Dims> MIndex;
...
}
```
The runtime class of `sycl::item`:
```
template <int dimensions = 1, bool with_offset = true> class item {
...
  detail::ItemBase<dimensions, with_offset> MImpl;
...
}

```
Example of LLVM IR generated directly from clang:
```
%"class.cl::sycl::item" = type { %"struct.cl::sycl::detail::ItemBase" }
%"struct.cl::sycl::detail::ItemBase" = type { %"class.cl::sycl::range", %"class.cl::sycl::id", %"class.cl::sycl::id" }
```

Signed-off-by: Tsang, Whitney <whitney.tsang@intel.com>
  • Loading branch information
whitneywhtsang authored and etiotto committed Sep 6, 2022
1 parent c244dcf commit 7b3ac03
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 26 deletions.
87 changes: 61 additions & 26 deletions mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ template <typename SYCLType> static bool isMemRefOf(const Type &type) {
}

// Returns the element type of 'memref<?xSYCLType>'.
template <typename SYCLType>
static SYCLType getElementType(const Type &type) {
template <typename SYCLType> static SYCLType getElementType(const Type &type) {
assert(isMemRefOf<SYCLType>(type) && "Expecting memref<?xsycl::<type>>");
Type elemType = type.cast<MemRefType>().getElementType();
return elemType.cast<SYCLType>();
Expand Down Expand Up @@ -121,36 +120,74 @@ static Optional<Type> convertRangeType(sycl::RangeType type,
converter);
}

/// Create a LLVM struct type with name \p name, and the converted \p body as
/// the body.
static Optional<Type> convertBodyType(StringRef name,
llvm::ArrayRef<mlir::Type> body,
LLVMTypeConverter &converter) {
auto convertedTy =
LLVM::LLVMStructType::getIdentified(&converter.getContext(), name);
if (!convertedTy.isInitialized()) {
SmallVector<Type> convertedElemTypes;
convertedElemTypes.reserve(body.size());
if (failed(converter.convertTypes(body, convertedElemTypes)))
return llvm::None;
if (failed(convertedTy.setBody(convertedElemTypes, /*isPacked=*/false)))
return llvm::None;
}

return convertedTy;
}

/// Converts SYCL accessor implement device type to LLVM type.
static Optional<Type>
convertAccessorImplDeviceType(sycl::AccessorImplDeviceType type,
LLVMTypeConverter &converter) {
SmallVector<Type> convertedElemTypes;
convertedElemTypes.reserve(type.getBody().size());
if (failed(converter.convertTypes(type.getBody(), convertedElemTypes)))
return llvm::None;

return LLVM::LLVMStructType::getNewIdentified(
&converter.getContext(), "class.cl::sycl::detail::AccessorImplDevice",
convertedElemTypes, /*isPacked=*/false);
return convertBodyType("class.cl::sycl::detail::AccessorImplDevice" +
std::to_string(type.getDimension()),
type.getBody(), converter);
}

/// Converts SYCL accessor type to LLVM type.
static Optional<Type> convertAccessorType(sycl::AccessorType type,
LLVMTypeConverter &converter) {
SmallVector<Type> convertedElemTypes;
convertedElemTypes.reserve(type.getBody().size());
if (failed(converter.convertTypes(type.getBody(), convertedElemTypes)))
return llvm::None;
auto convertedTy = LLVM::LLVMStructType::getIdentified(
&converter.getContext(),
"class.cl::sycl::accessor" + std::to_string(type.getDimension()));
if (!convertedTy.isInitialized()) {
SmallVector<Type> convertedElemTypes;
convertedElemTypes.reserve(type.getBody().size());
if (failed(converter.convertTypes(type.getBody(), convertedElemTypes)))
return llvm::None;

auto ptrTy = LLVM::LLVMPointerType::get(type.getType(), /*addressSpace=*/1);
auto structTy =
LLVM::LLVMStructType::getLiteral(&converter.getContext(), ptrTy);
convertedElemTypes.push_back(structTy);

if (failed(convertedTy.setBody(convertedElemTypes, /*isPacked=*/false)))
return llvm::None;
}

return convertedTy;
}

auto ptrTy = LLVM::LLVMPointerType::get(type.getType(), /*addressSpace=*/1);
auto structTy =
LLVM::LLVMStructType::getLiteral(&converter.getContext(), ptrTy);
convertedElemTypes.push_back(structTy);
/// Converts SYCL item base type to LLVM type.
static Optional<Type> convertItemBaseType(sycl::ItemBaseType type,
LLVMTypeConverter &converter) {
return convertBodyType("class.cl::sycl::detail::ItemBase." +
std::to_string(type.getDimension()) +
(type.getWithOffset() ? ".true" : ".false"),
type.getBody(), converter);
}

return LLVM::LLVMStructType::getNewIdentified(
&converter.getContext(), "class.cl::sycl::accessor", convertedElemTypes,
/*isPacked=*/false);
/// Converts SYCL item type to LLVM type.
static Optional<Type> convertItemType(sycl::ItemType type,
LLVMTypeConverter &converter) {
return convertBodyType("class.cl::sycl::item." +
std::to_string(type.getDimension()) +
(type.getWithOffset() ? ".true" : ".false"),
type.getBody(), converter);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -188,7 +225,7 @@ class ConstructorPattern final
MLIRContext *context = module.getContext();

// Lookup the ctor function to use.
const auto &registry = SYCLFuncRegistry::create(module, rewriter);
const auto &registry = SYCLFuncRegistry::create(module, rewriter);
auto voidTy = LLVM::LLVMVoidType::get(context);
SYCLFuncDescriptor::FuncId funcId =
registry.getFuncId(SYCLFuncDescriptor::FuncIdKind::IdCtor, voidTy,
Expand Down Expand Up @@ -235,12 +272,10 @@ void mlir::sycl::populateSYCLToLLVMTypeConversion(
typeConverter.addConversion(
[&](sycl::IDType type) { return convertIDType(type, typeConverter); });
typeConverter.addConversion([&](sycl::ItemBaseType type) {
llvm_unreachable("SYCLToLLVM - sycl::ItemBaseType not handle (yet)");
return llvm::None;
return convertItemBaseType(type, typeConverter);
});
typeConverter.addConversion([&](sycl::ItemType type) {
llvm_unreachable("SYCLToLLVM - sycl::ItemType not handle (yet)");
return llvm::None;
return convertItemType(type, typeConverter);
});
typeConverter.addConversion([&](sycl::NdItemType type) {
llvm_unreachable("SYCLToLLVM - sycl::NdItemType not handle (yet)");
Expand Down
12 changes: 12 additions & 0 deletions mlir-sycl/test/Conversion/SYCLToLLVM/sycl-types-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
// CHECK: llvm.func @test_accessorImplDevice(%arg0: !llvm.[[ACCESSORIMPLDEVICE_1:struct<"class.cl::sycl::detail::AccessorImplDevice.*", \(]][[ID_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
// CHECK: llvm.func @test_accessor.1(%arg0: !llvm.[[ACCESSOR_1:struct<"class.cl::sycl::accessor.*", \(]][[ACCESSORIMPLDEVICE_1]][[ID_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]], struct<(ptr<i32, 1>)>)>)
// CHECK: llvm.func @test_accessor.2(%arg0: !llvm.[[ACCESSOR_2:struct<"class.cl::sycl::accessor.*", \(]][[ACCESSORIMPLDEVICE_2:struct<"class.cl::sycl::detail::AccessorImplDevice.*", \(]][[ID_2:struct<"class.cl::sycl::id.*", \(]][[ARRAY_2]][[SUFFIX]], [[RANGE_2]][[ARRAY_2]][[SUFFIX]], [[RANGE_2]][[ARRAY_2]][[SUFFIX]][[SUFFIX]], struct<(ptr<i64, 1>)>)>)
// CHECK: llvm.func @test_item_base.true(%arg0: !llvm.[[ITEM_BASE_1_TRUE:struct<"class.cl::sycl::detail::ItemBase.1.true", \(]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
// CHECK: llvm.func @test_item_base.false(%arg0: !llvm.[[ITEM_BASE_2_FALSE:struct<"class.cl::sycl::detail::ItemBase.2.false", \(]][[RANGE_2]][[ARRAY_2]][[SUFFIX]], [[ID_2]][[ARRAY_2]][[SUFFIX]][[SUFFIX]])
// CHECK: llvm.func @test_item(%arg0: !llvm.[[ITEM_1_TRUE:struct<"class.cl::sycl::item.1.true", \(]][[ITEM_BASE_1_TRUE]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]])

module {
func.func @test_array.1(%arg0: !sycl.array<[1], (memref<1xi64>)>) {
Expand All @@ -34,4 +37,13 @@ module {
func.func @test_accessor.2(%arg0: !sycl.accessor<[2, i64, write, global_buffer], (!sycl.accessor_impl_device<[2], (!sycl.id<2>, !sycl.range<2>, !sycl.range<2>)>)>) {
return
}
func.func @test_item_base.true(%arg0: !sycl.item_base<[1, true], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>) {
return
}
func.func @test_item_base.false(%arg0: !sycl.item_base<[2, false], (!sycl.range<2>, !sycl.id<2>)>) {
return
}
func.func @test_item(%arg0: !sycl.item<[1, true], (!sycl.item_base<[1, true], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>)>) {
return
}
}

0 comments on commit 7b3ac03

Please sign in to comment.