Using jax in ml-notebook instantly kills kernel #387

Closed
dhruvbalwada opened this issue Oct 10, 2022 · 53 comments · Fixed by #389

Comments

@dhruvbalwada
Member

dhruvbalwada commented Oct 10, 2022

I am trying to use jax on the ml-notebook, but run into the problem that even the most basic functions kill the kernel.

If I run the following code (first 4 lines from the jax tutorial):

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)

I get the message that the kernel is restarting. I am using the m2lines deployment.

Since I don't even know what the error is, I am not sure who the right person to reach out to about this is. I will tag @jmunroe from 2i2c, as his team is most well-versed in this hub.

@dhruvbalwada
Member Author

I was able to figure out that this issue has something to do with ptxas, but I am not sure how to address it. Here is a screenshot:
[Screenshot: ptxas-related error message]

@rabernat
Member

Thanks for digging into this Dhruv! @ngam - any thoughts on what might be wrong here?

@ngam
Contributor

ngam commented Oct 12, 2022

That's what I addressed in the previous PR (add cuda-nvcc).

Could you print some diagnostics please? For example:

! conda list | grep jaxlib
! conda list | grep cuda-nvcc
! conda list | grep tensorflow

Also, it is important to know which ml-notebook image we are talking about here.

@ngam
Contributor

ngam commented Oct 12, 2022

If the problem is what I think it is (missing cuda-nvcc), this is how you know: ! conda list | grep cuda-nvcc will return nothing. To fix it, simply install cuda-nvcc (from the nvidia channel): conda install cuda-nvcc -c nvidia. (The cuda-nvcc version hopefully doesn't matter, but if it does, we can think about constraining it.)
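
For context, a minimal set of notebook cells for that check and the fix, assuming the image's conda environment lives at /srv/conda/envs/notebook as in the tracebacks later in this thread (the expected outputs are assumptions, not guarantees):

! conda list | grep cuda-nvcc            # empty output means cuda-nvcc (and the ptxas it ships) is missing
! which ptxas                            # after installing cuda-nvcc, this should resolve inside the notebook env
! conda install -y cuda-nvcc -c nvidia   # install from the nvidia channel if the checks above come up empty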

@dhruvbalwada
Member Author

Thank you for looking through this so promptly.

I am not sure how to check which ml-notebook image is being used. I use the m2lines 2i2c deployment and select the TensorFlow option (how would I check the notebook image version?).
[Screenshot: server options selection]

As far as the diagnostics go:
[Screenshot: output of the diagnostic commands]

@rabernat
Member

To fix it, simply install cuda-nvcc (from the nvidia channel): conda install cuda-nvcc -c nvidia

Dhruv can you try this and see if it works?

@dhruvbalwada
Member Author

I just tried that.

Now I get a new long error:

IPython 8.5.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax.numpy as jnp
   ...: from jax import grad, jit, vmap
   ...: from jax import random
   ...: 

In [2]: key = random.PRNGKey(0)
2022-10-12 15:16:32.114395: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc:57] cuLinkAddData fails. This is usually caused by stale driver version.
2022-10-12 15:16:32.114450: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:1325] The CUDA linking API did not work. Please use XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 to bypass it, but expect to get longer compilation time due to the lack of multi-threading.
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In [2], line 1
----> 1 key = random.PRNGKey(0)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/random.py:128, in PRNGKey(seed)
    114 """Create a pseudo-random number generator (PRNG) key given an integer seed.
    115 
    116 The resulting key carries the default PRNG implementation, as
   (...)
    125 
    126 """
    127 impl = default_prng_impl()
--> 128 key = prng.seed_with_impl(impl, seed)
    129 return _return_prng_keys(True, key)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/prng.py:262, in seed_with_impl(impl, seed)
    261 def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:
--> 262   return random_seed(seed, impl=impl)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/prng.py:555, in random_seed(seeds, impl)
    553 else:
    554   seeds_arr = jnp.asarray(seeds)
--> 555 return random_seed_p.bind(seeds_arr, impl=impl)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/core.py:325, in Primitive.bind(self, *args, **params)
    322 def bind(self, *args, **params):
    323   assert (not config.jax_enable_checks or
    324           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 325   return self.bind_with_trace(find_top_trace(args), args, params)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/core.py:328, in Primitive.bind_with_trace(self, trace, args, params)
    327 def bind_with_trace(self, trace, args, params):
--> 328   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    329   return map(full_lower, out) if self.multiple_results else full_lower(out)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/core.py:686, in EvalTrace.process_primitive(self, primitive, tracers, params)
    685 def process_primitive(self, primitive, tracers, params):
--> 686   return primitive.impl(*tracers, **params)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/prng.py:567, in random_seed_impl(seeds, impl)
    565 @random_seed_p.def_impl
    566 def random_seed_impl(seeds, *, impl):
--> 567   base_arr = random_seed_impl_base(seeds, impl=impl)
    568   return PRNGKeyArray(impl, base_arr)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/prng.py:572, in random_seed_impl_base(seeds, impl)
    570 def random_seed_impl_base(seeds, *, impl):
    571   seed = iterated_vmap_unary(seeds.ndim, impl.seed)
--> 572   return seed(seeds)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/prng.py:807, in threefry_seed(seed)
    804   raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")
    805 convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
    806 k1 = convert(
--> 807     lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
    808 with jax.numpy_dtype_promotion('standard'):
    809   # TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit
    810   # inputs. We should avoid this.
    811   k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/lax/lax.py:500, in shift_right_logical(x, y)
    498 def shift_right_logical(x: Array, y: Array) -> Array:
    499   r"""Elementwise logical right shift: :math:`x \gg y`."""
--> 500   return shift_right_logical_p.bind(x, y)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/core.py:325, in Primitive.bind(self, *args, **params)
    322 def bind(self, *args, **params):
    323   assert (not config.jax_enable_checks or
    324           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 325   return self.bind_with_trace(find_top_trace(args), args, params)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/core.py:328, in Primitive.bind_with_trace(self, trace, args, params)
    327 def bind_with_trace(self, trace, args, params):
--> 328   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    329   return map(full_lower, out) if self.multiple_results else full_lower(out)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/core.py:686, in EvalTrace.process_primitive(self, primitive, tracers, params)
    685 def process_primitive(self, primitive, tracers, params):
--> 686   return primitive.impl(*tracers, **params)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/dispatch.py:111, in apply_primitive(prim, *args, **params)
    109 def apply_primitive(prim, *args, **params):
    110   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 111   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
    112                                         **params)
    113   return compiled_fun(*args)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/util.py:222, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
    220   return f(*args, **kwargs)
    221 else:
--> 222   return cached(config._trace_context(), *args, **kwargs)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/util.py:215, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
    213 @functools.lru_cache(max_size)
    214 def cached(_, *args, **kwargs):
--> 215   return f(*args, **kwargs)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/dispatch.py:195, in xla_primitive_callable(prim, *arg_specs, **params)
    193   else:
    194     return out,
--> 195 compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
    196                                   prim.name, donated_invars, False, *arg_specs)
    197 if not prim.multiple_results:
    198   return lambda *args, **kw: compiled(*args, **kw)[0]

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/dispatch.py:324, in _xla_callable_uncached(fun, device, backend, name, donated_invars, keep_unused, *arg_specs)
    321   return sharded_lowering(fun, device, backend, name,
    322                           donated_invars, keep_unused, *arg_specs)
    323 else:
--> 324   return lower_xla_callable(fun, device, backend, name, donated_invars, False,
    325                             keep_unused, *arg_specs).compile().unsafe_call

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/dispatch.py:934, in XlaComputation.compile(self)
    931     self._executable = XlaCompiledComputation.from_trivial_jaxpr(
    932         **self.compile_args)
    933   else:
--> 934     self._executable = XlaCompiledComputation.from_xla_computation(
    935         self.name, self._hlo, self._in_type, self._out_type,
    936         **self.compile_args)
    938 return self._executable

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/dispatch.py:1052, in XlaCompiledComputation.from_xla_computation(name, xla_computation, in_type, out_type, nreps, device, backend, tuple_args, in_avals, out_avals, has_unordered_effects, ordered_effects, kept_var_idx, keepalive, host_callbacks)
   1049 options.parameter_is_tupled_arguments = tuple_args
   1050 with log_elapsed_time(f"Finished XLA compilation of {name} "
   1051                       "in {elapsed_time} sec"):
-> 1052   compiled = compile_or_get_cached(backend, xla_computation, options,
   1053                                    host_callbacks)
   1054 buffer_counts = get_buffer_counts(out_avals, ordered_effects,
   1055                                   has_unordered_effects)
   1056 execute = _execute_compiled if nreps == 1 else _execute_replicated

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/dispatch.py:1006, in compile_or_get_cached(backend, computation, compile_options, host_callbacks)
   1004     ir_str = computation
   1005   _dump_ir_to_file(module_name, ir_str)
-> 1006 return backend_compile(backend, computation, compile_options, host_callbacks)

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/profiler.py:313, in annotate_function.<locals>.wrapper(*args, **kwargs)
    310 @wraps(func)
    311 def wrapper(*args, **kwargs):
    312   with TraceAnnotation(name, **decorator_kwargs):
--> 313     return func(*args, **kwargs)
    314   return wrapper

File /srv/conda/envs/notebook/lib/python3.9/site-packages/jax/_src/dispatch.py:950, in backend_compile(backend, built_c, options, host_callbacks)
    945   return backend.compile(built_c, compile_options=options,
    946                          host_callbacks=host_callbacks)
    947 # Some backends don't have `host_callbacks` option yet
    948 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    949 # to take in `host_callbacks`
--> 950 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: UNKNOWN: no kernel image is available for execution on the device

@rabernat
Member

Dhruv, it looks like the pangeo/ml-notebook:2022.09.21 image predates #378, which might resolve the issue. I'm going to try that on https://staging.leap.2i2c.cloud/ and see if it works.

@rabernat
Member

I was unable to overwrite the ML image using the configurator, so I haven't gotten an answer yet.

@ngam
Contributor

ngam commented Oct 12, 2022

Thanks for reporting back the errors. These are the relevant bits:

2022-10-12 15:16:32.114395: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc:57] cuLinkAddData fails. This is usually caused by stale driver version.
2022-10-12 15:16:32.114450: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:1325] The CUDA linking API did not work. Please use XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 to bypass it, but expect to get longer compilation time due to the lack of multi-threading.
...
XlaRuntimeError: UNKNOWN: no kernel image is available for execution on the device

And this is definitely related to the ptxas stuff. One thing to try (unlikely to resolve this, but might as well try it) is to get a specific cuda-nvcc, e.g.

conda install 'cuda-nvcc==11.7.99' -c nvidia

dhruvbalwada added a commit to dhruvbalwada/infrastructure that referenced this issue Oct 12, 2022
This change should fix the jax problems that are mentioned here: pangeo-data/pangeo-docker-images#387 (comment)
@ngam
Contributor

ngam commented Oct 12, 2022

Okay, I also wonder if this is an incompatibility with the K80 itself... I hope not, but it could be. The jaxlib versions we have are strictly 11.2+, and we target these compute capabilities:

TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_62,sm_70,sm_72,sm_75,sm_80,sm_86,compute_86

This sadly skips sm_37, which is the compute capability the K80 needs. (This is what we do across the whole ecosystem, e.g. tensorflow too, and I think we got the list from upstream.)

I can produce a specific jaxlib build for you later this week to test if nothing else fixes this issue. If that resolves it, then we can add it here https://github.com/conda-forge/jaxlib-feedstock/blob/04564570fb408c83222d67fe948811ac5e3fb9d2/recipe/build.sh#L44

(I happen to have access only to A100 and V100 clusters, so my testing could easily miss older GPUs.)
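
A quick way to check which GPU model (and hence which compute capability) a given server landed on, sketched here under the assumption that nvidia-smi is on PATH inside the image:

! nvidia-smi --query-gpu=name,driver_version --format=csv
# roughly: K80 -> sm_37, T4 -> sm_75, V100 -> sm_70, A100 -> sm_80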

@rabernat
Member

rabernat commented Oct 12, 2022

Alternatively, we could configure our cloud hubs to use a different GPU. What would you recommend as a good tradeoff between cost and performance for general ML work?

https://cloud.google.com/compute/gpus-pricing

@ngam
Contributor

ngam commented Oct 12, 2022

The K80 was definitely fine in my previous experience (before I got involved in building jaxlib). The best tradeoff overall (cost, availability, etc.) seems to be the T4: free options for exploratory work (e.g. Planetary Computer, the never-available SageMaker free tier, and the ever-present Colab) all offer T4s, and the T4 supports more of the interesting modern optimizations (I believe).

@ngam
Contributor

ngam commented Oct 12, 2022

Actually, I now realize I have access to some weird "vis" compute node that I never used, that may still have K80s. I will test later today to confirm...

@ngam
Contributor

ngam commented Oct 12, 2022

I have tested this snippet with jax/jaxlib from conda-forge:

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)

on various clusters with A100, V100, and K80 GPUs, and it works fine on all of them. Crucially, these clusters all have thorough system-level installations of all the CUDA compilers (by my request from a while ago, especially ptxas). Now I need to figure out a way to test this in isolation with the ptxas from the nvidia channel. If, for some annoying reason, we cannot get this to work correctly, I will work with the team here to rebuild the containers from scratch based on the NGC containers (where all the NVIDIA proprietary pieces are available). More soon!
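
For anyone following along, one way to see which ptxas an environment actually picks up (a sketch; the path is an assumption based on the conda env used in this image):

! which ptxas        # should point into /srv/conda/envs/notebook/bin if the nvidia-channel copy is being used
! ptxas --version    # reports the CUDA toolkit release that ptxas came from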

@ngam
Contributor

ngam commented Oct 12, 2022

Okay, another update. I can reproduce the same exact error as above now inside a container on a cloud service (not a cluster). This is definitely an issue with ptxas. Let me try to get to the bottom of this.

@dhruvbalwada
Member Author

Thank you so much ngam!

@ngam
Contributor

ngam commented Oct 12, 2022

TEMPORARY SOLUTION: the key issue here is the mismatch between the driver and the software versions from the nvidia channel (something currently beyond my understanding), but here's the solution that got this container to work on a cloud service:

mamba install cuda-nvcc==11.6.* -c nvidia

@dhruvbalwada could you please test with this command? If it works, we will pin cuda-nvcc in these containers until we find a better and more sustainable solution.

@yuvipanda
Member

On the cluster, we follow https://cloud.google.com/kubernetes-engine/docs/how-to/gpus#installing_drivers to install the drivers. We use the 'default' version here. The m2lines hub is on k8s 1.21 and hence is running driver version R450 (whatever that means?!). We can upgrade that to 'latest' if needed.
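
(For reference, a sketch of how one might confirm the driver-installer rollout and that GPU nodes advertise the resource, assuming the default daemonset name from the linked GKE guide and kubectl access to the cluster; <gpu-node> is a placeholder:)

kubectl -n kube-system get daemonset nvidia-driver-installer
kubectl describe node <gpu-node> | grep -i nvidia.com/gpu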

@dhruvbalwada
Member Author

dhruvbalwada commented Oct 12, 2022

@ngam - I tried downgrading cuda-nvcc to 11.6 as you suggested (after some updates to the cloud deployment, it was at 11.8.x). However, that did not solve the problem, sadly.

@ngam
Contributor

ngam commented Oct 12, 2022

It will likely be difficult to downgrade cuda-nvcc enough for this to work on the K80...

A (safe) alternative is to simply follow the instructions in the error message:

set XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1

--

Can you check if the following snippet resolves it for you?

import os
os.environ["XLA_FLAGS"]="--xla_gpu_force_compilation_parallelism=1"

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)
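
(Note that the variable has to be set before JAX initializes its GPU backend, which is why the snippet sets it before importing jax. An equivalent approach, assuming you start Python from a terminal on the same server, is to export it in the shell beforehand:)

export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1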

@dhruvbalwada
Member Author

dhruvbalwada commented Oct 12, 2022

Maybe helpful: I was reading jax-ml/jax#5723. It says that nvcc --version (right now 11.6, after downgrading from 11.8, the default on m2lines) and the CUDA version reported by nvidia-smi (right now 11.0) should be the same.

@yuvipanda - is it possible to align these version numbers?
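
For reference, the two numbers being compared here can be read side by side from a notebook (a sketch; the grep patterns assume the usual output layout):

! nvcc --version | grep release          # toolkit/ptxas version installed in the conda env
! nvidia-smi | grep "CUDA Version"       # highest CUDA version the installed driver supports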

@ngam
Contributor

ngam commented Oct 12, 2022

Yes, but we don't have cuda-nvcc 11.0 on the nvidia channel :( the lowest is 11.3

yuvipanda will need to figure out how to match those numbers to the other crazy numbers (e.g. R450)...

Try import os; os.environ["XLA_FLAGS"]="--xla_gpu_force_compilation_parallelism=1" as a temporary workaround...

@dhruvbalwada
Member Author

@ngam setting the XLA flags as you and the error message suggested did work.

It comes with the downside noted in the message:

2022-10-12 19:37:48.587389: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_asm_compiler.cc:61] cuLinkAddData fails. This is usually caused by stale driver version.
2022-10-12 19:37:48.587444: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:1363] The CUDA linking API did not work. Please use XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 to bypass it, but expect to get longer compilation time due to the lack of multi-threading.

However, since I am new to all this, it is unlikely to be a bottleneck for anything I am doing for now.

@dhruvbalwada
Member Author

Awesome, yes I see it. Will report back on results soon.

@dhruvbalwada
Member Author

@yuvipanda - I tried the T4. It does not resolve the problem, maybe because the version numbers still differ (nvcc at 11.8 and nvidia-smi at 11.0).
As discussed above, we can bypass this problem by turning off multi-threaded compilation from inside Python.

Maybe the version numbers need to be pinned? @ngam do you have suggestions for what to do with a working T4?

@yuvipanda
Member

@dhruvbalwada ok, i'm going to try changing the driver version now, so let's see.

@yuvipanda
Member

@dhruvbalwada Try restarting your server, you should see cuda version 11.4 now?

yuvipanda added a commit to yuvipanda/pilot-hubs that referenced this issue Oct 12, 2022
And use it for the m2lines hub and leap hub, based on
conversations in pangeo-data/pangeo-docker-images#387
@yuvipanda
Member

yuvipanda commented Oct 12, 2022

Here's the new version:

(notebook) jovyan@jupyter-yuvipanda:~$ nvidia-smi
Wed Oct 12 20:49:39 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   47C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

vs the old version:

(notebook) jovyan@jupyter-yuvipanda:~$ nvidia-smi
Wed Oct 12 20:51:44 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.119.04   Driver Version: 450.119.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   71C    P8    35W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

@dhruvbalwada
Member Author

I see CUDA version 11.4 and nvcc version 11.8.89. The XlaRuntimeError still shows up. :(
@ngam - do you have any suggestions for specific version numbers that might fix this?

@ngam
Contributor

ngam commented Oct 12, 2022

Try conda install -c nvidia cuda-nvcc==11.4.* or ... ==11.3.*

@dhruvbalwada
Member Author

dhruvbalwada commented Oct 12, 2022

So what did work for me now is this: I downgraded nvcc to 11.4, which fixed the problem. Should we pin this for now?

*This is all for T4.

@dhruvbalwada
Member Author

Unless it is possible to take CUDA all the way up to 11.8 (without breaking anything else)?

@yuvipanda
Member

@dhruvbalwada let me poke around.

@yuvipanda
Member

@dhruvbalwada I'll have to upgrade the k8s master and then the nodes to make that happen, I think. Should be doable without interruption; am trying that out now. Will keep you posted here.

@yuvipanda
Member

I think it would be very useful to have a table of cuda versions and the pangeo docker images they are compatible with.

@dhruvbalwada
Member Author

I don't know how to get that. I am also not sure what other things (apart from JAX) rely on the CUDA versions.

@yuvipanda
Member

@dhruvbalwada you're gonna see a couple minutes interruption on m2lines now, sorry about that.

@dhruvbalwada
Member Author

No worries, I am not doing anything of significance at the moment. Kill the instance if you have to.

@yuvipanda
Member

@dhruvbalwada ok, killed :)

@yuvipanda
Member

Now at:

(notebook) jovyan@jupyter-yuvipanda:~$ nvidia-smi
Wed Oct 12 22:38:45 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   51C    P8    11W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

@dhruvbalwada @ngam I think 11.6 is as high as GKE is gonna go for now.

@dhruvbalwada
Member Author

dhruvbalwada commented Oct 13, 2022

Thank you Yuvi!

As predicted by ngam:

On T4:

  • Downgrading to nvcc 11.6 (matching the driver's CUDA version) makes things work.
    Use: mamba install cuda-nvcc==11.6.* -c nvidia

On K80

  • Downgrading to nvcc 11.6 (matching the driver's CUDA version) does not make things work.
  • Bypassing the CUDA linking step by turning off multi-threaded compilation works. Use:
import os
os.environ["XLA_FLAGS"]="--xla_gpu_force_compilation_parallelism=1"

@dhruvbalwada
Member Author

@ngam - should we pin nvcc to 11.6 for now (so that users don't have to downgrade manually every time)?

@ngam
Contributor

ngam commented Oct 13, 2022

Sure, here's a PR: #389

@dhruvbalwada
Member Author

Thank you. I think we can close this issue when the PR is merged.
