6 changes: 6 additions & 0 deletions lib/bumblebee/text/generation.ex
@@ -369,6 +369,12 @@ defmodule Bumblebee.Text.Generation do
if config.forced_token_ids do
&forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids)
end,
if config.allowed_token_ids != [] do
&allowed_tokens_processor(&1, &2, allowed_token_ids: config.allowed_token_ids)
end,
if config.dfa do
&dfa_processor(&1, &2, dfa: config.dfa)
end,
if config.temperature && config.temperature != 1.0 do
&temperature_processor(&1, &2, temperature: config.temperature)
end
101 changes: 101 additions & 0 deletions lib/bumblebee/text/generation/logits_processing.ex
@@ -3,6 +3,91 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do

import Nx.Defn

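# Constrains generation to paths through a deterministic finite automaton
# (DFA), given as a map with a :state_transitions list of
# {current_state, token_id, next_state} triples (see the dfa generation
# config option). A minimal sketch with hypothetical token ids, accepting
# one "a" followed by any number of "b":
#
#     dfa = %{state_transitions: [{0, 64, 1}, {1, 65, 1}]}
#
# At every step, tokens without an outgoing transition from the current
# state have their logits masked to negative infinity.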
deftransform dfa_processor(logits, context, opts \\ []) do
opts = Keyword.validate!(opts, [:dfa])
dfa = opts[:dfa]

# the transition tensor needs a row for every state id in use, including
# states that only appear as a transition target
num_states =
dfa.state_transitions
|> Enum.flat_map(fn {state, _token_id, next_state} -> [state, next_state] end)
|> Enum.max()
|> Kernel.+(1)

# rows are states, columns are token ids; a value of 0 means "no transition"
state_transitions_tensor = Nx.broadcast(0, {num_states, Nx.size(logits)})

state_transitions_tensor =
for {current_state, token_id, next_state} <- dfa.state_transitions,
reduce: state_transitions_tensor do
state_transitions_tensor ->
Nx.indexed_put(
state_transitions_tensor,
Nx.tensor([current_state, token_id]),
next_state
)
end

# the DFA always starts in state 0; vectorized to line up with the batched context
initial_state = Nx.tensor([0]) |> Nx.vectorize(batch: 1)

current_state =
find_current_state(
initial_state,
state_transitions_tensor,
context.sequence,
context.input_length,
context.length
)

suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits))

# keep logits only for tokens with an outgoing transition from the current state
Nx.select(state_transitions_tensor[current_state], logits, suppressed_logits)
end

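# Recovers the DFA state after the tokens generated so far. A sketch of the
# two paths, assuming the transition tensor
#
#     [[0, 1, 2, 3],
#      [1, 3, 0, 0],
#      [2, 2, 0, 0]]
#
# If the last token id is 1, several states transition on it (column
# [1, 3, 2]), so the generated tokens are replayed step by step from the
# initial state. If the last token id is 3, only state 0 transitions on it
# (column [3, 0, 0]), so the next state can be read straight off the column.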
defn find_current_state(
initial_state,
state_transitions_tensor,
sequence,
input_length,
current_length
) do
generated_length = current_length - input_length

last_token_id = sequence[current_length - 1]
token_column = state_transitions_tensor[[.., last_token_id]] |> Nx.squeeze()

# top_k returns the two largest values in the column (and their row indices).
# If only one state has a transition on this token, the column has exactly
# one value != 0, namely top_values[0]. If top_values[1] != 0 as well, at
# least two states transition on this token, so the token is ambiguous.
{top_values, top_indices} = Nx.top_k(token_column, k: 2)

ambiguous_token? = Nx.not_equal(top_values[1], Nx.tensor(0))

state =
cond do
generated_length == 0 ->
initial_state

ambiguous_token? ->
# replay the generated tokens from the initial state to recover the
# current state deterministically
{state, _i, _sequence, _input_length, _generated_length, _state_transitions_tensor} =
while {state = initial_state, i = 0, sequence, input_length, generated_length,
state_transitions_tensor},
Nx.less(i, generated_length) do
chosen_token = sequence[input_length + i]
new_state = state_transitions_tensor[[state, chosen_token]]

{new_state, i + 1, sequence, input_length, generated_length,
state_transitions_tensor}
end

state

true ->
# only one state transitions on this token, so the single non-zero
# value in the column is unambiguous and is itself the next state
top_values[0]
end

state
end

deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do
opts = Keyword.validate!(opts, [:suppressed_token_ids])

Expand All @@ -11,6 +96,12 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
Nx.indexed_put(logits, indices, values)
end

deftransform allowed_tokens_processor(logits, _context, opts \\ []) do
opts = Keyword.validate!(opts, [:allowed_token_ids])

allow_token_ids(logits, opts[:allowed_token_ids])
end

defn bos_token_processor(logits, context, opts \\ []) do
opts = keyword!(opts, [:bos_token_id])
bos_token_id = opts[:bos_token_id]
@@ -113,6 +204,16 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
|> Nx.put_slice([token_id], Nx.tensor([0], type: Nx.type(logits)))
end

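# A usage sketch (mirroring the style of the processor tests):
#
#     allow_token_ids(Nx.tensor([1.0, 2.0, 3.0, 4.0]), [1, 3])
#     #=> f32[4] [-Inf, 2.0, -Inf, 4.0]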
deftransformp allow_token_ids(logits, allowed_token_ids) do
# allowed_token_ids arrives as a plain list of token ids
allowed_indices = Nx.tensor(allowed_token_ids)
allowed_logits = Nx.take(logits, allowed_indices)
suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits))

indices = Nx.new_axis(allowed_indices, -1)
Nx.indexed_put(suppressed_logits, indices, allowed_logits)
end

deftransformp ignore_token_id(logits, token_id) do
Nx.put_slice(
logits,
8 changes: 8 additions & 0 deletions lib/bumblebee/text/generation_config.ex
@@ -93,6 +93,14 @@ defmodule Bumblebee.Text.GenerationConfig do
default: [],
doc: "a list of token ids to suppress during generation"
],
allowed_token_ids: [
default: [],
doc: "a list of token ids to enforce during generation (suppressing the all tokens that are not in the list)"
],
dfa: [
default: nil,
doc: "the definition of a very simple deterministic finite automaton (dfa) for the generation"
],
no_repeat_ngram_length: [
default: nil,
doc: "when set, n-grams of the given length can occur only once in the generated sequence"
127 changes: 127 additions & 0 deletions pair_programming.exs
@@ -0,0 +1,127 @@
Mix.install([
{:bumblebee, path: "../bumblebee_bitcrowd"},
{:nx, "~> 0.10.0", override: true},
{:emlx, github: "elixir-nx/emlx"},
{:benchee, "~> 1.0"}
])

Nx.global_default_backend({EMLX.Backend, device: :gpu})
repo = {:hf, "HuggingFaceTB/SmolLM2-135M-Instruct"}
{:ok, model_info} = Bumblebee.load_model(repo, backend: {EMLX.Backend, device: :gpu})
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

sequence_length = 512

prompt = """
Give me an array that contains a mix of numbers and text.
There MUST be at least one number and one text.
Valid examples are:

["hello",89,"hola",6,4,8]
"""

numbers = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]
array_start_token = "["
array_end_token = "]"
array_addition_token = ","
# inside a string we need "every token except the quote": the quote
# (token id 18) is what closes the string again
string_token = "\""

# TODO: map states to indices explicitly (currently implied by list position)
states = [
:starting,
:in_array,
:in_number,
:in_addition,
:in_string,
:end_of_string,
:ending,
:done
]

state_to_num = fn state -> Enum.find_index(states, &(&1 == state)) end

# ------------------------------ above: characters ------------------------------ #
# ------------------------------ below: token ids ------------------------------- #

array_start_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, array_start_token)
array_end_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, array_end_token)
addition_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, array_addition_token)
string_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, string_token)
end_of_sequence_token_id = Bumblebee.Tokenizer.special_token_id(tokenizer, :eos)

# token ids 0..17 are this tokenizer's special tokens
special_tokens_ids = Enum.to_list(0..17)
number_tokens_ids = Enum.map(numbers, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
vocabulary_token_ids = Enum.to_list(0..(model_info.spec.vocab_size - 1))

# a string may contain any token except the quote and the special tokens
string_token_ids = vocabulary_token_ids -- ([string_token_id] ++ special_tokens_ids)

## example: for the generated sequence of token ids 75, 18, ...

# state            0   1
# chosen token id  75  18
# new state        1   4

## tensor
# (state, token id) -> new state; entries that are not listed below stay 0,
# which dfa_processor treats as "no transition"
## state \ token ids   0  1  2 ... 18 ... 33 ... 75  76
## starting (0)                                   1
## in_array (1)                   4       2           6
## in_number (2)
## in_addition (3)
## in_string (4)
## end_of_string (5)
## ending (6)
## done (7)

## which tokens lead from a given state to which next state
state_transitions =
[
# starting
{:starting, [array_start_token_id], :in_array},
# in_array
{:in_array, number_tokens_ids, :in_number},
{:in_array, [array_end_token_id], :ending},
{:in_array, [string_token_id], :in_string},
# in_number
{:in_number, number_tokens_ids, :in_number},
{:in_number, [addition_token_id], :in_addition},
{:in_number, [array_end_token_id], :ending},
# in_addition
{:in_addition, number_tokens_ids, :in_number},
{:in_addition, [string_token_id], :in_string},
# in_string
{:in_string, string_token_ids, :in_string},
{:in_string, [string_token_id], :end_of_string},
# end_of_string
{:end_of_string, [addition_token_id], :in_addition},
{:end_of_string, [array_end_token_id], :ending},
# ending
{:ending, [end_of_sequence_token_id], :done}
]
|> Enum.flat_map(fn {current_state, token_ids, next_state} ->
for token_id <- token_ids do
{state_to_num.(current_state), token_id, state_to_num.(next_state)}
end
end
end)

dfa = %{state_transitions: state_transitions}

generation_config =
Bumblebee.configure(generation_config,
max_new_tokens: 24,
strategy: %{type: :multinomial_sampling, top_p: 0.6},
dfa: dfa
)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
compile: [batch_size: 1, sequence_length: sequence_length],
stream: false,
defn_options: [compiler: Nx.Defn.Evaluator]
)

%{results: [_result]} = Nx.Serving.run(serving, prompt) |> dbg

18 changes: 18 additions & 0 deletions test/bumblebee/text/generation/logits_processing_test.exs
@@ -5,6 +5,24 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do

alias Bumblebee.Text.Generation.LogitsProcessing

describe "find_current_state" do
test "finds ambiguous token ids" do
## rows = states / columns = token ids -> value = next state
##   0 1 2 3
##   1 3 0 0
##   2 2 0 0
## token id 1 has a non-zero entry in more than one row, so it is ambiguous
ambiguous_token_id = 1
state_transitions_tensor = Nx.tensor([[0, 1, 2, 3], [1, 3, 0, 0], [2, 2, 0, 0]])
token_column = state_transitions_tensor[[.., ambiguous_token_id]] |> Nx.squeeze()
{top_values, _top_indices} = Nx.top_k(token_column, k: 2)

ambiguous_token? = Nx.not_equal(top_values[1], 0)

assert Nx.to_number(ambiguous_token?) == 1
end
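
# A sketch of the complementary, unambiguous case under the same layout:
# token id 3 only has a transition from state 0, to state 3.
test "reads the next state off the column for unambiguous token ids" do
unambiguous_token_id = 3
state_transitions_tensor = Nx.tensor([[0, 1, 2, 3], [1, 3, 0, 0], [2, 2, 0, 0]])
token_column = state_transitions_tensor[[.., unambiguous_token_id]] |> Nx.squeeze()
{top_values, _top_indices} = Nx.top_k(token_column, k: 2)

# a zero second-largest value means only one state transitions on this token
assert Nx.to_number(top_values[1]) == 0
# ...and the largest value itself is the next state
assert Nx.to_number(top_values[0]) == 3
end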
end

describe "suppressed_tokens_processor/3" do
test "ignores the given tokens" do
logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])