Skip to content

Commit 005f71b

Browse files
committed
Add test for dfa_processor/3
1 parent 3afc672 commit 005f71b

File tree

1 file changed

+73
-2
lines changed

1 file changed

+73
-2
lines changed

test/bumblebee/text/generation/logits_processing_test.exs

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,77 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
55

66
alias Bumblebee.Text.Generation.LogitsProcessing
77

8+
describe "dfa_processor/3" do
9+
test "constrained sampling with DFA" do
10+
# the list of all allowed transitions
11+
transitions = [
12+
# {state, token_id, next_state}
13+
{0, 1, 1},
14+
{1, 2, 2},
15+
{2, 1, 1}
16+
]
17+
18+
initial_state = 0
19+
20+
logits = Nx.tensor([0.0, 1.0, 2.0, 3.0])
21+
22+
start_sequence = [1, 0, 0, 0]
23+
context = context(start_sequence)
24+
25+
# according to our transitions
26+
# the only allowed sequences are:
27+
# (state: 0) -> (token_id: 1) -> (state: 1)
28+
# (state: 1) -> (token_id: 2) -> (state: 2)
29+
# (state: 2) -> (token_id: 1) -> (state: 1)
30+
# (state: 1) -> (token_id: 2) -> (state: 2)
31+
# ...
32+
33+
dfa = %{state_transitions: transitions, initial_state: initial_state}
34+
35+
# (state: 0) -> (token_id: 1) -> (state: 1)
36+
{processed_logits, context} = LogitsProcessing.dfa_processor(logits, context, dfa: dfa)
37+
processed_logits = Nx.devectorize(processed_logits, keep_names: false) |> Nx.squeeze()
38+
39+
# in this transition only token_id 1 was allowed
40+
expected_logits = Nx.tensor([:neg_infinity, 1.0, :neg_infinity, :neg_infinity])
41+
42+
assert_equal(processed_logits, expected_logits)
43+
44+
expected_last_state = Nx.tensor([0]) |> Nx.vectorize(:batch)
45+
assert_equal(context.logits_processor_state.dfa, expected_last_state)
46+
47+
new_sequence = Nx.tensor([1, 1, 0, 0])
48+
context = %{context | length: context.length + 1, sequence: new_sequence}
49+
50+
# (state: 1) -> (token_id: 2) -> (state: 2)
51+
{processed_logits, context} = LogitsProcessing.dfa_processor(logits, context, dfa: dfa)
52+
processed_logits = Nx.devectorize(processed_logits, keep_names: false) |> Nx.squeeze()
53+
54+
# in this transition only token_id 2 was allowed
55+
expected_logits = Nx.tensor([:neg_infinity, :neg_infinity, 2.0, :neg_infinity])
56+
57+
assert_equal(processed_logits, expected_logits)
58+
59+
expected_last_state = Nx.tensor([1]) |> Nx.vectorize(:batch)
60+
assert_equal(context.logits_processor_state.dfa, expected_last_state)
61+
62+
new_sequence = Nx.tensor([1, 1, 2, 0])
63+
context = %{context | length: context.length + 1, sequence: new_sequence}
64+
65+
# (state: 2) -> (token_id: 1) -> (state: 1)
66+
{processed_logits, context} = LogitsProcessing.dfa_processor(logits, context, dfa: dfa)
67+
processed_logits = Nx.devectorize(processed_logits, keep_names: false) |> Nx.squeeze()
68+
69+
# in this transition only token_id 1 was allowed
70+
expected_logits = Nx.tensor([:neg_infinity, 1.0, :neg_infinity, :neg_infinity])
71+
72+
assert_equal(processed_logits, expected_logits)
73+
74+
expected_last_state = Nx.tensor([2]) |> Nx.vectorize(:batch)
75+
assert_equal(context.logits_processor_state.dfa, expected_last_state)
76+
end
77+
end
78+
879
describe "stateful logits processors" do
980
defmodule StatefulLogitsProcessing do
1081
import Nx.Defn
@@ -13,7 +84,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
1384
initial_suppressed_index = Nx.tensor([opts[:initial_suppressed_index]])
1485

1586
suppressed_index =
16-
context.logits_processor_states[:next_suppressed_index] || initial_suppressed_index
87+
context.logits_processor_state[:next_suppressed_index] || initial_suppressed_index
1788

1889
values =
1990
Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.size(suppressed_index))
@@ -430,7 +501,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
430501
sequence: Nx.tensor(sequence),
431502
length: Enum.count(sequence, &(&1 != 0)),
432503
input_length: 1,
433-
logits_processor_states: %{}
504+
logits_processor_state: %{}
434505
}
435506
end
436507
end

0 commit comments

Comments
 (0)