Commit: Prelim release
danielhanchen committed Mar 1, 2025
1 parent 0887650 commit 996dca3
Showing 12 changed files with 152 additions and 231 deletions.
19 changes: 0 additions & 19 deletions unsloth/__init__.py
@@ -46,25 +46,6 @@
# Fixes https://github.com/unslothai/unsloth/issues/1266
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

if "CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
devices = os.environ["CUDA_VISIBLE_DEVICES"]
# Check if there are multiple cuda devices set in env
if not devices.isdigit():
first_id = devices.split(",")[0]
warnings.warn(
f"Unsloth: 'CUDA_VISIBLE_DEVICES' is currently {devices} \n"\
"Unsloth currently does not support multi GPU setups - but we are working on it!\n"\
"Multiple CUDA devices detected but we require a single device.\n"\
f"We will override CUDA_VISIBLE_DEVICES to first device: {first_id}."
)
os.environ["CUDA_VISIBLE_DEVICES"] = str(first_id)
else:
# warnings.warn("Unsloth: 'CUDA_VISIBLE_DEVICES' is not set. We shall set it ourselves.")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
pass

# Reduce VRAM usage by reducing fragmentation
# And optimize pinning of memory
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \
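This hunk removes the import-time logic that forced CUDA_VISIBLE_DEVICES down to a single device (the PYTORCH_CUDA_ALLOC_CONF value on the last kept line is truncated by the diff viewer). As a hedged sketch, not part of this commit, a user who still wants to pin one GPU can set the environment themselves before importing the package; the device index "1" below is only an example:

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # keep device ordering stable
os.environ["CUDA_VISIBLE_DEVICES"] = "1"          # example index, pick your GPU

import unsloth  # noqa: E402,F401  (import only after the env vars are set)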
8 changes: 4 additions & 4 deletions unsloth/kernels/layernorm.py
@@ -105,10 +105,10 @@ def forward(ctx, X, W, b, eps):
X = X.view(-1, dim)
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)

Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
mu = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
device = X.device
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
r = torch.empty(n_rows, dtype = torch.float32, device = device)
mu = torch.empty(n_rows, dtype = torch.float32, device = device)

layernorm_forward[(n_rows,)](
Y, Y.stride(0),
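The change above, like the rms_layernorm and swiglu hunks below, follows one pattern: kernel outputs are allocated on the input tensor's device instead of a hardcoded "cuda:0", so the kernels can run when the model sits on a GPU other than cuda:0. A minimal sketch of the pattern (illustrative names, not the commit's code):

import torch

def allocate_outputs(X: torch.Tensor):
    n_rows, n_cols = X.shape
    device = X.device  # whatever GPU X already lives on
    Y  = torch.empty((n_rows, n_cols), dtype=X.dtype, device=device)
    r  = torch.empty(n_rows, dtype=torch.float32, device=device)
    mu = torch.empty(n_rows, dtype=torch.float32, device=device)
    return Y, r, mu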
7 changes: 4 additions & 3 deletions unsloth/kernels/rms_layernorm.py
@@ -148,9 +148,10 @@ def forward(ctx, X : torch.Tensor, W : torch.Tensor, eps : float, gemma : bool =
BLOCK_SIZE : int
num_warps : int
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
device = X.device

Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
r = torch.empty(n_rows, dtype = torch.float32, device = device)

fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
fx[(n_rows,)](
@@ -180,7 +181,7 @@ def backward(ctx, dY : torch.Tensor):
n_cols : int
n_rows, n_cols = dY.shape
# dW = X
dX = torch.empty_like(dY, device = "cuda:0") if ctx.GEMMA else dY
dX = torch.empty_like(dY) if ctx.GEMMA else dY

_rms_layernorm_backward[(n_rows,)](
dY, dY.stride(0),
2 changes: 1 addition & 1 deletion unsloth/kernels/swiglu.py
@@ -41,7 +41,7 @@ def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
def swiglu_fg_kernel(e, g):
batch, seq_len, hd = e.shape
n_elements = e.numel()
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda:0")
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
return h
95 changes: 50 additions & 45 deletions unsloth/kernels/utils.py
@@ -61,12 +61,29 @@ def calculate_settings(n : int) -> (int, int,):


import bitsandbytes as bnb
import ctypes

# https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
global CUDA_STREAM
CUDA_STREAM = None
get_ptr = bnb.functional.get_ptr
import ctypes

# Get array of CUDA streams and other buffers
global CUDA_STREAMS
global WEIGHT_BUFFERS
global ABSMAX_BUFFERS

_CUDA_STREAMS = {
(index := torch.cuda.device(i).idx) : ctypes.c_void_p(torch._C._cuda_getCurrentRawStream(index))
for i in range(torch.cuda.device_count())
}
CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
ABSMAX_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
for k, v in _CUDA_STREAMS.items(): CUDA_STREAMS[k] = v
CUDA_STREAMS = tuple(CUDA_STREAMS)
del _CUDA_STREAMS

# Bitsandbytes operations
ctypes_c_int = ctypes.c_int
ctypes_c_int32 = ctypes.c_int32
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
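For context, a minimal sketch of the per-device lookup this hunk introduces: one raw CUDA stream handle is cached per visible GPU, then indexed by a tensor's device ordinal when a kernel is launched (illustrative, not the commit's code; it relies on the same private torch._C call used above):

import ctypes
import torch

# One raw stream pointer per visible GPU; list position == device ordinal.
CUDA_STREAMS = tuple(
    ctypes.c_void_p(torch._C._cuda_getCurrentRawStream(i))
    for i in range(torch.cuda.device_count())
)

def stream_for(t: torch.Tensor) -> ctypes.c_void_p:
    # Pick the stream that matches the tensor's device at call time.
    return CUDA_STREAMS[t.device.index]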
@@ -118,11 +135,6 @@ def get_lora_parameters_bias(proj):
return W, QUANT_STATE(W), A, B, s, bias
pass

global WEIGHT_BUFFER
WEIGHT_BUFFER = None
global ABSMAX_BUFFER
ABSMAX_BUFFER = None

if HAS_CUDA_STREAM:
@torch.inference_mode
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
@@ -145,8 +157,10 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
global CUDA_STREAM
if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0")
global CUDA_STREAMS
device = W.device
device_index = device.index
CUDA_STREAM = CUDA_STREAMS[device_index]

n_elements_absmax = absmax.numel()

@@ -155,11 +169,13 @@

# Use same buffers for faster inference
size = shape[0]*shape[1]
global WEIGHT_BUFFER
global ABSMAX_BUFFER
global WEIGHT_BUFFERS
global ABSMAX_BUFFERS
WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]
ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
if WEIGHT_BUFFER is None:
WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0", requires_grad = False)
ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False)
WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = device, requires_grad = False)
ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)

if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size)
if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax)
@@ -168,11 +184,11 @@
out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
else:
if out is None:
out = torch.empty(shape, dtype = dtype, device = "cuda:0", requires_grad = False)
out = torch.empty(shape, dtype = dtype, device = device, requires_grad = False)
else:
assert(out.shape == shape)
assert(out.dtype == dtype)
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False)
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
pass

# NF4 dequantization of statistics
@@ -217,31 +233,15 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
pass

n_elements_absmax = absmax.numel()
device = W.device

# Create weight matrix
if use_global_buffer:

# Use same buffers for faster inference
size = shape[0]*shape[1]
global WEIGHT_BUFFER
global ABSMAX_BUFFER
if WEIGHT_BUFFER is None:
WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0", requires_grad = False)
ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0", requires_grad = False)

if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size)
if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax)

out = WEIGHT_BUFFER[:size].view(shape)
out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
if out is None:
out = torch.empty(shape, dtype = dtype, device = device, requires_grad = False)
else:
if out is None:
out = torch.empty(shape, dtype = dtype, device = "cuda:0", requires_grad = False)
else:
assert(out.shape == shape)
assert(out.dtype == dtype)
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False)
pass
assert(out.shape == shape)
assert(out.dtype == dtype)
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)

# Do dequantization
ptr_out_absmax = get_ptr(out_absmax)
@@ -288,14 +288,16 @@ def fast_gemv(X, W, quant_state, out = None):
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
global CUDA_STREAM
if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0")
global CUDA_STREAMS
device = W.device
device_index = device.index
CUDA_STREAM = CUDA_STREAMS[device_index]

# assert(dtype == X.dtype)
bout = shape[0]

if out is None:
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
out = torch.empty((1, 1, bout,), dtype = dtype, device = device)
# else:
# assert(out.shape == (1, 1, bout,))
# pass
@@ -313,7 +315,7 @@
ldb = ctypes_c_int32(ldb)
ldc = ctypes_c_int32(ldc)

df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
df = torch.empty(absmax.shape, dtype = torch.float32, device = device)
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM,
@@ -357,9 +359,10 @@ def fast_gemv(X, W, quant_state, out = None):
pass
# assert(dtype == X.dtype)
bout = shape[0]
device = W.device

if out is None:
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
out = torch.empty((1, 1, bout,), dtype = dtype, device = device)
# else:
# assert(out.shape == (1, 1, bout,))
# pass
@@ -377,7 +380,7 @@
ldb = ctypes_c_int32(ldb)
ldc = ctypes_c_int32(ldc)

df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
df = torch.empty(absmax.shape, dtype = torch.float32, device = device)
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
ctypes_c_int(blocksize2), ctypes_c_int(df.numel()),
@@ -400,6 +403,7 @@ def fast_gemv(X, W, quant_state, out = None):
torch_mm = torch.mm
torch_mv = torch.mv
torch_matmul = torch.matmul
torch_addmm = torch.addmm
def fast_linear_forward(proj, X, temp_lora = None, out = None):

W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
@@ -461,7 +465,8 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
if A is not None:
# LoRA is enabled
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
out = torch_addmm(X @ A.to(dtype), B.to(dtype), alpha = s, beta = 1.0, out = out)
# out += (X @ A.to(dtype)) @ (s * B.to(dtype))
pass

return out.view(batch, seq_len, -1) if reshape else out
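The matmul_lora hunk above swaps the two-step LoRA update out += (X @ A) @ (s * B) for a torch.addmm call intended to fold the scale and the accumulation into a single fused op. For reference, the standard addmm semantics are beta * input + alpha * (mat1 @ mat2); a small equivalence sketch with illustrative shapes and names:

import torch

X   = torch.randn(16, 64)   # flattened activations: (tokens, in_features)
A   = torch.randn(64, 8)    # LoRA A, already transposed
B   = torch.randn(8, 32)    # LoRA B, already transposed
s   = 0.5                   # LoRA scaling factor
out = torch.randn(16, 32)   # result of the base projection

fused   = torch.addmm(out, X @ A, B, beta=1.0, alpha=s)  # one fused kernel
unfused = out + (X @ A) @ (s * B)                        # original two-step form
assert torch.allclose(fused, unfused, atol=1e-5)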
64 changes: 23 additions & 41 deletions unsloth/models/_utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "2025.2.15"
__version__ = "2025.3.1"

__all__ = [
"SUPPORTS_BFLOAT16",
@@ -37,7 +37,6 @@
"torch_compile_options",
"patch_linear_scaling",
"patch_llama_rope_scaling",
"check_nvidia",
"create_boolean_mask",
"torch_amp_custom_fwd",
"torch_amp_custom_bwd",
@@ -703,9 +702,7 @@ def get_statistics():
# =============================================
# Fixes Bitsandbytes to remove missing warnings
from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
from inspect import getsource
from accelerate.utils.dataclasses import DistributedType
BitsAndBytesConfig__init__ = getsource(BitsAndBytesConfig.__init__)
BitsAndBytesConfig__init__ = inspect.getsource(BitsAndBytesConfig.__init__)
BitsAndBytesConfig__init__ = re.sub(
r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
"",
@@ -719,27 +716,29 @@ def get_statistics():
"__init__",
"_BitsAndBytesConfig__init__",
)
exec(BitsAndBytesConfig__init__, globals())

def _prepare_backend(
self, cpu = False, sagemaker_dp = False, backend: str = None,
) -> tuple[str, DistributedType]:
return None, DistributedType.NO
if torch.cuda.device_count() == 1:
from accelerate.utils.dataclasses import DistributedType
def _prepare_backend(
self, cpu = False, sagemaker_dp = False, backend: str = None,
) -> tuple[str, DistributedType]:
return None, DistributedType.NO
pass
import accelerate.state
accelerate.state.PartialState._prepare_backend = _prepare_backend

import accelerate.accelerator
prepare = inspect.getsource(accelerate.accelerator.Accelerator.prepare)
prepare = prepare.split("\n")
spaces = prepare[0].find("def")
prepare = "\n".join(x[spaces:] for x in prepare)
x = "for obj in args:"
s = " "*spaces
prepare = prepare.replace(x, f'self.state.distributed_type = DistributedType.NO\n{s}{x}', 1)
exec(prepare, globals())
accelerate.accelerator.Accelerator.prepare = prepare
pass
import accelerate.state
accelerate.state.PartialState._prepare_backend = _prepare_backend

import accelerate.accelerator
prepare = inspect.getsource(accelerate.accelerator.Accelerator.prepare)
prepare = prepare.split("\n")
spaces = prepare[0].find("def")
prepare = "\n".join(x[spaces:] for x in prepare)
x = "for obj in args:"
s = " "*spaces
prepare = prepare.replace(x, f'self.state.distributed_type = DistributedType.NO\n{s}{x}', 1)
exec(prepare, globals())
accelerate.accelerator.Accelerator.prepare = prepare

exec(BitsAndBytesConfig__init__, globals())

import transformers.utils.quantization_config
transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__
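The block above patches BitsAndBytesConfig.__init__ at runtime: it grabs the source with inspect.getsource, removes the branch it does not want with re.sub, renames the function, and exec's the edited copy into globals. A minimal, self-contained sketch of that technique on a toy function (not the commit's code):

import inspect
import re

def greet(name):
    return f"hello {name}"

src = inspect.getsource(greet)
src = re.sub(r"hello", "hi", src)               # edit the source text
src = src.replace("def greet", "def greet_v2")  # keep the original intact
exec(src, globals())                            # compile the edited copy

print(greet_v2("unsloth"))  # -> "hi unsloth"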
@@ -963,21 +962,6 @@ def patch_llama_rope_scaling(
pass


def check_nvidia():
# Unsloth doesn't work yet on AMD devices - we're working on it!
output = np.array([0,])
try:
output = subprocess.check_output("nvidia-smi --query-gpu=memory.used --format=csv", shell = True)
output = re.findall(rb'([\d]{1,})[\s]{1,}M', output)
output = np.array([int(x.decode('utf-8'))/1024 for x in output])
except:
if not torch.cuda.is_available():
raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!")
return output
pass
PRE_CHECK = check_nvidia()


def create_boolean_mask(n = 4096, sliding_window = 2048):
# Creates a boolean mask for attention
mask = torch.ones(n, n, dtype = torch.bool)
@@ -1122,8 +1106,6 @@ def patch_gradient_accumulation_fix(Trainer):
items_in_trainer = dir(transformers.trainer)
good_items = []
for item in items_in_trainer:
# TODO: Support Deepspeed
if item.startswith(("deepspeed", "xm", "met", "smp")): continue
if item in function: good_items.append(item)
pass
exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals())