Commit d1c9406

Merge pull request #3 from wei-tan/master
move printf to debug
2 parents: a98e41a + 11186cd

File tree

6 files changed (+179, -146 lines)


als/README.md

Lines changed: 33 additions & 12 deletions
@@ -6,24 +6,37 @@ This folder contains:
 
 (2) the JNI code to link to and accelerate the ALS.scala program in Spark MLlib.
 
-## Technical details
-By optimizing memory access and parallelism, cuMF is much faster and cost-efficient compared with state-of-art CPU based solutions.
+## What is matrix factorization?
 
-More details can be found at:
+Matrix factorization (MF) factors a sparse rating matrix R (m by n, with N_z non-zero elements) into an m-by-f matrix and an f-by-n matrix, as shown below.
+
+<img src=https://github.com/wei-tan/CUDA-MLlib/raw/master/als/images/mf.png width=444 height=223 />
+
+Matrix factorization (MF) is at the core of many popular algorithms, e.g., [collaborative filtering](https://en.wikipedia.org/wiki/Collaborative_filtering), word embedding, and topic modeling. GPUs (graphics processing units), with their massive number of cores and high intra-chip memory bandwidth, are a promising platform for accelerating MF much further, provided their architectural characteristics are exploited appropriately.
+
+## What is cuMF?
+
+**CuMF** is a CUDA-based matrix factorization library that optimizes the alternating least squares (ALS) method to solve very large-scale MF. CuMF uses a set of techniques to maximize performance on single and multiple GPUs: smart access of sparse data that leverages the GPU memory hierarchy, data parallelism used in conjunction with model parallelism, minimized communication overhead among GPUs, and a novel topology-aware parallel reduction scheme.
+
+On a single machine with four Nvidia GPU cards, cuMF can be 6-10 times as fast, and 33-100 times as cost-efficient, as state-of-the-art distributed CPU solutions. Moreover, cuMF can solve the largest matrix factorization problem reported in the literature to date.
+
+CuMF achieves excellent scalability and performance by applying the following techniques on GPUs:
+
+(1) On a single GPU, MF deals with sparse matrices, which makes it difficult to fully utilize the GPU's compute power. We optimize memory access in ALS through several techniques, including reducing discontiguous memory access, retaining hotspot variables in faster memory, and aggressively using registers. In this way cuMF gets closer to the roofline performance of a single GPU.
 
-1) This Nvidia GTC 2016 talk
-ppt:
+(2) On multiple GPUs, we add data parallelism to ALS's inherent model parallelism. Data parallelism needs a faster reduction operation among GPUs, which leads to (3).
 
-<http://www.slideshare.net/tanwei/s6211-cumf-largescale-matrix-factorization-on-just-one-machine-with-gpus>
+(3) We develop an innovative topology-aware parallel reduction method to fully leverage the bandwidth between GPUs. In this way cuMF ensures that multiple GPUs are utilized efficiently and simultaneously.
 
-video:
+## Use cuMF to accelerate Spark ALS
 
-<http://on-demand.gputechconf.com/gtc/2016/video/S6211.html>
+CuMF can be used standalone, or to accelerate the [ALS implementation in Spark MLlib](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala).
 
-2) This HPDC 2016 paper:
+We modified Spark's ml/recommendation/als.scala ([code](https://github.com/wei-tan/SparkGPU/blob/MLlib/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala)) to detect GPUs and offload the forming and solving of the ALS subproblems to GPUs, while retaining the shuffling on Spark RDDs.
 
-"Faster and Cheaper: Parallelizing Large-Scale Matrix Factorization on GPUs"
-<http://arxiv.org/abs/1603.03820>
+<img src=https://github.com/wei-tan/CUDA-MLlib/raw/master/als/images/spark-gpu.png width=380 height=240 />
+
+This approach has several advantages. First, existing Spark applications relying on mllib/ALS need no change. Second, we leverage the best of both Spark (to scale out to multiple nodes) and GPUs (to scale up within one node).
 
 ## Build
 There are scripts to build the program locally, run in local mode, and run in distributed mode.
@@ -52,4 +65,12 @@ We are trying to improve the usability, stability and performance. Here are some
 
 (1) Out-of-memory error from GPUs, when there are many CPU threads accessing a small number of GPUs on any node. We tested Netflix data on one node, with 12 CPU cores used by the executor, and 2 Nvidia K40 GPU cards. If you have more GPU cards, you may be able to accomodate more CPU cores/threads. Otherwise you need to lessen the #cores assigned to Spark executor.
 
-(2) CPU-GPU hybrid execution. We want to push as much workload to GPU as possible. If GPUs cannot accomodate all CPU threads, we want to retain the execution on CPUs.
+(2) CPU-GPU hybrid execution. We want to push as much of the workload to the GPU as possible. If the GPUs cannot accommodate all CPU threads, we want to keep the remaining execution on CPUs.
+
+## References
+
+More details can be found in:
+
+1) CuMF: Large-Scale Matrix Factorization on Just One Machine with GPUs. Nvidia GTC 2016 talk. [ppt](http://www.slideshare.net/tanwei/s6211-cumf-largescale-matrix-factorization-on-just-one-machine-with-gpus), [video](http://on-demand.gputechconf.com/gtc/2016/video/S6211.html)
+
+2) Faster and Cheaper: Parallelizing Large-Scale Matrix Factorization on GPUs. Wei Tan, Liangliang Cao, Liana Fong. [HPDC 2016](http://arxiv.org/abs/1603.03820), Kyoto, Japan.
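
For reference, the factorization described in the new README text corresponds to the usual regularized least-squares objective that ALS minimizes. The notation below (X for the m-by-f factor, Θ for the f-by-n factor, λ for the regularization weight) is a standard textbook formulation given as a sketch, not an excerpt from this commit:

```latex
% Sketch of the MF objective behind the README text above (standard notation, assumed).
% R is the m-by-n sparse rating matrix with N_z observed entries r_{uv};
% x_u is the f-dimensional factor for row (user) u, \theta_v the factor for column (item) v.
\min_{X,\Theta}\;\sum_{(u,v)\ \mathrm{observed}} \bigl(r_{uv} - x_u^{\top}\theta_v\bigr)^{2}
\;+\;\lambda\Bigl(\sum_{u}\lVert x_u\rVert^{2} + \sum_{v}\lVert\theta_v\rVert^{2}\Bigr)

% ALS alternates closed-form solves: with \Theta fixed, each x_u solves an f-by-f
% normal-equation system over the items I_u rated by u (and symmetrically for \theta_v):
x_u \;=\; \bigl(\Theta_{I_u}\Theta_{I_u}^{\top} + \lambda I_f\bigr)^{-1}\,\Theta_{I_u}\,r_{u,I_u}^{\top}
```

These per-row and per-column f-by-f solves are the "forming and solving" work that the modified als.scala hands to the GPU through the JNI entry point in als/src/CuMFJNIInterface.cpp.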

als/images/mf.png

45 KB

als/images/spark-gpu.png

113 KB

als/src/CuMFJNIInterface.cpp

Lines changed: 98 additions & 94 deletions
@@ -24,107 +24,111 @@
 
 JNIEXPORT jobjectArray JNICALL Java_org_apache_spark_ml_recommendation_CuMFJNIInterface_doALSWithCSR
   (JNIEnv * env, jobject obj, jint m, jint n, jint f, jint nnz, jdouble lambda, jobjectArray sortedSrcFactors, jintArray csrRow, jintArray csrCol, jfloatArray csrVal){
-    //checkCudaErrors(cudaSetDevice(1));
-    //use multiple GPUs
-    //select a GPU for *this* specific dataset
-    int whichGPU = get_gpu();
-    checkCudaErrors(cudaSetDevice(whichGPU));
-    cudaStream_t cuda_stream;
-    cudaStreamCreate(&cuda_stream);
-    /* check correctness
-    int csrRowlen = env->GetArrayLength(csrRow);
-    int csrCollen = env->GetArrayLength(csrCol);
-    int csrVallen = env->GetArrayLength(csrVal);
-    assert(csrRowlen == m + 1);
-    assert(csrCollen == nnz);
-    assert(csrVallen == nnz);
-    */
-    int* csrRowIndexHostPtr;
-    int* csrColIndexHostPtr;
-    float* csrValHostPtr;
-    /*
-    printf("csrRow of len %d: ", len);
-    for (int i = 0; i < len; i++) {
-        printf("%d ", body[i]);
-    }
-    printf("\n");
-    */
-    //calculate X from thetaT
-    float* thetaTHost;
-    cudacall(cudaMallocHost( (void** ) &thetaTHost, n * f * sizeof(thetaTHost[0])) );
-    //to be returned
-    float* XTHost;
-    cudacall(cudaMallocHost( (void** ) &XTHost, m * f * sizeof(XTHost[0])) );
-
-    int numSrcBlocks = env->GetArrayLength(sortedSrcFactors);
-    //WARNING: ReleaseFloatArrayElements and DeleteLocalRef are important;
-    //Otherwise result is correct but performance is bad
-    int index = 0;
-    for(int i = 0; i < numSrcBlocks; i++){
-        jobject factorsPerBlock = env->GetObjectArrayElement(sortedSrcFactors, i);
-        int numFactors = env->GetArrayLength((jobjectArray)factorsPerBlock);
-        for(int j = 0; j < numFactors; j++){
-            jobject factor = env->GetObjectArrayElement((jobjectArray)factorsPerBlock, j);
-            jfloat *factorfloat = (jfloat *) env->GetPrimitiveArrayCritical( (jfloatArray)factor, 0);
-            memcpy(thetaTHost + index*f, factorfloat, sizeof(float)*f);
-            index ++;
-            env->ReleasePrimitiveArrayCritical((jfloatArray)factor, factorfloat, 0);
-            env->DeleteLocalRef(factor);
+    try{
+        //checkCudaErrors(cudaSetDevice(1));
+        //use multiple GPUs
+        //select a GPU for *this* specific dataset
+        int whichGPU = get_gpu();
+        checkCudaErrors(cudaSetDevice(whichGPU));
+        cudaStream_t cuda_stream;
+        cudaStreamCreate(&cuda_stream);
+        /* check correctness
+        int csrRowlen = env->GetArrayLength(csrRow);
+        int csrCollen = env->GetArrayLength(csrCol);
+        int csrVallen = env->GetArrayLength(csrVal);
+        assert(csrRowlen == m + 1);
+        assert(csrCollen == nnz);
+        assert(csrVallen == nnz);
+        */
+        int* csrRowIndexHostPtr;
+        int* csrColIndexHostPtr;
+        float* csrValHostPtr;
+        /*
+        printf("csrRow of len %d: ", len);
+        for (int i = 0; i < len; i++) {
+            printf("%d ", body[i]);
         }
-        env->DeleteLocalRef(factorsPerBlock);
-    }
-    // get a pointer to the raw input data, pinning them in memory
-    csrRowIndexHostPtr = (jint*) env->GetPrimitiveArrayCritical(csrRow, 0);
-    csrColIndexHostPtr = (jint*) env->GetPrimitiveArrayCritical(csrCol, 0);
-    csrValHostPtr = (jfloat*) env->GetPrimitiveArrayCritical(csrVal, 0);
+        printf("\n");
+        */
+        //calculate X from thetaT
+        float* thetaTHost;
+        cudacall(cudaMallocHost( (void** ) &thetaTHost, n * f * sizeof(thetaTHost[0])) );
+        //to be returned
+        float* XTHost;
+        cudacall(cudaMallocHost( (void** ) &XTHost, m * f * sizeof(XTHost[0])) );
+
+        int numSrcBlocks = env->GetArrayLength(sortedSrcFactors);
+        //WARNING: ReleaseFloatArrayElements and DeleteLocalRef are important;
+        //Otherwise result is correct but performance is bad
+        int index = 0;
+        for(int i = 0; i < numSrcBlocks; i++){
+            jobject factorsPerBlock = env->GetObjectArrayElement(sortedSrcFactors, i);
+            int numFactors = env->GetArrayLength((jobjectArray)factorsPerBlock);
+            for(int j = 0; j < numFactors; j++){
+                jobject factor = env->GetObjectArrayElement((jobjectArray)factorsPerBlock, j);
+                jfloat *factorfloat = (jfloat *) env->GetPrimitiveArrayCritical( (jfloatArray)factor, 0);
+                memcpy(thetaTHost + index*f, factorfloat, sizeof(float)*f);
+                index ++;
+                env->ReleasePrimitiveArrayCritical((jfloatArray)factor, factorfloat, 0);
+                env->DeleteLocalRef(factor);
+            }
+            env->DeleteLocalRef(factorsPerBlock);
+        }
+        // get a pointer to the raw input data, pinning them in memory
+        csrRowIndexHostPtr = (jint*) env->GetPrimitiveArrayCritical(csrRow, 0);
+        csrColIndexHostPtr = (jint*) env->GetPrimitiveArrayCritical(csrCol, 0);
+        csrValHostPtr = (jfloat*) env->GetPrimitiveArrayCritical(csrVal, 0);
 
-    /*
-    printf("thetaTHost of len %d: \n", n*f);
-    for (int i = 0; i < n*f; i++) {
-        printf("%f ", thetaTHost[i]);
-    }
-    printf("\n");
-    */
-    int * d_csrRowIndex = 0;
-    int * d_csrColIndex = 0;
-    float * d_csrVal = 0;
+        /*
+        printf("thetaTHost of len %d: \n", n*f);
+        for (int i = 0; i < n*f; i++) {
+            printf("%f ", thetaTHost[i]);
+        }
+        printf("\n");
+        */
+        int * d_csrRowIndex = 0;
+        int * d_csrColIndex = 0;
+        float * d_csrVal = 0;
 
-    cudacall(cudaMalloc((void** ) &d_csrRowIndex,(m + 1) * sizeof(float)));
-    cudacall(cudaMalloc((void** ) &d_csrColIndex, nnz * sizeof(float)));
-    cudacall(cudaMalloc((void** ) &d_csrVal, nnz * sizeof(float)));
-    cudacall(cudaMemcpyAsync(d_csrRowIndex, csrRowIndexHostPtr,(size_t ) ((m + 1) * sizeof(float)), cudaMemcpyHostToDevice, cuda_stream));
-    cudacall(cudaMemcpyAsync(d_csrColIndex, csrColIndexHostPtr,(size_t ) (nnz * sizeof(float)), cudaMemcpyHostToDevice, cuda_stream));
-    cudacall(cudaMemcpyAsync(d_csrVal, csrValHostPtr,(size_t ) (nnz * sizeof(float)),cudaMemcpyHostToDevice, cuda_stream));
-    cudaStreamSynchronize(cuda_stream);
+        cudacall(mallocBest((void** ) &d_csrRowIndex,(m + 1) * sizeof(float)));
+        cudacall(mallocBest((void** ) &d_csrColIndex, nnz * sizeof(float)));
+        cudacall(mallocBest((void** ) &d_csrVal, nnz * sizeof(float)));
+        cudacall(cudaMemcpyAsync(d_csrRowIndex, csrRowIndexHostPtr,(size_t ) ((m + 1) * sizeof(float)), cudaMemcpyHostToDevice, cuda_stream));
+        cudacall(cudaMemcpyAsync(d_csrColIndex, csrColIndexHostPtr,(size_t ) (nnz * sizeof(float)), cudaMemcpyHostToDevice, cuda_stream));
+        cudacall(cudaMemcpyAsync(d_csrVal, csrValHostPtr,(size_t ) (nnz * sizeof(float)),cudaMemcpyHostToDevice, cuda_stream));
+        cudaStreamSynchronize(cuda_stream);
 
-    // un-pin the host arrays, as we're done with them
-    env->ReleasePrimitiveArrayCritical(csrRow, csrRowIndexHostPtr, 0);
-    env->ReleasePrimitiveArrayCritical(csrCol, csrColIndexHostPtr, 0);
-    env->ReleasePrimitiveArrayCritical(csrVal, csrValHostPtr, 0);
+        // un-pin the host arrays, as we're done with them
+        env->ReleasePrimitiveArrayCritical(csrRow, csrRowIndexHostPtr, 0);
+        env->ReleasePrimitiveArrayCritical(csrCol, csrColIndexHostPtr, 0);
+        env->ReleasePrimitiveArrayCritical(csrVal, csrValHostPtr, 0);
 
-    printf("\tdoALSWithCSR with m=%d,n=%d,f=%d,nnz=%d,lambda=%f \n.", m, n, f, nnz, lambda);
-    try{
+        #ifdef DEBUG
+        printf("\tdoALSWithCSR with m=%d,n=%d,f=%d,nnz=%d,lambda=%f \n.", m, n, f, nnz, lambda);
+        #endif
         doALSWithCSR(cuda_stream, d_csrRowIndex, d_csrColIndex, d_csrVal, thetaTHost, XTHost, m, n, f, nnz, lambda, 1);
-    }
-    catch (thrust::system_error &e) {
-        printf("CUDA error during some_function: %s", e.what());
 
-    }
-    jclass floatArrayClass = env->FindClass("[F");
-    jobjectArray output = env->NewObjectArray(m, floatArrayClass, 0);
-    for (int i = 0; i < m; i++) {
-        jfloatArray col = env->NewFloatArray(f);
-        env->SetFloatArrayRegion(col, 0, f, XTHost + i*f);
-        env->SetObjectArrayElement(output, i, col);
-        env->DeleteLocalRef(col);
-    }
-    cudaFreeHost(thetaTHost);
-    cudaFreeHost(XTHost);
-    //TODO: stream create and destroy expensive?
-    checkCudaErrors(cudaStreamSynchronize(cuda_stream));
-    checkCudaErrors(cudaStreamDestroy(cuda_stream));
-    cudaCheckError();
-    return output;
+        jclass floatArrayClass = env->FindClass("[F");
+        jobjectArray output = env->NewObjectArray(m, floatArrayClass, 0);
+        for (int i = 0; i < m; i++) {
+            jfloatArray col = env->NewFloatArray(f);
+            env->SetFloatArrayRegion(col, 0, f, XTHost + i*f);
+            env->SetObjectArrayElement(output, i, col);
+            env->DeleteLocalRef(col);
+        }
+        cudaFreeHost(thetaTHost);
+        cudaFreeHost(XTHost);
+        //TODO: stream create and destroy expensive?
+        checkCudaErrors(cudaStreamSynchronize(cuda_stream));
+        checkCudaErrors(cudaStreamDestroy(cuda_stream));
+        #ifdef DEBUG
+        cudaCheckError();
+        #endif
+        return output;
+    }
+    catch (thrust::system_error &e) {
+        printf("CUDA error during some function: %s", e.what());
+    }
 }
 
 JNIEXPORT void JNICALL Java_org_apache_spark_ml_recommendation_CuMFJNIInterface_testjni
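
The substance of this commit is visible in the hunk above: the diagnostic printf and cudaCheckError() are now compiled only when DEBUG is defined, and the whole JNI body runs inside a try block that reports thrust::system_error instead of letting it escape into the JVM. Below is a minimal, self-contained sketch of that pattern; the file name, build line, and the toy thrust::reduce call are illustrative assumptions, not code from this repository.

```cpp
// debug_guard_demo.cu -- illustrative sketch only (not from CUDA-MLlib).
// Demonstrates the two patterns introduced above: #ifdef DEBUG around
// diagnostic printf calls, and a try/catch that reports thrust::system_error.
// Example build (assumed): nvcc -DDEBUG debug_guard_demo.cu -o debug_guard_demo
#include <cstdio>
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <thrust/system_error.h>

// Sum n ones on the GPU; Thrust throws thrust::system_error on CUDA failures.
static float sum_on_gpu(int n) {
    thrust::device_vector<float> v(n, 1.0f);
    return thrust::reduce(v.begin(), v.end());
}

int main() {
    int n = 1 << 20;
#ifdef DEBUG
    // Compiled only into debug builds, mirroring the printf moved under DEBUG above.
    printf("\tsum_on_gpu with n=%d\n", n);
#endif
    try {
        float s = sum_on_gpu(n);
#ifdef DEBUG
        printf("result = %f\n", s);
#endif
        (void)s; // silence an unused-variable warning in non-debug builds
    } catch (thrust::system_error &e) {
        // Same reporting style as the catch block added in CuMFJNIInterface.cpp.
        printf("CUDA error during sum_on_gpu: %s\n", e.what());
        return 1;
    }
    return 0;
}
```

Building without -DDEBUG strips the per-call logging, which is the point of the "move printf to debug" change.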
