|
19 | 19 |
|
20 | 20 | #include "TMVA/RModelParser_Keras.h" |
21 | 21 |
|
| 22 | +#include "TMVA/Tools.h" |
| 23 | +#include "TMVA/MethodBase.h" |
| 24 | +#include "TMVA/Types.h" |
| 25 | + |
| 26 | +#include "Rtypes.h" |
| 27 | +#include "TString.h" |
| 28 | +#include <vector> |
| 29 | + |
22 | 30 | #include <Python.h> |
23 | 31 |
|
24 | 32 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION |
25 | 33 | #include <numpy/arrayobject.h> |
26 | 34 |
|
27 | 35 |
|
28 | | -namespace TMVA{ |
29 | | -namespace Experimental{ |
30 | | -namespace SOFIE{ |
31 | | -namespace PyKeras{ |
| 36 | +namespace TMVA::Experimental::SOFIE::PyKeras { |
| 37 | + |
| 38 | +namespace { |
| 39 | + |
| 40 | +// Utility functions (taken from PyMethodBase in PyMVA) |
| 41 | + |
| 42 | +void PyRunString(TString code, PyObject *globalNS, PyObject *localNS) |
| 43 | +{ |
| 44 | + PyObject *fPyReturn = PyRun_String(code, Py_single_input, globalNS, localNS); |
| 45 | + if (!fPyReturn) { |
| 46 | + std::cout << "\nPython error message:\n"; |
| 47 | + PyErr_Print(); |
| 48 | + throw std::runtime_error("\nFailed to run python code: " + code); |
| 49 | + } |
| 50 | +} |
32 | 51 |
|
33 | | -// Referencing Python utility functions present in PyMethodBase |
34 | | -static void(& PyRunString)(TString, PyObject*, PyObject*) = PyMethodBase::PyRunString; |
35 | | -static const char*(& PyStringAsString)(PyObject*) = PyMethodBase::PyStringAsString; |
36 | | -static std::vector<size_t>(& GetDataFromTuple)(PyObject*) = PyMethodBase::GetDataFromTuple; |
37 | | -static PyObject*(& GetValueFromDict)(PyObject*, const char*) = PyMethodBase::GetValueFromDict; |
| 52 | +const char *PyStringAsString(PyObject *string) |
| 53 | +{ |
| 54 | + PyObject *encodedString = PyUnicode_AsUTF8String(string); |
| 55 | + const char *cstring = PyBytes_AsString(encodedString); |
| 56 | + return cstring; |
| 57 | +} |
| 58 | + |
| 59 | +std::vector<size_t> GetDataFromTuple(PyObject *tupleObject) |
| 60 | +{ |
| 61 | + std::vector<size_t> tupleVec; |
| 62 | + for (Py_ssize_t tupleIter = 0; tupleIter < PyTuple_Size(tupleObject); ++tupleIter) { |
| 63 | + auto itemObj = PyTuple_GetItem(tupleObject, tupleIter); |
| 64 | + if (itemObj == Py_None) |
| 65 | + tupleVec.push_back(0); // case shape is for example (None,2,3) |
| 66 | + else |
| 67 | + tupleVec.push_back((size_t)PyLong_AsLong(itemObj)); |
| 68 | + } |
| 69 | + return tupleVec; |
| 70 | +} |
| 71 | + |
| 72 | +PyObject *GetValueFromDict(PyObject *dict, const char *key) |
| 73 | +{ |
| 74 | + return PyDict_GetItemWithError(dict, PyUnicode_FromString(key)); |
| 75 | +} |
| 76 | + |
| 77 | +} // namespace |
38 | 78 |
|
39 | 79 | namespace INTERNAL{ |
40 | 80 |
|
@@ -1036,7 +1076,5 @@ RModel Parse(std::string filename, int batch_size){ |
1036 | 1076 |
|
1037 | 1077 | return rmodel; |
1038 | 1078 | } |
1039 | | -}//PyKeras |
1040 | | -}//SOFIE |
1041 | | -}//Experimental |
1042 | | -}//TMVA |
| 1079 | + |
| 1080 | +} // namespace TMVA::Experimental::SOFIE::PyKeras |
0 commit comments