Skip to content

Commit d3c1895

Browse files
committed
pytorch converter
1 parent 5310e8d commit d3c1895

File tree

16 files changed

+2052
-0
lines changed

16 files changed

+2052
-0
lines changed

nnvm/include/nnvm/top/nn.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,39 @@ struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> {
394394
}
395395
};
396396

397+
398+
struct AdaptiveMaxPool2DParam : public dmlc::Parameter<AdaptiveMaxPool2DParam> {
399+
TShape output_size;
400+
std::string layout;
401+
402+
DMLC_DECLARE_PARAMETER(AdaptiveMaxPool2DParam) {
403+
DMLC_DECLARE_FIELD(output_size)
404+
.describe("Array of output height and width");
405+
DMLC_DECLARE_FIELD(layout).set_default("NCHW")
406+
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
407+
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
408+
"dimensions respectively. Convolution is applied on the 'H' and"
409+
"'W' dimensions.");
410+
}
411+
};
412+
413+
414+
struct AdaptiveAvgPool2DParam : public dmlc::Parameter<AdaptiveAvgPool2DParam> {
415+
TShape output_size;
416+
std::string layout;
417+
418+
DMLC_DECLARE_PARAMETER(AdaptiveAvgPool2DParam) {
419+
DMLC_DECLARE_FIELD(output_size)
420+
.describe("Array of output height and width");
421+
DMLC_DECLARE_FIELD(layout).set_default("NCHW")
422+
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
423+
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
424+
"dimensions respectively. Convolution is applied on the 'H' and"
425+
"'W' dimensions.");
426+
}
427+
};
428+
429+
397430
struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
398431
int scale;
399432
std::string layout;

nnvm/python/nnvm/frontend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77
from .darknet import from_darknet
88
from .tensorflow import from_tensorflow
99
from .caffe2 import from_caffe2
10+
from .pytorch import from_pytorch
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
r'''PyTorch->NNVM converter'''
2+
from .converter import from_pytorch

0 commit comments

Comments
 (0)