Skip to content

Commit b85cfa1

Browse files
committed
gemm: Allow custom packing functions
For complex element types, we'll want to use a different packing function. Add packing to the GemmKernel interface so that kernels can request a different packing function. The standard packing function is unchanged, but it moves into its own module in the code.
1 parent 5e0aea7 commit b85cfa1

File tree

4 files changed

+118
-76
lines changed

4 files changed

+118
-76
lines changed

src/gemm.rs

Lines changed: 4 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
use core::cell::UnsafeCell;
1111
use core::cmp::min;
1212
use core::mem::size_of;
13-
use core::ptr::copy_nonoverlapping;
1413
use core::slice;
1514

1615
use crate::aligned_alloc::Alloc;
@@ -19,7 +18,6 @@ use crate::ptr::Ptr;
1918
use crate::util::range_chunk;
2019
use crate::util::round_up_to;
2120

22-
use crate::kernel::ConstNum;
2321
use crate::kernel::Element;
2422
use crate::kernel::GemmKernel;
2523
use crate::kernel::GemmSelect;
@@ -302,8 +300,8 @@ unsafe fn gemm_loop<K>(
302300
let a = a.stride_offset(csa, kkc * l4);
303301

304302
// Pack B -> B~
305-
pack::<K::NRTy, _>(kc, nc, slice::from_raw_parts_mut(bpp.ptr(), bp_size),
306-
b.ptr(), csb, rsb);
303+
K::pack_nr(kc, nc, slice::from_raw_parts_mut(bpp.ptr(), bp_size),
304+
b.ptr(), csb, rsb);
307305

308306
// First time writing to C, use user's `beta`, else accumulate
309307
let betap = if l4 == 0 { beta } else { <_>::one() };
@@ -322,8 +320,8 @@ unsafe fn gemm_loop<K>(
322320
let c = c.stride_offset(rsc, kmc * l3);
323321

324322
// Pack A -> A~
325-
pack::<K::MRTy, _>(kc, mc, slice::from_raw_parts_mut(app.ptr(), ap_size),
326-
a.ptr(), rsa, csa);
323+
K::pack_mr(kc, mc, slice::from_raw_parts_mut(app.ptr(), ap_size),
324+
a.ptr(), rsa, csa);
327325

328326
// LOOP 2 and 1
329327
gemm_packed::<K>(nc, kc, mc,
@@ -471,76 +469,6 @@ unsafe fn align_ptr<T>(align_to: usize, mut ptr: *mut T) -> *mut T {
471469
ptr
472470
}
473471

474-
/// Pack matrix into `pack`
475-
///
476-
/// + kc: length of the micropanel
477-
/// + mc: number of rows/columns in the matrix to be packed
478-
/// + pack: packing buffer
479-
/// + a: matrix,
480-
/// + rsa: row stride
481-
/// + csa: column stride
482-
///
483-
/// + MR: kernel rows/columns that we round up to
484-
// If one of pack and a is of a reference type, it gets a noalias annotation which
485-
// gives benefits to optimization. The packing buffer is contiguous so it can be passed as a slice
486-
// here.
487-
unsafe fn pack<MR, T>(kc: usize, mc: usize, pack: &mut [T],
488-
a: *const T, rsa: isize, csa: isize)
489-
where T: Element,
490-
MR: ConstNum,
491-
{
492-
let pack = pack.as_mut_ptr();
493-
let mr = MR::VALUE;
494-
let mut p = 0; // offset into pack
495-
496-
if rsa == 1 {
497-
// if the matrix is contiguous in the same direction we are packing,
498-
// copy a kernel row at a time.
499-
for ir in 0..mc/mr {
500-
let row_offset = ir * mr;
501-
for j in 0..kc {
502-
let a_row = a.stride_offset(rsa, row_offset)
503-
.stride_offset(csa, j);
504-
copy_nonoverlapping(a_row, pack.add(p), mr);
505-
p += mr;
506-
}
507-
}
508-
} else {
509-
// general layout case
510-
for ir in 0..mc/mr {
511-
let row_offset = ir * mr;
512-
for j in 0..kc {
513-
for i in 0..mr {
514-
let a_elt = a.stride_offset(rsa, i + row_offset)
515-
.stride_offset(csa, j);
516-
copy_nonoverlapping(a_elt, pack.add(p), 1);
517-
p += 1;
518-
}
519-
}
520-
}
521-
}
522-
523-
let zero = <_>::zero();
524-
525-
// Pad with zeros to multiple of kernel size (uneven mc)
526-
let rest = mc % mr;
527-
if rest > 0 {
528-
let row_offset = (mc/mr) * mr;
529-
for j in 0..kc {
530-
for i in 0..mr {
531-
if i < rest {
532-
let a_elt = a.stride_offset(rsa, i + row_offset)
533-
.stride_offset(csa, j);
534-
copy_nonoverlapping(a_elt, pack.add(p), 1);
535-
} else {
536-
*pack.add(p) = zero;
537-
}
538-
p += 1;
539-
}
540-
}
541-
}
542-
}
543-
544472
/// Call the GEMM kernel with a "masked" output C.
545473
///
546474
/// Simply redirect the MR by NR kernel output to the passed

src/kernel.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
// except according to those terms.
88

99
use crate::archparam;
10+
use crate::packing::pack;
1011

1112
/// General matrix multiply kernel
1213
pub(crate) trait GemmKernel {
@@ -35,6 +36,33 @@ pub(crate) trait GemmKernel {
3536
#[inline(always)]
3637
fn mc() -> usize { archparam::S_MC }
3738

39+
/// Pack matrix A into its packing buffer.
///
/// Delegates to the default packing routine `pack`, rounding row panels
/// up to `Self::MRTy` (the kernel's MR block size). See `pack` for more
/// documentation.
///
/// Override only if the default packing function does not
/// use the right layout.
///
/// # Safety
///
/// Same contract as `pack`: `a` with strides `rsa`/`csa` must be valid
/// for reading the `mc` x `kc` submatrix, and `pack_buf` must be large
/// enough for `kc * mc` rounded up to a multiple of MR — TODO confirm
/// exact capacity contract against the packing-buffer allocation site.
#[inline]
unsafe fn pack_mr(kc: usize, mc: usize, pack_buf: &mut [Self::Elem],
                  a: *const Self::Elem, rsa: isize, csa: isize)
{
    pack::<Self::MRTy, _>(kc, mc, pack_buf, a, rsa, csa)
}
51+
52+
/// Pack matrix B into its packing buffer.
///
/// Delegates to the default packing routine `pack`, rounding column
/// panels up to `Self::NRTy` (the kernel's NR block size). See `pack`
/// for more documentation.
///
/// Override only if the default packing function does not
/// use the right layout.
///
/// # Safety
///
/// Same contract as `pack`: `a` with strides `rsa`/`csa` must be valid
/// for reading the `mc` x `kc` submatrix, and `pack_buf` must be large
/// enough for `kc * mc` rounded up to a multiple of NR — TODO confirm
/// exact capacity contract against the packing-buffer allocation site.
#[inline]
unsafe fn pack_nr(kc: usize, mc: usize, pack_buf: &mut [Self::Elem],
                  a: *const Self::Elem, rsa: isize, csa: isize)
{
    pack::<Self::NRTy, _>(kc, mc, pack_buf, a, rsa, csa)
}
64+
65+
3866
/// Matrix multiplication kernel
3967
///
4068
/// This does the matrix multiplication:

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ pub(crate) use archparam_defaults as archparam;
149149

150150
mod gemm;
151151
mod kernel;
152+
mod packing;
152153
mod ptr;
153154
mod threading;
154155

src/packing.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Copyright 2016 - 2023 Ulrik Sverdrup "bluss"
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use rawpointer::PointerExt;
10+
11+
use core::ptr::copy_nonoverlapping;
12+
13+
use crate::kernel::ConstNum;
14+
use crate::kernel::Element;
15+
16+
/// Pack matrix into `pack`
///
/// + kc: length of the micropanel (number of columns copied per panel)
/// + mc: number of rows/columns in the matrix to be packed
/// + pack: packing buffer
/// + a: matrix,
/// + rsa: row stride
/// + csa: column stride
///
/// + MR: kernel rows/columns that we round up to
///
/// The output layout is panel-major: `mc` is split into panels of MR
/// rows; each panel stores its `kc` columns contiguously, MR elements
/// per column. A final partial panel (when `mc % MR != 0`) is
/// zero-padded so every panel is exactly MR wide.
///
/// # Safety
///
/// Caller must ensure `a` together with strides `rsa`/`csa` is valid
/// for reading every element of the `mc` x `kc` submatrix, and that
/// `pack` holds at least `kc * round_up(mc, MR)` elements — TODO
/// confirm exact capacity contract against the buffer allocation site.
// If one of pack and a is of a reference type, it gets a noalias annotation which
// gives benefits to optimization. The packing buffer is contiguous so it can be passed as a slice
// here.
pub(crate) unsafe fn pack<MR, T>(kc: usize, mc: usize, pack: &mut [T],
                                 a: *const T, rsa: isize, csa: isize)
    where T: Element,
          MR: ConstNum,
{
    let pack = pack.as_mut_ptr();
    let mr = MR::VALUE; // kernel block size we round panels up to
    let mut p = 0; // offset into pack

    if rsa == 1 {
        // if the matrix is contiguous in the same direction we are packing,
        // copy a kernel row at a time.
        for ir in 0..mc/mr {
            let row_offset = ir * mr;
            for j in 0..kc {
                let a_row = a.stride_offset(rsa, row_offset)
                             .stride_offset(csa, j);
                // rsa == 1 makes the mr source elements contiguous,
                // so one bulk copy fills a whole panel column.
                copy_nonoverlapping(a_row, pack.add(p), mr);
                p += mr;
            }
        }
    } else {
        // general layout case: strides are arbitrary, so copy
        // element by element.
        for ir in 0..mc/mr {
            let row_offset = ir * mr;
            for j in 0..kc {
                for i in 0..mr {
                    let a_elt = a.stride_offset(rsa, i + row_offset)
                                 .stride_offset(csa, j);
                    copy_nonoverlapping(a_elt, pack.add(p), 1);
                    p += 1;
                }
            }
        }
    }

    let zero = <_>::zero();

    // Pad with zeros to multiple of kernel size (uneven mc)
    let rest = mc % mr;
    if rest > 0 {
        let row_offset = (mc/mr) * mr;
        for j in 0..kc {
            for i in 0..mr {
                if i < rest {
                    // real trailing rows of the matrix
                    let a_elt = a.stride_offset(rsa, i + row_offset)
                                 .stride_offset(csa, j);
                    copy_nonoverlapping(a_elt, pack.add(p), 1);
                } else {
                    // zero-fill so the last panel is a full mr wide
                    *pack.add(p) = zero;
                }
                p += 1;
            }
        }
    }
}
85+

0 commit comments

Comments
 (0)