Skip to content

Commit 54a534c

Browse files
committed
fix cpplint and revert float64 change
1 parent df7688c commit 54a534c

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

python/tvm/topi/cuda/scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
104104
# The following algorithm performs parallel exclusive scan
105105
# Up Sweep of exclusive scan
106106
lim = tvm.tir.generic.cast(
107-
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float32"))), "int64"
107+
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64"
108108
)
109109
with ib.for_range(0, lim, dtype="int64") as l2_width:
110110
width = 2 << l2_width

python/tvm/topi/cuda/sort.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def compare(a, b):
239239

240240
# Sort the lower levels of the merge using odd-even sort, it's fast for small inputs
241241
lower_lim = tvm.tir.generic.cast(
242-
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float32"))), "int64"
242+
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float64"))), "int64"
243243
)
244244

245245
_odd_even_sort(
@@ -255,7 +255,7 @@ def compare(a, b):
255255
)
256256

257257
upper_lim = tvm.tir.generic.cast(
258-
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float32"))), "int64"
258+
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
259259
)
260260

261261
def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count):

src/target/spirv/ir_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ class IRBuilder {
491491
*/
492492
Value GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index);
493493

494-
// TODO doc
494+
// TODO(masahi): doc
495495
Value DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding);
496496
Value GetUniform(Value ptr_ubo, const SType& v_type, uint32_t index);
497497
/*!

0 commit comments

Comments
 (0)