6 changes: 6 additions & 0 deletions lib/bumblebee/text/generation.ex
@@ -369,6 +369,12 @@ defmodule Bumblebee.Text.Generation do
if config.forced_token_ids do
&forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids)
end,
if config.allowed_token_ids != [] do
&allowed_tokens_processor(&1, &2, allowed_token_ids: config.allowed_token_ids)
end,
if config.dfa do
&dfa_processor(&1, &2, dfa: config.dfa)
end,
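# the constraint processors above (allowed tokens, DFA) run before the
# sampling-related processors such as temperature below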
if config.temperature && config.temperature != 1.0 do
&temperature_processor(&1, &2, temperature: config.temperature)
end
143 changes: 143 additions & 0 deletions lib/bumblebee/text/generation/logits_processing.ex
@@ -3,6 +3,133 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do

import Nx.Defn

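# Constrains generation to sequences accepted by a deterministic finite
# automaton (DFA): at each step, only tokens with an outgoing transition
# from the current DFA state keep their logits; all others get -infinity.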
deftransform dfa_processor(logits, context, opts \\ []) do
opts = Keyword.validate!(opts, [:dfa])
dfa = opts[:dfa]
dfa_mode = dfa[:mode]

last_state =
dfa.state_transitions
|> Enum.map(fn {state, _token_id, next_state} -> max(state, next_state) end)
|> Enum.max()

num_states = last_state + 1

state_transitions_tensor = Nx.broadcast(0, {num_states, Nx.size(logits)})

state_transitions_tensor =
for {current_state, token_id, next_state} <- dfa.state_transitions,
reduce: state_transitions_tensor do
state_transitions_tensor ->
Nx.indexed_put(
state_transitions_tensor,
Nx.tensor([current_state, token_id]),
next_state
)
end

initial_state = Nx.tensor([dfa.initial_state]) |> Nx.vectorize(:batch)

case dfa_mode do
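# :stateful keeps the current DFA state in context.logits_processor_state,
# so each step only needs one transition lookup for the last generated token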
:stateful ->
current_state =
if context.length == context.input_length do
initial_state
else
last_state = context.logits_processor_state.dfa

current_state_from_last_state(
state_transitions_tensor,
context.sequence,
context.length,
last_state
)
end

logits = suppress_logits(logits, state_transitions_tensor, current_state)

context = put_in(context, [:logits_processor_state, :dfa], current_state)

{logits, context}

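# :stateless recomputes the current state from the generated tokens on
# every step (see find_current_state) instead of carrying processor state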
:stateless ->
current_state =
if context.length == context.input_length do
initial_state
else
find_current_state(
initial_state,
state_transitions_tensor,
context.sequence,
context.input_length,
context.length
)
end

suppress_logits(logits, state_transitions_tensor, current_state)
end
end

defnp suppress_logits(logits, state_transitions_tensor, state) do
suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits))
Nx.select(state_transitions_tensor[state], logits, suppressed_logits)
end

defnp current_state_from_last_state(
state_transitions_tensor,
sequence,
current_length,
last_state
) do
last_token_id = sequence[current_length - 1]
state_transitions_tensor[[last_state, last_token_id]] |> Nx.squeeze()
end

defn find_current_state(
initial_state,
state_transitions_tensor,
sequence,
input_length,
current_length
) do
generated_length = current_length - input_length

last_token_id = sequence[current_length - 1]
token_column = state_transitions_tensor[[.., last_token_id]] |> Nx.squeeze()

# top_k returns the two largest values (and their indices) of the column.
# If the token is unambiguous, the column has exactly one value != 0 (that's
# top_values[0]); if top_values[1] != 0, the column has at least two values
# != 0, so the token is ambiguous and we must replay the generated sequence.
{top_values, _top_indices} = Nx.top_k(token_column, k: 2)

ambiguous_token? = top_values[[1]]

state =
cond do
ambiguous_token? ->
{state, _i, _sequence, _input_length, _generated_length, _state_transitions_tensor} =
while {state = initial_state, i = 0, sequence, input_length, generated_length,
state_transitions_tensor},
Nx.less(i, generated_length) do
chosen_token = sequence[input_length + i]
new_state = state_transitions_tensor[[state, chosen_token]]

{new_state, i + 1, sequence, input_length, generated_length,
state_transitions_tensor}
end

state

true ->
# we know that top_values[0] is the state we moved to
# as it's the only state transition with new state != 0 for the token_id
top_values[[0]]
end

state
end

deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do
opts = Keyword.validate!(opts, [:suppressed_token_ids])

@@ -11,6 +138,12 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
Nx.indexed_put(logits, indices, values)
end

deftransform allowed_tokens_processor(logits, _context, opts \\ []) do
opts = Keyword.validate!(opts, [:allowed_token_ids])

allow_token_ids(logits, opts[:allowed_token_ids])
end

defn bos_token_processor(logits, context, opts \\ []) do
opts = keyword!(opts, [:bos_token_id])
bos_token_id = opts[:bos_token_id]
@@ -113,6 +246,16 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
|> Nx.put_slice([token_id], Nx.tensor([0], type: Nx.type(logits)))
end

deftransformp allow_token_ids(logits, allowed_token_ids) do
# allowed_token_ids comes in as a plain list, so build an index tensor from it
allowed_indices = Nx.tensor(allowed_token_ids)
allowed_logits = Nx.take(logits, allowed_indices)
suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits))

indices = Nx.new_axis(allowed_indices, -1)
Nx.indexed_put(suppressed_logits, indices, allowed_logits)
end

deftransformp ignore_token_id(logits, token_id) do
Nx.put_slice(
logits,
9 changes: 9 additions & 0 deletions lib/bumblebee/text/generation_config.ex
@@ -93,6 +93,15 @@ defmodule Bumblebee.Text.GenerationConfig do
default: [],
doc: "a list of token ids to suppress during generation"
],
allowed_token_ids: [
default: [],
doc:
"a list of token ids allowed during generation (all tokens not in the list are suppressed)"
],
dfa: [
default: nil,
doc:
"the definition of a deterministic finite automaton (DFA) constraining generation: a map with :initial_state, :state_transitions (a list of {current_state, token_id, next_state} tuples) and :mode (:stateful or :stateless)"
],
no_repeat_ngram_length: [
default: nil,
doc: "when set, n-grams of the given length can occur only once in the generated sequence"
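For reference, a minimal sketch of how the two new options could be configured together. This is illustrative only: the token ids and transitions are made up, but the dfa map shape is the one dfa_processor above consumes, and Bumblebee.configure/2 is the same call the demo script uses.

generation_config =
  Bumblebee.configure(generation_config,
    # keep only these (hypothetical) token ids sampleable
    allowed_token_ids: [42, 99],
    # transitions are {current_state, token_id, next_state} tuples
    dfa: %{
      initial_state: 0,
      state_transitions: [{0, 42, 1}, {1, 99, 1}, {1, 42, 1}],
      mode: :stateful
    }
  )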
205 changes: 205 additions & 0 deletions pair_programming.exs
@@ -0,0 +1,205 @@
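# Demo script: DFA-constrained generation ("array of integers") with
# SmolLM2-135M-Instruct via Bumblebee's new dfa generation option.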
Mix.install([
{:bumblebee, path: "../bumblebee_bitcrowd"},
{:nx, "~> 0.10.0", override: true},
{:exla, "~> 0.10.0"},
{:emlx, github: "elixir-nx/emlx"},
{:benchee, "~> 1.0"}
])

# backend = EMLX.Backend
# compiler = Nx.Defn.Evaluator
backend = EXLA.Backend
compiler = EXLA

Nx.global_default_backend(backend)

repo = {:hf, "HuggingFaceTB/SmolLM2-135M-Instruct"}

sequence_length = 512

prompt = """
Give me an array that contains a mix of numbers and text.
There MUST be at least one number and one text.
Valid examples are:

["hello",89,"hola",6,4,8]
"""

# this DFA definition is "array of integers", generated by outlines-core
#
# let schema = r#"{
# "type": "array",
# "items": {
# "type": "integer"
# }
# }"#;

initial_state = 64

state_transitions =
[
{96, 33, 128},
{96, 40, 128},
{96, 36, 128},
{96, 32, 112},
{96, 39, 128},
{96, 35, 128},
{96, 38, 128},
{96, 34, 128},
{96, 41, 128},
{96, 37, 128},
{144, 2, 144},
{176, 77, 144},
{224, 33, 240},
{224, 40, 240},
{224, 36, 240},
{224, 32, 112},
{224, 39, 240},
{224, 35, 240},
{224, 38, 240},
{224, 34, 240},
{224, 41, 240},
{224, 37, 240},
{128, 33, 128},
{128, 77, 144},
{128, 36, 128},
{128, 28, 192},
{128, 39, 128},
{128, 10790, 224},
{128, 34, 128},
{128, 37, 128},
{128, 40, 128},
{128, 32, 128},
{128, 216, 176},
{128, 35, 128},
{128, 38, 128},
{128, 41, 128},
{128, 6329, 144},
{80, 33, 128},
{80, 77, 144},
{80, 29, 96},
{80, 36, 128},
{80, 41, 128},
{80, 32, 112},
{80, 39, 128},
{80, 216, 176},
{80, 35, 128},
{80, 40, 128},
{80, 38, 128},
{80, 34, 128},
{80, 6329, 144},
{80, 37, 128},
{112, 216, 176},
{112, 10790, 224},
{112, 77, 144},
{112, 6329, 144},
{112, 28, 192},
{64, 9197, 96},
{64, 75, 160},
{208, 33, 240},
{208, 29, 224},
{208, 36, 240},
{208, 40, 240},
{208, 32, 112},
{208, 39, 240},
{208, 35, 240},
{208, 38, 240},
{208, 34, 240},
{208, 41, 240},
{208, 37, 240},
{160, 33, 128},
{160, 77, 144},
{160, 36, 128},
{160, 39, 128},
{160, 256, 176},
{160, 731, 96},
{160, 34, 128},
{160, 37, 128},
{160, 29, 96},
{160, 40, 128},
{160, 32, 112},
{160, 216, 80},
{160, 35, 128},
{160, 38, 128},
{160, 6329, 144},
{160, 41, 128},
{240, 33, 240},
{240, 77, 144},
{240, 36, 240},
{240, 28, 192},
{240, 39, 240},
{240, 10790, 224},
{240, 34, 240},
{240, 37, 240},
{240, 40, 240},
{240, 32, 240},
{240, 216, 176},
{240, 35, 240},
{240, 38, 240},
{240, 41, 240},
{240, 6329, 144},
{192, 33, 240},
{192, 29, 224},
{192, 36, 240},
{192, 40, 240},
{192, 32, 112},
{192, 39, 240},
{192, 216, 208},
{192, 35, 240},
{192, 731, 224},
{192, 38, 240},
{192, 34, 240},
{192, 41, 240},
{192, 37, 240}
]

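# outlines-core emits sparse state ids; remap them to a contiguous 0..n-1
# range so the {num_states, vocab_size} transition tensor stays small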
unique_states =
state_transitions
|> Enum.flat_map(fn {state, _token_id, next_state} -> [state, next_state] end)
|> Enum.uniq()
|> Enum.sort()

states_map = for {state, i} <- Enum.with_index(unique_states), into: %{}, do: {state, i}

compact_states =
Enum.map(state_transitions, fn {state, token_id, next_state} ->
{states_map[state], token_id, states_map[next_state]}
end)

state_transitions = compact_states
initial_state = states_map[initial_state]

dfa = %{state_transitions: state_transitions, mode: :stateful, initial_state: initial_state}

build_serving = fn backend, compiler, max_new_tokens, dfa ->
Nx.global_default_backend(backend)

{:ok, model_info} = Bumblebee.load_model(repo, backend: backend)

{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

generation_config =
Bumblebee.configure(generation_config,
max_new_tokens: max_new_tokens,
# min_length: sequence_length + max_new_tokens,
strategy: %{type: :multinomial_sampling, top_p: 0.9},
# strategy: %{type: :greedy_search},
dfa: dfa
)

Bumblebee.Text.generation(model_info, tokenizer, generation_config,
compile: [batch_size: 1, sequence_length: sequence_length],
stream: false,
defn_options: [compiler: compiler]
)
end

max_new_tokens = 32
# toggle the constraint here: switch mode to :stateless, or set dfa = nil to disable it
dfa = %{dfa | mode: :stateful}

serving = build_serving.(backend, compiler, max_new_tokens, dfa)

for _i <- 1..3 do
%{results: [_result]} = Nx.Serving.run(serving, prompt) |> dbg
end