
jax-metal: 3-D convolution crashes #21554

@jonatanklosko

Description

import jax
import jax.numpy as jnp

def f(left, right):
  return jax.lax.conv(left, right, window_strides=[1, 1, 1], padding="SAME")

left = jnp.full((1, 4, 5, 5, 5), 1, dtype=jnp.float32)
right = jnp.full((4, 4, 3, 3, 3), 1, dtype=jnp.float32)

# Print the lowered StableHLO (shown below)
print(jax.jit(f).lower(left, right).as_text())
# Execute on the default (METAL) device; this is the call that crashes
print(jax.jit(f)(left, right))
HLO
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1x4x5x5x5xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<4x4x3x3x3xf32> {mhlo.layout_mode = "default"}) -> (tensor<1x4x5x5x5xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1, 2]x[o, i, 0, 1, 2]->[b, f, 0, 1, 2], window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1], reverse = [false, false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<1x4x5x5x5xf32>, tensor<4x4x3x3x3xf32>) -> tensor<1x4x5x5x5xf32>
    return %0 : tensor<1x4x5x5x5xf32>
  }
}
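For reference, the `pad = [[1, 1], [1, 1], [1, 1]]` in the HLO above is what `padding="SAME"` resolves to for 5x5x5 inputs, a 3x3x3 kernel and stride 1. The sketch below assumes the usual SAME rule (total padding split with the smaller half on the low side); `same_padding` is a hypothetical helper for illustration, not a JAX function.

```python
import math


def same_padding(in_size: int, kernel: int, stride: int = 1) -> tuple[int, int]:
    """Hypothetical helper: (low, high) padding for one spatial dim under "SAME".

    Assumes the common rule: output size is ceil(in_size / stride), and any odd
    leftover padding goes on the high side.
    """
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + kernel - in_size, 0)
    low = total // 2
    return low, total - low


# Spatial dims 5x5x5, kernel 3x3x3, stride 1, as in the repro
pads = [same_padding(5, 3, 1) for _ in range(3)]
print(pads)  # [(1, 1), (1, 1), (1, 1)]
```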

Running this crashes the process with:

LLVM ERROR: Failed to infer result type(s).

I assume this means 3-D convolution is not yet supported by jax-metal.
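For comparison once support lands, here is a minimal NumPy sketch of what the call above should return. It assumes the layouts from the HLO (`[b, f, 0, 1, 2]` input, `[o, i, 0, 1, 2]` kernel), stride 1, odd kernel sizes (so SAME padding is symmetric) and no kernel flip, matching `reverse = [false, false, false]`. `conv3d_same_reference` is a hypothetical helper, not part of JAX.

```python
import numpy as np


def conv3d_same_reference(lhs: np.ndarray, rhs: np.ndarray) -> np.ndarray:
    """Hypothetical helper: stride-1, SAME-padded 3-D cross-correlation.

    lhs is [b, f, d0, d1, d2]; rhs is [o, i, k0, k1, k2] (feature_group_count = 1).
    """
    b, f, d0, d1, d2 = lhs.shape
    o, i, k0, k1, k2 = rhs.shape
    assert f == i, "input features must match kernel input features"
    assert all(k % 2 == 1 for k in (k0, k1, k2)), "sketch assumes odd kernel sizes"
    p0, p1, p2 = k0 // 2, k1 // 2, k2 // 2
    # Zero-pad only the three spatial dims
    padded = np.pad(lhs, ((0, 0), (0, 0), (p0, p0), (p1, p1), (p2, p2)))
    out = np.zeros((b, o, d0, d1, d2), dtype=lhs.dtype)
    for x in range(d0):
        for y in range(d1):
            for z in range(d2):
                window = padded[:, :, x:x + k0, y:y + k1, z:z + k2]  # [b, i, k0, k1, k2]
                # Contract over input features and the three kernel dims
                out[:, :, x, y, z] = np.einsum("bijkl,oijkl->bo", window, rhs)
    return out


left = np.ones((1, 4, 5, 5, 5), dtype=np.float32)
right = np.ones((4, 4, 3, 3, 3), dtype=np.float32)
expected = conv3d_same_reference(left, right)
print(expected.shape)           # (1, 4, 5, 5, 5)
print(expected[0, 0, 2, 2, 2])  # 108.0 (interior: 4 features * 27 taps)
print(expected[0, 0, 0, 0, 0])  # 32.0 (corner: 4 features * 8 in-bounds taps)
```

With all-ones inputs, each output element is simply 4 times the number of in-bounds kernel taps, which makes a backend result easy to eyeball.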

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal: 0.0.7
