Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Functional] Adds occa::function and occa::array #442

Merged
merged 20 commits into from
Jan 17, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[Experimental] Adds initial occa::function
  • Loading branch information
dmed256 committed Jan 17, 2021
commit 8e64cf873cf87f03fc1c20287994d8ed9ff4d4cc
81 changes: 81 additions & 0 deletions include/occa/experimental/functional/function.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#ifndef OCCA_EXPERIMENTAL_FUNCTIONAL_FUNCTION_HEADER
#define OCCA_EXPERIMENTAL_FUNCTIONAL_FUNCTION_HEADER

#include <functional>

#include <occa/experimental/functional/functionDefinition.hpp>
#include <occa/experimental/functional/scope.hpp>
#include <occa/experimental/functional/utils.hpp>
#include <occa/utils/hash.hpp>
#include <occa/dtype.hpp>

namespace occa {
template <class Function>
class function;

class baseFunction {
public:
occa::scope scope;
hash_t hash_;

baseFunction(const occa::scope &scope_);

functionDefinition& definition();

virtual int argumentCount() const = 0;

hash_t hash() const;

operator hash_t () const;
};

template <class ReturnType, class ...ArgTypes>
class function<ReturnType(ArgTypes...)> : public baseFunction {
private:
std::function<ReturnType(ArgTypes...)> lambda;

public:
function(const occa::scope &scope_,
std::function<ReturnType(ArgTypes...)> lambda_,
const char *source) :
baseFunction(scope_),
lambda(lambda_) {

hash_ = functionDefinition::cache(
scope,
source,
getReturnType(),
getArgTypes()
).get()->hash;
}

hash_t getTypeHash() const {
hash_t typeHash = occa::hash(getReturnType().name());
for (auto &argType : getArgTypes()) {
typeHash ^= occa::hash(argType.name());
}
return typeHash;
}

int argumentCount() const {
return (int) sizeof...(ArgTypes);
}

dtype_t getReturnType() const {
return dtype::get<ReturnType>();
}

dtypeVector getArgTypes() const {
return dtype::getMany<ArgTypes...>();
}

ReturnType operator () (ArgTypes... args) {
return lambda(args...);
}

template <class TM>
friend class array;
};
}

#endif
52 changes: 52 additions & 0 deletions include/occa/experimental/functional/functionDefinition.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#ifndef OCCA_EXPERIMENTAL_FUNCTIONAL_FUNCTIONDEFINITION_HEADER
#define OCCA_EXPERIMENTAL_FUNCTIONAL_FUNCTIONDEFINITION_HEADER

#include <memory>
#include <string>

#include <occa/dtype.hpp>
#include <occa/utils/hash.hpp>
#include <occa/experimental/functional/scope.hpp>

namespace occa {
class functionDefinition;

typedef std::shared_ptr<functionDefinition> functionDefinitionSharedPtr;

class functionDefinition {
public:
occa::scope scope;
std::string source;
dtype_t returnType;
dtypeVector argTypes;

hash_t hash;
std::string argumentSource;
std::string bodySource;

functionDefinition();

int functionArgumentCount() const;
int totalArgumentCount() const;

std::string getFunctionSource(const std::string &functionName);

static hash_t getHash(const occa::scope &scope,
const std::string &source,
const dtype_t &returnType);

static functionDefinitionSharedPtr cache(
const occa::scope &scope,
const std::string &source,
const dtype_t &returnType,
const dtypeVector &argTypes
);

static void skipLambdaCapture(const char *&c);
static std::string getArgumentSource(const std::string &source,
const occa::scope &scope);
static std::string getBodySource(const std::string &source);
};
}

#endif
36 changes: 36 additions & 0 deletions include/occa/experimental/functional/utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#ifndef OCCA_EXPERIMENTAL_FUNCTIONAL_UTILS_HEADER
#define OCCA_EXPERIMENTAL_FUNCTIONAL_UTILS_HEADER

#define OCCA_FUNCTION(scope, lambda) \
::occa::inferFunction(scope, lambda, #lambda)

namespace occa {
template <class Function>
class function;

//---[ Magic ]------------------------
// C++ template magic for casting between types at compile-time
// lambda
// -> std::function<ret(args...)>
// -> occa::function<ret(args...)>
template <typename TM>
struct inferFunctionHelper;

template <typename ReturnType, typename ClassType, typename ...ArgTypes>
struct inferFunctionHelper<ReturnType(ClassType::*)(ArgTypes...) const> {
using occaFunctionType = occa::function<ReturnType(ArgTypes...)>;
};

template <typename LambdaType>
typename inferFunctionHelper<decltype(&LambdaType::operator())>::occaFunctionType
inferFunction(const occa::scope &scope,
LambdaType const &lambda,
const char *source) {
return typename inferFunctionHelper<decltype(&LambdaType::operator())>::occaFunctionType(
scope, lambda, source
);
}
//====================================
}

#endif
69 changes: 43 additions & 26 deletions src/core/kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#include <occa/core/device.hpp>
#include <occa/core/kernel.hpp>
#include <occa/core/memory.hpp>
#include <occa/utils/uva.hpp>
#include <occa/internal/io.hpp>
#include <occa/internal/core/device.hpp>
#include <occa/internal/core/kernel.hpp>
#include <occa/internal/lang/builtins/types.hpp>
#include <occa/internal/lang/parser.hpp>
#include <occa/internal/utils/sys.hpp>
#include <occa/utils/uva.hpp>
#include <occa/internal/experimental/functional/functionStore.hpp>

namespace occa {
//---[ kernel ]-----------------------
Expand Down Expand Up @@ -202,56 +203,72 @@ namespace occa {
//---[ Kernel Properties ]------------
// Properties:
// defines : Object
// functions : Object
// includes : Array
// headers : Array
// include_paths : Array

hash_t kernelHeaderHash(const occa::json &props) {
return (
occa::hash(props["defines"])
^ props["functions"]
^ props["includes"]
^ props["headers"]
);
}

std::string assembleKernelHeader(const occa::json &props) {
std::string header;
std::string kernelHeader;

// Add defines
const jsonObject &defines = props["defines"].object();
jsonObject::const_iterator it = defines.begin();
while (it != defines.end()) {
header += "#define ";
header += ' ';
header += it->first;
header += ' ';
header += (std::string) it->second;
header += '\n';
++it;
for (const auto &entry : props["defines"].object()) {
if (entry.second.isString()) {
kernelHeader += "#define ";
kernelHeader += ' ';
kernelHeader += entry.first;
kernelHeader += ' ';
kernelHeader += (std::string) entry.second;
kernelHeader += '\n';
}
}

// Add includes
const jsonArray &includes = props["includes"].array();
const int includeCount = (int) includes.size();
for (int i = 0; i < includeCount; ++i) {
if (includes[i].isString()) {
header += "#include \"";
header += (std::string) includes[i];
header += "\"\n";
for (const auto &include : props["includes"].array()) {
if (include.isString()) {
kernelHeader += "#include \"";
kernelHeader += (std::string) include;
kernelHeader += "\"\n";
}
}

// Add header
const jsonArray &lines = props["headers"].array();
const int lineCount = (int) lines.size();
for (int i = 0; i < lineCount; ++i) {
if (lines[i].isString()) {
header += (std::string) lines[i];
header += "\n";
for (const auto &header : props["headers"].array()) {
if (header.isString()) {
kernelHeader += (std::string) header;
kernelHeader += "\n";
}
}

return header;
// Add functions
for (const auto &entry : props["functions"].object()) {
if (entry.second.isString()) {
const std::string &functionName = entry.first;
const std::string &functionHashStr = entry.second;

functionDefinitionSharedPtr fnDefPtr = functionStore.get(
hash_t::fromString(functionHashStr)
);
if (!fnDefPtr) {
continue;
}

kernelHeader += fnDefPtr.get()->getFunctionSource(functionName);
kernelHeader += '\n';
}
}


return kernelHeader;
}
//====================================
}
20 changes: 20 additions & 0 deletions src/experimental/functional/function.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include <occa/experimental/functional/function.hpp>
#include <occa/internal/experimental/functional/functionStore.hpp>

namespace occa {
baseFunction::baseFunction(const occa::scope &scope_) :
scope(scope_) {}

functionDefinition& baseFunction::definition() {
// Should be initialized at this point
return *functionStore.get(hash_);
}

hash_t baseFunction::hash() const {
return hash_;
}

baseFunction::operator hash_t () const {
return hash_;
}
}
Loading