Skip to content

Commit

Permalink
[Codemod][python/main_function] caffe2: (pytorch#113357)
Browse files Browse the repository at this point in the history
Differential Revision: D51149464

Pull Request resolved: pytorch#113357
Approved by: https://github.com/huydhn
  • Loading branch information
zsol authored and pytorchmergebot committed Nov 15, 2023
1 parent 87aeb24 commit 9b736c7
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 66 deletions.
109 changes: 62 additions & 47 deletions tools/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
import os
import re
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import subprocess
import textwrap
import yaml
from collections import OrderedDict
from torchgen.code_template import CodeTemplate
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Optional
from typing import Any, Dict, List, Optional, Tuple

import yaml
from torchgen.code_template import CodeTemplate
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode

Expand Down Expand Up @@ -128,51 +130,63 @@ class ShaderInfo:
bias_storage_type: str = ""
register_for: Optional[Tuple[str, List[str]]] = None


def getName(filePath: str) -> str:
    """Turn a file path's basename into an identifier-safe name.

    Path separators and dots are replaced with underscores, e.g.
    ``/a/b/conv2d.glsl`` becomes ``conv2d_glsl``.
    """
    base = os.path.basename(filePath)
    for ch in ("/", "."):
        base = base.replace(ch, "_")
    return base


def isDescriptorLine(lineStr: str) -> bool:
    """Return True if the GLSL line opens a descriptor declaration,
    i.e. it begins with ``layout(set``."""
    return lineStr.startswith("layout(set")


def isTileSizeLine(lineStr: str) -> bool:
    """Return True if the line is a ``* TILE_SIZE = (`` comment header."""
    return lineStr.startswith(" * TILE_SIZE = (")


def findTileSizes(lineStr: str) -> List[int]:
    """Parse the three TILE_SIZE integers from a shader comment line.

    Raises AssertionError when the line does not match the expected
    ``* TILE_SIZE = (x, y, z)`` form.
    """
    pattern = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
    match = re.search(pattern, lineStr)
    if match is None:
        raise AssertionError("matches is None in findTileSizes")
    return [int(group) for group in match.groups()]


def isWeightStorageTypeLine(lineStr: str) -> bool:
    """Return True if the line is a ``* WEIGHT_STORAGE = `` comment header."""
    return lineStr.startswith(" * WEIGHT_STORAGE = ")


def getWeightStorageType(lineStr: str) -> str:
    """Extract the storage-type token (e.g. ``TEXTURE_2D``) from a
    ``* WEIGHT_STORAGE = `` comment line.

    Raises AssertionError when the line does not match.
    """
    found = re.search(r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)", lineStr)
    if found is None:
        raise AssertionError("matches is None in getWeightStorageType")
    return found.group(1)


def isBiasStorageTypeLine(lineStr: str) -> bool:
    """Return True if the line is a ``* BIAS_STORAGE = `` comment header."""
    return lineStr.startswith(" * BIAS_STORAGE = ")


def getBiasStorageType(lineStr: str) -> str:
    """Extract the storage-type token (e.g. ``TEXTURE_2D``) from a
    ``* BIAS_STORAGE = `` comment line.

    Raises AssertionError when the line does not match.
    """
    found = re.search(r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)", lineStr)
    if found is None:
        raise AssertionError("matches is None in getBiasStorageType")
    return found.group(1)


def isRegisterForLine(lineStr: str) -> bool:
    """Return True if the line declares a REGISTER_FOR entry.

    Expects a shader name and a list of at least one registry key, e.g.
    ``* REGISTER_FOR = ('conv2d', ['conv2d'])``.
    """
    # Check for Shader Name and a list of at least one Registry Key.
    # (A redundant duplicate assignment of this pattern was removed.)
    register_for_id = (
        r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
    )
    return re.search(register_for_id, lineStr) is not None


def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
register_for_pattern = r"'([A-Za-z0-9_]+)'"
matches = re.findall(register_for_pattern, lineStr)
Expand All @@ -181,6 +195,7 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
matches_list = list(matches)
return (matches_list[0], matches_list[1:])


typeIdMapping = {
r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
Expand All @@ -189,12 +204,13 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
}

# Map the *_STORAGE tokens found in shader comments to the C++ enum
# spelling emitted into the generated source. The empty string is the
# fallback when no storage type was declared.
# (Duplicate keys left over from a formatting change were removed; dict
# semantics made the later entries win, so behavior is unchanged.)
storageTypeToEnum = {
    "TEXTURE_2D": "api::StorageType::TEXTURE_2D",
    "TEXTURE_3D": "api::StorageType::TEXTURE_3D",
    "BUFFER": "api::StorageType::BUFFER",
    "": "api::StorageType::UNKNOWN",
}


def determineDescriptorType(lineStr: str) -> str:
for identifier, typeNum in typeIdMapping.items():
if re.search(identifier, lineStr):
Expand All @@ -203,6 +219,7 @@ def determineDescriptorType(lineStr: str) -> str:
"No matching descriptor type for " + lineStr + " in determineDescriptorType"
)


def getShaderInfo(srcFilePath: str) -> ShaderInfo:
shader_info = ShaderInfo([], [], "")
with open(srcFilePath) as srcFile:
Expand All @@ -220,9 +237,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:

return shader_info


def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
template_dir_path = os.path.join(src_dir_path, "templates")
vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True)
vexs = glob.glob(os.path.join(template_dir_path, "**", "*.yaml"), recursive=True)
parameter_yaml_files = []
for f in vexs:
if len(f) > 1:
Expand All @@ -231,7 +249,7 @@ def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
for params_yaml in parameter_yaml_files:
generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call]

vexs = glob.glob(os.path.join(src_dir_path, '**', '*.glslt'), recursive=True)
vexs = glob.glob(os.path.join(src_dir_path, "**", "*.glslt"), recursive=True)
templateSrcPaths = []
for f in vexs:
if len(f) > 1:
Expand All @@ -258,7 +276,7 @@ def genCppH(
templateSrcPaths = []

for srcDirPath in srcDirPaths:
vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
vexs = glob.glob(os.path.join(srcDirPath, "**", "*.glsl"), recursive=True)
for f in vexs:
if len(f) > 1:
templateSrcPaths.append(f)
Expand All @@ -267,7 +285,7 @@ def genCppH(
# Now add glsl files that are generated from templates
genGLSLFromGLSLT(srcDirPath, tmpDirPath)

vexs = glob.glob(os.path.join(tmpDirPath, '**', '*.glsl'), recursive=True)
vexs = glob.glob(os.path.join(tmpDirPath, "**", "*.glsl"), recursive=True)
for f in vexs:
if len(f) > 1:
templateSrcPaths.append(f)
Expand All @@ -283,17 +301,20 @@ def genCppH(
codeTemplate = CodeTemplate.from_file(templateSrcPath)
srcPath = tmpDirPath + "/" + name + ".glsl"
content = codeTemplate.substitute(env)
with open(srcPath, 'w') as fw:
with open(srcPath, "w") as fw:
fw.write(content)

spvPath = tmpDirPath + "/" + name + ".spv"
print(f"spvPath {spvPath}")

cmd = [
glslcPath, "-fshader-stage=compute",
srcPath, "-o", spvPath,
glslcPath,
"-fshader-stage=compute",
srcPath,
"-o",
spvPath,
"--target-env=vulkan1.0",
"-Werror"
"-Werror",
] + [arg for srcDirPath in srcDirPaths for arg in ["-I", srcDirPath]]

print("\nglslc cmd:", cmd)
Expand Down Expand Up @@ -323,7 +344,9 @@ def genCppH(
h += "extern const ShaderListing shader_infos;\n"
h += "extern ShaderRegistry shader_registry;\n"
h += "inline const ShaderListing& get_shader_infos() {\n return shader_infos;\n}\n"
h += "inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n"
h += (
"inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n"
)

h += nsend

Expand All @@ -341,8 +364,8 @@ def genCppH(
name = getName(spvPath).replace("_spv", "")

print(f"spvPath:{spvPath}")
with open(spvPath, 'rb') as fr:
next_bin = array.array('I', fr.read())
with open(spvPath, "rb") as fr:
next_bin = array.array("I", fr.read())
sizeBytes = 4 * len(next_bin)
shader_info_bin_code.append(
"const uint32_t {}_bin[] = {{\n{}\n}};".format(
Expand All @@ -362,7 +385,7 @@ def genCppH(
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))

shader_info_args = [
f"\"vulkan.{name}\"",
f'"vulkan.{name}"',
f"{name}_bin",
str(sizeBytes),
shader_info_layouts,
Expand All @@ -373,7 +396,7 @@ def genCppH(

shader_info_cpp_code.append(
textwrap.indent(
"{{\"{}\",\n api::ShaderInfo(\n{})}}".format(
'{{"{}",\n api::ShaderInfo(\n{})}}'.format(
name,
textwrap.indent(",\n".join(shader_info_args), " "),
),
Expand All @@ -386,7 +409,7 @@ def genCppH(
for registry_key in registry_keys:
shader_info_registry_code.append(
textwrap.indent(
f"{{\"{op_name}\", {{{{\"{registry_key}\", \"{name}\"}}}}}}",
f'{{"{op_name}", {{{{"{registry_key}", "{name}"}}}}}}',
" ",
),
)
Expand Down Expand Up @@ -421,34 +444,20 @@ def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:


def main(argv: List[str]) -> int:
parser = argparse.ArgumentParser(description='')
parser = argparse.ArgumentParser(description="")
parser.add_argument(
'-i',
'--glsl-paths',
nargs='+',
"-i",
"--glsl-paths",
nargs="+",
help='List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"',
default=['.'],
default=["."],
)
parser.add_argument("-c", "--glslc-path", required=True, help="")
parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
parser.add_argument("-o", "--output-path", required=True, help="")
parser.add_argument(
'-c',
'--glslc-path',
required=True,
help='')
parser.add_argument(
'-t',
'--tmp-dir-path',
required=True,
help='/tmp')
parser.add_argument(
'-o',
'--output-path',
required=True,
help='')
parser.add_argument(
"--env",
metavar="KEY=VALUE",
nargs='*',
help="Set a number of key-value pairs")
"--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs"
)
options = parser.parse_args()
env = DEFAULT_ENV
for key, value in parse_arg_env(options.env).items():
Expand All @@ -466,9 +475,15 @@ def main(argv: List[str]) -> int:
srcDirPaths=options.glsl_paths,
glslcPath=options.glslc_path,
tmpDirPath=options.tmp_dir_path,
env=env)
env=env,
)

return 0

if __name__ == '__main__':

def invoke_main() -> None:
    """Script entry point: run ``main`` and exit with its status code."""
    raise SystemExit(main(sys.argv))


if __name__ == "__main__":
invoke_main() # pragma: no cover
6 changes: 5 additions & 1 deletion tools/substitute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os.path


if __name__ == "__main__":
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--input-file")
parser.add_argument("--output-file")
Expand All @@ -22,3 +22,7 @@

with open(output_file, "w") as f:
f.write(contents)


if __name__ == "__main__":
main() # pragma: no cover
30 changes: 20 additions & 10 deletions torch/utils/_freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
import itertools
import marshal
import os
import types
from dataclasses import dataclass
from pathlib import Path
from typing import List
import types


PATH_MARKER = "<Generated by torch::deploy>"
Expand Down Expand Up @@ -121,10 +121,10 @@ def write_bytecode(self, install_root):
Shared frozen modules evenly across the files.
"""
bytecode_file_names = [
f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)
bytecode_file_names = [f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)]
bytecode_files = [
open(os.path.join(install_root, name), "w") for name in bytecode_file_names
]
bytecode_files = [open(os.path.join(install_root, name), "w") for name in bytecode_file_names]
it = itertools.cycle(bytecode_files)
for m in self.frozen_modules:
self.write_frozen(m, next(it))
Expand Down Expand Up @@ -202,7 +202,6 @@ def get_module_qualname(self, file_path: Path, top_package_path: Path) -> List[s
module_parent = normalized_path.parent.parts
return list(module_parent) + [module_basename]


def compile_string(self, file_content: str) -> types.CodeType:
# instead of passing in the real build time path to 'compile', we
# pass in a marker instead. This prevents the build time path being
Expand Down Expand Up @@ -239,19 +238,26 @@ def compile_file(self, path: Path, top_package_path: Path):

bytecode = marshal.dumps(co)
size = len(bytecode)
if path.name == '__init__.py':
if path.name == "__init__.py":
# Python packages are signified by negative size.
size = -size
self.frozen_modules.append(
FrozenModule(".".join(module_qualname), c_name, size, bytecode)
)

if __name__ == "__main__":

def main() -> None:
parser = argparse.ArgumentParser(description="Compile py source")
parser.add_argument("paths", nargs="*", help="Paths to freeze.")
parser.add_argument("--verbose", action="store_true", help="Print debug logs")
parser.add_argument("--install-dir", "--install_dir", help="Root directory for all output files")
parser.add_argument("--oss", action="store_true", help="If it's OSS build, add a fake _PyImport_FrozenModules")
parser.add_argument(
"--install-dir", "--install_dir", help="Root directory for all output files"
)
parser.add_argument(
"--oss",
action="store_true",
help="If it's OSS build, add a fake _PyImport_FrozenModules",
)
parser.add_argument(
"--symbol-name",
"--symbol_name",
Expand All @@ -265,7 +271,7 @@ def compile_file(self, path: Path, top_package_path: Path):

for p in args.paths:
path = Path(p)
if path.is_dir() and not Path.exists(path / '__init__.py'):
if path.is_dir() and not Path.exists(path / "__init__.py"):
# this 'top level path p' is a standard directory containing modules,
# not a module itself
# each 'mod' could be a dir containing __init__.py or .py file
Expand All @@ -277,3 +283,7 @@ def compile_file(self, path: Path, top_package_path: Path):

f.write_bytecode(args.install_dir)
f.write_main(args.install_dir, args.oss, args.symbol_name)


if __name__ == "__main__":
main() # pragma: no cover
Loading

0 comments on commit 9b736c7

Please sign in to comment.