-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathnn.py
2679 lines (2109 loc) · 89.6 KB
/
nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
"""torch.ops.aten operators under the `nn` module.
- No inplace operators.
- All functions should not have the script() decorator. This is because
we want to delay the compilation of the function.
"""
# pylint: disable=unused-argument
from __future__ import annotations
import math
from typing import Optional, Sequence, Tuple, TypeVar, Union
from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64, ir
from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import (
IntType,
TFloat,
TFloatOrUInt8,
TInt,
TReal,
TTensor,
)
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType
# Pi as a plain Python float; used by the tanh approximation of GELU in this module.
_MATH_PI = math.pi
# Short alias for the Rank helper defined in the shared common ops module.
Rank = common_ops.Rank
# Largest and smallest values representable by a signed 64-bit integer.
_INT64_MAX = 9223372036854775807
_INT64_MIN = -9223372036854775808
# All float types but float32
TFloatUnlessFloat32 = TypeVar("TFloatUnlessFloat32", bound=Union[BFLOAT16, FLOAT16, DOUBLE])
# NOTE: Implementations of adaptive_average_pool are handled by torch decomp
def aten_adaptive_max_pool1d(
    self: TensorType, output_size: Sequence[int]
) -> tuple[TensorType, TensorType]:
    """adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_adaptive_max_pool2d(
    self: TensorType, output_size: Sequence[int]
) -> tuple[TensorType, TensorType]:
    """adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_adaptive_max_pool2d_backward(
    grad_output: TensorType, self: TensorType, indices: TensorType
) -> TensorType:
    """adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_adaptive_max_pool3d(
    self: TensorType, output_size: Sequence[int]
) -> tuple[TensorType, TensorType]:
    """adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_adaptive_max_pool3d_backward(
    grad_output: TensorType, self: TensorType, indices: TensorType
) -> TensorType:
    """adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def _adjust_attributes_of_avg_pool(
expand_size: int,
kernel_size: Sequence[int],
stride: Sequence[int],
padding: Sequence[int],
) -> Tuple[Sequence[int], Sequence[int], Sequence[int]]:
"""Adjust attributes of avg_pool to match ONNX specification."""
if isinstance(kernel_size, int):
kernel_shape = [kernel_size] * expand_size
else:
kernel_shape = kernel_size
if isinstance(padding, int):
pads = [padding] * expand_size * 2
elif len(padding) == 1:
pads = padding * expand_size * 2
elif len(padding) == 2:
pads = padding * expand_size
else:
pads = padding * 2
if isinstance(stride, int):
strides = [stride] * expand_size
elif not stride:
strides = kernel_shape
else:
strides = stride
return (kernel_shape, strides, pads)
@torch_op("aten::avg_pool1d", trace_only=True)
def aten_avg_pool1d(
    self: TFloat,
    kernel_size: Sequence[int],
    stride: Sequence[int] = (),
    padding: Sequence[int] = (0,),
    ceil_mode: bool = False,
    count_include_pad: bool = True,
) -> TFloat:
    """avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor"""

    # Torch allows a single number to describe kernel, stride and padding,
    # while ONNX expects explicit per-dimension lists with begin/end padding.
    # A 1-D pool has exactly one pooled dimension.
    onnx_kernel, onnx_strides, onnx_pads = _adjust_attributes_of_avg_pool(
        1, kernel_size, stride, padding
    )

    return op.AveragePool(
        self,
        kernel_shape=onnx_kernel,
        strides=onnx_strides,
        pads=onnx_pads,
        count_include_pad=count_include_pad,
        ceil_mode=ceil_mode,
    )
@torch_op("aten::avg_pool2d", trace_only=True)
def aten_avg_pool2d(
    self: TFloat,
    kernel_size: Sequence[int],
    stride: Sequence[int] = (),
    padding: Sequence[int] = (0, 0),
    ceil_mode: bool = False,
    count_include_pad: bool = True,
    divisor_override: Optional[int] = None,
) -> TFloat:
    """avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor"""

    # Torch allows a single number to describe kernel, stride and padding,
    # while ONNX expects explicit per-dimension lists with begin/end padding.
    # A 2-D pool has two pooled dimensions.
    onnx_kernel, onnx_strides, onnx_pads = _adjust_attributes_of_avg_pool(
        2, kernel_size, stride, padding
    )

    # NOTE: divisor_override is accepted but currently has no effect on the result.
    # TODO: supporting divisor_override requires op.Mul(result, mask), where
    # mask = [
    #    1, 2, 3, S,..3, 2, 1
    #    2, 4, 6, 2S, 6, 4, 2
    #    3, 6, 9, 3S, 9, 6, 3
    #    S, 2S,3S,SS,3S,2S, S
    #    3, 6, 9, 3S, 9, 6, 3
    #    2, 4, 6, 2S, 6, 4, 2
    #    1, 2, 3, S,..3, 2, 1
    # ]
    # S is the stride size (S=4 in this sketch) and may be repeated many times
    # depending on the image size.
    return op.AveragePool(
        self,
        kernel_shape=onnx_kernel,
        strides=onnx_strides,
        pads=onnx_pads,
        count_include_pad=count_include_pad,
        ceil_mode=ceil_mode,
    )
def aten_avg_pool2d_backward(
    grad_output: TensorType,
    self: TensorType,
    kernel_size: Sequence[int],
    stride: Sequence[int],
    padding: Sequence[int],
    ceil_mode: bool,
    count_include_pad: bool,
    divisor_override: Optional[int],
) -> TensorType:
    """avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
@torch_op("aten::avg_pool3d", trace_only=True)
def aten_avg_pool3d(
    self: TFloat,
    kernel_size: Sequence[int],
    stride: Sequence[int] = (),
    padding: Sequence[int] = (0, 0, 0),
    ceil_mode: bool = False,
    count_include_pad: bool = True,
    divisor_override: Optional[int] = None,
) -> TFloat:
    """avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor"""

    # Torch allows a single number to describe kernel, stride and padding,
    # while ONNX expects explicit per-dimension lists with begin/end padding.
    # A 3-D pool has three pooled dimensions.
    onnx_kernel, onnx_strides, onnx_pads = _adjust_attributes_of_avg_pool(
        3, kernel_size, stride, padding
    )

    # NOTE: divisor_override is accepted but currently has no effect on the result.
    # TODO: supporting divisor_override requires op.Mul(result, mask), where
    # mask = [
    #    1, 2, 3, S,..3, 2, 1
    #    2, 4, 6, 2S, 6, 4, 2
    #    3, 6, 9, 3S, 9, 6, 3
    #    S, 2S,3S,SS,3S,2S, S
    #    3, 6, 9, 3S, 9, 6, 3
    #    2, 4, 6, 2S, 6, 4, 2
    #    1, 2, 3, S,..3, 2, 1
    # ]
    # S is the stride size (S=4 in this sketch) and may be repeated many times
    # depending on the image size.
    pooled = op.AveragePool(
        self,
        ceil_mode=ceil_mode,
        count_include_pad=count_include_pad,
        kernel_shape=onnx_kernel,
        pads=onnx_pads,
        strides=onnx_strides,
    )
    return pooled
def aten_avg_pool3d_backward(
    grad_output: TensorType,
    self: TensorType,
    kernel_size: Sequence[int],
    stride: Sequence[int],
    padding: Sequence[int],
    ceil_mode: bool,
    count_include_pad: bool,
    divisor_override: Optional[int],
) -> TensorType:
    """avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_binary_cross_entropy(
    self: TensorType,
    target: TensorType,
    weight: Optional[TensorType] = None,
    reduction: int = 1,
) -> TensorType:
    """binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_binary_cross_entropy_backward(
    grad_output: TensorType,
    self: TensorType,
    target: TensorType,
    weight: Optional[TensorType] = None,
    reduction: int = 1,
) -> TensorType:
    """binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
@torch_op("aten::celu", trace_only=True)
def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT:
    """celu(Tensor self, Scalar alpha=1.0) -> Tensor"""

    # op.Celu only supports float32, which is why this overload takes FLOAT only.
    activated = op.Celu(self, alpha=alpha)
    return activated
@torch_op("aten::celu", trace_only=True)
def aten_celu_type_promoted(
    self: TFloatUnlessFloat32, alpha: float = 1.0
) -> TFloatUnlessFloat32:
    """celu(Tensor self, Scalar alpha=1.0) -> Tensor"""

    # Run Celu in float32, then cast the result back to the input's type.
    as_float32 = op.Cast(self, to=FLOAT.dtype)
    celu_float32 = op.Celu(as_float32, alpha=alpha)
    return op.CastLike(celu_float32, self)
@torch_op("aten::col2im", trace_only=True)
def aten_col2im(
    self: TReal,
    output_size: INT64,
    kernel_size: INT64,
    dilation: Sequence[int] = (1, 1),
    padding: Sequence[int] = (0, 0),
    stride: Sequence[int] = (1, 1),
) -> TReal:
    """col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor"""

    # ONNX expects output_size, kernel_size, dilation and stride to have
    # length 2, and pads to have the four-element form [w, x, y, z].
    #   [w]          -> [w, w, w, w]
    #   [w, x]       -> [w, x, w, x]
    #   anything else is assumed to already be [w, x, y, z]
    repeat_count = {1: 4, 2: 2}.get(len(padding))
    pads = padding if repeat_count is None else padding * repeat_count

    # A single ONNX op is enough here, so no private function is needed.
    return op.Col2Im(
        self,
        output_size,
        kernel_size,
        strides=stride,
        pads=pads,
        dilations=dilation,
    )
def aten_conv_depthwise3d(
    self: TensorType,
    weight: TensorType,
    kernel_size: Sequence[int],
    bias: Optional[TensorType],
    stride: Sequence[int],
    padding: INT64,
    dilation: Sequence[int],
) -> TensorType:
    """conv_depthwise3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, SymInt[3] padding, int[3] dilation) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
@torch_op("aten::cross_entropy_loss", trace_only=True)
def aten_cross_entropy_loss(
    self: TFloat,
    target: IntType,
    weight: Optional[TFloat] = None,
    reduction: int = 1,  # default is 'mean'
    ignore_index: int = -100,
    label_smoothing: float = 0.0,  # this was ignored due to ONNX not support
) -> TFloat:
    """cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor"""

    # Translate the integer reduction code into the string ONNX expects;
    # any code other than 0 ("none") or 2 ("sum") is treated as "mean".
    if reduction == 0:
        reduction_name = "none"
    elif reduction == 2:
        reduction_name = "sum"
    else:
        reduction_name = "mean"

    # Only the first output (the loss) is returned; the second is discarded.
    loss, _ = op.SoftmaxCrossEntropyLoss(
        self, target, weight, reduction=reduction_name, ignore_index=ignore_index
    )
    return loss
@torch_op("aten::elu", trace_only=True)
def aten_elu(
    self: TFloat,
    alpha: float = 1.0,
    scale: float = 1.0,
    input_scale: float = 1.0,
) -> TFloat:
    """elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor"""

    # Bring both scalar factors to the input's type before multiplying.
    input_factor = op.CastLike(input_scale, self)
    output_factor = op.CastLike(scale, self)

    # result = scale * Elu(input_scale * self)
    scaled_input = op.Mul(self, input_factor)
    activated = op.Elu(scaled_input, alpha=alpha)
    return op.Mul(activated, output_factor)
def aten_elu_backward(
    grad_output: TensorType,
    alpha: float,
    scale: float,
    input_scale: float,
    is_result: bool,
    self_or_result: TensorType,
) -> TensorType:
    """elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_flatten_dense_tensors(tensors: Sequence[TensorType]) -> TensorType:
    """flatten_dense_tensors(Tensor[] tensors) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_fractional_max_pool2d(
    self: TensorType,
    kernel_size: Sequence[int],
    output_size: Sequence[int],
    random_samples: TensorType,
) -> tuple[TensorType, TensorType]:
    """fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor)

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_fractional_max_pool2d_backward(
    grad_output: TensorType,
    self: TensorType,
    kernel_size: Sequence[int],
    output_size: Sequence[int],
    indices: TensorType,
) -> TensorType:
    """fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_fractional_max_pool3d(
    self: TensorType,
    kernel_size: Sequence[int],
    output_size: Sequence[int],
    random_samples: TensorType,
) -> tuple[TensorType, TensorType]:
    """fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor)

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_fractional_max_pool3d_backward(
    grad_output: TensorType,
    self: TensorType,
    kernel_size: Sequence[int],
    output_size: Sequence[int],
    indices: TensorType,
) -> TensorType:
    """fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
@torch_op("aten::gelu", trace_only=True)
def aten_gelu(self: TReal, approximate: str = "none") -> TReal:
    """gelu(Tensor self, *, str approximate='none') -> Tensor"""

    # Only "tanh" selects the approximation; every other value uses the erf form.
    if approximate == "tanh":
        return _aten_gelu_approximate_tanh(self)
    return _aten_gelu_approximate_none(self)
@torch_op("aten::gelu", private=True)
def _aten_gelu_approximate_none(self: TReal) -> TReal:
    """gelu(Tensor self, *, str approximate='none') -> Tensor"""

    # GELU(x) = 0.5 * x * [1 + erf(x / sqrt(2))]
    # 1.4142135623730951 is sqrt(2) written out as a literal.
    scaled = op.Div(self, 1.4142135623730951)
    erf_value = op.Erf(scaled)
    one_plus_erf = op.Add(erf_value, 1)
    gated = op.Mul(self, one_plus_erf)
    return op.Mul(0.5, gated)
@torch_op("aten::gelu", private=True)
def _aten_gelu_approximate_tanh(self: TReal) -> TReal:
    """gelu(Tensor self, *, str approximate='none') -> Tensor"""

    # GELU(x) = 0.5 * x * {1 + tanh[sqrt(2/pi) * (x + 0.044715 * x^3)]}
    x_cubed = op.Pow(self, 3)
    cubic_term = op.Mul(0.044715, x_cubed)
    polynomial = op.Add(self, cubic_term)
    # sqrt(2/pi) is built from graph ops rather than a precomputed constant,
    # which keeps the construction explicit.
    two_over_pi = op.CastLike(op.Div(2.0, _MATH_PI), self)
    tanh_argument = op.Mul(op.Sqrt(two_over_pi), polynomial)
    tanh_value = op.Tanh(tanh_argument)
    one_plus_tanh = op.Add(tanh_value, 1)
    gated = op.Mul(self, one_plus_tanh)
    return op.Mul(0.5, gated)
def aten_gelu_backward(
    grad_output: TensorType, self: TensorType, approximate: str = "none"
) -> TensorType:
    """gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
@torch_op("aten::glu")
def aten_glu(self: TFloat, dim: int = -1) -> TFloat:
    """glu(Tensor self, int dim=-1) -> Tensor"""

    # Split the input into two parts along `dim`; the second part, passed
    # through a sigmoid, gates the first.
    value_half, gate_half = op.Split(self, axis=dim, num_outputs=2)
    gate = op.Sigmoid(gate_half)
    return op.Mul(value_half, gate)
def aten_glu_backward(grad_output: TensorType, self: TensorType, dim: int) -> TensorType:
    """glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_glu_backward_jvp(
    grad_x: TensorType,
    grad_glu: TensorType,
    x: TensorType,
    dgrad_glu: TensorType,
    dx: TensorType,
    dim: int,
) -> TensorType:
    """glu_backward_jvp(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
@torch_op("aten::group_norm", trace_only=True)
def aten_group_norm(
    input: TFloat,
    num_groups: int,
    weight: Optional[TFloat] = None,
    bias: Optional[TFloat] = None,
    eps: float = 1e-05,
    cudnn_enabled: bool = True,
) -> TensorType:
    """group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor"""

    # N, C and HxW are not needed as separate values: the input tensor carries them.
    if weight is None:
        # Default scale is 1.0, expanded to the channel dimension (input dim 1).
        weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))
    if bias is None:
        # Default shift is 0.0, expanded to the channel dimension (input dim 1).
        bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))

    # onnx.GroupNorm() wants weight and bias sized per group, whereas the aten
    # op supplies them sized per channel. Because of that mismatch the
    # normalization is emulated with onnx.InstanceNorm() over a regrouped input.
    minus_one = op.Constant(value_ints=[-1])
    # Identity scale/shift for InstanceNorm, following the Torch ONNX converter.
    groups_1d = op.Reshape(num_groups, minus_one)
    # A 0 in a Reshape shape keeps that dimension unchanged; InstanceNorm needs
    # the input laid out as [0, group, -1].
    grouped_shape = op.Concat(op.Constant(value_ints=[0]), groups_1d, minus_one, axis=0)
    grouped_input = op.Reshape(input, grouped_shape)
    unit_scale = op.Expand(op.CastLike(op.Constant(value_float=1.0), input), groups_1d)
    zero_shift = op.Expand(op.CastLike(op.Constant(value_float=0.0), input), groups_1d)
    normalized = op.InstanceNormalization(
        grouped_input, unit_scale, zero_shift, epsilon=eps
    )
    # Restore the original input shape.
    normalized = op.Reshape(normalized, op.Shape(input))

    # Apply the affine transform with the caller's weight and bias. They are
    # unsqueezed at axes 1 .. rank-2 so that they broadcast against the
    # normalized tensor.
    rank = Rank(input)
    one = op.Constant(value_int=1)
    broadcast_axes = op.Range(one, op.Sub(rank, one), one)

    weight_broadcast = op.CastLike(op.Unsqueeze(weight, broadcast_axes), normalized)
    scaled = op.Mul(normalized, weight_broadcast)
    bias_broadcast = op.CastLike(op.Unsqueeze(bias, broadcast_axes), scaled)
    return op.Add(scaled, bias_broadcast)
def aten_glu_jvp(glu: TensorType, x: TensorType, dx: TensorType, dim: int) -> TensorType:
    """glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
@torch_op("aten::hardsigmoid", trace_only=True)
def aten_hardsigmoid(self: TFloat) -> TFloat:
    """hardsigmoid(Tensor self) -> Tensor"""

    # The op is called with a fixed slope of 1/6 and a fixed offset of 1/2.
    slope = 1 / 6
    offset = 1 / 2
    return op.HardSigmoid(self, alpha=slope, beta=offset)
def aten_hardsigmoid_backward(grad_output: TensorType, self: TensorType) -> TensorType:
    """hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
@torch_op("aten::hardswish")
def aten_hardswish(self: TFloat) -> TFloat:
    """hardswish(Tensor self) -> Tensor"""

    # Maps one-to-one onto the ONNX HardSwish operator.
    activated = op.HardSwish(self)
    return activated
def aten_hardswish_backward(grad_output: TensorType, self: TensorType) -> TensorType:
    """hardswish_backward(Tensor grad_output, Tensor self) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
@torch_op("aten::hardtanh")
def aten_hardtanh(self: TReal, min_val: float = -1.0, max_val: float = 1.0) -> TReal:
    """hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor"""

    # hardtanh is a clip of the input to the range [min_val, max_val].
    clipped = op.Clip(self, min_val, max_val)
    return clipped
@torch_op("aten::hardtanh_backward", trace_only=True)
def aten_hardtanh_backward(
    grad_output: TensorType, self: TensorType, min_val: float, max_val: float
) -> TensorType:
    """hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor"""

    # Each mask is 0.0 where the input lies strictly outside the bound and
    # 1.0 elsewhere, so the gradient is zeroed wherever the input was clipped.
    not_above_max = op.Where(op.Greater(self, max_val), 0.0, 1.0)
    not_below_min = op.Where(op.Less(self, min_val), 0.0, 1.0)
    upper_masked = op.Mul(grad_output, not_above_max)
    return op.Mul(upper_masked, not_below_min)
def aten_huber_loss(
    self: TensorType, target: TensorType, reduction: int = 1, delta: float = 1.0
) -> TensorType:
    """huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_huber_loss_backward(
    grad_output: TensorType, self: TensorType, target: TensorType, reduction: int, delta: float
) -> TensorType:
    """huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def _get_im2col_indices_along_dim(
    input_d: TInt,
    kernel_size_d: int,
    dilation_d: int,
    padding_d: int,
    stride_d: int,
):
    """Build the gather indices of every sliding block along one spatial dimension.

    The result is the broadcast sum of a row of block start positions and a
    column of dilated kernel offsets, i.e. one column of indices per block.
    """

    # Input is always 4-D (N, C, H, W).
    # Block starts along dim d run from 0 up to (but excluding)
    #   input[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1)
    # in steps of the stride.
    kernel_reach = dilation_d * (kernel_size_d - 1)
    start_limit = input_d + ((padding_d * 2) - kernel_reach)

    # Start index of every block, laid out as a row.
    block_starts = op.Range(0, start_limit, stride_d)
    block_starts = op.Unsqueeze(block_starts, [0])

    # Offsets of the dilated kernel taps, laid out as a column.
    kernel_offsets = op.Range(0, kernel_size_d * dilation_d, dilation_d)
    kernel_offsets = op.Unsqueeze(kernel_offsets, [1])

    # Broadcasting the addition pairs every block start with every kernel offset.
    return op.Add(block_starts, kernel_offsets)
def _get_im2col_padded_input(input, padding_h, padding_w):
    """Pad `input` by `padding_h` and `padding_w` on both sides of dims 2 and 3."""

    # Input is always a 4-D tensor (N, C, H, W) and the padding arrives as
    # (padding_h, padding_w). ONNX Pad wants all begin values followed by all
    # end values: (dim1_begin, dim2_begin, ..., dim1_end, dim2_end, ...).
    # The begin half and the end half are identical: no padding on N and C,
    # then the H and W padding.
    pad_parts = []
    for _ in range(2):
        pad_parts.append(op.Constant(value_ints=[0, 0]))
        pad_parts.append(op.Unsqueeze(padding_h, [0]))
        pad_parts.append(op.Unsqueeze(padding_w, [0]))

    onnx_pads = op.Concat(*pad_parts, axis=0)
    return op.Pad(input, onnx_pads)
def _get_im2col_output_shape(input, kernel_h, kernel_w):
    """Return the shape tensor [dim0, dim1 * kernel_h * kernel_w, -1] for `input`."""

    dims = op.Shape(input)
    batch = op.Gather(dims, 0, axis=0)
    channels = op.Gather(dims, 1, axis=0)
    # Every channel contributes one output row per kernel element.
    unfolded_channels = op.Mul(channels, kernel_h * kernel_w)

    shape_parts = [
        op.Unsqueeze(batch, [0]),
        op.Unsqueeze(unfolded_channels, [0]),
        # -1 lets Reshape infer the last dimension.
        op.Constant(value_ints=[-1]),
    ]
    return op.Concat(*shape_parts, axis=0)
@torch_op("aten::im2col", trace_only=True)
def aten_im2col(
    self: TReal,
    kernel_size: Sequence[int],
    dilation: Sequence[int] = (1, 1),
    padding: Sequence[int] = (0, 0),
    stride: Sequence[int] = (1, 1),
) -> TensorType:
    """im2col(Tensor self, int[2] kernel_size, int[2] dilation=1, int[2] padding=0, int[2] stride=1) -> Tensor"""

    shape = op.Shape(self)
    input_h = op.Gather(shape, 2, axis=0)
    input_w = op.Gather(shape, 3, axis=0)

    # A scalar argument applies to both height and width.
    if not isinstance(kernel_size, Sequence):
        kernel_size = (kernel_size, kernel_size)
    if not isinstance(dilation, Sequence):
        dilation = (dilation, dilation)
    if not isinstance(padding, Sequence):
        padding = (padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride)

    kernel_hw = list(kernel_size)
    dilation_hw = list(dilation)
    padding_hw = list(padding)
    stride_hw = list(stride)

    kernel_h, kernel_w = kernel_hw[0], kernel_hw[1]
    dilation_h, dilation_w = dilation_hw[0], dilation_hw[1]
    padding_h, padding_w = padding_hw[0], padding_hw[1]
    stride_h, stride_w = stride_hw[0], stride_hw[1]

    row_indices = _get_im2col_indices_along_dim(
        input_h, kernel_h, dilation_h, padding_h, stride_h
    )
    col_indices = _get_im2col_indices_along_dim(
        input_w, kernel_w, dilation_w, padding_w, stride_w
    )

    target_shape = _get_im2col_output_shape(self, kernel_h, kernel_w)
    padded = _get_im2col_padded_input(self, padding_h, padding_w)

    # Worked example: a 4-D input of size (1, 1, 3, 3) with kernel_size=2,
    # stride=1 and dilation=1:
    # [[[[1., 2., 3.,],
    #    [4., 5., 6.,],
    #    [7., 8., 9.,]]]]
    # Step 1: gather rows (dim=2) with blocks_row_indices = [[0,1], [1,2]]:
    # [[[[[1., 2., 3.],
    #     [4., 5., 6.]],
    #    [[4., 5., 6.],
    #     [7., 8., 9.]]]]]
    # Step 2: gather columns (dim=4) with blocks_col_indices = [[0,1], [1,2]]:
    # [[[[[[1., 2.],
    #      [4., 5.]],
    #     [[2., 3.],
    #      [5., 6]]],
    #    [[[4., 5.],
    #      [7., 8.]],
    #     [[5., 6.],
    #      [8., 9.]]]]]]
    # Step 3: transpose dims 3 (depth) and 4 (rows), then reshape to the
    # output shape (1, 1, 4, 4):
    # [[[1., 2., 4., 5.],
    #   [2., 3., 5., 6.],
    #   [4., 5., 7., 8.],
    #   [5., 6., 8., 9.]]]
    gathered_rows = op.Gather(padded, row_indices, axis=2)
    gathered_blocks = op.Gather(gathered_rows, col_indices, axis=4)
    transposed = op.Transpose(gathered_blocks, perm=[0, 1, 2, 4, 3, 5])
    return op.Reshape(transposed, target_shape)
def aten_infinitely_differentiable_gelu_backward(
    grad: TensorType, self: TensorType
) -> TensorType:
    """infinitely_differentiable_gelu_backward(Tensor grad, Tensor self) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_l1_loss(self: TensorType, target: TensorType, reduction: int = 1) -> TensorType:
    """l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
@torch_op("aten::leaky_relu", trace_only=True)
def aten_leaky_relu(self: TFloat, negative_slope: float = 0.01) -> TFloat:
    """leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"""

    # Torch's negative_slope is passed as the `alpha` attribute of ONNX LeakyRelu.
    activated = op.LeakyRelu(self, alpha=negative_slope)
    return activated
def aten_leaky_relu_backward(
    grad_output: TensorType, self: TensorType, negative_slope: float, self_is_result: bool
) -> TensorType:
    """leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
@torch_op("aten::linear", trace_only=True)
def aten_linear(input: TFloat, weight: TFloat, bias: Optional[TFloat] = None) -> TFloat:
    """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"""

    if len(input.shape) != 2:
        # General case: multiply by the transposed weight, then add the bias
        # only if one was given.
        product = op.MatMul(input, op.Transpose(weight, perm=[1, 0]))
        if bias is not None:
            return op.Add(product, bias)
        return product

    # Rank-2 input: a single Gemm with transB=True covers the transpose,
    # the product and the bias.
    return op.Gemm(input, weight, bias, transB=True)
@torch_op("aten::log_sigmoid")
def aten_log_sigmoid(self: TFloat) -> TFloat:
    """log_sigmoid(Tensor self) -> Tensor"""

    # Computed literally as log(sigmoid(x)).
    sigmoid = op.Sigmoid(self)
    return op.Log(sigmoid)
def aten_log_sigmoid_backward(
    grad_output: TensorType, self: TensorType, buffer: TensorType
) -> TensorType:
    """log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_log_sigmoid_forward(self: TensorType) -> tuple[TensorType, TensorType]:
    """log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
def aten_logit_backward(
    grad_output: TensorType, self: TensorType, eps: Optional[float] = None
) -> TensorType:
    """logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor

    Placeholder: this operator has no implementation here yet.

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError
@torch_op("aten::max_pool1d", trace_only=True)
def aten_max_pool1d(
    self: TFloatOrUInt8,
    kernel_size: Sequence[int],
    stride: Sequence[int] = (),
    padding: Sequence[int] = (0,),
    dilation: Sequence[int] = (1,),
    ceil_mode: bool = False,
) -> TFloatOrUInt8:
    """max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor"""

    # Torch allows a single number to describe kernel, stride, padding and
    # dilation, while ONNX expects them spelled out explicitly.
    # A 1-D pool has exactly one pooled dimension.
    pooled_dims = 1
    onnx_kernel, onnx_strides, onnx_pads, onnx_dilations = _adjust_attributes_of_max_pool(
        pooled_dims, kernel_size, stride, padding, dilation
    )

    # NOTE(review): the trailing 2 is presumably the rank of an unbatched 1-D
    # input; confirm against _aten_max_pool_onnx.
    return _aten_max_pool_onnx(
        self, onnx_kernel, onnx_strides, onnx_pads, onnx_dilations, ceil_mode, 2
    )
@torch_op("aten::max_pool1d_with_indices", trace_only=True)
def aten_max_pool1d_with_indices(
    self: TFloatOrUInt8,
    kernel_size: Sequence[int],
    stride: Sequence[int] = (),
    padding: Sequence[int] = (0,),
    dilation: Sequence[int] = (1,),
    ceil_mode: bool = False,
) -> Tuple[TFloatOrUInt8, INT64]:
    """max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""

    # Torch allows a single number to describe kernel, stride, padding and
    # dilation, while ONNX expects them spelled out explicitly.
    # A 1-D pool has exactly one pooled dimension.
    pooled_dims = 1
    onnx_kernel, onnx_strides, onnx_pads, onnx_dilations = _adjust_attributes_of_max_pool(
        pooled_dims, kernel_size, stride, padding, dilation
    )

    # Per-dimension helper lists expected by the shared implementation.
    # NOTE(review): their exact meaning is defined by
    # _aten_max_pool_with_indices_onnx; confirm against its signature.
    ones_per_dim = [1] * pooled_dims
    zeros_per_dim = [0] * pooled_dims
    axes_from_two = list(range(2, 2 + pooled_dims))

    return _aten_max_pool_with_indices_onnx(
        self,
        onnx_kernel,
        onnx_strides,
        onnx_pads,
        onnx_dilations,
        ceil_mode,
        2,
        ones_per_dim,
        zeros_per_dim,
        axes_from_two,
    )
def _adjust_attributes_of_max_pool(
expand_size: int,
kernel_size: Sequence[int],
stride: Sequence[int],
padding: Sequence[int],
dilation: Sequence[int],
) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
if isinstance(dilation, int):
dilations = [dilation] * expand_size
else:
dilations = dilation
if isinstance(kernel_size, int):
kernel_shape = [kernel_size] * expand_size
else:
kernel_shape = kernel_size
# NOTE: expand_size is the dimension of pooling kernel,
# ONNX needs begin and end padding so we need to double the padding
# NOTE: expand size prevents padding from being a single int in
# multiple dimension cases
if isinstance(padding, int):
pads = [padding] * expand_size * 2
elif len(padding) == 1:
pads = padding * expand_size * 2
elif len(padding) == 2:
# 2D padding
pads = padding * 2
elif len(padding) == 3:
# 3D padding
pads = padding * 2
else:
# When padding is already done for all dimensions,
# we don't need to double it
# eg: (1, 1, 1, 1, 1, 1)
pads = padding
if isinstance(stride, int):
strides = [stride] * expand_size
elif not stride:
strides = kernel_shape
else:
strides = stride
return (kernel_shape, strides, pads, dilations)
@torch_op("aten::max_pool2d", trace_only=True)
def aten_max_pool2d(
    self: TFloatOrUInt8,
    kernel_size: Sequence[int],
    stride: Sequence[int] = (),
    padding: Sequence[int] = (0, 0),
    dilation: Sequence[int] = (1, 1),
    ceil_mode: bool = False,
) -> TFloatOrUInt8:
    """max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor"""

    # Torch allows a single number to describe kernel, stride, padding and
    # dilation, while ONNX expects a pair [x, y] spelled out explicitly.
    # A 2-D pool has two pooled dimensions.
    pooled_dims = 2
    onnx_kernel, onnx_strides, onnx_pads, onnx_dilations = _adjust_attributes_of_max_pool(
        pooled_dims, kernel_size, stride, padding, dilation
    )

    # NOTE(review): the trailing 3 is presumably the rank of an unbatched 2-D
    # input; confirm against _aten_max_pool_onnx.
    return _aten_max_pool_onnx(
        self, onnx_kernel, onnx_strides, onnx_pads, onnx_dilations, ceil_mode, 3
    )
def _aten_max_pool_onnx(
self: TFloatOrUInt8,
kernel_shape: Sequence[int],
strides: Sequence[int],
pads: Sequence[int],
dilations: Sequence[int],
ceil_mode: bool,