Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,36 @@ defmodule Bumblebee do
end
end

@doc """
Initializes state for a new logits processor.

Returns `state`, which is an opaque `Nx.Container`, and it is then
passed to and returned from `process/4`.
"""
@doc type: :logits_processor
@spec logits_processor_init(
Bumblebee.LogitsProcessor.t(),
context :: term()
) :: Bumblebee.LogitsProcessor.state()
def logits_processor_init(%module{} = logits_processor, context) do
module.init(logits_processor, context)
end

@doc """
Processes logits, applying specific rules. Receives context, state and
logits, and returns updated logits and state.
"""
@doc type: :logits_processor
@spec logits_processor_process(
Bumblebee.LogitsProcessor.t(),
Bumblebee.LogitsProcessor.state(),
logits :: Nx.Tensor.t(),
context :: term()
) :: {Bumblebee.LogitsProcessor.state(), logits :: Nx.Tensor.t()}
def logits_processor_process(%module{} = logits_processor, state, logits, context) do
module.process(logits_processor, state, logits, context)
end

@doc """
Initializes state for a new scheduler loop.

Expand Down
38 changes: 38 additions & 0 deletions lib/bumblebee/logits_processor.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
defmodule Bumblebee.LogitsProcessor do
@moduledoc """
An interface for configuring and using logits processors.

Logits processors are used during autoregressive generation to modify
predicted scores at each generation step. This allows for applying
certain rules to the model output to control which tokens are picked
at each generation step, and which are not.

Every module implementing this behaviour is expected to also define
a configuration struct.
"""

@type t :: Bumblebee.Configurable.t()

@type state :: Nx.Container.t()

@doc """
Initializes state for a new logits processor.

Returns `state`, which is an opaque `Nx.Container`, and it is then
passed to and returned from `process/2`.

Oftentimes logits processors are stateless, in which case this
function can return an empty container, such as `{}`.
"""
@callback init(t(), any()) :: state()

@doc """
Processes logits, applying specific rules.
"""
@callback process(
t(),
state(),
logits :: Nx.Tensor.t(),
context :: term()
) :: {logits :: Nx.Tensor.t(), state :: map()}
end
Loading