Skip to content

Commit e647731

Browse files
create_min_lower_func
1 parent 68c5ec8 commit e647731

File tree

6 files changed

+49
-29
lines changed

6 files changed

+49
-29
lines changed

3rdparty/posit/posit-wrapper.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "universal/posit/math/hyperbolic.hpp"
3333
#include "universal/posit/math/logarithm.hpp"
3434
#include "universal/posit/math/sqrt.hpp"
35+
#include "universal/posit/numeric_limits.hpp"
3536

3637
TVM_DLL sw::unum::posit<8, 2> Uint8ToPosit8es2(uint8_t in) {
3738
sw::unum::bitblock<8> bb;
@@ -40,12 +41,15 @@ TVM_DLL sw::unum::posit<8, 2> Uint8ToPosit8es2(uint8_t in) {
4041
}
4142

4243
extern "C" {
43-
TVM_DLL uint8_t RawPosit8es2(uint8_t in) { return in; }
44-
4544
TVM_DLL uint8_t Posit8es2toUint8(sw::unum::posit<8, 2> in) {
4645
return static_cast<uint8_t>(in.get().to_ullong());
4746
}
4847

48+
TVM_DLL uint8_t MinPosit8es2() {
49+
auto min = std::numeric_limits<sw::unum::posit<8, 2>>::lowest();
50+
return Posit8es2toUint8(min);
51+
}
52+
4953
TVM_DLL float Posit8es2ToFloat(uint8_t in) { return Uint8ToPosit8es2(in).operator float(); }
5054

5155
TVM_DLL uint8_t FloatToPosit8es2(float in) {
@@ -104,12 +108,15 @@ TVM_DLL sw::unum::posit<16, 2> Uint16ToPosit16es2(uint16_t in) {
104108
}
105109

106110
extern "C" {
107-
TVM_DLL uint16_t RawPosit16es2(uint16_t in) { return in; }
108-
109111
TVM_DLL uint16_t Posit16es2toUint16(sw::unum::posit<16, 2> in) {
110112
return static_cast<uint16_t>(in.get().to_ullong());
111113
}
112114

115+
TVM_DLL uint8_t MinPosit16es2() {
116+
auto min = std::numeric_limits<sw::unum::posit<16, 2>>::lowest();
117+
return Posit16es2toUint16(min);
118+
}
119+
113120
TVM_DLL float Posit16es2ToFloat(uint16_t in) { return Uint16ToPosit16es2(in).operator float(); }
114121

115122
TVM_DLL uint16_t FloatToPosit16es2(float in) {
@@ -168,12 +175,15 @@ TVM_DLL sw::unum::posit<32, 2> Uint32ToPosit32es2(uint32_t in) {
168175
}
169176

170177
extern "C" {
171-
TVM_DLL uint32_t RawPosit32es2(uint32_t in) { return in; }
172-
173178
TVM_DLL uint32_t Posit32es2ToUint32(sw::unum::posit<32, 2> in) {
174179
return static_cast<uint32_t>(in.get().to_ullong());
175180
}
176181

182+
TVM_DLL uint8_t MinPosit32es2() {
183+
auto min = std::numeric_limits<sw::unum::posit<32, 2>>::lowest();
184+
return Posit32es2ToUint32(min);
185+
}
186+
177187
TVM_DLL float Posit32es2ToFloat(uint32_t in) { return Uint32ToPosit32es2(in).operator float(); }
178188

179189
TVM_DLL uint32_t FloatToPosit32es2(float in) {

python/tvm/relay/frontend/change_datatype.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
from ..transform.transform import function_pass
2222
from ..expr import var, bind
2323

24-
25-
@function_pass()
24+
@function_pass(opt_level=0)
2625
class ChangeDatatype(ExprMutator):
2726
"""Mutator for changing the datatype of Relay programs.
2827

python/tvm/target/datatype.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,27 @@ def register_min_func(func, type_name):
214214
"""
215215
_register_func("tvm.datatype.min." + type_name, func)
216216

217+
def create_min_lower_func(extern_func_map, type_name):
218+
"""Returns a function which lowers the minimum value operation to a pure extern call.
219+
220+
Parameters
221+
----------
222+
extern_func_map : map
223+
A map from bit lengths to the external function name
224+
225+
type_name : string
226+
The name of the custom datatype, e.g. posites2 (but not custom[posites2]32).
227+
"""
228+
def lower(num_bits):
229+
dtype = f'custom[{type_name}]{num_bits}'
230+
231+
if num_bits not in extern_func_map:
232+
raise RuntimeError('missing minimum function for {dtype}')
233+
234+
return call_pure_extern(dtype, extern_func_map[num_bits])
235+
236+
return lower
237+
217238
def create_lower_func(extern_func_map):
218239
"""Returns a function which lowers an operation to a function call.
219240

src/arith/rewrite_simplify.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
#include <algorithm>
3232

33+
#include "../target/datatype/registry.h"
3334
#include "const_fold.h"
3435
#include "pattern_match.h"
3536

@@ -460,6 +461,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
460461

461462
// x / 2.0 = x * 0.5
462463
if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
464+
CHECK(op->dtype.is_float() ||
465+
datatype::Registry::Global()->GetTypeRegistered(op->dtype.code()));
463466
return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
464467
}
465468

src/tir/op/op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ PrimExpr min_value(const DataType& dtype) {
182182
if (datatype::Registry::Global()->GetTypeRegistered(dtype.code())) {
183183
auto f = datatype::GetMinFunc(dtype.code());
184184
CHECK(f) << "No minimum function registered for custom dtype " << (unsigned int)dtype.code();
185-
// TODO(@hypercubestart) Document this change (and others associated with the overflowing floatimm min bug)
185+
// TODO(@hypercubestart) Document this change (and others associated with the overflowing
186+
// floatimm min bug)
186187
return (*f)(dtype.bits());
187188
} else if (dtype.is_int()) {
188189
if (dtype.bits() == 64) {

tests/python/unittest/test_custom_datatypes.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from tvm.relay.testing.resnet import get_workload as get_resnet
2828
from tvm.relay.testing.layers import batch_norm_infer
2929
from tvm.relay.testing.mobilenet import get_workload as get_mobilenet
30-
from tvm.target.datatype import register, register_min_func, register_op, create_lower_func, lower_ite, lower_call_pure_extern
30+
from tvm.target.datatype import register, register_min_func, register_op, create_lower_func, lower_ite, lower_call_pure_extern, create_min_lower_func
3131
from tvm.tir.op import call_pure_extern
3232

3333
# we use a random seed to generate input_data
@@ -170,25 +170,11 @@ def setup():
170170
8: 'Posit8es2Tanh'
171171
}), "Call", "llvm", "posites2", intrinsic_name="tir.tanh")
172172

173-
def posit_min_func(num_bits):
174-
# the minimum representable posit is all 1's in binary,
175-
# here we encode the raw bit representation in an integer
176-
# and use the extern function to simply interpret
177-
# the integer as a posites2
178-
#
179-
# another possible way is to create a FloatImm storing the value
180-
# of the minimum as a float64 and then casting to `posites2`,
181-
# but float imprecision makes this approach susceptible to hard-to-find bugs
182-
value = np.dtype('int' + str(num_bits)).type(-1)
183-
dtype = 'custom[posites2]' + str(num_bits)
184-
func_map = {
185-
32: 'RawPosit32es2',
186-
16: 'RawPosit16es2',
187-
8: 'RawPosit8es2'
188-
}
189-
return call_pure_extern(dtype, func_map[num_bits], value)
190-
register_min_func(posit_min_func, "posites2")
191-
173+
register_min_func(create_min_lower_func({
174+
32: 'MinPosit32es2',
175+
16: 'MinPosit16es2',
176+
8: 'MinPosit8es2'
177+
}, "posites2"), "posites2")
192178

193179
def run_ops(src_dtype, dst_dtype, rtol=1e-7, atol=1e-7):
194180
"""Run the same op, but with two different datatypes"""

0 commit comments

Comments
 (0)