-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
vulkan : implement Stable Diffusion operators (#904)
* Fix Vulkan repeat op * Implement Vulkan concat op * Delete old Vulkan shader generator * Implement Vulkan im2col op * Implement Vulkan unary gelu_quick op * Implement Vulkan group_norm op * Implement Vulkan timestep_embedding op * Implement Vulkan upscale op * Fix Vulkan vk_context tensor extra index issue * Fix Vulkan matmul shader parameter bug * Properly fix Vulkan matmul shader parameter bug * Add Vulkan ADD f16 + f32 -> f16 operator support * Implement Vulkan tanh op * Fix Vulkan group count too large Validation error on non-Nvidia GPUs * Throw error when too much memory is requested * Fix another Vulkan group count too large Validation error on non-Nvidia GPUs * Fix matmul MMQ condition * Implement Vulkan pad op * Fix Vulkan crash when tensor is used multiple times in a compute graph * Add Vulkan CONCAT f16 + f16 -> f16 op * Add Vulkan LEAKY_RELU op
- Loading branch information
Showing
29 changed files
with
1,029 additions
and
3,412 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#version 450 | ||
|
||
#include "types.comp" | ||
#include "generic_binary_head.comp" | ||
|
||
void main() { | ||
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; | ||
const int dim = p.param3; | ||
|
||
if (idx >= p.ne) { | ||
return; | ||
} | ||
|
||
const uint i3 = idx / (p.ne22*p.ne21*p.ne20); | ||
const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20; | ||
const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20); | ||
const uint i2_offset = i2*p.ne21*p.ne20; | ||
const uint i1 = (idx - i3_offset - i2_offset) / p.ne20; | ||
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20; | ||
|
||
uint o[4] = {0, 0, 0, 0}; | ||
o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03)); | ||
|
||
const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; | ||
const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10; | ||
const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20; | ||
|
||
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; | ||
|
||
#ifndef OPTIMIZATION_ERROR_WORKAROUND | ||
data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]); | ||
#else | ||
data_d[p.d_offset + dst_idx] = is_src0 ? data_a[src0_idx] : data_b[src1_idx]; | ||
#endif | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#version 450 | ||
|
||
#include "generic_head.comp" | ||
#include "types.comp" | ||
|
||
#extension GL_EXT_control_flow_attributes : enable | ||
|
||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; | ||
|
||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; | ||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; | ||
|
||
void main() { | ||
const float GELU_QUICK_COEF = -1.702f; | ||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; | ||
|
||
if (i >= p.KX) { | ||
return; | ||
} | ||
|
||
const float x = float(data_a[i]); | ||
data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x)))); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#version 450 | ||
|
||
#include "generic_head.comp" | ||
#include "types.comp" | ||
|
||
#extension GL_EXT_control_flow_attributes : enable | ||
#define BLOCK_SIZE 512 | ||
|
||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; | ||
|
||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; | ||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; | ||
|
||
shared float tmp[BLOCK_SIZE]; | ||
|
||
void main() { | ||
const uint group_size = p.KX; | ||
const float eps = p.param1; | ||
|
||
const uint tid = gl_LocalInvocationID.x; | ||
const uint start = gl_WorkGroupID.x * group_size + tid; | ||
const uint end = start + group_size; | ||
|
||
tmp[tid] = 0.0f; | ||
|
||
// Calculate mean | ||
[[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { | ||
tmp[tid] += float(data_a[col]); | ||
} | ||
|
||
// tmp up partial tmps and write back result | ||
barrier(); | ||
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { | ||
if (tid < s) { | ||
tmp[tid] += tmp[tid + s]; | ||
} | ||
barrier(); | ||
} | ||
|
||
const float mean = tmp[0] / group_size; | ||
barrier(); | ||
tmp[tid] = 0.0f; | ||
|
||
// Calculate variance | ||
[[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { | ||
const float xi = float(data_a[col]) - mean; | ||
data_d[col] = D_TYPE(xi); | ||
tmp[tid] += xi * xi; | ||
} | ||
|
||
// sum up partial sums and write back result | ||
barrier(); | ||
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { | ||
if (tid < s) { | ||
tmp[tid] += tmp[tid + s]; | ||
} | ||
barrier(); | ||
} | ||
|
||
const float variance = tmp[0] / group_size; | ||
const float scale = inversesqrt(variance + eps); | ||
|
||
[[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { | ||
data_d[col] *= D_TYPE(scale); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#version 450 | ||
|
||
#extension GL_EXT_shader_16bit_storage : require | ||
|
||
layout (push_constant) uniform parameter | ||
{ | ||
uint batch_offset; uint offset_delta; | ||
uint IC; | ||
uint IW; uint IH; | ||
uint OW; uint OH; | ||
uint KW; uint KH; | ||
uint pelements; | ||
uint CHW; | ||
int s0; int s1; | ||
int p0; int p1; | ||
int d0; int d1; | ||
} p; | ||
|
||
#include "types.comp" | ||
|
||
#define BLOCK_SIZE 256 | ||
|
||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; | ||
|
||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; | ||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; | ||
|
||
void main() { | ||
const uint i = gl_GlobalInvocationID.x; | ||
if (i >= p.pelements) { | ||
return; | ||
} | ||
|
||
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); | ||
const uint kx = i / ksize; | ||
const uint kd = kx * ksize; | ||
const uint ky = (i - kd) / p.OW; | ||
const uint ix = i % p.OW; | ||
|
||
const uint oh = gl_GlobalInvocationID.y; | ||
const uint batch = gl_GlobalInvocationID.z / p.IC; | ||
const uint ic = gl_GlobalInvocationID.z % p.IC; | ||
|
||
const uint iiw = ix * p.s0 + kx * p.d0 - p.p0; | ||
const uint iih = oh * p.s1 + ky * p.d1 - p.p1; | ||
|
||
const uint offset_dst = | ||
((batch * p.OH + oh) * p.OW + ix) * p.CHW + | ||
(ic * (p.KW * p.KH) + ky * p.KW + kx); | ||
|
||
if (iih < 0 || iih >= p.IH || iiw < 0 || iiw >= p.IW) { | ||
data_d[offset_dst] = D_TYPE(0.0f); | ||
} else { | ||
const uint offset_src = ic * p.offset_delta + batch * p.batch_offset; | ||
data_d[offset_dst] = D_TYPE(data_a[offset_src + iih * p.IW + iiw]); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#version 450 | ||
|
||
#include "generic_head.comp" | ||
#include "types.comp" | ||
|
||
#extension GL_EXT_control_flow_attributes : enable | ||
|
||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; | ||
|
||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; | ||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; | ||
|
||
void main() { | ||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; | ||
|
||
if (i >= p.KX) { | ||
return; | ||
} | ||
|
||
const float val = float(data_a[i]); | ||
data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#version 450 | ||
|
||
#include "types.comp" | ||
#include "generic_unary_head.comp" | ||
|
||
void main() { | ||
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; | ||
|
||
if (idx >= p.ne) { | ||
return; | ||
} | ||
|
||
const uint i3 = idx / (p.ne12*p.ne11*p.ne10); | ||
const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10; | ||
const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10); | ||
const uint i2_offset = i2*p.ne11*p.ne10; | ||
const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; | ||
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; | ||
|
||
const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; | ||
const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; | ||
|
||
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; | ||
|
||
data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : 0.0f); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.