Skip to content

Commit

Permalink
Conversion of 'sycl.constructor(%0, %1) {type = @range}' (#51)
Browse files Browse the repository at this point in the history
This PR adds support for converting the `sycl.constructor(%0, %1) {type = @range}` operation (representing construction of a sycl::range<n> object) to a call to the appropriate `sycl::range<n>` constructor.

Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
  • Loading branch information
etiotto committed Sep 6, 2022
1 parent 7b3ac03 commit 5d6c920
Show file tree
Hide file tree
Showing 4 changed files with 408 additions and 128 deletions.
62 changes: 50 additions & 12 deletions mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class SYCLFuncRegistry;
/// needs to be created in SYCLFuncRegistry constructor.
class SYCLFuncDescriptor {
friend class SYCLFuncRegistry;
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &, const SYCLFuncDescriptor &);

public:
/// Enumerates SYCL functions.
Expand All @@ -44,26 +45,51 @@ class SYCLFuncDescriptor {
Id1CtorSizeT, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type)
Id2CtorSizeT, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type)
Id3CtorSizeT, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type)
Id1CtorRange, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long)
Id2CtorRange, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long)
Id3CtorRange, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long)
Id1CtorItem, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long)
Id2CtorItem, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long)
Id3CtorItem, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long)
Id1Ctor2SizeT, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long)
Id2Ctor2SizeT, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long)
Id3Ctor2SizeT, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long)
Id1Ctor3SizeT, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long)
Id2Ctor3SizeT, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long)
Id3Ctor3SizeT, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long)
Id1CopyCtor, // sycl::id<1>::id(sycl::id<1> const&)
Id2CopyCtor, // sycl::id<2>::id(sycl::id<2> const&)
Id3CopyCtor, // sycl::id<3>::id(sycl::id<3> const&)

// Member functions for ..TODO..
// Member functions for the sycl::Range<n> class.
Range1CtorDefault, // sycl::Range<1>::range()
Range2CtorDefault, // sycl::range<2>::range()
Range3CtorDefault, // sycl::range<3>::range()
Range1CtorSizeT, // sycl::range<1>::range<1>(std::enable_if<(1)==(1), unsigned long>::type)
Range2CtorSizeT, // sycl::range<2>::range<2>(std::enable_if<(2)==(2), unsigned long>::type)
Range3CtorSizeT, // sycl::range<3>::range<3>(std::enable_if<(3)==(3), unsigned long>::type)
Range1Ctor2SizeT, // sycl::range<1>::range<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long)
Range2Ctor2SizeT, // sycl::range<2>::range<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long)
Range3Ctor2SizeT, // sycl::range<3>::range<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long)
Range1Ctor3SizeT, // sycl::range<1>::range<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long)
Range2Ctor3SizeT, // sycl::range<2>::range<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long)
Range3Ctor3SizeT, // sycl::range<3>::range<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long)
Range1CopyCtor, // sycl::range<1>::range(sycl::range<1> const&)
Range2CopyCtor, // sycl::range<2>::range(sycl::range<2> const&)
Range3CopyCtor, // sycl::range<3>::range(sycl::range<3> const&)
};
// clang-format on

/// Enumerates the kind of FuncId.
enum class FuncIdKind {
Unknown,
IdCtor, // any sycl::id<n> constructors
IdCtor, // any sycl::id<n> constructors.
RangeCtor // any sycl::range<n> constructors.
};

/// Returns the funcIdKind given a \p funcId.
static FuncIdKind getFuncIdKind(FuncId funcId);

/// Retuns a descriptive name for the given \p funcIdKind.
static std::string funcIdKindToName(FuncIdKind funcIdKind);

/// Retuns the FuncIdKind given a descriptive \p name.
static FuncIdKind nameToFuncIdKind(Twine name);

// Call the SYCL constructor identified by \p id with the given \p args.
static Value call(FuncId id, ValueRange args,
const SYCLFuncRegistry &registry, OpBuilder &b,
Expand All @@ -73,8 +99,11 @@ class SYCLFuncDescriptor {
/// Private constructor: only available to 'SYCLFuncRegistry'.
SYCLFuncDescriptor(FuncId id, StringRef name, Type outputTy,
ArrayRef<Type> argTys)
: id(id), name(name), outputTy(outputTy),
argTys(argTys.begin(), argTys.end()) {}
: funcId(id), funcIdKind(getFuncIdKind(id)), name(name),
outputTy(outputTy), argTys(argTys.begin(), argTys.end()) {
assert(funcId != FuncId::Unknown && "Illegal function id");
assert(funcIdKind != FuncIdKind::Unknown && "Illegal function id kind");
}

/// Inject the declaration for this function into the module.
void declareFunction(ModuleOp &module, OpBuilder &b);
Expand All @@ -83,13 +112,22 @@ class SYCLFuncDescriptor {
static bool isIdCtor(FuncId funcId);

private:
FuncId id; // unique identifier for a SYCL function
FuncId funcId = FuncId::Unknown; // SYCL function identifier
FuncIdKind funcIdKind = FuncIdKind::Unknown; // SYCL function kind
StringRef name; // SYCL function name
Type outputTy; // SYCL function output type
SmallVector<Type, 4> argTys; // SYCL function arguments types
FlatSymbolRefAttr funcRef; // Reference to the SYCL function declaration
FlatSymbolRefAttr funcRef; // Reference to the SYCL function
};

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const SYCLFuncDescriptor &desc) {
os << "funcId=" << (int)desc.funcId
<< ", funcIdKind=" << SYCLFuncDescriptor::funcIdKindToName(desc.funcIdKind)
<< ", name='" << desc.name.str() << "')";
return os;
}

/// \class SYCLFuncRegistry
/// Singleton class representing the set of SYCL functions callable from the
/// compiler.
Expand Down
Loading

0 comments on commit 5d6c920

Please sign in to comment.