Commit
* Add conv1d grouped convs on CPU
* Add GPU support
* Parallelize inside Metal kernel
* Cleanup
* Update mlx/ops.cpp
  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* New unfold kernel + remove unused code
* Remove copy and refactor
* Update vjp and reuse steel gemm
* Fix groups on CPU
* Fix Metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Showing 11 changed files with 634 additions and 56 deletions.
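This commit makes `groups` a working argument of `mx.conv1d` on both CPU and GPU. With `groups=g`, the input channels are split into `g` equal slices and each slice is convolved with its own block of `O/g` filters, so every filter sees only `C/g` input channels. The snippet below is a minimal sketch, not part of the commit, that checks this equivalence against per-group calls to `mx.conv1d`; all shapes and variable names are illustrative.

import mlx.core as mx

N, iH, C, O, wH, groups = 2, 16, 8, 8, 3, 4
a = mx.random.uniform(shape=(N, iH, C))            # input, channels-last (N, L, C)
b = mx.random.uniform(shape=(O, wH, C // groups))  # each filter sees C/groups channels

out = mx.conv1d(a, b, stride=1, padding=0, groups=groups)

# Reference: run an independent conv1d per channel group and concatenate.
Cg, Og = C // groups, O // groups
parts = []
for g in range(groups):
    a_g = a[:, :, g * Cg : (g + 1) * Cg]   # this group's input channels
    b_g = b[g * Og : (g + 1) * Og]         # this group's filters
    parts.append(mx.conv1d(a_g, b_g, stride=1, padding=0, groups=1))
ref = mx.concatenate(parts, axis=-1)

assert mx.allclose(out, ref).item()

The commit's new benchmark file, which times exactly this operation against PyTorch on MPS, follows.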
@@ -0,0 +1,123 @@
import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    # Warm up and drain any pending MPS work before timing.
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_1D(strides=1, padding=0, groups=1):
    def mx_conv_1D(a, b):
        ys = []
        for _ in range(N_iter_func):
            y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_1D


def make_pt_conv_1D(strides=1, padding=0, groups=1):
    @torch.no_grad()
    def pt_conv_1D(a, b):
        ys = []
        for _ in range(N_iter_func):
            y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_1D


def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
    scale = 1.0 / math.sqrt(wH * C)
    a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    # MLX uses channels-last (N, L, C); PyTorch expects channels-first (N, C, L).
    a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_1D(strides, padding, groups)
    f_pt = make_pt_conv_1D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Correctness check: compare MLX against PyTorch's CPU reference.
    out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv1d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    # (N, iH, C, wH, O, strides, padding, groups)
    shapes = (
        (4, 32, 32, 5, 32, 1, 2, 1),
        (4, 32, 32, 5, 32, 1, 2, 2),
        (4, 32, 32, 5, 32, 1, 2, 4),
        (4, 32, 32, 5, 32, 1, 2, 8),
        (4, 32, 32, 5, 32, 1, 2, 8),
        (4, 32, 32, 5, 32, 1, 2, 16),
        (4, 32, 32, 5, 32, 1, 2, 32),
        (4, 32, 256, 5, 512, 1, 2, 2),
        (4, 32, 256, 5, 512, 1, 2, 128),
        (4, 32, 256, 5, 512, 1, 2, 256),
    )

    for dtype in dtypes:
        print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
        for N, iH, C, wH, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, iH, C, wH, O, strides, padding, np_dtype, groups
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
            )

            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")