
Commit 4cac65a

[Dev] Bump Version to dev0.8 and fix issue INT8xINT2 (apache#49)
* improve e4m3 decoding
* append fp16xint1
* Update submodule commit reference
* chore: Update shared memory scope for float32 output dtype
* BUGFIX: UINT8/INT8 Decoding
* feat: Add rasterization options for roller module
* Refactor tensorcore_legalization method to optimize tensor core usage
* feat: Add function to collect variables from expression, improve for splitk
* chore: Update typing import in __init__.py
* chore: Refactor CPU execution of operators
* Refactor matmul implementation for splitk layout
* chore: Update version to 0.0.1.dev8

Co-authored-by: LeiWang199 <leiwang199>
1 parent 99a744e

File tree

9 files changed: +22 −8 lines

VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-0.0.1.dev7
+0.0.1.dev8

integration/BitNet/utils_quant.py

Lines changed: 2 additions & 1 deletion
@@ -119,7 +119,6 @@ def native_forward(self, input):
         return out
 
     def forward_fp32_simulated(self, input):
-        print("input: ", input)
         quant_input = self.activation_quant(input, self.input_bits).detach()
         quant_weight = self.weight_quant(self.weight).detach()
 
@@ -139,6 +138,8 @@ def forward_fp32_simulated(self, input):
         return out
 
     def forward(self, input):
+        # return self.forward_fp32_simulated(input)
+
         quant_input = self.activation_quant(input, self.input_bits).detach()
         fp32_out = self.bitblas_matmul(quant_input, self.weight)
         sw = self.sw
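
For orientation: the `forward` path above quantizes activations on the fly and hands the int8 tensor to the fused BitBLAS kernel. A minimal sketch of a BitNet-style per-token absmax `activation_quant` (an assumption about its shape for illustration; the actual implementation lives elsewhere in utils_quant.py):

    import torch

    def activation_quant_sketch(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
        # Per-token absmax quantization: scale each row so its largest
        # magnitude maps onto the signed integer range, then round and clamp.
        qmax = 2 ** (num_bits - 1) - 1  # 127 for int8
        scale = qmax / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        return (x * scale).round().clamp(-qmax - 1, qmax)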

python/bitblas/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -81,4 +81,4 @@ def _init_logger():
 
 _init_logger()
 
-__version__ = "0.0.1.dev7"
+__version__ = "0.0.1.dev8"

python/bitblas/base/utils.py

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,7 @@
 import tempfile
 import itertools
 from tvm.ir.supply import GlobalVarSupply
-from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4
+from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 import logging
 
 logger = logging.getLogger(__name__)
@@ -205,6 +205,7 @@ def _build(context) -> str:
     def tvm_callback_cuda_postproc(code, _):
         code = tensor_replace_dp4a(code)
         code = tensor_remove_make_int4(code)
+        code = tensor_remove_make_int2(code)
         return code
 
     with tvm.transform.PassContext(config={"tir.use_async_copy": True, **config.pass_context}):
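
For context, `tvm_callback_cuda_postproc` is the hook name TVM looks up to let user code rewrite generated CUDA source before it is handed to nvcc; the diff simply adds the new `tensor_remove_make_int2` cleanup to that chain. A minimal sketch of registering such a hook globally (the registration style here is an assumption for illustration; BitBLAS defines the callback locally inside `_build`):

    import tvm
    from bitblas.utils import (tensor_replace_dp4a, tensor_remove_make_int4,
                               tensor_remove_make_int2)

    @tvm.register_func("tvm_callback_cuda_postproc", override=True)
    def _cuda_postproc(code, _):
        # Apply the same string-level cleanups to every generated CUDA kernel.
        code = tensor_replace_dp4a(code)
        code = tensor_remove_make_int4(code)
        return tensor_remove_make_int2(code)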

python/bitblas/ops/general_matmul.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
 from .impl.matmul_dequantize_impl import (
     select_implementation as weight_dequantize_implementation,)
 from .impl.matmul_impl import select_implementation as consistent_implementation
-from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4
+from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 from bitblas.utils.target_detector import auto_detect_nvidia_target
 from dataclasses import dataclass
 from .ladder_permutate import LadderPermutate, LadderPermutateConfig
@@ -398,6 +398,7 @@ def _select_implementation(self):
     def post_process(self, code: str) -> str:
         code = tensor_replace_dp4a(code)
         code = tensor_remove_make_int4(code)
+        code = tensor_remove_make_int2(code)
         return code
 
     def retrieve_weight_shape(self):

python/bitblas/ops/matmul.py

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,7 @@
 from typing import List, Union, Optional, Any, Tuple
 from .operator import Operator, TransformKind
 from .impl.matmul_impl import select_implementation
-from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4
+from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 from dataclasses import dataclass
 from .ladder_permutate import LadderPermutate, LadderPermutateConfig
 import logging
@@ -189,6 +189,7 @@ def _select_implementation(self):
     def post_process(self, code: str) -> str:
         code = tensor_replace_dp4a(code)
         code = tensor_remove_make_int4(code)
+        code = tensor_remove_make_int2(code)
         return code
 
     def _profile_latency_with_dynamic_range(self) -> List:

python/bitblas/ops/matmul_dequantize.py

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,7 @@
 from typing import Any, List, Literal, Optional, Tuple, Union
 from .operator import Operator, TransformKind
 from .impl.matmul_dequantize_impl import select_implementation
-from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4
+from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 from bitblas.utils.tensor_adapter import tvm_tensor_to_torch
 from dataclasses import dataclass
 from .ladder_permutate import LadderPermutate, LadderPermutateConfig
@@ -234,6 +234,7 @@ def _select_implementation(self):
     def post_process(self, code: str) -> str:
         code = tensor_replace_dp4a(code)
         code = tensor_remove_make_int4(code)
+        code = tensor_remove_make_int2(code)
         return code
 
     def retrieve_weight_shape(self):

python/bitblas/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4  # noqa: F401
+from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2  # noqa: F401
 from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor  # noqa: F401
 from .target_detector import get_all_nvidia_targets, auto_detect_nvidia_target  # noqa: F401

python/bitblas/utils/post_process.py

Lines changed: 9 additions & 0 deletions
@@ -27,3 +27,12 @@ def tensor_remove_make_int4(source: str) -> str:
         "make_int4(0, 0, 0, 0)",
     )
     return source
+
+def tensor_remove_make_int2(source: str) -> str:
+    # remove make_int2 with 8 signed char arguments
+    # TODO(lei): this should be fixed upstream in TVM in the future
+    source = source.replace(
+        "make_int2((signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0)",
+        "make_int2(0, 0)",
+    )
+    return source
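
This is the heart of the INT8xINT2 fix: for these kernels TVM's CUDA codegen can emit a `make_int2(...)` initializer carrying eight `signed char` zeros, which nvcc rejects because `int2` has only two fields, so the helper rewrites it to a plain zero initializer. A quick self-contained check (the surrounding kernel text is illustrative, not captured from a real build):

    from bitblas.utils import tensor_remove_make_int2

    args = ", ".join(["(signed char)0"] * 8)      # the invalid 8-argument list
    bad = f"int2 fragment = make_int2({args});"   # as the codegen may emit it
    print(tensor_remove_make_int2(bad))
    # -> int2 fragment = make_int2(0, 0);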
