Using jax in ml-notebook instantly kills kernel #387
Comments
Thanks for digging into this Dhruv! @ngam - any thoughts on what might be wrong here? |
That's what I addressed in the previous PR (add cuda-nvcc). Could you print some diagnostics please? For example:
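(The exact diagnostics were not preserved in this excerpt; the commands below are an assumed, typical set for comparing the driver's CUDA version against the toolkit shipped in the image.)

```bash
# CUDA version supported by the NVIDIA driver on the node
nvidia-smi
# Toolkit/compiler versions inside the image (cuda-nvcc provides ptxas)
nvcc --version
ptxas --version
# CUDA- and jax-related packages installed from conda/mamba
conda list | grep -iE "cuda|jax"
```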
Also, it is important to know which ml-notebook image we are talking about here. |
If the problem is what I think it is (missing cuda-nvcc), this is how you know: |
The information is here: https://github.com/2i2c-org/infrastructure/blob/63a034530f30179bc80313126624cc337836e744/config/clusters/m2lines/common.values.yaml#L94-L97 It's |
Dhruv can you try this and see if it works? |
I just tried that. Now I get a new long error:
|
Dhruv, it looks like the |
I was unable to overwrite the ML image using the configurator, so I haven't gotten an answer yet. |
Thanks for reporting back the errors. These are the relevant bits:
And this is definitely related to the ptxas stuff. One thing to try (unlikely to resolve this, but might as well try it) is to get a specific cuda-nvcc, e.g.
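(The specific pin is not preserved here; a sketch of what installing a specific cuda-nvcc could look like, with 11.6 as an assumed version:)

```bash
# Pin cuda-nvcc from the nvidia channel to a specific minor version
# (11.6 is only an example; match it to the CUDA version nvidia-smi reports)
mamba install -c nvidia "cuda-nvcc=11.6.*"
```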
|
This change should fix the jax problems that are mentioned here: pangeo-data/pangeo-docker-images#387 (comment)
Okay, I also wonder if this is a bad issue with K80... I hope not, but it could be. The jaxlib versions we have are strictly 11.2+, and we target these compute capabilities:
This sadly happens to skip the K80's compute capability. I can produce a specific jaxlib build for you later this week to test, if nothing else fixes this issue. If that resolves it, then we can add it here: https://github.com/conda-forge/jaxlib-feedstock/blob/04564570fb408c83222d67fe948811ac5e3fb9d2/recipe/build.sh#L44 (I only have access to clusters with A100 and V100, so my testing could easily miss older GPUs). |
Alternatively, we could configure our cloud hubs to use a different GPU. What would you recommend as a good tradeoff between cost and performance for general ML work? |
The K80 was definitely good in my previous experience (before I got involved in building jaxlib). The best of all worlds (cost, etc.) seems to be the T4: there are free T4s available for exploratory work (e.g. Planetary Computer, the never-available SageMaker free tier, and the ever-present Colab all have them readily), and the T4 supports more of the interesting modern optimizations (I believe). |
Actually, I now realize I have access to some weird "vis" compute node that I never used, that may still have K80s. I will test later today to confirm... |
I have tested this snippet with jax/jaxlib from conda-forge:

```python
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)
```

on various clusters with A100, V100, and K80 and they all work fine. Crucially, these clusters all have thorough system-level installations of all the CUDA compilers (by my request from a while ago, especially ptxas). Now I need to figure out a way to test this in isolation with the ptxas from the nvidia channel. If, for some annoying reason, we cannot get this to work correctly, I will work with the team here to rebuild the containers from scratch based on the NGC containers (where they make available all the NVIDIA proprietary stuff). More soon! |
Okay, another update. I can reproduce the same exact error as above now inside a container on a cloud service (not a cluster). This is definitely an issue with ptxas. Let me try to get to the bottom of this. |
Thank you so much ngam! |
TEMPORARY SOLUTION: the key issue here is the mismatch between the driver and the software versions from the nvidia channel (something currently beyond my understanding), but here's the solution that got this container to work on a cloud service:
@dhruvbalwada could you please test with this command? If it works, we will pin cuda-nvcc in these containers until we find a better and more sustainable solution. |
On the cluster, we follow https://cloud.google.com/kubernetes-engine/docs/how-to/gpus#installing_drivers to install the drivers. We use the 'default' version here. The m2lines hub is on k8s 1.21 and hence is running driver version R450 (whatever that means?!). We can upgrade that to 'latest' if needed. |
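(For reference, the driver installation in those GKE docs is done by applying NVIDIA's driver-installer DaemonSet; the manifest below is the commonly documented one for Container-Optimized OS nodes and should be checked against the linked page.)

```bash
# Install the NVIDIA driver on GKE GPU nodes via the driver-installer DaemonSet
kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nvidia-driver-installer/cos/daemonset-preloaded.yaml
```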
@ngam - I tried downgrading cuda-nvcc to 11.6 as you suggested (after some updates to the cloud deployment, it was at 11.8.x). However, that sadly did not solve the problem. |
It will likely be difficult to downgrade cuda-nvcc enough for this to work on the K80... A (safe) alternative is to simply follow the instructions in the error message and set the XLA flag it suggests. Can you check if the following snippet resolves it for you?
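(The snippet itself is not preserved in this excerpt. Judging from the later comments about a compilation-time trade-off, it presumably set an XLA flag along the lines below; the exact flag is an assumption, not quoted from the original.)

```python
import os

# Set before importing jax so XLA picks the flag up when the GPU backend initializes.
# The trade-off, per the XLA error message, is longer compile times because parallel
# compilation is disabled.
os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1"

import jax
print(jax.devices())
```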
|
Maybe helpful: I was reading jax-ml/jax#5723. It suggests the nvcc and driver CUDA versions need to match. @yuvipanda - is it possible to align these version numbers? |
Yes, but we don't have cuda-nvcc 11.0 on the nvidia channel :( The lowest is 11.3. yuvipanda will need to figure out how to match those numbers to the other crazy numbers (e.g. R450)... Try |
@ngam - setting the XLA flags as you and the error message suggested did work. It comes with the downside, as quoted:
However, since I am new to all this, it is unlikely to be a bottleneck for anything I am doing for now. |
Awesome, yes I see it. Will report back on results soon. |
@yuvipanda - I tried the T4. It does not resolve the problem, maybe because the version numbers are still mismatched (nvcc at 11.8 and nvidia-smi at 11.0). Maybe the version numbers need to be pinned? @ngam, do you have suggestions for how to get this working on the T4? |
@dhruvbalwada ok, i'm going to try changing the driver version now, so let's see. |
@dhruvbalwada Try restarting your server, you should see cuda version 11.4 now? |
And use it for the m2lines hub and leap hub, based on conversations in pangeo-data/pangeo-docker-images#387
Here's the new version:
vs the old version:
|
I see cuda version 11.4 and nvcc version 11.8.89. The XLARuntimeError still shows up. :( |
Try |
So, what did work for me now is this (all of the following is for the T4). |
Unless it is possible to take CUDA all the way up to 11.8 (without breaking anything else)? |
@dhruvbalwada let me poke around. |
@dhruvbalwada I'll have to upgrade the k8s master, and then upgrade the nodes, to make that happen I think. Should be doable without interruption; am trying that out now. Will keep you posted here. |
I think it would be very useful to have a table of cuda versions and the pangeo docker images they are compatible with. |
I don't know how to get that. I also am not sure what other things (apart from JAX) are relying on the cuda versions. |
@dhruvbalwada you're gonna see a couple minutes interruption on m2lines now, sorry about that. |
No worries, I am not doing anything of significance at the moment. Kill the instance if you have to. |
@dhruvbalwada ok, killed :) |
Now at:
@dhruvbalwada @ngam I think 11.6 is as high as GKE is gonna go for now. |
Thank you Yuvi! As predicted by ngam: On T4:
On K80
|
@ngam - should we pin nvcc to 11.6 for now (so that the user doesn't have to manually downgrade every time)? |
Sure, here's a PR: #389 |
Thank you. I think we can close this issue when the PR is merged. |
I am trying to use jax on the ml-notebook, but run into the problem that even the most basic functions kill the kernel.
If I run the following code (first 4 lines from the jax tutorial):
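(Reconstructed below — presumably the same snippet ngam quotes earlier in this thread, since the original code block was not preserved:)

```python
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)
```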
I get an error that the kernel is restarting. I am using the m2lines deployment.
Since I don't even know what the error is, I am not sure who to reach out to about this. I will tag @jmunroe from 2i2c, as his team is the most well-versed in this hub.