@@ -558,82 +558,65 @@ static at::Tensor compute_output(
558
558
auto dim = input.ndimension ();
559
559
auto dilated = params.is_dilated ();
560
560
561
- if (dilated) {
562
- if (params.transposed ) {
563
- /* dilated && transposed */
564
- if (dim == 4 ) {
565
- at::SpatialFullDilatedConvolution_updateOutput (
566
- input, output, weight, bias, columns, ones,
567
- kernel_size[1 ], kernel_size[0 ],
568
- params.stride [1 ], params.stride [0 ],
569
- params.padding [1 ], params.padding [0 ],
570
- params.dilation [1 ], params.dilation [0 ],
571
- params.output_padding [1 ], params.output_padding [0 ]); goto done;
572
- } else if (dim == 5 ) {
573
- at::VolumetricFullDilatedConvolution_updateOutput (
574
- input, output, weight, bias, columns, ones,
575
- params.stride [0 ], params.stride [2 ], params.stride [1 ],
576
- params.padding [0 ], params.padding [2 ], params.padding [1 ],
577
- params.dilation [0 ], params.dilation [2 ], params.dilation [1 ],
578
- params.output_padding [0 ], params.output_padding [2 ], params.output_padding [1 ]); goto done;
561
+
562
+ if (params.transposed ) {
563
+ if (dim == 4 ) {
564
+ at::SpatialFullDilatedConvolution_updateOutput (
565
+ input, output, weight, bias, columns, ones,
566
+ kernel_size[1 ], kernel_size[0 ],
567
+ params.stride [1 ], params.stride [0 ],
568
+ params.padding [1 ], params.padding [0 ],
569
+ dilated ? params.dilation [1 ] : 1 ,
570
+ dilated ? params.dilation [0 ] : 1 ,
571
+ params.output_padding [1 ], params.output_padding [0 ]); goto done;
572
+ } else if (dim == 5 ) {
573
+ at::VolumetricFullDilatedConvolution_updateOutput (
574
+ input, output, weight, bias, columns, ones,
575
+ params.stride [0 ], params.stride [2 ], params.stride [1 ],
576
+ params.padding [0 ], params.padding [2 ], params.padding [1 ],
577
+ dilated ? params.dilation [0 ] : 1 ,
578
+ dilated ? params.dilation [2 ] : 1 ,
579
+ dilated ? params.dilation [1 ] : 1 ,
580
+ params.output_padding [0 ], params.output_padding [2 ], params.output_padding [1 ]); goto done;
579
581
}
580
- } else /* ! transposed */ {
581
- /* dilated && !transposed */
582
- if (dim == 4 ) {
582
+ } else { /* Not transposed */
583
+ if (dim == 4 ) {
584
+ if (dilated ) {
583
585
at::SpatialDilatedConvolution_updateOutput (
584
- input, output, weight, bias, columns, ones,
585
- kernel_size[1 ], kernel_size[0 ],
586
- params.stride [1 ], params.stride [0 ],
587
- params.padding [1 ], params.padding [0 ],
588
- params.dilation [1 ], params.dilation [0 ]); goto done;
589
- } else if (dim == 5 ) {
590
- at::VolumetricDilatedConvolution_updateOutput (
591
- input, output, weight, bias, columns, ones,
592
- kernel_size[0 ], kernel_size[2 ], kernel_size[1 ],
593
- params.stride [0 ], params.stride [2 ], params.stride [1 ],
594
- params.padding [0 ], params.padding [2 ], params.padding [1 ],
595
- params.dilation [0 ], params.dilation [2 ], params.dilation [1 ]); goto done;
596
- }
597
- }
598
- } else /* !dilated */ {
599
- if (params.transposed ) {
600
- /* !dilated && transposed */
601
- if (dim == 4 ) {
602
- at::SpatialFullConvolution_updateOutput (
603
- input, output, weight, bias, columns, ones,
604
- kernel_size[1 ], kernel_size[0 ],
605
- params.stride [1 ], params.stride [0 ],
606
- params.padding [1 ], params.padding [0 ],
607
- params.output_padding [1 ], params.output_padding [0 ]); goto done;
608
- } else if (dim == 5 ) {
609
- at::VolumetricFullConvolution_updateOutput (
610
- input, output, weight, bias, columns, ones,
611
- params.stride [0 ], params.stride [2 ], params.stride [1 ],
612
- params.padding [0 ], params.padding [2 ], params.padding [1 ],
613
- params.output_padding [0 ], params.output_padding [2 ], params.output_padding [1 ]); goto done;
614
- }
615
- } else /* !transposed */ {
616
- /* !dilated && !transposed */
617
- if (dim == 4 ) {
586
+ input, output, weight, bias, columns, ones,
587
+ kernel_size[1 ], kernel_size[0 ],
588
+ params.stride [1 ], params.stride [0 ],
589
+ params.padding [1 ], params.padding [0 ],
590
+ params.dilation [1 ], params.dilation [0 ]); goto done;
591
+ } else {
592
+ /* CPU implementation has specialized MM kernels
593
+ for non-dilated case here */
618
594
at::SpatialConvolutionMM_updateOutput (
619
595
input, output, weight, bias, columns, ones,
620
596
kernel_size[1 ], kernel_size[0 ],
621
597
params.stride [1 ], params.stride [0 ],
622
598
params.padding [1 ], params.padding [0 ]); goto done;
623
- } else if (dim == 5 && input.type ().isCuda ()) {
624
- at::VolumetricConvolution_updateOutput (
625
- input, output, weight, bias, columns, ones,
626
- params.stride [0 ], params.stride [2 ], params.stride [1 ],
627
- params.padding [0 ], params.padding [2 ], params.padding [1 ]); goto done;
628
- } else if (dim == 5 ) {
629
- at::VolumetricConvolutionMM_updateOutput (
630
- input, output, weight, bias, columns,
631
- kernel_size[0 ], kernel_size[2 ], kernel_size[1 ],
632
- params.stride [0 ], params.stride [2 ], params.stride [1 ],
633
- params.padding [0 ], params.padding [2 ], params.padding [1 ]); goto done;
634
599
}
600
+ } else if (dim == 5 && (input.type ().isCuda () || dilated)) {
601
+ at::VolumetricDilatedConvolution_updateOutput (
602
+ input, output, weight, bias, columns, ones,
603
+ kernel_size[0 ], kernel_size[2 ], kernel_size[1 ],
604
+ params.stride [0 ], params.stride [2 ], params.stride [1 ],
605
+ params.padding [0 ], params.padding [2 ], params.padding [1 ],
606
+ dilated ? params.dilation [0 ] : 1 ,
607
+ dilated ? params.dilation [2 ] : 1 ,
608
+ dilated ? params.dilation [1 ] : 1 ); goto done;
609
+ } else if (dim == 5 ) { /* dim == 5, CPU, non-dilated */
610
+ /* CPU implementation has specialized MM kernels
611
+ for non-dilated case here */
612
+ at::VolumetricConvolutionMM_updateOutput (
613
+ input, output, weight, bias, columns,
614
+ kernel_size[0 ], kernel_size[2 ], kernel_size[1 ],
615
+ params.stride [0 ], params.stride [2 ], params.stride [1 ],
616
+ params.padding [0 ], params.padding [2 ], params.padding [1 ]); goto done;
635
617
}
636
618
}
619
+
637
620
throw std::runtime_error (" unsupported ConvNd parameters" );
638
621
639
622
done:
@@ -649,82 +632,64 @@ static at::Tensor compute_grad_input(
649
632
auto dim = input.ndimension ();
650
633
auto dilated = params.is_dilated ();
651
634
652
- if (dilated) {
653
- if (params.transposed ) {
654
- /* dilated && transposed */
655
- if (dim == 4 ) {
656
- at::SpatialFullDilatedConvolution_updateGradInput (
657
- input, grad_output, grad_input, weight, columns,
658
- kernel_size[1 ], kernel_size[0 ],
659
- params.stride [1 ], params.stride [0 ],
660
- params.padding [1 ], params.padding [0 ],
661
- params.dilation [1 ], params.dilation [0 ],
662
- params.output_padding [1 ], params.output_padding [0 ]); goto done;
663
- } else if (dim == 5 ) {
664
- at::VolumetricFullDilatedConvolution_updateGradInput (
665
- input, grad_output, grad_input, weight, columns, ones,
666
- params.stride [0 ], params.stride [2 ], params.stride [1 ],
667
- params.padding [0 ], params.padding [2 ], params.padding [1 ],
668
- params.dilation [0 ], params.dilation [2 ], params.dilation [1 ],
669
- params.output_padding [0 ], params.output_padding [2 ], params.output_padding [1 ]); goto done;
670
- }
671
- } else /* !transposed */ {
672
- /* dilated && !transposed */
673
- if (dim == 4 ) {
674
- at::SpatialDilatedConvolution_updateGradInput (
635
+ if (params.transposed ) {
636
+ if (dim == 4 ) {
637
+ at::SpatialFullDilatedConvolution_updateGradInput (
675
638
input, grad_output, grad_input, weight, columns,
676
639
kernel_size[1 ], kernel_size[0 ],
677
640
params.stride [1 ], params.stride [0 ],
678
641
params.padding [1 ], params.padding [0 ],
679
- params.dilation [1 ], params.dilation [0 ]); goto done;
680
- } else if (dim == 5 ) {
681
- at::VolumetricDilatedConvolution_updateGradInput (
682
- input, grad_output, grad_input, weight, columns,
683
- kernel_size[0 ], kernel_size[2 ], kernel_size[1 ],
642
+ dilated ? params.dilation [1 ] : 1 ,
643
+ dilated ? params.dilation [0 ] : 1 ,
644
+ params.output_padding [1 ], params.output_padding [0 ]); goto done;
645
+ } else if (dim == 5 ) {
646
+ at::VolumetricFullDilatedConvolution_updateGradInput (
647
+ input, grad_output, grad_input, weight, columns, ones,
684
648
params.stride [0 ], params.stride [2 ], params.stride [1 ],
685
649
params.padding [0 ], params.padding [2 ], params.padding [1 ],
686
- params.dilation [0 ], params.dilation [2 ], params.dilation [1 ]); goto done;
687
- }
650
+ dilated ? params.dilation [0 ] : 1 ,
651
+ dilated ? params.dilation [2 ] : 1 ,
652
+ dilated ? params.dilation [1 ] : 1 ,
653
+ params.output_padding [0 ], params.output_padding [2 ], params.output_padding [1 ]); goto done;
688
654
}
689
- } else /* !dilated */ {
690
- if (params.transposed ) {
691
- /* !dilated && transposed */
692
- if (dim == 4 ) {
693
- at::SpatialFullConvolution_updateGradInput (
655
+ } else { /* Not transposed */
656
+ if (dim == 4 ) {
657
+ if (dilated) {
658
+ at::SpatialDilatedConvolution_updateGradInput (
694
659
input, grad_output, grad_input, weight, columns,
695
660
kernel_size[1 ], kernel_size[0 ],
696
661
params.stride [1 ], params.stride [0 ],
697
662
params.padding [1 ], params.padding [0 ],
698
- params.output_padding [1 ], params.output_padding [0 ]); goto done;
699
- } else if (dim == 5 ) {
700
- at::VolumetricFullConvolution_updateGradInput (
701
- input, grad_output, grad_input, weight, columns, ones,
702
- params.stride [0 ], params.stride [2 ], params.stride [1 ],
703
- params.padding [0 ], params.padding [2 ], params.padding [1 ],
704
- params.output_padding [0 ], params.output_padding [2 ], params.output_padding [1 ]); goto done;
705
- }
706
- } else /* !transposed */ {
707
- /* !dilated && !transposed */
708
- if (dim == 4 ) {
663
+ params.dilation [1 ], params.dilation [0 ]); goto done;
664
+ } else {
665
+ /* CPU implementation has specialized MM kernels
666
+ for non-dilated case here */
709
667
at::SpatialConvolutionMM_updateGradInput (
710
668
input, grad_output, grad_input, weight, columns, ones,
711
669
kernel_size[1 ], kernel_size[0 ],
712
670
params.stride [1 ], params.stride [0 ],
713
671
params.padding [1 ], params.padding [0 ]); goto done;
714
- } else if (dim == 5 && input.type ().isCuda ()) {
715
- at::VolumetricConvolution_updateGradInput (
672
+ }
673
+ } else if (dim == 5 && (input.type ().isCuda () || dilated)) {
674
+ at::VolumetricDilatedConvolution_updateGradInput (
716
675
input, grad_output, grad_input, weight, columns,
676
+ kernel_size[0 ], kernel_size[2 ], kernel_size[1 ],
717
677
params.stride [0 ], params.stride [2 ], params.stride [1 ],
718
- params.padding [0 ], params.padding [2 ], params.padding [1 ]); goto done;
719
- } else if (dim == 5 ) {
678
+ params.padding [0 ], params.padding [2 ], params.padding [1 ],
679
+ dilated ? params.dilation [0 ] : 1 ,
680
+ dilated ? params.dilation [2 ] : 1 ,
681
+ dilated ? params.dilation [1 ] : 1 ); goto done;
682
+ } else if (dim == 5 ) { /* dim == 5, CPU, non-dilated */
683
+ /* CPU implementation has specialized MM kernels
684
+ for non-dilated case here */
720
685
at::VolumetricConvolutionMM_updateGradInput (
721
686
input, grad_output, grad_input, weight, columns, ones,
722
687
kernel_size[0 ], kernel_size[2 ], kernel_size[1 ],
723
688
params.stride [0 ], params.stride [2 ], params.stride [1 ],
724
689
params.padding [0 ], params.padding [2 ], params.padding [1 ]); goto done;
725
- }
726
690
}
727
691
}
692
+
728
693
throw std::runtime_error (" unsupported ConvNdBackward parameters" );
729
694
730
695
done:
@@ -748,82 +713,64 @@ static tensor_pair compute_grad_params(
748
713
auto dim = input.ndimension ();
749
714
auto dilated = params.is_dilated ();
750
715
751
- if (dilated) {
752
- if (params.transposed ) {
753
- /* dilated && transposed */
754
- if (dim == 4 ) {
755
- at::SpatialFullDilatedConvolution_accGradParameters (
716
+ if (params.transposed ) {
717
+ if (dim == 4 ) {
718
+ at::SpatialFullDilatedConvolution_accGradParameters (
756
719
input, grad_output, grad_weight, grad_bias, columns, ones,
757
720
kernel_size[1 ], kernel_size[0 ],
758
721
params.stride [1 ], params.stride [0 ],
759
722
params.padding [1 ], params.padding [0 ],
760
- params.dilation [1 ], params.dilation [0 ],
723
+ dilated ? params.dilation [1 ] : 1 ,
724
+ dilated ? params.dilation [0 ] : 1 ,
761
725
params.output_padding [1 ], params.output_padding [0 ], 1.0 ); goto done;
762
- } else if (dim == 5 ) {
726
+ } else if (dim == 5 ) {
763
727
at::VolumetricFullDilatedConvolution_accGradParameters (
764
728
input, grad_output, grad_weight, grad_bias, columns, ones,
765
729
params.stride [0 ], params.stride [2 ], params.stride [1 ],
766
730
params.padding [0 ], params.padding [2 ], params.padding [1 ],
767
- params.dilation [0 ], params.dilation [2 ], params.dilation [1 ],
731
+ dilated ? params.dilation [0 ] : 1 ,
732
+ dilated ? params.dilation [2 ] : 1 ,
733
+ dilated ? params.dilation [1 ] : 1 ,
768
734
params.output_padding [0 ], params.output_padding [2 ], params.output_padding [1 ], 1.0 ); goto done;
769
- }
770
- } else /* ! transposed */ {
771
- /* dilated && !transposed */
772
- if (dim == 4 ) {
735
+ }
736
+ } else { /* Not transposed */
737
+ if (dim == 4 ) {
738
+ if (dilated ) {
773
739
at::SpatialDilatedConvolution_accGradParameters (
774
740
input, grad_output, grad_weight, grad_bias, columns, ones,
775
741
kernel_size[1 ], kernel_size[0 ],
776
742
params.stride [1 ], params.stride [0 ],
777
743
params.padding [1 ], params.padding [0 ],
778
744
params.dilation [1 ], params.dilation [0 ], 1.0 ); goto done;
779
- } else if (dim == 5 ) {
780
- at::VolumetricDilatedConvolution_accGradParameters (
781
- input, grad_output, grad_weight, grad_bias, columns, ones,
782
- kernel_size[0 ], kernel_size[2 ], kernel_size[1 ],
783
- params.stride [0 ], params.stride [2 ], params.stride [1 ],
784
- params.padding [0 ], params.padding [2 ], params.padding [1 ],
785
- params.dilation [0 ], params.dilation [2 ], params.dilation [1 ], 1.0 ); goto done;
786
- }
787
- }
788
- } else /* !dilated */ {
789
- if (params.transposed ) {
790
- /* !dilated && transposed */
791
- if (dim == 4 ) {
792
- at::SpatialFullConvolution_accGradParameters (
793
- input, grad_output, grad_weight, grad_bias, columns, ones,
794
- kernel_size[1 ], kernel_size[0 ],
795
- params.stride [1 ], params.stride [0 ],
796
- params.padding [1 ], params.padding [0 ],
797
- params.output_padding [1 ], params.output_padding [0 ], 1.0 ); goto done;
798
- } else if (dim == 5 ) {
799
- at::VolumetricFullConvolution_accGradParameters (
800
- input, grad_output, grad_weight, grad_bias, columns, ones,
801
- params.stride [0 ], params.stride [2 ], params.stride [1 ],
802
- params.padding [0 ], params.padding [2 ], params.padding [1 ],
803
- params.output_padding [0 ], params.output_padding [2 ], params.output_padding [1 ], 1.0 ); goto done;
804
- }
805
- } else /* !transposed */ {
806
- /* !dilated && !transposed */
807
- if (dim == 4 ) {
745
+ } else {
746
+ /* CPU implementation has specialized MM kernels
747
+ for non-dilated case here */
808
748
at::SpatialConvolutionMM_accGradParameters (
809
749
input, grad_output, grad_weight, grad_bias, columns, ones,
810
750
kernel_size[1 ], kernel_size[0 ],
811
751
params.stride [1 ], params.stride [0 ],
812
752
params.padding [1 ], params.padding [0 ], 1.0 ); goto done;
813
- } else if (dim == 5 && input.type ().isCuda ()) {
814
- at::VolumetricConvolution_accGradParameters (
753
+ }
754
+ } else if (dim == 5 && (input.type ().isCuda () || dilated)) {
755
+ at::VolumetricDilatedConvolution_accGradParameters (
815
756
input, grad_output, grad_weight, grad_bias, columns, ones,
757
+ kernel_size[0 ], kernel_size[2 ], kernel_size[1 ],
816
758
params.stride [0 ], params.stride [2 ], params.stride [1 ],
817
- params.padding [0 ], params.padding [2 ], params.padding [1 ], 1.0 ); goto done;
818
- } else if (dim == 5 ) {
759
+ params.padding [0 ], params.padding [2 ], params.padding [1 ],
760
+ dilated ? params.dilation [0 ] : 1 ,
761
+ dilated ? params.dilation [2 ] : 1 ,
762
+ dilated ? params.dilation [1 ] : 1 , 1.0 ); goto done;
763
+ } else if (dim == 5 ) { /* dim == 5, CPU, non-dilated */
764
+ /* CPU implementation has specialized MM kernels
765
+ for non-dilated case here */
819
766
at::VolumetricConvolutionMM_accGradParameters (
820
767
input, grad_output, grad_weight, grad_bias, columns,
821
768
kernel_size[0 ], kernel_size[2 ], kernel_size[1 ],
822
769
params.stride [0 ], params.stride [2 ], params.stride [1 ],
823
770
params.padding [0 ], params.padding [2 ], params.padding [1 ], 1.0 ); goto done;
824
- }
825
771
}
826
772
}
773
+
827
774
throw std::runtime_error (" unsupported ConvNdBackward parameters" );
828
775
829
776
done:
0 commit comments