Skip to content

remove raising on tie_break :high #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

EMLX is the Nx Backend for the [MLX](https://github.com/ml-explore/mlx) library.

Because of MLX's nature, EMLX with GPU backend is only supported on macOS.
Because of MLX's nature, EMLX with GPU backend is only supported on macOS.

MLX with CPU backend is available on most mainstream platforms, however, the CPU backend may not be as optimized as the GPU backend,
MLX with CPU backend is available on most mainstream platforms, however, the CPU backend may not be as optimized as the GPU backend,
especially for non-macOS OSes, as they're not prioritized for development. Right now, EMLX supports x86_64 and arm64 architectures
on both macOS and Linux.

Expand Down Expand Up @@ -49,6 +49,23 @@ Defaulting to Nx.Defn.Evaluator is the safest option for now.
Nx.Defn.default_options(compiler: EMLX)
```

### Configuration

EMLX supports several configuration options that can be set in your application's config:

#### `:warn_unsupported_option`

Controls whether warnings are logged when unsupported options are used with certain operations.

- **Type**: `boolean`
- **Default**: `true`
- **Description**: When enabled, EMLX will log warnings for operations that receive options not supported by the MLX backend. For example, `Nx.argmax/2` and `Nx.argmin/2` with `tie_break: :high` will log a warning since MLX doesn't support this tie-breaking behavior.

```elixir
# In config/config.exs
config :emlx, :warn_unsupported_option, false
```

### MLX binaries

EMLX relies on the [MLX](https://github.com/ml-explore/mlx) library to function, and currently EMLX will download precompiled builds from [mlx-build](https://github.com/cocoa-xu/mlx-build).
Expand Down
8 changes: 6 additions & 2 deletions lib/emlx/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ defmodule EMLX.Backend do
alias Nx.Tensor, as: T
alias EMLX.Backend, as: Backend

require Logger

defstruct [:ref, :shape, :type, :data]

@impl true
Expand Down Expand Up @@ -535,8 +537,10 @@ defmodule EMLX.Backend do
axis = opts[:axis]
keep_axis = opts[:keep_axis] == true

if opts[:tie_break] == :high do
raise "Nx.Backend.#{unquote(op)}/3 with tie_break: :high is not supported in EMLX"
if Application.get_env(:emlx, :warn_unsupported_option, true) and opts[:tie_break] == :high do
Logger.warning(
"Nx.Backend.#{unquote(op)}/3 with tie_break: :high is not supported in EMLX"
)
end

t_mx = from_nx(tensor)
Expand Down
48 changes: 48 additions & 0 deletions test/emlx/config_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
defmodule EMLX.ConfigTest do
use EMLX.Case, async: false

import ExUnit.CaptureLog

setup do
# Store original config value to restore after each test
original_value = Application.get_env(:emlx, :warn_unsupported_option, true)

on_exit(fn ->
Application.put_env(:emlx, :warn_unsupported_option, original_value)
end)

{:ok, original_value: original_value}
end

# Test both argmax and argmin with config enabled/disabled
for op <- [:argmax, :argmin] do
describe "#{op} with warn_unsupported_option" do
test "logs warning when config is enabled (default)" do
Application.put_env(:emlx, :warn_unsupported_option, true)

tensor = Nx.tensor([[1, 3, 2], [6, 4, 5]], backend: EMLX.Backend)

log_output =
capture_log(fn ->
Nx.unquote(op)(tensor, axis: 0, tie_break: :high)
end)

assert log_output =~
"Nx.Backend.#{unquote(op)}/3 with tie_break: :high is not supported in EMLX"
end

test "does not log warning when config is disabled" do
Application.put_env(:emlx, :warn_unsupported_option, false)

tensor = Nx.tensor([[1, 3, 2], [6, 4, 5]], backend: EMLX.Backend)

log_output =
capture_log(fn ->
Nx.unquote(op)(tensor, axis: 0, tie_break: :high)
end)

assert log_output == ""
end
end
end
end