@@ -5,6 +5,82 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
 
   alias Bumblebee.Text.Generation.LogitsProcessing
 
+  describe "dfa_processor/3" do
+    test "constrained sampling with DFA" do
+      # the list of all allowed transitions
+      transitions = [
+        # {state, token_id, next_state}
+        {0, 1, 1},
+        {1, 2, 2},
+        {2, 1, 1}
+      ]
+
+      initial_state = 0
+
+      logits = Nx.tensor([0.0, 1.0, 2.0, 3.0])
+
+      start_sequence = [1, 0, 0, 0]
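+      # context/1 (defined at the bottom of this file) builds the generation
+      # context, with length set to the number of non-padding tokens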
+      context = context(start_sequence)
+
+      # according to our transitions,
+      # the only allowed sequences are:
+      #   (state: 0) -> (token_id: 1) -> (state: 1)
+      #   (state: 1) -> (token_id: 2) -> (state: 2)
+      #   (state: 2) -> (token_id: 1) -> (state: 1)
+      #   (state: 1) -> (token_id: 2) -> (state: 2)
+      #   ...
+
+      dfa = %{state_transitions: transitions, initial_state: initial_state}
+
+      # (state: 0) -> (token_id: 1) -> (state: 1)
+      {processed_logits, context} = LogitsProcessing.dfa_processor(logits, context, dfa: dfa)
+      processed_logits = processed_logits |> Nx.devectorize(keep_names: false) |> Nx.squeeze()
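+      # drop the vectorized :batch axis so we can compare against a plain tensor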
+
+      # only token_id 1 is allowed at this step
+      expected_logits = Nx.tensor([:neg_infinity, 1.0, :neg_infinity, :neg_infinity])
+
+      assert_equal(processed_logits, expected_logits)
+
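+      # the processor stores under :dfa the state this mask was computed from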
+      expected_last_state = Nx.tensor([0]) |> Nx.vectorize(:batch)
+      assert_equal(context.logits_processor_state.dfa, expected_last_state)
+
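+      # simulate sampling the allowed token: append it and bump the length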
+      new_sequence = Nx.tensor([1, 1, 0, 0])
+      context = %{context | length: context.length + 1, sequence: new_sequence}
+
+      # (state: 1) -> (token_id: 2) -> (state: 2)
+      {processed_logits, context} = LogitsProcessing.dfa_processor(logits, context, dfa: dfa)
+      processed_logits = processed_logits |> Nx.devectorize(keep_names: false) |> Nx.squeeze()
+
+      # only token_id 2 is allowed at this step
+      expected_logits = Nx.tensor([:neg_infinity, :neg_infinity, 2.0, :neg_infinity])
+
+      assert_equal(processed_logits, expected_logits)
+
+      expected_last_state = Nx.tensor([1]) |> Nx.vectorize(:batch)
+      assert_equal(context.logits_processor_state.dfa, expected_last_state)
+
+      new_sequence = Nx.tensor([1, 1, 2, 0])
+      context = %{context | length: context.length + 1, sequence: new_sequence}
+
+      # (state: 2) -> (token_id: 1) -> (state: 1)
+      {processed_logits, context} = LogitsProcessing.dfa_processor(logits, context, dfa: dfa)
+      processed_logits = processed_logits |> Nx.devectorize(keep_names: false) |> Nx.squeeze()
+
+      # only token_id 1 is allowed at this step
+      expected_logits = Nx.tensor([:neg_infinity, 1.0, :neg_infinity, :neg_infinity])
+
+      assert_equal(processed_logits, expected_logits)
+
+      expected_last_state = Nx.tensor([2]) |> Nx.vectorize(:batch)
+      assert_equal(context.logits_processor_state.dfa, expected_last_state)
+    end
+  end
+
   describe "stateful logits processors" do
     defmodule StatefulLogitsProcessing do
       import Nx.Defn
@@ -13,7 +89,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
         initial_suppressed_index = Nx.tensor([opts[:initial_suppressed_index]])
 
         suppressed_index =
-          context.logits_processor_states[:next_suppressed_index] || initial_suppressed_index
+          context.logits_processor_state[:next_suppressed_index] || initial_suppressed_index
 
         values =
           Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.size(suppressed_index))
@@ -430,7 +506,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
       sequence: Nx.tensor(sequence),
       length: Enum.count(sequence, &(&1 != 0)),
       input_length: 1,
-      logits_processor_states: %{}
+      logits_processor_state: %{}
     }
   end
 end