Skip to content

Commit

Permalink
Diagnose Concat, Compress, ArgMax, ArgMin axis attribute out-of-bounds (
Browse files Browse the repository at this point in the history
…llvm#1291)

* Diagnose Concat, Compress, ArgMax, ArgMin axis attribute when it exceeds limit

Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>

* Diagnose Concat, Compress, ArgMax, ArgMin axis attribute when it exceeds limit

Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>

* Address code review comments

Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>

* Create onnx_mlir::Diagnostic class, use it in verify() and shapeInference()

Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>

* Create onnx_mlir::Diagnostic class, use it in verify() and shapeInference()

Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>

* Create onnx_mlir::Diagnostic class, use it in verify() and shapeInference()

Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>

* Diagnose ArgMax & Compress axis value during shape inference.

Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>

* Fix Compress code generation.

Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>

* Fix Windows build

Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>

Co-authored-by: Tung D. Le <tungld@gmail.com>
  • Loading branch information
Ettore Tiotto and tungld authored Apr 5, 2022
1 parent baf6e83 commit 8af8f42
Show file tree
Hide file tree
Showing 16 changed files with 424 additions and 153 deletions.
45 changes: 45 additions & 0 deletions docs/Dialects/krnl.md
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,51 @@ Traits: MemRefsNormalizable
| `scale` | floating-point
| `seed` | floating-point

### `krnl.seqextract` (::mlir::KrnlSeqExtractOp)

Krnl load from a seq

A sequence is represented as memref<memref<>>.
This op loads a tensor from the sequence 'seq' at position 'index'
and returns the tensor, which will be freed by Bufferization::Deallocation.
The element in the sequence will become null.

Interfaces: AllocationOpInterface, MemoryEffectOpInterface

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `seq` | memref of any type values
| `index` | index

#### Results:

| Result | Description |
| :----: | ----------- |
| `output` | any type

### `krnl.seqstore` (::mlir::KrnlSeqStoreOp)

Krnl store into a seq

A sequence is represented as memref<memref<>>.
This op copies the tensor to be stored, casting the type if needed.
The motivation for introducing this op is to help bufferization::deallocation.
Experiments showed that a memref would be freed after it is stored.
However, storing a memref only writes out the pointer to the memref,
so its memory cannot be freed.

Traits: MemRefsNormalizable

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `input` | any type
| `seq` | memref of any type values
| `index` | index

### `krnl.shape` (::mlir::KrnlShapeOp)

Krnl operation to retrieve the shape of a MemRef.
Expand Down
118 changes: 61 additions & 57 deletions src/Compiler/CompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,17 @@

#define DEBUG_TYPE "compiler_utils"

using namespace std;
using namespace mlir;
using namespace onnx_mlir;

const string OnnxMlirEnvOptionName = "ONNX_MLIR_FLAGS";
const std::string OnnxMlirEnvOptionName = "ONNX_MLIR_FLAGS";
#if defined(ONNX_MLIR_REPOSITORY) && defined(ONNX_MLIR_REVISION) && \
defined(LLVM_REPOSITORY) && defined(LLVM_REVISION)
static const string OnnxMlirVersion =
static const std::string OnnxMlirVersion =
"onnx-mlir version 1.0.0 (" ONNX_MLIR_REPOSITORY " " ONNX_MLIR_REVISION
" " LLVM_REPOSITORY " " LLVM_REVISION ")";
#else
const string OnnxMlirVersion = "onnx-mlir version 1.0.0";
const std::string OnnxMlirVersion = "onnx-mlir version 1.0.0";
#endif

namespace {
Expand Down Expand Up @@ -119,9 +118,9 @@ static std::string getRuntimeDir() {
if (envDir && llvm::sys::fs::exists(envDir.getValue()))
return envDir.getValue();

string execDir = llvm::sys::path::parent_path(getExecPath()).str();
std::string execDir = llvm::sys::path::parent_path(getExecPath()).str();
if (llvm::sys::path::stem(execDir).str().compare("bin") == 0) {
string p = execDir.substr(0, execDir.size() - 3);
std::string p = execDir.substr(0, execDir.size() - 3);
if (llvm::sys::fs::exists(p + "lib"))
return p + "lib";
}
Expand All @@ -148,15 +147,15 @@ static std::string getRuntimeDir() {
// installed system wide but to different places and their sources have been
// removed. So we force CMAKE_INSTALL_PREFIX to be the same as that of
// llvm-project.
static std::string getToolPath(string tool) {
string execDir = llvm::sys::path::parent_path(getExecPath()).str();
static std::string getToolPath(std::string tool) {
std::string execDir = llvm::sys::path::parent_path(getExecPath()).str();
llvm::SmallString<8> toolPath(execDir);
llvm::sys::path::append(toolPath, tool);
string p = llvm::StringRef(toolPath).str();
std::string p = llvm::StringRef(toolPath).str();
if (llvm::sys::fs::can_execute(p))
return p;
else
return string();
return std::string();
}

// Helper struct to make command construction and execution easy & readable.
Expand Down Expand Up @@ -243,12 +242,12 @@ struct Command {
// =============================================================================
// Methods for compiling and file processing.

void loadMLIR(string inputFilename, mlir::MLIRContext &context,
void loadMLIR(std::string inputFilename, mlir::MLIRContext &context,
mlir::OwningOpRef<ModuleOp> &module) {
// Handle '.mlir' input to the ONNX-MLIR frontend.
// The mlir format indicates that one or more of the supported
// representations are used in the file.
string errorMessage;
std::string errorMessage;
auto input = openInputFile(inputFilename, &errorMessage);
if (!input) {
llvm::errs() << errorMessage << "\n";
Expand All @@ -257,8 +256,9 @@ void loadMLIR(string inputFilename, mlir::MLIRContext &context,

// Parse the input mlir.
llvm::SourceMgr sourceMgr;
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
module = mlir::parseSourceFile(sourceMgr, &context);
module = mlir::parseSourceFile<ModuleOp>(sourceMgr, &context);
if (!module) {
llvm::errs() << "Error can't load file " << inputFilename << "\n";
exit(1);
Expand All @@ -267,11 +267,11 @@ void loadMLIR(string inputFilename, mlir::MLIRContext &context,

// Write LLVM optimized bitcode.
static void genLLVMBitcode(const mlir::OwningOpRef<ModuleOp> &module,
string optimizedBitcodePath, string outputBaseName) {
error_code error;
std::string optimizedBitcodePath, std::string outputBaseName) {
std::error_code error;

// Write bitcode to a file.
string unoptimizedBitcodePath = outputBaseName + ".unoptimized.bc";
std::string unoptimizedBitcodePath = outputBaseName + ".unoptimized.bc";
llvm::FileRemover unoptimizedBitcodeRemover(
unoptimizedBitcodePath, !keepFiles(KeepFilesOfType::Bitcode));

Expand Down Expand Up @@ -312,7 +312,7 @@ static void genLLVMBitcode(const mlir::OwningOpRef<ModuleOp> &module,
moduleBitcodeStream.flush();

// Use the LLVM's 'opt' command to optimize the bitcode.
string optPath = getToolPath("opt");
std::string optPath = getToolPath("opt");
Command optBitcode(/*exePath=*/!optPath.empty() ? optPath : kOptPath);
optBitcode.appendStr(getOptimizationLevelOption())
.appendStr(getTargetTripleOption())
Expand All @@ -326,15 +326,16 @@ static void genLLVMBitcode(const mlir::OwningOpRef<ModuleOp> &module,
}

// Compile LLVM bitcode to object file.
static std::string genModelObject(string bitcodePath, string outputBaseName) {
static std::string genModelObject(
std::string bitcodePath, std::string outputBaseName) {

#ifdef _WIN32
string modelObjPath = outputBaseName + ".obj";
std::string modelObjPath = outputBaseName + ".obj";
#else
string modelObjPath = outputBaseName + ".o";
std::string modelObjPath = outputBaseName + ".o";
#endif

string llcPath = getToolPath("llc");
std::string llcPath = getToolPath("llc");
Command llvmToObj(/*exePath=*/!llcPath.empty() ? llcPath : kLlcPath);
llvmToObj.appendStr(getOptimizationLevelOption())
.appendStr(getTargetTripleOption())
Expand All @@ -351,7 +352,7 @@ static std::string genModelObject(string bitcodePath, string outputBaseName) {
}

static void genJniObject(const mlir::OwningOpRef<ModuleOp> &module,
string jniSharedLibPath, string jniObjPath) {
std::string jniSharedLibPath, std::string jniObjPath) {
Command ar(/*exePath=*/kArPath);
ar.appendStr("x")
// old version of ar does not support --output so comment out
Expand All @@ -365,27 +366,27 @@ static void genJniObject(const mlir::OwningOpRef<ModuleOp> &module,
}

// Link everything into a shared object.
static std::string genSharedLib(string outputBaseName, std::vector<string> opts,
std::vector<string> objs, std::vector<string> libs,
std::vector<string> libDirs) {
static std::string genSharedLib(std::string outputBaseName,
std::vector<std::string> opts, std::vector<std::string> objs,
std::vector<std::string> libs, std::vector<std::string> libDirs) {

#ifdef _WIN32
string sharedLibPath = outputBaseName + ".dll";
std::vector<string> outputOpt = {"/Fe:" + sharedLibPath};
std::string sharedLibPath = outputBaseName + ".dll";
std::vector<std::string> outputOpt = {"/Fe:" + sharedLibPath};
// link has to be before def and libpath since they need to be passed through
// to the linker
std::vector<string> sharedLibOpts = {
std::vector<std::string> sharedLibOpts = {
"/LD", "/link", "/NOLOGO", "/def:" + outputBaseName + ".def"};

llvm::for_each(libs, [](string &lib) { lib = lib + ".lib"; });
llvm::for_each(
libDirs, [](string &libDir) { libDir = "/libpath:\"" + libDir + "\""; });
llvm::for_each(libs, [](std::string &lib) { lib = lib + ".lib"; });
llvm::for_each(libDirs,
[](std::string &libDir) { libDir = "/libpath:\"" + libDir + "\""; });
#else
string sharedLibPath = outputBaseName + ".so";
std::vector<string> outputOpt = {"-o", sharedLibPath};
std::vector<string> sharedLibOpts = {"-shared", "-fPIC"};
llvm::for_each(libs, [](string &lib) { lib = "-l" + lib; });
llvm::for_each(libDirs, [](string &libDir) { libDir = "-L" + libDir; });
std::string sharedLibPath = outputBaseName + ".so";
std::vector<std::string> outputOpt = {"-o", sharedLibPath};
std::vector<std::string> sharedLibOpts = {"-shared", "-fPIC"};
llvm::for_each(libs, [](std::string &lib) { lib = "-l" + lib; });
llvm::for_each(libDirs, [](std::string &libDir) { libDir = "-L" + libDir; });
#endif

Command link(kCxxPath);
Expand All @@ -403,10 +404,10 @@ static std::string genSharedLib(string outputBaseName, std::vector<string> opts,
// Create jar containing java runtime and model shared library (which includes
// jni runtime).
static void genJniJar(const mlir::OwningOpRef<ModuleOp> &module,
string modelSharedLibPath, string modelJniJarPath) {
std::string modelSharedLibPath, std::string modelJniJarPath) {
llvm::SmallString<8> runtimeDir(getRuntimeDir());
llvm::sys::path::append(runtimeDir, "javaruntime.jar");
string javaRuntimeJarPath = llvm::StringRef(runtimeDir).str();
std::string javaRuntimeJarPath = llvm::StringRef(runtimeDir).str();

// Copy javaruntime.jar to model jar.
llvm::sys::fs::copy_file(javaRuntimeJarPath, modelJniJarPath);
Expand All @@ -423,7 +424,7 @@ static void genJniJar(const mlir::OwningOpRef<ModuleOp> &module,

std::string compileModuleToObject(
const mlir::OwningOpRef<ModuleOp> &module, std::string outputBaseName) {
string bitcodePath = outputBaseName + ".bc";
std::string bitcodePath = outputBaseName + ".bc";
genLLVMBitcode(module, bitcodePath, outputBaseName);
llvm::FileRemover bitcodeRemover(
bitcodePath, !keepFiles(KeepFilesOfType::Bitcode));
Expand All @@ -433,7 +434,7 @@ std::string compileModuleToObject(

std::string compileModuleToSharedLibrary(
const mlir::OwningOpRef<ModuleOp> &module, std::string outputBaseName) {
string modelObjPath = compileModuleToObject(module, outputBaseName);
std::string modelObjPath = compileModuleToObject(module, outputBaseName);
llvm::FileRemover modelObjRemover(
modelObjPath, !keepFiles(KeepFilesOfType::Object));

Expand All @@ -443,35 +444,35 @@ std::string compileModuleToSharedLibrary(

void compileModuleToJniJar(
const mlir::OwningOpRef<ModuleOp> &module, std::string outputBaseName) {
string modelObjPath = compileModuleToObject(module, outputBaseName);
std::string modelObjPath = compileModuleToObject(module, outputBaseName);
llvm::FileRemover modelObjRemover(
modelObjPath, !keepFiles(KeepFilesOfType::Object));

StringRef outputDir = llvm::sys::path::parent_path(outputBaseName);
if (outputDir.empty())
outputDir = StringRef(".");

string jniSharedLibPath = getRuntimeDir() + "/libjniruntime.a";
std::string jniSharedLibPath = getRuntimeDir() + "/libjniruntime.a";

llvm::SmallString<8> jniObjDir(outputDir);
llvm::sys::path::append(jniObjDir, "jnidummy.c.o");
string jniObjPath = llvm::StringRef(jniObjDir).str();
std::string jniObjPath = llvm::StringRef(jniObjDir).str();

genJniObject(module, jniSharedLibPath, jniObjPath);
llvm::FileRemover jniObjRemover(
jniObjPath, !keepFiles(KeepFilesOfType::Object));

llvm::SmallString<8> jniLibDir(outputDir);
llvm::sys::path::append(jniLibDir, "libmodel");
string jniLibBase = llvm::StringRef(jniLibDir).str();
std::string jniLibBase = llvm::StringRef(jniLibDir).str();

string modelSharedLibPath = genSharedLib(jniLibBase, {"-z", "noexecstack"},
{modelObjPath, jniObjPath}, {"jniruntime", "cruntime"},
{getRuntimeDir()});
std::string modelSharedLibPath = genSharedLib(jniLibBase,
{"-z", "noexecstack"}, {modelObjPath, jniObjPath},
{"jniruntime", "cruntime"}, {getRuntimeDir()});
llvm::FileRemover modelSharedLibRemover(
modelSharedLibPath, !keepFiles(KeepFilesOfType::Object));

string modelJniJarPath = outputBaseName + ".jar";
std::string modelJniJarPath = outputBaseName + ".jar";
genJniJar(module, modelSharedLibPath, modelJniJarPath);
}

Expand All @@ -489,11 +490,12 @@ void registerDialects(mlir::MLIRContext &context) {
context.getOrLoadDialect<mlir::KrnlOpsDialect>();
}

void processInputFile(string inputFilename, mlir::MLIRContext &context,
void processInputFile(std::string inputFilename, mlir::MLIRContext &context,
mlir::OwningOpRef<ModuleOp> &module, std::string *errorMessage) {
// Decide if the input file is an ONNX model or a model specified
// in MLIR. The extension of the file is the decider.
string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
std::string extension =
inputFilename.substr(inputFilename.find_last_of(".") + 1);
bool inputIsONNX = (extension == "onnx");
bool inputIsMLIR = (extension == "mlir");

Expand Down Expand Up @@ -524,13 +526,13 @@ void processInputArray(const void *onnxBuffer, int bufferSize,
ImportFrontendModelArray(onnxBuffer, bufferSize, context, module, options);
}

void outputCode(
mlir::OwningOpRef<ModuleOp> &module, string filename, string extension) {
void outputCode(mlir::OwningOpRef<ModuleOp> &module, std::string filename,
std::string extension) {
mlir::OpPrintingFlags flags;
if (preserveLocations)
flags.enableDebugInfo();

string errorMessage;
std::string errorMessage;
auto output = openOutputFile(filename + extension, &errorMessage);
if (!output) {
llvm::errs() << errorMessage << "\n";
Expand All @@ -541,8 +543,9 @@ void outputCode(
output->keep();
}

void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
mlir::MLIRContext &context, mlir::OwningOpRef<ModuleOp> &module) {
void emitOutputFiles(std::string outputBaseName,
EmissionTargetType emissionTarget, mlir::MLIRContext &context,
mlir::OwningOpRef<ModuleOp> &module) {
// For EmitONNXIR and EmitMLIR the constant value are embedded in the code
// thus making the code hard to read. These values can be elided by emitting
// two versions of the same source code:
Expand All @@ -562,15 +565,16 @@ void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
// necessary when emitting the .bc file.
switch (emissionTarget) {
case EmitObj: {
string modelObjPath = compileModuleToObject(module, outputBaseName);
std::string modelObjPath = compileModuleToObject(module, outputBaseName);
if (keepFiles(KeepFilesOfType::MLIR))
outputCode(module, outputBaseName, ".llvm.mlir");

if (VerboseOutput)
printf("Object file %s.o has been compiled.\n", outputBaseName.c_str());
} break;
case EmitLib: {
string sharedLib = compileModuleToSharedLibrary(module, outputBaseName);
std::string sharedLib =
compileModuleToSharedLibrary(module, outputBaseName);
if (keepFiles(KeepFilesOfType::MLIR))
outputCode(module, outputBaseName, ".llvm.mlir");
if (VerboseOutput)
Expand Down
Loading

0 comments on commit 8af8f42

Please sign in to comment.