Skip to content

Commit

Permalink
Bring Your Own Datatypes (apache#5812)
Browse files Browse the repository at this point in the history
* Add ChangeDatatype pass and unittest

* [WIP] Jared's work on Fri

This was work that Jared did on my computer, trying to get Inception v3 running.

* Fix simplify inference to work over different data types.

* Formatting

* Copy setup code from other test file

* Logging in Relay

* Remove duplicate TVM_DLL

* Add Sub, Mul, Div, Max to bfloat lib

* Fix previous broken rebased commit

* Remove line

* Add LowerCustomDatatypes to build passes

* Upcast ints to custom datatypes too, as well as to floats

* Add and use convert_ndarray

* Lower Call

* Relay: create constant scalars of custom dtypes

We use the same method we use in TVM: store the value in a double.

* Custom datatype formatting in Relay

* Update unittests

* Add simpler example that's not working yet

* Add Python unittests to Makefile

* Fix bug

* Fix function name in GetPackedFunc call

* convert_ndarray makes its own executor

* Add simple test case

* Move setup() calls

* Use convert_ndarray

* Change import to make it more specific

* Fix another Registry::Get call

* Allow users to register minimum functions for custom datatypes

This commit allows users to register global functions named
`tvm.datatype.min.<type name>` which take the number of bits in the custom type
and return the corresponding minimum value (as a double).

A similar commit will need to be created for max, whenever that ends up being
needed!

* Remove check for float

* Add test

* Fix inception test

* Add MobileNet

* Lower custom datatypes before intrinsics

* Add exp and sqrt bfloat functions

* [buggy commit] Lower intrinsics like sqrt, exp

This commit has bugs in it, I'm fairly certain.

* Formatting

* Fix bug

* Add lowering for new ops in test

* Add int to bfloat

* Remove print

* Add all tests

* Correct image size

* Add TODO

* Add "notbfloat" type

This type is for testing purposes. It just stores a float in a uint32. It was
used to confirm the fact that my bfloat "implementation" is very numerically
unstable and was causing issues when running the model.

* Convert arguments

Not sure how necessary this actually is.

* Rewrite custom datatype constants in Relay

* Add test_ops

* Print constants in Relay

* Use topi.testing

* Test conv2d

* Add test_model

* Comment out model tests

* Register notbfloat

This could be unregistered at some point later

* Add commented code

Remove later

* Add posit tests

* test_ops_same_function

* [temporary] move incomplete commit to macbook

* Add more to tests

* Formatting

* Uncomment add

* Remove bad tests

* Change comments

* Change function name and docstring

* Change main function

* Restructure tests

* Fix visibility of posit functions

* YAPF

* Switching keywords around to resolve build errors on some systems

* Improve test by running smaller mobilenet

* Add test_cast

* Change datatype name; add simple test

* Rename to posit32

* Merge 3 posit types into one file

* Add a nop type

* Remove bfloat

* Refactor test comments

* Refactor conv2d test

* Add optional tolerance arguments

* Add posit8 and posit16

* Add comment about posit8

* Whoops -- actually add noptype to CMakeLists

* Add rtol, atol to run_workload

* Add noptype to tests

* Run noptype over other models, too

* Pass correct arguments to calls

* Fix line length errors

* Raise tolerances (again) to avoid flaky test

* fix style

* add test for tanh, log, sigmoid

* Remove references to bfloat, notbfloat

* Change comments

* Remove old test file

* fix min func

* refactoring unit test file

* use posits es2

* cleanup

* comment

* coment if_then_else

* support different bit widths

* use random seed to create stable tests

* update documentation

* removed nop-type and code consistency

* add batchnorm test

* rebase and update

* fix tests and format

* pylint

* change order of include

* include order

* fix style

* remove posit c linkage

* update universal

* fix style

* fix test

* fix overflow error with minfunc and posits

* style

* use change_dtype to convert params

* update universal

* fix fatal error

* fix constant repr

* minor update to posites2

* update universal

* fix rst

* fix invalid import and sqrt

* update universal

* comments

* comments and expand testing

* increase atol/rtol for custom[posites2]32

* Re-add newline

* Remove comment

* Remove opt level and comment

* Change docstring

* Add TODO

* Add file header and newline

* Update docstring

* Update file docstring

* Update docstrings

* Delete todos

* create_min_lower_func

* add better debugging message

* docs

* add BYODT tutorial

* add todo

* Reformat some of tutorial to RST, plus code fixes

* tutorial notebook runs now

* fix hyperlink

* rebase

* add to tutorial

* fix mobilenet model

* add skip tag

* black lint

* add compiler flag and add dummy float

* myfloat and posites2 test

* remove universal

* lint

* lint

* add setup

* build with USE_POSIT for CI/CD

* fix posit cmake

* add cd /

* undo docker changes

* change tutorial to use myfloat

* move files

* lint

* fix

* remove filter

* fix lint

* fix suggestions

Co-authored-by: Jared Roesch <roeschinc@gmail.com>
Co-authored-by: Andrew Liu <andrewlliu@gmail.com>
  • Loading branch information
3 people authored Sep 26, 2020
1 parent c662638 commit 5aafff9
Show file tree
Hide file tree
Showing 26 changed files with 1,889 additions and 331 deletions.
84 changes: 0 additions & 84 deletions 3rdparty/bfloat16/bfloat16.cc

This file was deleted.

5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt")
tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson")

# Contrib library options
tvm_option(USE_BYOC_POSIT "Build with BYOC software emulated posit custom datatype" OFF)
tvm_option(USE_BLAS "The blas library to be linked" none)
tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
tvm_option(USE_MKLDNN "Build with MKLDNN" OFF)
Expand Down Expand Up @@ -257,6 +258,7 @@ endif(USE_VM_PROFILER)

file(GLOB DATATYPE_SRCS src/target/datatype/*.cc)
list(APPEND COMPILER_SRCS ${DATATYPE_SRCS})
list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")

file(GLOB RUNTIME_SRCS
src/runtime/*.cc
Expand All @@ -279,8 +281,6 @@ if (INDEX_DEFAULT_I64)
add_definitions(-DTVM_INDEX_DEFAULT_I64=1)
endif()

list(APPEND RUNTIME_SRCS 3rdparty/bfloat16/bfloat16.cc)

if(USE_RPC)
message(STATUS "Build with RPC support...")
file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc)
Expand Down Expand Up @@ -334,6 +334,7 @@ include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/CODEGENC.cmake)
include(cmake/modules/contrib/DNNL.cmake)
include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/Posit.cmake)
include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/NNPack.cmake)
Expand Down
1 change: 0 additions & 1 deletion LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ for text of these licenses.
Apache Software Foundation License 2.0
--------------------------------------

3rdparty/bfloat16/bfloat16.cc
3rdparty/dlpack
3rdparty/dmlc-core

Expand Down
3 changes: 3 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ set(USE_LLVM OFF)
#---------------------------------------------
# Contrib libraries
#---------------------------------------------
# Whether to build with BYOC software emulated posit custom datatype
set(USE_BYOC_POSIT OFF)

# Whether use BLAS, choices: openblas, atlas, apple
set(USE_BLAS none)

Expand Down
1 change: 1 addition & 0 deletions cmake/modules/LibInfo.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ function(add_lib_info src_file)
TVM_INFO_HIDE_PRIVATE_SYMBOLS="${HIDE_PRIVATE_SYMBOLS}"
TVM_INFO_USE_TF_TVMDSOOP="${USE_TF_TVMDSOOP}"
TVM_INFO_USE_FALLBACK_STL_MAP="${USE_FALLBACK_STL_MAP}"
TVM_INFO_USE_BYOC_POSIT="${USE_BYOC_POSIT}"
TVM_INFO_USE_BLAS="${USE_BLAS}"
TVM_INFO_USE_MKL="${USE_MKL}"
TVM_INFO_USE_MKLDNN="${USE_MKLDNN}"
Expand Down
26 changes: 26 additions & 0 deletions cmake/modules/contrib/Posit.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

if(USE_BYOC_POSIT)
message(STATUS "Build with contrib.posit")
if (NOT UNIVERSAL_PATH)
message(FATAL_ERROR "Fail to get Universal path")
endif(NOT UNIVERSAL_PATH)

include_directories(${UNIVERSAL_PATH}/include)
list(APPEND COMPILER_SRCS "src/target/datatype/posit/posit-wrapper.cc")
endif(USE_BYOC_POSIT)
9 changes: 0 additions & 9 deletions licenses/LICENSE.bfloat16.txt

This file was deleted.

2 changes: 2 additions & 0 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def _build_for_device(input_mod, target, target_host):
tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.Simplify(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerCustomDatatypes(),
tvm.tir.transform.LowerIntrin(),
]
)
Expand All @@ -279,6 +280,7 @@ def _build_for_device(input_mod, target, target_host):
tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerCustomDatatypes(),
tvm.tir.transform.LowerIntrin(),
tvm.tir.transform.CombineContextCall(),
]
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/relay/backend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def _tensor_value_repr(tvalue):

@tvm._ffi.register_func("relay._constant_repr")
def _tensor_constant_repr(tvalue):
dtype = tvm.runtime.DataType(tvalue.data.dtype)
if tvm.target.datatype.get_type_registered(dtype.type_code):
return "custom tensor of type " + dtype.type_code
return str(tvalue.data.asnumpy())


Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@
from .darknet import from_darknet
from .pytorch import from_pytorch
from .caffe import from_caffe
from .change_datatype import ChangeDatatype
107 changes: 107 additions & 0 deletions python/tvm/relay/frontend/change_datatype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""Change Datatype Pass"""
from ..function import Function
from ..expr_functor import ExprMutator
from ..transform.transform import function_pass
from ..expr import var, bind


@function_pass(opt_level=0)
class ChangeDatatype(ExprMutator):
"""Mutator for changing the datatype of Relay programs.
This pass should be useful for users of the Bring Your Own Datatypes
framework.
TODO(@gussmith23 @hypercubestart) Add link to documentation when it exists
Example:
.. code-block:: python
from tvm.relay.testing.inception_v3 import get_workload
mod, params = get_workload()
def change_dtype(mod, params, src, dst):
mod = ChangeDatatype(src, dst)(mod)
params = dict((p, tvm.nd.array(params[p].asnumpy().astype(dst))) for p in params)
return mod, params
mod, params = change_dtype(mod, params, "float32", "custom[posites2]32")
Parameters
----------
src : String
The source datatype name, e.g. "float" or "posites2" (but not "float32"
or "custom[posites2]32").
dst : String
The destination datatype name, in the same format.
Returns
-------
mod : tvm.IRModule
Module where all nodes of dtype `src` have been changed to have dtype
`dst`.
"""

def __init__(self, src, dst):
self.src = src
self.dst = dst
super().__init__()

def transform_function(self, func, mod, ctx):
return self.visit(func)

def visit_constant(self, const):
if const.data.dtype == self.src:
return const.astype(self.dst)
return const

def visit_function(self, fn):
new_params = []
binds = {}

for param in fn.params:
# Get the parameter's type annotation.
var_type = param.type_annotation

# See if we want to replace dtype.
if var_type.dtype == self.src:
dtype = self.dst
else:
dtype = var_type.dtype

# Generate new variable.
new_param = var(param.name_hint, shape=var_type.shape, dtype=dtype)

new_params.append(new_param)
binds[param] = new_param

new_body = self.visit(fn.body)
# Rewrite the body to use new parameters.
new_body = bind(new_body, binds)

# Construct the updated function and return.
return Function(
new_params,
new_body,
# You could change the return type, if you use None it will re-infer.
None,
type_params=fn.type_params,
attrs=fn.attrs,
)
Loading

0 comments on commit 5aafff9

Please sign in to comment.