Skip to content

Commit 6b8c578

Browse files
gussmith23jroeschhypercubestart
authored andcommitted
Bring Your Own Datatypes (apache#5812)
* Add ChangeDatatype pass and unittest * [WIP] Jared's work on Fri This was work that Jared did on my computer, trying to get Inception v3 running. * Fix simplify inference to work over different data types. * Formatting * Copy setup code from other test file * Logging in Relay * Remove duplicate TVM_DLL * Add Sub, Mul, Div, Max to bfloat lib * Fix previous broken rebased commit * Remove line * Add LowerCustomDatatypes to build passes * Upcast ints to custom datatypes too, as well as to floats * Add and use convert_ndarray * Lower Call * Relay: create constant scalars of custom dtypes We use the same method we use in TVM: store the value in a double. * Custom datatype formatting in Relay * Update unittests * Add simpler example that's not working yet * Add Python unittests to Makefile * Fix bug * Fix function name in GetPackedFunc call * convert_ndarray makes its own executor * Add simple test case * Move setup() calls * Use convert_ndarray * Change import to make it more specific * Fix another Registry::Get call * Allow users to register minimum functions for custom datatypes This commit allows users to register global functions named `tvm.datatype.min.<type name>` which take the number of bits in the custom type and return the corresponding minimum value (as a double). A similar commit will need to be created for max, whenever that ends up being needed! * Remove check for float * Add test * Fix inception test * Add MobileNet * Lower custom datatypes before intrinsics * Add exp and sqrt bfloat functions * [buggy commit] Lower intrinsics like sqrt, exp This commit has bugs in it, I'm fairly certain. * Formatting * Fix bug * Add lowering for new ops in test * Add int to bfloat * Remove print * Add all tests * Correct image size * Add TODO * Add "notbfloat" type This type is for testing purposes. It just stores a float in a uint32. It was used to confirm the fact that my bfloat "implementation" is very numerically unstable and was causing issues when running the model. * Convert arguments Not sure how necessary this actually is. * Rewrite custom datatype constants in Relay * Add test_ops * Print constants in Relay * Use topi.testing * Test conv2d * Add test_model * Comment out model tests * Register notbfloat This could be unregistered at some point later * Add commented code Remove later * Add posit tests * test_ops_same_function * [temporary] move incomplete commit to macbook * Add more to tests * Formatting * Uncomment add * Remove bad tests * Change comments * Change function name and docstring * Change main function * Restructure tests * Fix visibility of posit functions * YAPF * Switching keywords around to resolve build errors on some systems * Improve test by running smaller mobilenet * Add test_cast * Change datatype name; add simple test * Rename to posit32 * Merge 3 posit types into one file * Add a nop type * Remove bfloat * Refactor test comments * Refactor conv2d test * Add optional tolerance arguments * Add posit8 and posit16 * Add comment about posit8 * Whoops -- actually add noptype to CMakeLists * Add rtol, atol to run_workload * Add noptype to tests * Run noptype over other models, too * Pass correct arguments to calls * Fix line length errors * Raise tolerances (again) to avoid flaky test * fix style * add test for tanh, log, sigmoid * Remove references to bfloat, notbfloat * Change comments * Remove old test file * fix min func * refactoring unit test file * use posits es2 * cleanup * comment * coment if_then_else * support different bit widths * use random seed to create stable tests * update documentation * removed nop-type and code consistency * add batchnorm test * rebase and update * fix tests and format * pylint * change order of include * include order * fix style * remove posit c linkage * update universal * fix style * fix test * fix overflow error with minfunc and posits * style * use change_dtype to convert params * update universal * fix fatal error * fix constant repr * minor update to posites2 * update universal * fix rst * fix invalid import and sqrt * update universal * comments * comments and expand testing * increase atol/rtol for custom[posites2]32 * Re-add newline * Remove comment * Remove opt level and comment * Change docstring * Add TODO * Add file header and newline * Update docstring * Update file docstring * Update docstrings * Delete todos * create_min_lower_func * add better debugging message * docs * add BYODT tutorial * add todo * Reformat some of tutorial to RST, plus code fixes * tutorial notebook runs now * fix hyperlink * rebase * add to tutorial * fix mobilenet model * add skip tag * black lint * add compiler flag and add dummy float * myfloat and posites2 test * remove universal * lint * lint * add setup * build with USE_POSIT for CI/CD * fix posit cmake * add cd / * undo docker changes * change tutorial to use myfloat * move files * lint * fix * remove filter * fix lint * fix suggestions Co-authored-by: Jared Roesch <roeschinc@gmail.com> Co-authored-by: Andrew Liu <andrewlliu@gmail.com>
1 parent 9e93e3c commit 6b8c578

File tree

26 files changed

+1889
-331
lines changed

26 files changed

+1889
-331
lines changed

3rdparty/bfloat16/bfloat16.cc

Lines changed: 0 additions & 84 deletions
This file was deleted.

CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt")
5757
tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson")
5858

5959
# Contrib library options
60+
tvm_option(USE_BYOC_POSIT "Build with BYOC software emulated posit custom datatype" OFF)
6061
tvm_option(USE_BLAS "The blas library to be linked" none)
6162
tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
6263
tvm_option(USE_MKLDNN "Build with MKLDNN" OFF)
@@ -272,6 +273,7 @@ endif(USE_VM_PROFILER)
272273

273274
file(GLOB DATATYPE_SRCS src/target/datatype/*.cc)
274275
list(APPEND COMPILER_SRCS ${DATATYPE_SRCS})
276+
list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")
275277

276278
file(GLOB RUNTIME_SRCS
277279
src/runtime/*.cc
@@ -294,8 +296,6 @@ if (INDEX_DEFAULT_I64)
294296
add_definitions(-DTVM_INDEX_DEFAULT_I64=1)
295297
endif()
296298

297-
list(APPEND RUNTIME_SRCS 3rdparty/bfloat16/bfloat16.cc)
298-
299299
if(USE_RPC)
300300
message(STATUS "Build with RPC support...")
301301
file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc)
@@ -349,6 +349,7 @@ include(cmake/modules/contrib/BLAS.cmake)
349349
include(cmake/modules/contrib/CODEGENC.cmake)
350350
include(cmake/modules/contrib/DNNL.cmake)
351351
include(cmake/modules/contrib/Random.cmake)
352+
include(cmake/modules/contrib/Posit.cmake)
352353
include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
353354
include(cmake/modules/contrib/Sort.cmake)
354355
include(cmake/modules/contrib/NNPack.cmake)

LICENSE

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,6 @@ for text of these licenses.
209209
Apache Software Foundation License 2.0
210210
--------------------------------------
211211

212-
3rdparty/bfloat16/bfloat16.cc
213212
3rdparty/dlpack
214213
3rdparty/dmlc-core
215214

cmake/config.cmake

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ set(USE_LLVM OFF)
120120
#---------------------------------------------
121121
# Contrib libraries
122122
#---------------------------------------------
123+
# Whether to build with BYOC software emulated posit custom datatype
124+
set(USE_BYOC_POSIT OFF)
125+
123126
# Whether use BLAS, choices: openblas, atlas, apple
124127
set(USE_BLAS none)
125128

cmake/modules/LibInfo.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ function(add_lib_info src_file)
5353
TVM_INFO_HIDE_PRIVATE_SYMBOLS="${HIDE_PRIVATE_SYMBOLS}"
5454
TVM_INFO_USE_TF_TVMDSOOP="${USE_TF_TVMDSOOP}"
5555
TVM_INFO_USE_FALLBACK_STL_MAP="${USE_FALLBACK_STL_MAP}"
56+
TVM_INFO_USE_BYOC_POSIT="${USE_BYOC_POSIT}"
5657
TVM_INFO_USE_BLAS="${USE_BLAS}"
5758
TVM_INFO_USE_MKL="${USE_MKL}"
5859
TVM_INFO_USE_MKLDNN="${USE_MKLDNN}"

cmake/modules/contrib/Posit.cmake

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
if(USE_BYOC_POSIT)
19+
message(STATUS "Build with contrib.posit")
20+
if (NOT UNIVERSAL_PATH)
21+
message(FATAL_ERROR "Fail to get Universal path")
22+
endif(NOT UNIVERSAL_PATH)
23+
24+
include_directories(${UNIVERSAL_PATH}/include)
25+
list(APPEND COMPILER_SRCS "src/target/datatype/posit/posit-wrapper.cc")
26+
endif(USE_BYOC_POSIT)

licenses/LICENSE.bfloat16.txt

Lines changed: 0 additions & 9 deletions
This file was deleted.

python/tvm/driver/build_module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ def _build_for_device(input_mod, target, target_host):
264264
tvm.tir.transform.LowerWarpMemory(),
265265
tvm.tir.transform.Simplify(),
266266
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
267+
tvm.tir.transform.LowerCustomDatatypes(),
267268
tvm.tir.transform.LowerIntrin(),
268269
]
269270
)
@@ -279,6 +280,7 @@ def _build_for_device(input_mod, target, target_host):
279280
tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
280281
tvm.tir.transform.LowerTVMBuiltin(),
281282
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
283+
tvm.tir.transform.LowerCustomDatatypes(),
282284
tvm.tir.transform.LowerIntrin(),
283285
tvm.tir.transform.CombineContextCall(),
284286
]

python/tvm/relay/backend/_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def _tensor_value_repr(tvalue):
9090

9191
@tvm._ffi.register_func("relay._constant_repr")
9292
def _tensor_constant_repr(tvalue):
93+
dtype = tvm.runtime.DataType(tvalue.data.dtype)
94+
if tvm.target.datatype.get_type_registered(dtype.type_code):
95+
return "custom tensor of type " + dtype.type_code
9396
return str(tvalue.data.asnumpy())
9497

9598

python/tvm/relay/frontend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,4 @@
3636
from .caffe import from_caffe
3737
from .sklearn import from_sklearn
3838
from .sklearn import from_auto_ml
39+
from .change_datatype import ChangeDatatype

0 commit comments

Comments
 (0)