Commit ec7da3e

Multiprocessing for GAMM uses shared memory for inverse calculations
JoKra1 committed Mar 14, 2024
1 parent 15fa19b commit ec7da3e
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions src/mssm/src/python/gamm_solvers.py
@@ -6,6 +6,7 @@
 from .formula import build_sparse_matrix_from_formula,setup_cache,clear_cache,cpp_solvers,pd,Formula,mp,repeat,os,map_csc_to_eigen
 from tqdm import tqdm
 from functools import reduce
+from multiprocessing import managers,shared_memory
 
 CACHE_DIR = './.db'
 SHOULD_CACHE = False
@@ -272,6 +273,19 @@ def compute_B(L,P,lTerm,n_c=10):
     B = cpp_solve_tr(L,P @ lTerm.D_J_emb[:,D_start:D_end])
     return B.power(2).sum()
 
+def compute_block_linv_shared(address_dat,address_ptr,address_idx,shape_dat,shape_ptr,rows,cols,nnz,LB):
+    dat_shared = shared_memory.SharedMemory(name=address_dat,create=False)
+    ptr_shared = shared_memory.SharedMemory(name=address_ptr,create=False)
+    idx_shared = shared_memory.SharedMemory(name=address_idx,create=False)
+
+    data = np.ndarray(shape_dat,dtype=np.double,buffer=dat_shared.buf)
+    indptr = np.ndarray(shape_ptr,dtype=np.int32,buffer=ptr_shared.buf)
+    indices = np.ndarray(shape_dat,dtype=np.int32,buffer=idx_shared.buf)
+
+    L = cpp_solvers.solve_tr(rows, cols, nnz, data, indptr, indices, LB)
+
+    return L
+
 def compute_Linv(L,n_c=10):
     # Solves L @ inv(L) = I for Binv(L) parallelizing over column
     # blocks of I if int(I.shape[1]/2000) > 1
@@ -288,10 +302,27 @@ def compute_Linv(L,n_c=10):
     split = np.array_split(range(n_col),n_c)
     LBs = [T[:,split[i]] for i in range(n_c)]
 
-    with mp.Pool(processes=n_c) as pool:
-        args = zip(repeat(L),LBs)
+    with managers.SharedMemoryManager() as manager, mp.Pool(processes=n_c) as pool:
+        # Create shared memory copies of data, indptr, and indices
+        rows, cols, nnz, data, indptr, indices = map_csc_to_eigen(L)
+        shape_dat = data.shape
+        shape_ptr = indptr.shape
+
+        dat_mem = manager.SharedMemory(data.nbytes)
+        dat_shared = np.ndarray(shape_dat, dtype=np.double, buffer=dat_mem.buf)
+        dat_shared[:] = data[:]
+
+        ptr_mem = manager.SharedMemory(indptr.nbytes)
+        ptr_shared = np.ndarray(shape_ptr, dtype=np.int32, buffer=ptr_mem.buf)
+        ptr_shared[:] = indptr[:]
+
+        idx_mem = manager.SharedMemory(indices.nbytes)
+        idx_shared = np.ndarray(shape_dat, dtype=np.int32, buffer=idx_mem.buf)
+        idx_shared[:] = indices[:]
+
+        args = zip(repeat(dat_mem.name),repeat(ptr_mem.name),repeat(idx_mem.name),repeat(shape_dat),repeat(shape_ptr),repeat(rows),repeat(cols),repeat(nnz),LBs)
 
-        LBinvs = pool.starmap(cpp_solve_tr,args)
+        LBinvs = pool.starmap(compute_block_linv_shared,args)
 
     return scp.sparse.hstack(LBinvs)
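
Note: the sketch below is not part of the commit; it illustrates the general pattern the new code relies on. Rather than pickling the Cholesky factor L for every worker task, compute_Linv now copies the factor's CSC buffers (data, indptr, indices) into blocks owned by a SharedMemoryManager, and each worker re-attaches to those blocks by name and wraps them in NumPy arrays without copying. A minimal, self-contained example of that pattern, with a toy workload and hypothetical names (block_sum, n_c) rather than anything from mssm:

import numpy as np
from itertools import repeat
from multiprocessing import Pool, managers, shared_memory

def block_sum(shm_name, shape, start, stop):
    # Toy worker standing in for compute_block_linv_shared: re-attach to the
    # block created by the parent; no data is copied here.
    shm = shared_memory.SharedMemory(name=shm_name, create=False)
    values = np.ndarray(shape, dtype=np.double, buffer=shm.buf)
    total = float(values[start:stop].sum())
    shm.close()  # detach this process's handle; the manager owns the block
    return total

if __name__ == "__main__":
    values = np.arange(1_000_000, dtype=np.double)
    n_c = 4
    with managers.SharedMemoryManager() as manager, Pool(processes=n_c) as pool:
        # One copy into shared memory, visible to all workers by name.
        mem = manager.SharedMemory(values.nbytes)
        shared = np.ndarray(values.shape, dtype=np.double, buffer=mem.buf)
        shared[:] = values[:]
        # Split the index range into n_c blocks, as compute_Linv does for columns of I.
        splits = np.array_split(np.arange(len(values)), n_c)
        args = zip(repeat(mem.name), repeat(values.shape),
                   [int(s[0]) for s in splits], [int(s[-1]) + 1 for s in splits])
        print(sum(pool.starmap(block_sum, args)))

The trade-off is the same as in the commit: the parent pays one copy into shared memory up front, and every worker then reads the same buffers instead of receiving its own serialized copy of the matrix.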
