3 changes: 3 additions & 0 deletions lib/bumblebee/text/generation.ex
@@ -371,6 +371,9 @@ defmodule Bumblebee.Text.Generation do
if config.forced_token_ids do
&forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids)
end,
if config.dfa do
Bumblebee.configure(Bumblebee.Text.Generation.DFAProcessor, config.dfa)
end,
if config.temperature && config.temperature != 1.0 do
&temperature_processor(&1, &2, temperature: config.temperature)
end
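When `config.dfa` is set, the configured processor struct is appended to the logits processor pipeline. A rough sketch of the value this `Bumblebee.configure/2` call produces, assuming the `DFAProcessor` options introduced below; the states, token ids and vocab size are purely illustrative:

# Hypothetical result of Bumblebee.configure(Bumblebee.Text.Generation.DFAProcessor, config.dfa)
%Bumblebee.Text.Generation.DFAProcessor{
  initial_state: 1,
  state_transitions: [{1, 5, 2}, {2, 7, 1}],
  vocab_size: 50257
}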
105 changes: 105 additions & 0 deletions lib/bumblebee/text/generation/dfa_processor.ex
@@ -0,0 +1,105 @@
defmodule Bumblebee.Text.Generation.DFAProcessor do
@moduledoc false

import Nx.Defn

@behaviour Bumblebee.Configurable
@behaviour Bumblebee.LogitsProcessor

options = [
initial_state: [
default: nil,
doc: "the initial state"
],
state_transitions: [
default: nil,
doc: "the definition of a deterministic finite automaton used for constrained generation"
],
vocab_size: [
default: nil,
doc: "the size of the vocabulary"
]
]

defstruct Bumblebee.Shared.option_defaults(options)

@impl Bumblebee.Configurable
def config(logits_processor, opts) do
Bumblebee.Shared.put_config_attrs(logits_processor, opts)
end

@impl Bumblebee.LogitsProcessor
def init(logits_processor, _context) do
dfa = logits_processor

num_states =
dfa.state_transitions
|> Enum.flat_map(fn {state, _token_id, next_state} -> [state, next_state] end)
|> Enum.uniq()
|> Enum.count()

# we add 1 to num_states to keep an empty row for state 0, which represents
# "no transition", since 0 is the only falsy value in Nx
empty_state_transitions_tensor = Nx.broadcast(0, {num_states + 1, dfa.vocab_size})

state_transitions_tensor =
for transition <- dfa.state_transitions, reduce: empty_state_transitions_tensor do
transitions_tensor ->
{current_state, token_id, next_state} = transition
index = Nx.tensor([current_state, token_id])

Nx.indexed_put(transitions_tensor, index, next_state)
end

initial_state =
List.wrap(dfa.initial_state)
|> Enum.map(&List.wrap(&1))
|> Nx.tensor()

%{
dfa_state: %{
last_state: initial_state,
state_transitions_tensor: state_transitions_tensor
}
}
end

@impl Bumblebee.LogitsProcessor
def process(_logits_processor, state, logits, context) do
dfa_processing(logits, state, context)
end

deftransform dfa_processing(logits, state, context) do
transitions_tensor = state.dfa_state.state_transitions_tensor

last_state = state.dfa_state.last_state |> Nx.vectorize(:batch)
current_state = current_state(context, last_state, transitions_tensor)
logits = logits(logits, transitions_tensor, current_state)

current_state = Nx.devectorize(current_state, keep_names: false)

dfa_state = %{state.dfa_state | last_state: current_state}

state = %{state | dfa_state: dfa_state}

{logits, state}
end

defnp current_state(context, last_state, transitions_tensor) do
if context.length == context.input_length do
last_state
else
last_token_id = context.sequence[context.length - 1]
transitions_tensor[[Nx.squeeze(last_state), last_token_id]]
end
end

defnp logits(logits, transitions_tensor, current_state) do
suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits))
allowed_token_ids = transitions_tensor[Nx.squeeze(current_state)]

Nx.select(allowed_token_ids, logits, suppressed_logits)
end
end
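The masking in `logits/3` can be sketched in isolation: the row of the transitions tensor for the current state doubles as an allow-mask over the vocabulary, and `Nx.select/3` keeps logits only for tokens whose next state is non-zero. A toy sketch with a made-up vocabulary of four tokens:

# Row for the current state: token 1 leads to state 2, token 3 leads to state 1,
# all other tokens have no outgoing transition (0)
allowed_token_ids = Nx.tensor([0, 2, 0, 1])
logits = Nx.tensor([0.3, -1.2, 0.8, 0.5])

suppressed = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits))
Nx.select(allowed_token_ids, logits, suppressed)
#=> [-Inf, -1.2, -Inf, 0.5] (tokens 0 and 2 are suppressed)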
4 changes: 4 additions & 0 deletions lib/bumblebee/text/generation_config.ex
@@ -93,6 +93,10 @@ defmodule Bumblebee.Text.GenerationConfig do
default: [],
doc: "a list of token ids to suppress during generation"
],
dfa: [
default: nil,
doc: "the definition of a deterministic finite automaton (DFA) for constrained generation"
],
no_repeat_ngram_length: [
default: nil,
doc: "when set, n-grams of the given length can occur only once in the generated sequence"
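From the caller's side the new option would be set like any other generation option; a minimal sketch, reusing the DFA definition from the test below:

generation_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 4,
    dfa: [
      initial_state: 1,
      state_transitions: [{1, 2, 2}, {2, 3, 3}, {3, 2, 2}],
      vocab_size: spec.vocab_size
    ]
  )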
62 changes: 62 additions & 0 deletions test/bumblebee/text/generation_test.exs
@@ -107,6 +107,68 @@ defmodule Bumblebee.Text.GenerationTest do
assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]]))
end

test "DFA processor" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

{:ok, generation_config} =
Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec

input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]])
attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])
seed = Nx.tensor([0])

inputs = %{
"input_ids" => Nx.Batch.concatenate([input_ids, input_ids]),
"attention_mask" => Nx.Batch.concatenate([attention_mask, attention_mask]),
"seed" => Nx.Batch.concatenate([seed, seed])
}

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 4)

generate =
Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
logits_processors: [
Bumblebee.configure(Bumblebee.Text.Generation.DFAProcessor,
initial_state: [1, 2],
state_transitions: [
{1, 2, 2},
{2, 3, 3},
{3, 2, 2}
],
vocab_size: spec.vocab_size
)
]
)

%{token_ids: token_ids} = generate.(params, inputs)

# according to the DFA definition, the first batch entry starts in state 1

# first token_id should be 2
assert_equal(token_ids[[0, 0]], 2)

# second token_id should be 3
assert_equal(token_ids[[0, 1]], 3)

# third token_id should be 2
assert_equal(token_ids[[0, 2]], 2)

# second batch entry starts in state 2

# first token_id should be 3
assert_equal(token_ids[[1, 0]], 3)

# second token_id should be 2
assert_equal(token_ids[[1, 1]], 2)

# third token_id should be 3
assert_equal(token_ids[[1, 2]], 3)
end

test "with stateful logits processor with different batch sizes" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})