Skip to content

Commit

Permalink
[PyTorch][Vulkan] Add template based codegen for shader generation (p…
Browse files Browse the repository at this point in the history
…ytorch#88323)

We would like to be able to parameterize kernels such that a parameterized
algorithm can be implemented via templates. We can then profile performance of
a kernel with different parameter values. This enables us to determine what
parameters may work the best for a given kernel or a given device.

In this diff one such kernel added in 1x1 conv which parameters across size of
the tile being produced by each invocation.

Few other options for parameters can be:
- One can imagine dtype can also be a parameter such that we can do compute in
fp16 or int8/int16.
- Register blocking for input channels

Differential Revision: [D40280336](https://our.internmc.facebook.com/intern/diff/D40280336/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D40280336/)!
Pull Request resolved: pytorch#88323
Approved by: https://github.com/jmdetloff
  • Loading branch information
kimishpatel authored and pytorchmergebot committed Nov 3, 2022
1 parent 60925fc commit 893f8e3
Show file tree
Hide file tree
Showing 6 changed files with 301 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

/*
* TILE_SIZE = (2, 2, 1)
* WEIGHT_STORAGE = TEXTURE_3D
* TILE_SIZE = ($TILE_SIZE_X, $TILE_SIZE_Y, 1)
* WEIGHT_STORAGE = TEXTURE_2D
* WEIGHT_STORAGE_LAYOUT = OC4,IC4,4ic,4oc
*/

layout(std430) buffer;
Expand Down Expand Up @@ -54,17 +51,19 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
void main() {
const ivec3 gpos = ivec3(gl_GlobalInvocationID);

// Determine the output positions that will be written to.
// Output position for TILE_SIZE_X, TILE_SIZE_Y = 2, 2
// +--------+--------+
// | pos[0] | pos[1] |
// +--------+--------+
// | pos[2] | pos[3] |
// +--------+--------+
ivec3 pos[4];
pos[0] = ivec3(gpos.x * 2, gpos.y * 2, gpos.z);
pos[1] = ivec3(gpos.x * 2 + 1, gpos.y * 2, gpos.z);
pos[2] = ivec3(gpos.x * 2, gpos.y * 2 + 1, gpos.z);
pos[3] = ivec3(gpos.x * 2 + 1, gpos.y * 2 + 1, gpos.z);
ivec3 pos[$TILE_SIZE_X * $TILE_SIZE_Y];
for (int y = 0, i = 0; y < $TILE_SIZE_Y; ++y) {
for (int x = 0; x < $TILE_SIZE_X; ++x) {
pos[i] = ivec3(gpos.x * $TILE_SIZE_X + x, gpos.y * $TILE_SIZE_Y + y, gpos.z);
i++;
}
}

// If the top left position is out of bounds, then this invocation will have
// no work to do.
Expand All @@ -75,14 +74,14 @@ void main() {
// Compute the index of the input texture that needs to be loaded for each
// output position. Note that negative indices can be produced indicating that
// the top-left element is in a region added by padding.
ivec2 ipos[4];
for (int i = 0; i < 4; ++i) {
ivec2 ipos[$TILE_SIZE_X * $TILE_SIZE_Y];
for (int i = 0; i < $TILE_SIZE_X * $TILE_SIZE_Y; ++i) {
ipos[i] = pos[i].xy * uBlock.stride - uBlock.padding;
}

vec4 sum[4];
vec4 sum[$TILE_SIZE_X * $TILE_SIZE_Y];
sum[0] = texelFetch(uBias, ivec2(gpos.z, 0), 0);
for (int i = 1; i < 4; ++i) {
for (int i = 1; i < $TILE_SIZE_X * $TILE_SIZE_Y; ++i) {
sum[i] = sum[0];
}

Expand All @@ -92,13 +91,18 @@ void main() {
// During prepacking, the weight tensor has been permuted so that the
// channel (IC) dim is along the x axis, and the batch (OC) dim is along
// the z axis.
vec4 in_tex[$TILE_SIZE_X * $TILE_SIZE_Y];
const vec4 ktex_0 = texelFetch(uKernel, ivec2(z + 0, gpos.z), 0);
const vec4 ktex_1 = texelFetch(uKernel, ivec2(z + 1, gpos.z), 0);
const vec4 ktex_2 = texelFetch(uKernel, ivec2(z + 2, gpos.z), 0);
const vec4 ktex_3 = texelFetch(uKernel, ivec2(z + 3, gpos.z), 0);

for (int i = 0; i < 4; ++i) {
const vec4 in_tex = texelFetch(uInput, ivec3(ipos[i], z4), 0);
for (int i = 0; i < $TILE_SIZE_Y * $TILE_SIZE_X; ++i) {
in_tex[i] = texelFetch(uInput, ivec3(ipos[i], z4), 0);
}

for (int i = 0; i < $TILE_SIZE_Y * $TILE_SIZE_X; ++i) {
// For 2x2 tile size algorithm works as follows.
// To explain the calculations below, the contents one in_tex and the
// group of 4 texels loaded from uKernel are shown:
//
Expand Down Expand Up @@ -131,15 +135,14 @@ void main() {
//
// which is what is expressed in the following calculations. This is done
// for each output position.

sum[i] = fma(in_tex.xxxx, ktex_0, sum[i]);
sum[i] = fma(in_tex.yyyy, ktex_1, sum[i]);
sum[i] = fma(in_tex.zzzz, ktex_2, sum[i]);
sum[i] = fma(in_tex.wwww, ktex_3, sum[i]);
sum[i] = fma(in_tex[i].xxxx, ktex_0, sum[i]);
sum[i] = fma(in_tex[i].yyyy, ktex_1, sum[i]);
sum[i] = fma(in_tex[i].zzzz, ktex_2, sum[i]);
sum[i] = fma(in_tex[i].wwww, ktex_3, sum[i]);
}
}

for (int i = 0; i < 4; ++i) {
for (int i = 0; i < $TILE_SIZE_Y * $TILE_SIZE_X; ++i) {
if (all(lessThan(pos[i], uBlock.out_extents.xyz))) {
imageStore(
uOutput,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
conv2d_pw:
parameter_names_with_default_values:
TILE_SIZE_X: 2
TILE_SIZE_Y: 2
parameter_values:
- TILE_SIZE_X: 1
TILE_SIZE_Y: 1
26 changes: 26 additions & 0 deletions tools/BUCK.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,18 @@ def define_tools_targets(
"gen_vulkan_spv.py",
],
base_module = "",
deps = [
torchgen_deps,
":gen_aten_vulkan_glsl_lib",
],
)

python_library(
name = "gen_aten_vulkan_glsl_lib",
srcs = [
"gen_vulkan_glsl.py",
],
base_module = "tools",
deps = [
torchgen_deps,
],
Expand All @@ -223,6 +235,20 @@ def define_tools_targets(
"PUBLIC",
],
deps = [
":gen_aten_vulkan_glsl_lib",
":gen_aten_vulkan_spv_lib",
],
)

python_test(
name = "vulkan_codegen_test",
srcs = [
"test/test_vulkan_codegen.py",
],
contacts = contacts,
visibility = ["PUBLIC"],
deps = [
":gen_aten_vulkan_glsl_lib",
":gen_aten_vulkan_spv_lib",
],
)
Expand Down
111 changes: 111 additions & 0 deletions tools/gen_vulkan_glsl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import copy
import os

import yaml

from torchgen.code_template import CodeTemplate
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode

try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader # type: ignore[misc]

# https://gist.github.com/pypt/94d747fe5180851196eb
class UniqueKeyLoader(Loader):
def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def]
if not isinstance(node, MappingNode):
raise ConstructorError(
None,
None,
"expected a mapping node, but found %s" % node.id,
node.start_mark,
)
mapping = {}
for key_node, value_node in node.value:
key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call]
try:
hash(key)
except TypeError:
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"found unacceptable key ",
key_node.start_mark,
)
# check for duplicate keys
if key in mapping:
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"found duplicate key",
key_node.start_mark,
)
value = self.construct_object(value_node, deep=deep) # type: ignore[no-untyped-call]
mapping[key] = value
return mapping


class GLSLGenerator(object):
standard_header = """
#version 450 core
#define PRECISION $precision
#define FORMAT $format
"""

def __init__(self): # type: ignore[no-untyped-def]
self.ops_template_params = {}

def add_params_yaml(self, parameters_yaml_file): # type: ignore[no-untyped-def]
all_template_params = {}
with open(parameters_yaml_file, "r") as f:
contents = yaml.load(f, Loader=UniqueKeyLoader)
for key in contents:
all_template_params[key] = contents[key]
self.validate_and_construct_op_params(all_template_params) # type: ignore[no-untyped-call]

def validate_and_construct_op_params(self, all_template_params): # type: ignore[no-untyped-def]
for op in all_template_params:
if op in self.ops_template_params:
raise KeyError(f"{op} params file has already been parsed")
op_params_default_vals = all_template_params[op][
"parameter_names_with_default_values"
]
template_params_set = set(op_params_default_vals.keys())
self.ops_template_params[op] = []
self.ops_template_params[op].append(op_params_default_vals)
op_template_params_values = all_template_params[op]["parameter_values"]
for param_vals in op_template_params_values:
param_vals_set = set(param_vals.keys())
missing_keys = template_params_set - param_vals_set
invalid_keys = param_vals_set - template_params_set
if (len(invalid_keys)) > 0:
raise KeyError(f"Invalid keys {invalid_keys} are found")
param_vals_copy = copy.deepcopy(param_vals)
for key in missing_keys:
param_vals_copy[key] = op_params_default_vals[key]
self.ops_template_params[op].append(param_vals_copy)

def generate(self, glsl_template_in, out_dir): # type: ignore[no-untyped-def]
glsl_template_name = os.path.basename(glsl_template_in)
op_name, extension_name = glsl_template_name.split(".")
if extension_name != "glslt":
raise TypeError(f"invalid file type for glsl template {extension_name}")
if op_name not in self.ops_template_params:
raise KeyError(f"{op_name} params have not been populated")
code_template = CodeTemplate.from_file(glsl_template_in)
for template_params in self.ops_template_params[op_name]:
content = GLSLGenerator.standard_header
param_vals_string = "x".join([str(i) for i in template_params.values()])
output_file_name = op_name + "_" + param_vals_string + ".glsl"
content += code_template.substitute(template_params)
output_file = os.path.join(out_dir, output_file_name)
with open(output_file, "w") as f:
f.write(content)


# Remove this
if __name__ == "__main__":
pass
30 changes: 30 additions & 0 deletions tools/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from dataclasses import dataclass
from typing import List

from tools.gen_vulkan_glsl import GLSLGenerator

H_NAME = "spv.h"
CPP_NAME = "spv.cpp"
DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"}
Expand Down Expand Up @@ -78,12 +80,40 @@ def getShaderInfo(srcFilePath):

return shader_info

def genGLSLFromGLSLT(src_dir_path, tmp_dir_path):
template_dir_path = os.path.join(src_dir_path, "templates")
vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True)
parameter_yaml_files = []
for f in vexs:
if len(f) > 1:
parameter_yaml_files.append(f)
generator = GLSLGenerator()
for params_yaml in parameter_yaml_files:
generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call]

vexs = glob.glob(os.path.join(src_dir_path, '**', '*.glslt'), recursive=True)
templateSrcPaths = []
for f in vexs:
if len(f) > 1:
templateSrcPaths.append(f)
templateSrcPaths.sort()
for glslt in templateSrcPaths:
generator.generate(glslt, tmp_dir_path) # type: ignore[no-untyped-call]

def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
print("hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format(
hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath))

vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
templateSrcPaths = []
for f in vexs:
if len(f) > 1:
templateSrcPaths.append(f)
templateSrcPaths.sort()

# Now add glsl files that are generated from templates
genGLSLFromGLSLT(srcDirPath, tmpDirPath)
vexs = glob.glob(os.path.join(tmpDirPath, '**', '*.glsl'), recursive=True)
for f in vexs:
if len(f) > 1:
templateSrcPaths.append(f)
Expand Down
Loading

0 comments on commit 893f8e3

Please sign in to comment.