float16 support for GPU als model #661
Both training and inference times are slightly faster with fp16, but not drastically so:

This is as expected, since we're computing results in float32 and only storing them in float16.
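The "store in fp16, compute in fp32" idea can be sketched with plain numpy (this is just an illustration, not the library's GPU kernel; the array names and sizes are made up):

```python
import numpy as np

# Hypothetical factor matrices, stored in float16 to halve memory use
rng = np.random.default_rng(0)
user_factors = rng.standard_normal((1000, 128)).astype(np.float16)
item_factors = rng.standard_normal((5000, 128)).astype(np.float16)

# Inference-style scoring: upcast to float32 so both the multiply and
# the accumulation happen in 32-bit, matching the scheme described above
scores = user_factors.astype(np.float32) @ item_factors.astype(np.float32).T

print(scores.dtype)   # float32 results from float16 storage
```

On the GPU this upcast-and-accumulate step is fused into the mixed-precision matrix multiply, so the fp32 copies never need to be materialized.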
Running some quick experiments with cross-validation, I got equivalent results with both fp16 and fp32 factors. This indicates that there isn't an accuracy hit from using fp16 factors in the learned model. Running a simple experiment on the lastfm dataset:

```python
from implicit.evaluation import precision_at_k, train_test_split
from implicit.datasets.lastfm import get_lastfm
from implicit.gpu.als import AlternatingLeastSquares

_, _, ratings = get_lastfm()
train, test = train_test_split(ratings.T.tocsr())

fp_16_model = AlternatingLeastSquares(factors=128, dtype="float16")
fp_16_model.fit(train)
p = precision_at_k(fp_16_model, train, test, K=10)
print("precision@10, fp16", p)

fp_32_model = AlternatingLeastSquares(factors=128, dtype="float32")
fp_32_model.fit(train)
p = precision_at_k(fp_32_model, train, test, K=10)
print("precision@10, fp32", p)
```

prints out
(Note this was with just default hyperparameters; the goal here is to show whether the results are equivalent between fp16 and fp32, rather than to get the best possible results on the lastfm dataset.)
This adds support for using float16 factors in the GPU version of the ALS model. This halves the memory needed for the model's embeddings, while providing a small speedup in training time and virtually no difference in the accuracy of the learned model.
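The memory saving follows directly from the element sizes: float16 stores 2 bytes per value versus 4 for float32. A quick back-of-the-envelope check (the user and factor counts below are made up for illustration):

```python
import numpy as np

users, factors = 1_000_000, 128

# Bytes needed for the user-factor embedding matrix at each precision,
# computed from dtype item sizes without allocating the arrays
fp32_bytes = users * factors * np.dtype(np.float32).itemsize
fp16_bytes = users * factors * np.dtype(np.float16).itemsize

print(fp32_bytes // fp16_bytes)  # → 2, i.e. fp16 uses half the memory
```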
All computations are still performed in float32, for both training and inference. During inference this is done with mixed-precision matrix multiplications: the fp16 factors are multiplied together, with results accumulated in fp32. During training, the factors are converted from fp16 to fp32, and updates are calculated in 32-bit before being stored back as fp16.
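The training-side round-trip can be sketched in numpy as well (a simplified stand-in for the actual ALS solve; the "update" here is a dummy placeholder, not the real least-squares step):

```python
import numpy as np

rng = np.random.default_rng(42)

# Factors live in float16 between iterations
item_factors = rng.standard_normal((100, 64)).astype(np.float16)

# One training step: upcast, update in 32-bit, store back in half precision
work = item_factors.astype(np.float32)                  # fp16 -> fp32
update = 0.01 * rng.standard_normal(work.shape)         # placeholder for the ALS update
work += update.astype(np.float32)                       # arithmetic done in fp32
item_factors = work.astype(np.float16)                  # fp32 -> fp16 for storage
```

Keeping the arithmetic in fp32 avoids the overflow and rounding issues that would come from accumulating directly in half precision.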