Skip to content

Commit

Permalink
Implement aten.linear.default (#3594)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3594

As the title says: this implements `aten.linear.default` for the Vulkan backend.

The implementation is rather simple: compared to the existing matmul shaders, the linear variants only have to accumulate over `mat2` along the width dim rather than the height dim, because `mat2` (the weight) is stored transposed.

Reviewed By: yipjustin

Differential Revision: D57203869

fbshipit-source-id: 08932a75e66924a0dfb0816f8ccefa718a341dd8
  • Loading branch information
SS-JIA authored and facebook-github-bot committed May 14, 2024
1 parent e8a520c commit 0e7955d
Show file tree
Hide file tree
Showing 15 changed files with 330 additions and 131 deletions.
11 changes: 5 additions & 6 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_naive.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

#define PRECISION ${PRECISION}

$if MAT2_IS_TRANSPOSED:
#define MAT2_IS_TRANSPOSED

#include "indexing_utils.h"
#include "matmul.h"

Expand Down Expand Up @@ -45,24 +48,20 @@ void main() {
}

vec4 texel = vec4(0);
ivec3 mat1_pos = ivec3(0, pos.y, pos.z);

$if MAT1_PACKING == "W_packed":
$if MAT2_PACKING == "H_packed":
ivec3 mat2_pos = ivec3(pos.x * 4, 0, pos.z);
texel = matmul_naive_W_packed_H_packed(
im_mat1,
im_mat2,
mat1_pos,
mat2_pos,
pos,
in_sizes[0]);
$elif MAT2_PACKING == "W_packed":
ivec3 mat2_pos = ivec3(pos.x, 0, pos.z);
texel = matmul_naive_W_packed_W_packed(
im_mat1,
im_mat2,
mat1_pos,
mat2_pos,
pos,
in_sizes[0]);
$else:
$raise Exception("Unsupported value for MAT2_PACKING")
Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_naive.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ addmm_naive:
NDIM: 3
MAT1_PACKING: W_packed
MAT2_PACKING: H_packed
MAT2_IS_TRANSPOSED: false
generate_variant_forall:
DTYPE:
- VALUE: float
Expand All @@ -18,3 +19,6 @@ addmm_naive:
- NAME: addmm_naive_W_packed_H_packed
- NAME: addmm_naive_W_packed_W_packed
MAT2_PACKING: W_packed
- NAME: linear_naive_W_packed_W_packed
MAT2_PACKING: W_packed
MAT2_IS_TRANSPOSED: true
22 changes: 10 additions & 12 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

#define PRECISION ${PRECISION}

$if MAT2_IS_TRANSPOSED:
#define MAT2_IS_TRANSPOSED

#include "indexing_utils.h"
#include "matmul.h"

Expand All @@ -31,11 +34,8 @@ layout(set = 0, binding = 6) uniform PRECISION restrict SelfSizes {
ivec4 self_sizes;
};

layout(set = 0, binding = 7) uniform PRECISION restrict PackedDimMeta {
int packed_dim_size;
int packed_dim_size_padded;
int packed_dim_texel_len;
int packed_dim_padding;
layout(set = 0, binding = 7) uniform PRECISION restrict InLimits {
ivec3 in_limits;
};

layout(set = 0, binding = 8) uniform PRECISION restrict Params {
Expand All @@ -57,8 +57,7 @@ void main() {
im_mat2,
pos,
out_sizes[2],
packed_dim_texel_len,
packed_dim_padding);
in_limits[0]);

for (int idx_c = 0; idx_c < FOUR; idx_c++) {
for (int idx_r = 0; idx_r < FOUR; idx_r++) {
Expand All @@ -70,17 +69,16 @@ void main() {
out_pos,
self_sizes.x == 1,
self_sizes.y == 1);
results.data[idx_c][idx_r][0] = beta * self_texel.x + alpha * results.data[idx_c][idx_r][0];

// results is in transposed order w.r.t. the desired output
imageStore(
im_out,
out_pos,
vec4(
results.data[idx_c][idx_r][0],
results.data[idx_c][idx_r][1],
results.data[idx_c][idx_r][2],
results.data[idx_c][idx_r][3]));
beta * self_texel.x + alpha * results.data[idx_c][idx_r][0],
beta * self_texel.x + alpha * results.data[idx_c][idx_r][1],
beta * self_texel.x + alpha * results.data[idx_c][idx_r][2],
beta * self_texel.x + alpha * results.data[idx_c][idx_r][3]));
}
}
}
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ addmm_optimized:
DTYPE: float
NDIM: 3
PACKING: C_packed
MAT2_IS_TRANSPOSED: false
generate_variant_forall:
DTYPE:
- VALUE: float
- VALUE: half
shader_variants:
- NAME: addmm_optimized
- NAME: linear_optimized
MAT2_IS_TRANSPOSED: true
102 changes: 61 additions & 41 deletions backends/vulkan/runtime/graph/ops/glsl/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,66 @@ struct FloatMatrix {
float data[FOUR][FOUR][FOUR];
};

#ifdef MAT2_IS_TRANSPOSED
vec4 matmul_naive_W_packed_W_packed(
#else
vec4 matmul_naive_W_packed_H_packed(
sampler3D im_mat1,
sampler3D im_mat2,
ivec3 mat1_pos,
ivec3 mat2_pos,
#endif
const sampler3D im_mat1,
const sampler3D im_mat2,
const ivec3 out_pos,
const int width) {
ivec3 mat1_pos = ivec3(0, out_pos.y, out_pos.z);
#ifdef MAT2_IS_TRANSPOSED
ivec3 mat2_pos = ivec3(0, out_pos.x * 4, 0);
#else
ivec3 mat2_pos = ivec3(out_pos.x * 4, 0, out_pos.z);
#endif

vec4 texel = vec4(0);
int K = (width + 3) / 4;
const int K = (width + 3) / 4;

for (int i = 0; i < K; ++i) {
vec4 mat1_tex = texelFetch(im_mat1, mat1_pos, 0);
vec4 sums = vec4(
const vec4 mat1_tex = texelFetch(im_mat1, mat1_pos, 0);
#ifdef MAT2_IS_TRANSPOSED
const vec4 sums = vec4(
dot(mat1_tex, texelFetch(im_mat2, mat2_pos, 0)),
dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 1, 0), 0)),
dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 2, 0), 0)),
dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 3, 0), 0)));
#else
const vec4 sums = vec4(
dot(mat1_tex, texelFetch(im_mat2, mat2_pos, 0)),
dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(1, 0, 0), 0)),
dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(2, 0, 0), 0)),
dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(3, 0, 0), 0)));
#endif

texel += sums;

mat1_pos.x++;
#ifdef MAT2_IS_TRANSPOSED
mat2_pos.x++;
#else
mat2_pos.y++;
#endif
}

return texel;
}

#ifdef MAT2_IS_TRANSPOSED
vec4 matmul_naive_W_packed_H_packed(
#else
vec4 matmul_naive_W_packed_W_packed(
sampler3D im_mat1,
sampler3D im_mat2,
ivec3 mat1_pos,
ivec3 mat2_pos,
#endif
const sampler3D im_mat1,
const sampler3D im_mat2,
const ivec3 out_pos,
const int width) {
ivec3 mat1_pos = ivec3(0, out_pos.y, out_pos.z);
ivec3 mat2_pos = ivec3(out_pos.x, 0, out_pos.z);

vec4 texel = vec4(0);
int K = divup4(width);

Expand Down Expand Up @@ -87,7 +115,7 @@ vec4 get_texel_W_packed(
else if (broadcast_at_height) {
self_texel = texelFetch(im_self, ivec3(pos.x, 0, 0), 0);
} else {
self_texel = texelFetch(im_self, pos, 0);
self_texel = texelFetch(im_self, ivec3(pos.x, pos.y, 0), 0);
}

return self_texel;
Expand All @@ -112,7 +140,7 @@ vec4 get_texel_C_packed(
else if (broadcast_at_height) {
self_texel = texelFetch(im_self, ivec3(pos.x, 0, 0), 0);
} else {
self_texel = texelFetch(im_self, pos, 0);
self_texel = texelFetch(im_self, ivec3(pos.x, pos.y, 0), 0);
}

return self_texel;
Expand All @@ -123,8 +151,7 @@ FloatMatrix matmul_partial_4x4(
sampler3D im_mat2,
const ivec3 pos,
const int batch_size,
const int K_texel_len,
const int packed_dim_padding) {
const int K_texel_len) {
FloatMatrix results;
for (int i = 0; i < FOUR; i++) {
for (int j = 0; j < FOUR; j++) {
Expand All @@ -133,43 +160,36 @@ FloatMatrix matmul_partial_4x4(
}
}
}
vec4 im_mat1_partial_rows[FOUR];
vec4 im_mat2_partial_cols[FOUR];
vec4 im_mat1_partial_load[FOUR];
vec4 im_mat2_partial_load[FOUR];

for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
if (FOUR * pos.z + batch_idx >= batch_size) {
break;
}
// read and cache 4x4 tile of im_mat1 (4 adjacent rows)
int mat_z = FOUR * pos.z + batch_idx;
for (int mat1_x = 0; mat1_x < K_texel_len; mat1_x++) {
for (int mat1_row = 0; mat1_row < FOUR; mat1_row++) {
const int mat1_y = (FOUR * pos.y) + mat1_row;
const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, FOUR * pos.z + batch_idx);
im_mat1_partial_rows[mat1_row] = texelFetch(im_mat1, mat1_pos, 0);
// set the value out of the boundary to be 0
if (mat1_x == K_texel_len - 1 && packed_dim_padding > 0) {
for (int kk = 0; kk < packed_dim_padding; kk++) {
im_mat1_partial_rows[mat1_row][3 - kk] = 0;
}
}
}
// read and cache 4x4 tile of im_mat2 (4 adjacent columns)
for (int mat2_col = 0; mat2_col < FOUR; mat2_col++) {
const int mat2_x = (FOUR * pos.x) + mat2_col;
const ivec3 pos_rd = ivec3(mat2_x, mat1_x, FOUR * pos.z + batch_idx);
im_mat2_partial_cols[mat2_col] = texelFetch(im_mat2, pos_rd, 0);
// set the value out of the boundary to be 0
if (mat1_x == K_texel_len - 1 && packed_dim_padding > 0) {
for (int kk = 0; kk < packed_dim_padding; kk++) {
im_mat2_partial_cols[mat2_col][3 - kk] = 0;
}
}
for (int offset = 0; offset < FOUR; offset++) {
// read and cache 4x4 tile of im_mat1
const int mat1_y = (FOUR * pos.y) + offset;
const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, mat_z);
im_mat1_partial_load[offset] = texelFetch(im_mat1, mat1_pos, 0);
// read and cache 4x4 tile of im_mat2
#ifdef MAT2_IS_TRANSPOSED
const int mat2_y = (FOUR * pos.x) + offset;
const ivec3 mat2_pos = ivec3(mat1_x, mat2_y, 0);
im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
#else
const int mat2_x = (FOUR * pos.x) + offset;
const ivec3 mat2_pos = ivec3(mat2_x, mat1_x, mat_z);
im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
#endif
}
// perform partial dot products and add partial result to results
for (int out_row = 0; out_row < FOUR; out_row++) {
for (int out_col = 0; out_col < FOUR; out_col++) {
results.data[out_row][out_col][batch_idx] +=
dot(im_mat1_partial_rows[out_row], im_mat2_partial_cols[out_col]);
dot(im_mat1_partial_load[out_row], im_mat2_partial_load[out_col]);
}
}
}
Expand Down
12 changes: 5 additions & 7 deletions backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

#define PRECISION ${PRECISION}

$if MAT2_IS_TRANSPOSED:
#define MAT2_IS_TRANSPOSED

#include "indexing_utils.h"
#include "matmul.h"

Expand All @@ -35,24 +38,19 @@ void main() {
}

vec4 texel = vec4(0);
ivec3 mat1_pos = ivec3(0, pos.y, pos.z);

$if MAT1_PACKING == "W_packed":
$if MAT2_PACKING == "H_packed":
ivec3 mat2_pos = ivec3(pos.x * 4, 0, pos.z);
texel = matmul_naive_W_packed_H_packed(
im_mat1,
im_mat2,
mat1_pos,
mat2_pos,
pos,
in_sizes[0]);
$elif MAT2_PACKING == "W_packed":
ivec3 mat2_pos = ivec3(pos.x, 0, pos.z);
texel = matmul_naive_W_packed_W_packed(
im_mat1,
im_mat2,
mat1_pos,
mat2_pos,
pos,
in_sizes[0]);
$else:
$raise Exception("Unsupported value for MAT2_PACKING")
Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ matmul_naive:
NDIM: 3
MAT1_PACKING: W_packed
MAT2_PACKING: H_packed
MAT2_IS_TRANSPOSED: false
generate_variant_forall:
DTYPE:
- VALUE: float
Expand All @@ -18,3 +19,6 @@ matmul_naive:
- NAME: matmul_naive_W_packed_H_packed
- NAME: matmul_naive_W_packed_W_packed
MAT2_PACKING: W_packed
- NAME: matmul_transposed_naive_W_packed_W_packed
MAT2_PACKING: W_packed
MAT2_IS_TRANSPOSED: true
13 changes: 6 additions & 7 deletions backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

#define PRECISION ${PRECISION}

$if MAT2_IS_TRANSPOSED:
#define MAT2_IS_TRANSPOSED

#include "indexing_utils.h"
#include "matmul.h"

Expand All @@ -25,11 +28,8 @@ layout(set = 0, binding = 4) uniform PRECISION restrict OutSizes {
ivec4 out_sizes;
};

layout(set = 0, binding = 5) uniform PRECISION restrict PackedDimMeta {
int packed_dim_size;
int packed_dim_size_padded;
int packed_dim_texel_len;
int packed_dim_padding;
layout(set = 0, binding = 5) uniform PRECISION restrict InLimits {
ivec3 in_limits;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
Expand All @@ -46,8 +46,7 @@ void main() {
im_mat2,
pos,
out_sizes[2],
packed_dim_texel_len,
packed_dim_padding);
in_limits[0]);

for (int idx_c = 0; idx_c < FOUR; idx_c++) {
for (int idx_r = 0; idx_r < FOUR; idx_r++) {
Expand Down
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ matmul_optimized:
DTYPE: float
NDIM: 3
PACKING: C_packed
MAT2_IS_TRANSPOSED: false
generate_variant_forall:
DTYPE:
- VALUE: float
- VALUE: half
shader_variants:
- NAME: matmul_optimized
- NAME: matmul_transposed_optimized
MAT2_IS_TRANSPOSED: true
Loading

0 comments on commit 0e7955d

Please sign in to comment.