Open
Description
A list of candidate patterns, to mark which ones we think are worth doing, which are in progress, and which still need to be done.
- iota reshape (becomes single iota)
%195 = stablehlo.iota dim = 0 : tensor<1024xi32>
%196 = stablehlo.reshape %195 : (tensor<1024xi32>) -> tensor<1x1x1024xi32>
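For reference, a minimal sketch of what the folded form could look like, assuming the reshape only adds unit dimensions in front of the iota dimension. The function name is made up for illustration; the point is that the iota is emitted directly in the target shape, counting along the last dimension:

```mlir
// Hypothetical function name; only the op inside matters.
func.func @iota_reshape_folded() -> tensor<1x1x1024xi32> {
  // The 1024-element dimension moved from index 0 to index 2,
  // so the iota dimension moves with it and the reshape disappears.
  %0 = stablehlo.iota dim = 2 : tensor<1x1x1024xi32>
  func.return %0 : tensor<1x1x1024xi32>
}
```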
- reshape of pad (becomes a different pad)
%175 = stablehlo.pad %174, %148, low = [0, 0, 1024, 0, 0], high = [0, 0, 0, 0, 0], interior = [0, 0, 0, 0, 0] : (tensor<1x3x1024x1x1xf32>, tensor<f32>) -> tensor<1x3x2048x1x1xf32>
%176 = stablehlo.reshape %175 : (tensor<1x3x2048x1x1xf32>) -> tensor<1x3x2048xf32>
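One way this could be rewritten, sketched under the assumption that the reshape only drops unit dimensions that the pad leaves untouched (zero low, high, and interior padding). The function and argument names are placeholders: the reshape is applied to the smaller, unpadded tensor and the pad then runs on the rank-3 result:

```mlir
// Hypothetical names; %arg0 plays the role of %174 and %pad_value of %148.
func.func @reshape_then_pad(%arg0: tensor<1x3x1024x1x1xf32>, %pad_value: tensor<f32>) -> tensor<1x3x2048xf32> {
  // Drop the two trailing unit dimensions before padding.
  %0 = stablehlo.reshape %arg0 : (tensor<1x3x1024x1x1xf32>) -> tensor<1x3x1024xf32>
  // Same padding amounts, restricted to the three remaining dimensions.
  %1 = stablehlo.pad %0, %pad_value, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
  func.return %1 : tensor<1x3x2048xf32>
}
```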
- mul of pad with 0 (becomes pad of mul) 44026d4
%175 = stablehlo.pad %174, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
%177 = stablehlo.multiply %176, %112 : tensor<1x3x2048xf32>
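A rough sketch of the pad-of-mul form, assuming the multiply consumes the padded tensor and the pad value is a zero constant. The slice of the other operand is my assumption about how the shapes would be made to line up, not something taken from the commit, and all names are placeholders:

```mlir
// Hypothetical names; %x is the unpadded input, %y the other multiply operand.
func.func @pad_of_mul(%x: tensor<1x3x1024xf32>, %y: tensor<1x3x2048xf32>) -> tensor<1x3x2048xf32> {
  %zero = stablehlo.constant dense<0.000000e+00> : tensor<f32>
  // Keep only the part of %y that overlaps the non-padded region (indices 1024..2047).
  %y_slice = stablehlo.slice %y [0:1, 0:3, 1024:2048] : (tensor<1x3x2048xf32>) -> tensor<1x3x1024xf32>
  // Multiply on the smaller tensor.
  %prod = stablehlo.multiply %x, %y_slice : tensor<1x3x1024xf32>
  // Re-apply the zero padding afterwards.
  %res = stablehlo.pad %prod, %zero, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
  func.return %res : tensor<1x3x2048xf32>
}
```

Note that this is only exact when `0 * y` is `0` in the padded region: if `%y` holds NaN or Inf there, the original produces NaN while the rewritten form produces 0.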
- broadcast of pad (becomes pad of broadcast)
%175 = stablehlo.pad %174, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
%189 = stablehlo.broadcast_in_dim %177, dims = [0, 2, 4] : (tensor<1x3x2048xf32>) -> tensor<1x1x3x1024x2048xf32>
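A possible shape for the pad-of-broadcast form, assuming the broadcast operand is the padded tensor and the pad value is a zero constant (names are placeholders). The broadcast runs on the unpadded input, and the padding moves to the output dimension that the padded operand dimension maps to (operand dim 2 maps to result dim 4):

```mlir
// Hypothetical names; %x is the unpadded input.
func.func @pad_of_broadcast(%x: tensor<1x3x1024xf32>) -> tensor<1x1x3x1024x2048xf32> {
  %zero = stablehlo.constant dense<0.000000e+00> : tensor<f32>
  // Broadcast the small tensor; the last result dimension is 1024 instead of 2048.
  %bcast = stablehlo.broadcast_in_dim %x, dims = [0, 2, 4] : (tensor<1x3x1024xf32>) -> tensor<1x1x3x1024x1024xf32>
  // Pad only result dimension 4, which corresponds to the padded operand dimension.
  %res = stablehlo.pad %bcast, %zero, low = [0, 0, 0, 0, 1024], high = [0, 0, 0, 0, 0], interior = [0, 0, 0, 0, 0] : (tensor<1x1x3x1024x1024xf32>, tensor<f32>) -> tensor<1x1x3x1024x2048xf32>
  func.return %res : tensor<1x1x3x1024x2048xf32>
}
```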