Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions tmva/sofie/test/TestGelu.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "TMVA/SOFIE/RModelParser_ONNX.hxx"
#include <iostream>

int main() {

std::cout << "Running GELU parser test..." << std::endl;

TMVA::SOFIE::RModelParser_ONNX parser;

try {
auto model = parser.Parse("gelu.onnx");
std::cout << "Parsed successfully (unexpected)" << std::endl;
}
catch (...) {
std::cout << "Failed to parse GELU (expected behavior)" << std::endl;
}

return 0;
}
Binary file added tmva/sofie/test/gelu.onnx
Binary file not shown.
54 changes: 53 additions & 1 deletion tmva/sofie_parsers/src/RModelParser_PyTorch.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@


#include "TMVA/RModelParser_PyTorch.h"
#include "TMVA/ROperator_BasicBinary.hxx"
#include "TMVA/ROperator.hxx"

#include <Python.h>

Expand Down Expand Up @@ -74,6 +76,7 @@ std::unique_ptr<ROperator> MakePyTorchRelu(PyObject* fNode); // For instant
std::unique_ptr<ROperator> MakePyTorchSelu(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Selu operator
std::unique_ptr<ROperator> MakePyTorchSigmoid(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Sigmoid operator
std::unique_ptr<ROperator> MakePyTorchTranspose(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Transpose operator
std::unique_ptr<ROperator> MakePyTorchAdd(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Add operator

// For mapping PyTorch ONNX Graph's Node with the preparatory functions for ROperators
using PyTorchMethodMap = std::unordered_map<std::string, std::unique_ptr<ROperator> (*)(PyObject* fNode)>;
Expand All @@ -85,7 +88,8 @@ const PyTorchMethodMap mapPyTorchNode =
{"onnx::Relu", &MakePyTorchRelu},
{"onnx::Selu", &MakePyTorchSelu},
{"onnx::Sigmoid", &MakePyTorchSigmoid},
{"onnx::Transpose", &MakePyTorchTranspose}
{"onnx::Transpose", &MakePyTorchTranspose},
{"onnx::Add", &MakePyTorchAdd}
};


Expand Down Expand Up @@ -332,6 +336,54 @@ std::unique_ptr<ROperator> MakePyTorchConv(PyObject* fNode){
}
return op;
}


////////////////////////////////////////////////////////////////////////////////
/// \brief Prepares a ROperator_BasicBinary object for Add
///
/// \param[in] fNode Python PyTorch ONNX Graph node
/// \return Unique pointer to ROperator object
///
/// For Add operator (onnx::Add) of PyTorch's ONNX Graph, performs element-wise
/// addition of the input tensors and produces the output tensor.
std::unique_ptr<ROperator> MakePyTorchAdd(PyObject* fNode){

PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");

std::string fNodeDType =
PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));

std::string nameA = PyStringAsString(PyList_GetItem(fInputs,0));
std::string nameB = PyStringAsString(PyList_GetItem(fInputs,1));
std::string nameY = PyStringAsString(PyList_GetItem(fOutputs,0));

std::unique_ptr<ROperator> op;

switch(ConvertStringToType(fNodeDType)){

case ETensorType::FLOAT: {

op.reset(
new ROperator_BasicBinary<
float,
EBasicBinaryOperator::Add
>(nameA, nameB, nameY)
);

break;
}

default:

throw std::runtime_error(
"TMVA::SOFIE - Unsupported - Operator Add does not yet support input type "
+ fNodeDType
);
}

return op;
}
}//INTERNAL


Expand Down