Skip to content

Commit 0cf2c37

Browse files
killeent
authored and soumith committed
refactor nn calls in autograd convolution
1 parent e950c44 commit 0cf2c37

File tree

1 file changed

+110
-163
lines changed

1 file changed

+110
-163
lines changed

torch/csrc/autograd/functions/convolution.cpp

Lines changed: 110 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -558,82 +558,65 @@ static at::Tensor compute_output(
558558
auto dim = input.ndimension();
559559
auto dilated = params.is_dilated();
560560

561-
if (dilated) {
562-
if (params.transposed) {
563-
/* dilated && transposed */
564-
if (dim == 4) {
565-
at::SpatialFullDilatedConvolution_updateOutput(
566-
input, output, weight, bias, columns, ones,
567-
kernel_size[1], kernel_size[0],
568-
params.stride[1], params.stride[0],
569-
params.padding[1], params.padding[0],
570-
params.dilation[1], params.dilation[0],
571-
params.output_padding[1], params.output_padding[0]); goto done;
572-
} else if (dim == 5) {
573-
at::VolumetricFullDilatedConvolution_updateOutput(
574-
input, output, weight, bias, columns, ones,
575-
params.stride[0], params.stride[2], params.stride[1],
576-
params.padding[0], params.padding[2], params.padding[1],
577-
params.dilation[0], params.dilation[2], params.dilation[1],
578-
params.output_padding[0], params.output_padding[2], params.output_padding[1]); goto done;
561+
562+
if (params.transposed) {
563+
if (dim == 4) {
564+
at::SpatialFullDilatedConvolution_updateOutput(
565+
input, output, weight, bias, columns, ones,
566+
kernel_size[1], kernel_size[0],
567+
params.stride[1], params.stride[0],
568+
params.padding[1], params.padding[0],
569+
dilated ? params.dilation[1] : 1,
570+
dilated ? params.dilation[0] : 1,
571+
params.output_padding[1], params.output_padding[0]); goto done;
572+
} else if (dim == 5) {
573+
at::VolumetricFullDilatedConvolution_updateOutput(
574+
input, output, weight, bias, columns, ones,
575+
params.stride[0], params.stride[2], params.stride[1],
576+
params.padding[0], params.padding[2], params.padding[1],
577+
dilated ? params.dilation[0] : 1,
578+
dilated ? params.dilation[2] : 1,
579+
dilated ? params.dilation[1] : 1,
580+
params.output_padding[0], params.output_padding[2], params.output_padding[1]); goto done;
579581
}
580-
} else /* !transposed */ {
581-
/* dilated && !transposed */
582-
if (dim == 4) {
582+
} else { /* Not transposed */
583+
if (dim == 4) {
584+
if (dilated) {
583585
at::SpatialDilatedConvolution_updateOutput(
584-
input, output, weight, bias, columns, ones,
585-
kernel_size[1], kernel_size[0],
586-
params.stride[1], params.stride[0],
587-
params.padding[1], params.padding[0],
588-
params.dilation[1], params.dilation[0]); goto done;
589-
} else if (dim == 5) {
590-
at::VolumetricDilatedConvolution_updateOutput(
591-
input, output, weight, bias, columns, ones,
592-
kernel_size[0], kernel_size[2], kernel_size[1],
593-
params.stride[0], params.stride[2], params.stride[1],
594-
params.padding[0], params.padding[2], params.padding[1],
595-
params.dilation[0], params.dilation[2], params.dilation[1]); goto done;
596-
}
597-
}
598-
} else /* !dilated */ {
599-
if (params.transposed) {
600-
/* !dilated && transposed */
601-
if (dim == 4) {
602-
at::SpatialFullConvolution_updateOutput(
603-
input, output, weight, bias, columns, ones,
604-
kernel_size[1], kernel_size[0],
605-
params.stride[1], params.stride[0],
606-
params.padding[1], params.padding[0],
607-
params.output_padding[1], params.output_padding[0]); goto done;
608-
} else if (dim == 5) {
609-
at::VolumetricFullConvolution_updateOutput(
610-
input, output, weight, bias, columns, ones,
611-
params.stride[0], params.stride[2], params.stride[1],
612-
params.padding[0], params.padding[2], params.padding[1],
613-
params.output_padding[0], params.output_padding[2], params.output_padding[1]); goto done;
614-
}
615-
} else /* !transposed */ {
616-
/* !dilated && !transposed */
617-
if (dim == 4) {
586+
input, output, weight, bias, columns, ones,
587+
kernel_size[1], kernel_size[0],
588+
params.stride[1], params.stride[0],
589+
params.padding[1], params.padding[0],
590+
params.dilation[1], params.dilation[0]); goto done;
591+
} else {
592+
/* CPU implementation has specialized MM kernels
593+
for non-dilated case here */
618594
at::SpatialConvolutionMM_updateOutput(
619595
input, output, weight, bias, columns, ones,
620596
kernel_size[1], kernel_size[0],
621597
params.stride[1], params.stride[0],
622598
params.padding[1], params.padding[0]); goto done;
623-
} else if (dim == 5 && input.type().isCuda()) {
624-
at::VolumetricConvolution_updateOutput(
625-
input, output, weight, bias, columns, ones,
626-
params.stride[0], params.stride[2], params.stride[1],
627-
params.padding[0], params.padding[2], params.padding[1]); goto done;
628-
} else if (dim == 5) {
629-
at::VolumetricConvolutionMM_updateOutput(
630-
input, output, weight, bias, columns,
631-
kernel_size[0], kernel_size[2], kernel_size[1],
632-
params.stride[0], params.stride[2], params.stride[1],
633-
params.padding[0], params.padding[2], params.padding[1]); goto done;
634599
}
600+
} else if (dim == 5 && (input.type().isCuda() || dilated)) {
601+
at::VolumetricDilatedConvolution_updateOutput(
602+
input, output, weight, bias, columns, ones,
603+
kernel_size[0], kernel_size[2], kernel_size[1],
604+
params.stride[0], params.stride[2], params.stride[1],
605+
params.padding[0], params.padding[2], params.padding[1],
606+
dilated ? params.dilation[0] : 1,
607+
dilated ? params.dilation[2] : 1,
608+
dilated ? params.dilation[1] : 1); goto done;
609+
} else if (dim == 5) { /* dim == 5, CPU, non-dilated */
610+
/* CPU implementation has specialized MM kernels
611+
for non-dilated case here */
612+
at::VolumetricConvolutionMM_updateOutput(
613+
input, output, weight, bias, columns,
614+
kernel_size[0], kernel_size[2], kernel_size[1],
615+
params.stride[0], params.stride[2], params.stride[1],
616+
params.padding[0], params.padding[2], params.padding[1]); goto done;
635617
}
636618
}
619+
637620
throw std::runtime_error("unsupported ConvNd parameters");
638621

639622
done:
@@ -649,82 +632,64 @@ static at::Tensor compute_grad_input(
649632
auto dim = input.ndimension();
650633
auto dilated = params.is_dilated();
651634

652-
if (dilated) {
653-
if (params.transposed) {
654-
/* dilated && transposed */
655-
if (dim == 4) {
656-
at::SpatialFullDilatedConvolution_updateGradInput(
657-
input, grad_output, grad_input, weight, columns,
658-
kernel_size[1], kernel_size[0],
659-
params.stride[1], params.stride[0],
660-
params.padding[1], params.padding[0],
661-
params.dilation[1], params.dilation[0],
662-
params.output_padding[1], params.output_padding[0]); goto done;
663-
} else if (dim == 5) {
664-
at::VolumetricFullDilatedConvolution_updateGradInput(
665-
input, grad_output, grad_input, weight, columns, ones,
666-
params.stride[0], params.stride[2], params.stride[1],
667-
params.padding[0], params.padding[2], params.padding[1],
668-
params.dilation[0], params.dilation[2], params.dilation[1],
669-
params.output_padding[0], params.output_padding[2], params.output_padding[1]); goto done;
670-
}
671-
} else /* !transposed */ {
672-
/* dilated && !transposed */
673-
if (dim == 4) {
674-
at::SpatialDilatedConvolution_updateGradInput(
635+
if (params.transposed) {
636+
if (dim == 4) {
637+
at::SpatialFullDilatedConvolution_updateGradInput(
675638
input, grad_output, grad_input, weight, columns,
676639
kernel_size[1], kernel_size[0],
677640
params.stride[1], params.stride[0],
678641
params.padding[1], params.padding[0],
679-
params.dilation[1], params.dilation[0]); goto done;
680-
} else if (dim == 5) {
681-
at::VolumetricDilatedConvolution_updateGradInput(
682-
input, grad_output, grad_input, weight, columns,
683-
kernel_size[0], kernel_size[2], kernel_size[1],
642+
dilated ? params.dilation[1] : 1,
643+
dilated ? params.dilation[0] : 1,
644+
params.output_padding[1], params.output_padding[0]); goto done;
645+
} else if (dim == 5) {
646+
at::VolumetricFullDilatedConvolution_updateGradInput(
647+
input, grad_output, grad_input, weight, columns, ones,
684648
params.stride[0], params.stride[2], params.stride[1],
685649
params.padding[0], params.padding[2], params.padding[1],
686-
params.dilation[0], params.dilation[2], params.dilation[1]); goto done;
687-
}
650+
dilated ? params.dilation[0] : 1,
651+
dilated ? params.dilation[2] : 1,
652+
dilated ? params.dilation[1] : 1,
653+
params.output_padding[0], params.output_padding[2], params.output_padding[1]); goto done;
688654
}
689-
} else /* !dilated */ {
690-
if (params.transposed) {
691-
/* !dilated && transposed */
692-
if (dim == 4) {
693-
at::SpatialFullConvolution_updateGradInput(
655+
} else { /* Not transposed */
656+
if (dim == 4) {
657+
if (dilated) {
658+
at::SpatialDilatedConvolution_updateGradInput(
694659
input, grad_output, grad_input, weight, columns,
695660
kernel_size[1], kernel_size[0],
696661
params.stride[1], params.stride[0],
697662
params.padding[1], params.padding[0],
698-
params.output_padding[1], params.output_padding[0]); goto done;
699-
} else if (dim == 5) {
700-
at::VolumetricFullConvolution_updateGradInput(
701-
input, grad_output, grad_input, weight, columns, ones,
702-
params.stride[0], params.stride[2], params.stride[1],
703-
params.padding[0], params.padding[2], params.padding[1],
704-
params.output_padding[0], params.output_padding[2], params.output_padding[1]); goto done;
705-
}
706-
} else /* !transposed */ {
707-
/* !dilated && !transposed */
708-
if (dim == 4) {
663+
params.dilation[1], params.dilation[0]); goto done;
664+
} else {
665+
/* CPU implementation has specialized MM kernels
666+
for non-dilated case here */
709667
at::SpatialConvolutionMM_updateGradInput(
710668
input, grad_output, grad_input, weight, columns, ones,
711669
kernel_size[1], kernel_size[0],
712670
params.stride[1], params.stride[0],
713671
params.padding[1], params.padding[0]); goto done;
714-
} else if (dim == 5 && input.type().isCuda()) {
715-
at::VolumetricConvolution_updateGradInput(
672+
}
673+
} else if (dim == 5 && (input.type().isCuda() || dilated)) {
674+
at::VolumetricDilatedConvolution_updateGradInput(
716675
input, grad_output, grad_input, weight, columns,
676+
kernel_size[0], kernel_size[2], kernel_size[1],
717677
params.stride[0], params.stride[2], params.stride[1],
718-
params.padding[0], params.padding[2], params.padding[1]); goto done;
719-
} else if (dim == 5) {
678+
params.padding[0], params.padding[2], params.padding[1],
679+
dilated ? params.dilation[0] : 1,
680+
dilated ? params.dilation[2] : 1,
681+
dilated ? params.dilation[1] : 1); goto done;
682+
} else if (dim == 5) { /* dim == 5, CPU, non-dilated */
683+
/* CPU implementation has specialized MM kernels
684+
for non-dilated case here */
720685
at::VolumetricConvolutionMM_updateGradInput(
721686
input, grad_output, grad_input, weight, columns, ones,
722687
kernel_size[0], kernel_size[2], kernel_size[1],
723688
params.stride[0], params.stride[2], params.stride[1],
724689
params.padding[0], params.padding[2], params.padding[1]); goto done;
725-
}
726690
}
727691
}
692+
728693
throw std::runtime_error("unsupported ConvNdBackward parameters");
729694

730695
done:
@@ -748,82 +713,64 @@ static tensor_pair compute_grad_params(
748713
auto dim = input.ndimension();
749714
auto dilated = params.is_dilated();
750715

751-
if (dilated) {
752-
if (params.transposed) {
753-
/* dilated && transposed */
754-
if (dim == 4) {
755-
at::SpatialFullDilatedConvolution_accGradParameters(
716+
if (params.transposed) {
717+
if (dim == 4) {
718+
at::SpatialFullDilatedConvolution_accGradParameters(
756719
input, grad_output, grad_weight, grad_bias, columns, ones,
757720
kernel_size[1], kernel_size[0],
758721
params.stride[1], params.stride[0],
759722
params.padding[1], params.padding[0],
760-
params.dilation[1], params.dilation[0],
723+
dilated ? params.dilation[1] : 1,
724+
dilated ? params.dilation[0] : 1,
761725
params.output_padding[1], params.output_padding[0], 1.0); goto done;
762-
} else if (dim == 5) {
726+
} else if (dim == 5) {
763727
at::VolumetricFullDilatedConvolution_accGradParameters(
764728
input, grad_output, grad_weight, grad_bias, columns, ones,
765729
params.stride[0], params.stride[2], params.stride[1],
766730
params.padding[0], params.padding[2], params.padding[1],
767-
params.dilation[0], params.dilation[2], params.dilation[1],
731+
dilated ? params.dilation[0] : 1,
732+
dilated ? params.dilation[2] : 1,
733+
dilated ? params.dilation[1] : 1,
768734
params.output_padding[0], params.output_padding[2], params.output_padding[1], 1.0); goto done;
769-
}
770-
} else /* !transposed */ {
771-
/* dilated && !transposed */
772-
if (dim == 4) {
735+
}
736+
} else { /* Not transposed */
737+
if (dim == 4) {
738+
if (dilated) {
773739
at::SpatialDilatedConvolution_accGradParameters(
774740
input, grad_output, grad_weight, grad_bias, columns, ones,
775741
kernel_size[1], kernel_size[0],
776742
params.stride[1], params.stride[0],
777743
params.padding[1], params.padding[0],
778744
params.dilation[1], params.dilation[0], 1.0); goto done;
779-
} else if (dim == 5) {
780-
at::VolumetricDilatedConvolution_accGradParameters(
781-
input, grad_output, grad_weight, grad_bias, columns, ones,
782-
kernel_size[0], kernel_size[2], kernel_size[1],
783-
params.stride[0], params.stride[2], params.stride[1],
784-
params.padding[0], params.padding[2], params.padding[1],
785-
params.dilation[0], params.dilation[2], params.dilation[1], 1.0); goto done;
786-
}
787-
}
788-
} else /* !dilated */ {
789-
if (params.transposed) {
790-
/* !dilated && transposed */
791-
if (dim == 4) {
792-
at::SpatialFullConvolution_accGradParameters(
793-
input, grad_output, grad_weight, grad_bias, columns, ones,
794-
kernel_size[1], kernel_size[0],
795-
params.stride[1], params.stride[0],
796-
params.padding[1], params.padding[0],
797-
params.output_padding[1], params.output_padding[0], 1.0); goto done;
798-
} else if (dim == 5) {
799-
at::VolumetricFullConvolution_accGradParameters(
800-
input, grad_output, grad_weight, grad_bias, columns, ones,
801-
params.stride[0], params.stride[2], params.stride[1],
802-
params.padding[0], params.padding[2], params.padding[1],
803-
params.output_padding[0], params.output_padding[2], params.output_padding[1], 1.0); goto done;
804-
}
805-
} else /* !transposed */ {
806-
/* !dilated && !transposed */
807-
if (dim == 4) {
745+
} else {
746+
/* CPU implementation has specialized MM kernels
747+
for non-dilated case here */
808748
at::SpatialConvolutionMM_accGradParameters(
809749
input, grad_output, grad_weight, grad_bias, columns, ones,
810750
kernel_size[1], kernel_size[0],
811751
params.stride[1], params.stride[0],
812752
params.padding[1], params.padding[0], 1.0); goto done;
813-
} else if (dim == 5 && input.type().isCuda()) {
814-
at::VolumetricConvolution_accGradParameters(
753+
}
754+
} else if (dim == 5 && (input.type().isCuda() || dilated)) {
755+
at::VolumetricDilatedConvolution_accGradParameters(
815756
input, grad_output, grad_weight, grad_bias, columns, ones,
757+
kernel_size[0], kernel_size[2], kernel_size[1],
816758
params.stride[0], params.stride[2], params.stride[1],
817-
params.padding[0], params.padding[2], params.padding[1], 1.0); goto done;
818-
} else if (dim == 5) {
759+
params.padding[0], params.padding[2], params.padding[1],
760+
dilated ? params.dilation[0] : 1,
761+
dilated ? params.dilation[2] : 1,
762+
dilated ? params.dilation[1] : 1, 1.0); goto done;
763+
} else if (dim == 5) { /* dim == 5, CPU, non-dilated */
764+
/* CPU implementation has specialized MM kernels
765+
for non-dilated case here */
819766
at::VolumetricConvolutionMM_accGradParameters(
820767
input, grad_output, grad_weight, grad_bias, columns,
821768
kernel_size[0], kernel_size[2], kernel_size[1],
822769
params.stride[0], params.stride[2], params.stride[1],
823770
params.padding[0], params.padding[2], params.padding[1], 1.0); goto done;
824-
}
825771
}
826772
}
773+
827774
throw std::runtime_error("unsupported ConvNdBackward parameters");
828775

829776
done:

0 commit comments

Comments
 (0)