Skip to content

Commit

Permalink
Add SYCLAccessorFuncDescriptor class (#55)
Browse files Browse the repository at this point in the history
This PR add the `SYCLAccessorFuncDescriptor ` class. This class represents the ctors and member functions descriptors for the `sycl::accessor` class.

One function is added to the registry (`cl::sycl::accessor<int, 1, (cl::sycl::access::mode)1026, (cl::sycl::access::target)2014, (cl::sycl::access::placeholder)0, cl::sycl::ext::oneapi::accessor_property_list<> >::__init(int AS1*, cl::sycl::range<1>, cl::sycl::range<1>, cl::sycl::id<1>`)

Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
  • Loading branch information
etiotto committed Sep 6, 2022
1 parent 9656ebf commit 6eac75e
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 79 deletions.
62 changes: 47 additions & 15 deletions mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/IR/Types.h"

namespace mlir {
class LLVMTypeConverter;
namespace sycl {
class SYCLFuncRegistry;

Expand All @@ -30,14 +31,19 @@ class SYCLFuncRegistry;
/// needs to be created in SYCLFuncRegistry constructor.
class SYCLFuncDescriptor {
friend class SYCLFuncRegistry;
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &, const SYCLFuncDescriptor &);
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &,
const SYCLFuncDescriptor &);

public:
/// Enumerates SYCL functions.
// clang-format off
enum class FuncId {
Unknown,

// Member functions for the sycl:accessor class.
AccessorInt1ReadWriteGlobalBufferFalseInit, // sycl::accessor<int, 1, read_write, global_buffer, (placeholder)0>::
// __init(int AS1*, sycl::range<1>, sycl::range<1>, sycl::id<1>)

// Member functions for the sycl:id<n> class.
Id1CtorDefault, // sycl::id<1>::id()
Id2CtorDefault, // sycl::id<2>::id()
Expand Down Expand Up @@ -77,8 +83,9 @@ class SYCLFuncDescriptor {
/// Enumerates the kind of FuncId.
enum class FuncKind {
Unknown,
IdCtor, // any sycl::id<n> constructors.
RangeCtor // any sycl::range<n> constructors.
Accessor, // sycl::accessor class
Id, // sycl::id<n> class
Range, // sycl::range<n> class
};

/// Each descriptor is uniquely identified by the pair {FuncId, FuncKind}.
Expand Down Expand Up @@ -125,7 +132,7 @@ class SYCLFuncDescriptor {
StringRef name; // SYCL function name
Type outputTy; // SYCL function output type
SmallVector<Type, 4> argTys; // SYCL function arguments types
FlatSymbolRefAttr funcRef; // Reference to the SYCL function
FlatSymbolRefAttr funcRef; // Reference to the SYCL function
};

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
Expand All @@ -141,7 +148,11 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
return os;
}

#define DEFINE_CTOR_CLASS(ClassName, ClassKind) \
//===----------------------------------------------------------------------===//
// Derived classes specializing the generic SYCLFuncDescriptor.
//===----------------------------------------------------------------------===//

#define DEFINE_CLASS(ClassName, ClassKind) \
class ClassName : public SYCLFuncDescriptor { \
public: \
friend class SYCLFuncRegistry; \
Expand All @@ -156,36 +167,57 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
} \
bool isValid(FuncId) const override; \
};
DEFINE_CTOR_CLASS(SYCLIdCtorDescriptor, FuncKind::IdCtor)
DEFINE_CTOR_CLASS(SYCLRangeCtorDescriptor, FuncKind::RangeCtor)
#undef DEFINE_CTOR_CLASS
DEFINE_CLASS(SYCLAccessorFuncDescriptor, FuncKind::Accessor)
DEFINE_CLASS(SYCLIdFuncDescriptor, FuncKind::Id)
DEFINE_CLASS(SYCLRangeFuncDescriptor, FuncKind::Range)
#undef DEFINE_CLASS

/// \class SYCLFuncRegistry
/// Singleton class representing the set of SYCL functions callable from the
/// compiler.
class SYCLFuncRegistry {
using FuncId = SYCLFuncDescriptor::FuncId;
using FuncKind = SYCLFuncDescriptor::FuncKind;
using Registry = std::map<FuncId, SYCLFuncDescriptor>;

public:
~SYCLFuncRegistry() { instance = nullptr; }

/// Populate the registry.
static const SYCLFuncRegistry create(ModuleOp &module, OpBuilder &builder);

const SYCLFuncDescriptor &
getFuncDesc(SYCLFuncDescriptor::FuncId funcId) const {
/// Return the function descriptor corresponding to the given \p funcId.
const SYCLFuncDescriptor &getFuncDesc(FuncId funcId) const {
assert((registry.find(funcId) != registry.end()) &&
"function identified by 'funcId' not found in the SYCL function "
"registry");
return registry.at(funcId);
}

// Returns the SYCLFuncDescriptor::Id::FuncId corresponding to the function
// descriptor that matches the given \p funcKind and signature.
SYCLFuncDescriptor::FuncId getFuncId(SYCLFuncDescriptor::FuncKind funcKind,
Type retType, TypeRange argTypes) const;
/// Returns the SYCLFuncDescriptor::Id::FuncId corresponding to the function
/// descriptor that matches the given \p funcKind and signature.
FuncId getFuncId(FuncKind funcKind, Type retType, TypeRange argTypes) const;

private:
SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder);

using Registry = std::map<SYCLFuncDescriptor::FuncId, SYCLFuncDescriptor>;
/// Declare sycl::accessor<n> function descriptors and add them to the
/// registry.
void declareAccessorFuncDescriptors(LLVMTypeConverter &converter,
ModuleOp &module, OpBuilder &builder);

/// Declare sycl::id<n> function descriptors and add them to the registry.
void declareIdFuncDescriptors(LLVMTypeConverter &converter, ModuleOp &module,
OpBuilder &builder);

/// Declare sycl::range<n> function descriptors and add them to the registry.
void declareRangeFuncDescriptors(LLVMTypeConverter &converter,
ModuleOp &module, OpBuilder &builder);

/// Declare function descriptors and add them to the registry.
void declareFuncDescriptors(std::vector<SYCLFuncDescriptor> &descriptors,
ModuleOp &module, OpBuilder &builder);

static SYCLFuncRegistry *instance;
Registry registry;
};
Expand Down
Loading

0 comments on commit 6eac75e

Please sign in to comment.