Skip to content

Commit cab28ca

Browse files
committed
Added volumetric convolution
1 parent 473b461 commit cab28ca

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

torch2c/wrappers.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,8 +446,8 @@ def __init__(self, obj, prevfns):
446446
})
447447

448448
def call_tpl(self):
449-
# NOTE: use thnn_class_name, or replicate torch/nn/_functions/conv.py:125 thnn_class_name?
450-
#print(obj.thnn_class_name(prevfns[0]))
449+
# NOTE: use thnn_class_name, or replicate torch/nn/_functions/conv.py thnn_class_name
450+
# TODO: handle dilated convolution and transposed convolution cases
451451
if self.ndim == 1:
452452
return '''
453453
TH${T}Tensor *$id = TH${T}Tensor_new();
@@ -462,6 +462,10 @@ def call_tpl(self):
462462
'''
463463
elif self.ndim == 3:
464464
return '''
465+
TH${T}Tensor *$id = TH${T}Tensor_new();
466+
TH${T}Tensor *finput_$id = TH${T}Tensor_new();
467+
TH${T}Tensor *fgradInput_$id = TH${T}Tensor_new();
468+
THNN_${T}VolumetricConvolutionMM_updateOutput(NULL,$input,$id,$weight,$bias,finput_$id,$kt,$kw,$kh,$dt,$dw,$dh,$pt,$pw,$ph);
465469
'''
466470

467471
def free_tpl(self):
@@ -477,6 +481,8 @@ def free_tpl(self):
477481
'''
478482
elif self.ndim == 3:
479483
return '''
484+
TH${T}Tensor_free($id);
485+
TH${T}Tensor_free(finput_$id);
480486
'''
481487

482488
register(ConvNd, torch.nn._functions.conv.ConvNd)

0 commit comments

Comments
 (0)