Skip to content

Commit

Permalink
[BYOC][Verilator] add support to dynamically load hardware library (a…
Browse files Browse the repository at this point in the history
…pache#7286)

* add files

* remove import

* remove os import

* reorder header

* fix header order cpplint

* lint fix
  • Loading branch information
vegaluisjose authored Jan 22, 2021
1 parent 6787d74 commit af9d1d2
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 21 deletions.
4 changes: 2 additions & 2 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ set(USE_TENSORRT_RUNTIME OFF)
# Whether use VITIS-AI codegen
set(USE_VITIS_AI OFF)

# Build Verilator codegen and runtime, example located in 3rdparty/vta-hw/apps/verilator
set(USE_VERILATOR_HW OFF)
# Build Verilator codegen and runtime
set(USE_VERILATOR OFF)

# Build ANTLR parser for Relay text format
# Possible values:
Expand Down
8 changes: 2 additions & 6 deletions cmake/modules/contrib/Verilator.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,10 @@
# specific language governing permissions and limitations
# under the License.

if(USE_VERILATOR_HW STREQUAL "ON")
execute_process(COMMAND make --directory ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/vta-hw/apps/verilator)
if(USE_VERILATOR STREQUAL "ON")
file(GLOB VERILATOR_RELAY_CONTRIB_SRC src/relay/backend/contrib/verilator/codegen.cc)
list(APPEND COMPILER_SRCS ${VERILATOR_RELAY_CONTRIB_SRC})
list(APPEND COMPILER_SRCS ${JSON_RELAY_CONTRIB_SRC})
find_library(EXTERN_LIBRARY_VERILATOR NAMES verilator PATHS ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/vta-hw/apps/verilator)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_VERILATOR})
file(GLOB VERILATOR_CONTRIB_SRC src/runtime/contrib/verilator/verilator_runtime.cc)
list(APPEND COMPILER_SRCS ${VERILATOR_RELAY_CONTRIB_SRC})
list(APPEND RUNTIME_SRCS ${VERILATOR_CONTRIB_SRC})
endif()

30 changes: 29 additions & 1 deletion src/relay/backend/contrib/verilator/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ namespace contrib {

using namespace backend;

/*! \brief Verilator JSON serializer */
class VerilatorJSONSerializer : public backend::contrib::JSONSerializer {
using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
Expand Down Expand Up @@ -74,6 +75,24 @@ class VerilatorJSONSerializer : public backend::contrib::JSONSerializer {
}
};

/*! \brief Attributes to store the compiler options for Verilator */
struct VerilatorCompilerConfigNode : public tvm::AttrsNode<VerilatorCompilerConfigNode> {
String lib;

TVM_DECLARE_ATTRS(VerilatorCompilerConfigNode, "ext.attrs.VerilatorCompilerConfigNode") {
TVM_ATTR_FIELD(lib).set_default("libverilator.so");
}
};

class VerilatorCompilerConfig : public Attrs {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VerilatorCompilerConfig, Attrs,
VerilatorCompilerConfigNode);
};

TVM_REGISTER_NODE_TYPE(VerilatorCompilerConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.verilator.options", VerilatorCompilerConfig);

/*!
* \brief The external compiler/codegen tool. It takes a Relay expression/module and
* compile it into a runtime module.
Expand All @@ -87,9 +106,18 @@ runtime::Module VerilatorCompiler(const ObjectRef& ref) {
std::string graph_json = serializer.GetJSON();
auto params = serializer.GetParams();

// Get Verilator compiler options
auto ctx = transform::PassContext::Current();
auto cfg = ctx->GetConfig<VerilatorCompilerConfig>("relay.ext.verilator.options");
if (!cfg.defined()) {
cfg = AttrsWithDefaultValues<VerilatorCompilerConfig>();
}

auto lib_name = cfg.value()->lib;

const auto* pf = runtime::Registry::Get("runtime.VerilatorJSONRuntimeCreate");
CHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
auto mod = (*pf)(func_name, graph_json, params);
auto mod = (*pf)(lib_name, func_name, graph_json, params);
return mod;
}

Expand Down
69 changes: 59 additions & 10 deletions src/runtime/contrib/verilator/verilator_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@
* \brief A simple JSON runtime for Verilator.
*/

#include <dlfcn.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>

#include <cstddef>
#include <string>
#include <vector>

#include "../../library_module.h"
#include "../json/json_node.h"
#include "../json/json_runtime.h"
#include "verilator_device.h"
Expand All @@ -38,9 +40,40 @@ namespace tvm {
namespace runtime {
namespace contrib {

typedef VerilatorHandle (*VerilatorAllocFunc)();
typedef void (*VerilatorResetFunc)(VerilatorHandle, int);
typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int);

using namespace tvm::runtime;
using namespace tvm::runtime::json;

class VerilatorLibrary : public Library {
public:
~VerilatorLibrary() {
if (lib_handle_) Unload();
}
void Init(const std::string& name) { Load(name); }

void* GetSymbol(const char* name) final { return GetSymbol_(name); }

private:
// Library handle
void* lib_handle_{nullptr};
// load the library
void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
ICHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name << " " << dlerror();
}

void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); }

void Unload() {
dlclose(lib_handle_);
lib_handle_ = nullptr;
}
};

class VerilatorJSONRuntime : public JSONRuntimeBase {
public:
VerilatorJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
Expand All @@ -49,8 +82,25 @@ class VerilatorJSONRuntime : public JSONRuntimeBase {

const char* type_key() const { return "verilator_json"; }

void LoadLibrary(const std::string& lib_name) {
lib_ = new VerilatorLibrary();
lib_->Init(lib_name);
}

void Init(const Array<NDArray>& consts) override {
BuildEngine();
// get symbols
auto alloc_func = reinterpret_cast<VerilatorAllocFunc>(lib_->GetSymbol("VerilatorAlloc"));
ICHECK(alloc_func != nullptr);
auto reset_func = reinterpret_cast<VerilatorResetFunc>(lib_->GetSymbol("VerilatorReset"));
ICHECK(reset_func != nullptr);
vadd_func_ = reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));
ICHECK(vadd_func_ != nullptr);

// alloc device
device_ = (*alloc_func)();

// reset for 10 cycles
(*reset_func)(device_, 10);

CHECK_EQ(consts.size(), const_idx_.size())
<< "The number of input constants must match the number of required.";
Expand Down Expand Up @@ -80,7 +130,7 @@ class VerilatorJSONRuntime : public JSONRuntimeBase {
if ("add" == op_name) {
auto entry = node.GetInputs()[0];
auto shape = nodes_[entry.id_].GetOpShape()[entry.index_];
verilator_add(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
(*vadd_func_)(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
Expand All @@ -89,19 +139,18 @@ class VerilatorJSONRuntime : public JSONRuntimeBase {
}

private:
void BuildEngine() {
device_ = VerilatorAlloc();
// reset for 10 cycles
VerilatorReset(device_, 10);
}

/* The verilator handle. */
/* The verilator device handle. */
VerilatorHandle device_{nullptr};
/* The verilator library handle. */
VerilatorLibrary* lib_{nullptr};
/* The verilator add function handle */
VerilatorAddFunc vadd_func_{nullptr};
};

runtime::Module VerilatorJSONRuntimeCreate(String symbol_name, String graph_json,
runtime::Module VerilatorJSONRuntimeCreate(String lib_name, String symbol_name, String graph_json,
const Array<String>& const_names) {
auto n = make_object<VerilatorJSONRuntime>(symbol_name, graph_json, const_names);
n->LoadLibrary(lib_name);
return runtime::Module(n);
}

Expand Down
39 changes: 37 additions & 2 deletions tests/python/contrib/test_verilator/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
# under the License.
"""Verilator utility functions"""

import os
import sys
import subprocess as sp

import tvm
from tvm import relay
Expand Down Expand Up @@ -66,10 +68,43 @@ def offload(mod):
return mod


def verilator_app_path():
"""Find verilator hardware app path"""

cur_dir = os.path.dirname(os.path.realpath(__file__))
return os.path.join(
cur_dir,
"..",
"..",
"..",
"..",
"3rdparty",
"vta-hw",
"apps",
"verilator",
)


def compile_hardware():
"""Compile hardware into shared library"""

cmd = []
cmd.append("make")
cmd.append("--directory")
cmd.append(verilator_app_path())
sp.run(cmd, check=True)


def compile_module(mod):
"""Compile Relay module"""
"""Compile Relay module and hardware library"""

lib = os.path.join(verilator_app_path(), "libverilator.so")
if not os.path.isfile(lib):
compile_hardware()

with relay.build_config(opt_level=3):
with tvm.transform.PassContext(
opt_level=3, config={"relay.ext.verilator.options": {"lib": lib}}
):
exe = relay.vm.compile(mod, target="llvm", params=None)
code, lib = exe.save()
return runtime.vm.Executable.load_exec(code, lib)
Expand Down

0 comments on commit af9d1d2

Please sign in to comment.