@@ -45,12 +45,12 @@ def clone(self) -> "GuidanceLogitsProcessor":
45
45
cloned = copy .copy (self )
46
46
if self .initialized :
47
47
cloned .ll_matcher = llguidance .LLMatcher (
48
- self .ll_tokenizer ,
48
+ self .ll_tokenizer , # type: ignore[assignment]
49
49
self .grammar ,
50
50
log_level = int (os .environ .get ("LLGUIDANCE_LOG_LEVEL" , "1" )),
51
51
)
52
52
self .bitmask = llguidance .torch .allocate_token_bitmask (
53
- 1 , self .ll_tokenizer .vocab_size )
53
+ 1 , self .ll_tokenizer .vocab_size ) # type: ignore[attr-defined]
54
54
return cloned
55
55
56
56
def _initialize (self ):
@@ -72,7 +72,7 @@ def _initialize(self):
72
72
73
73
# create reusable bitmask
74
74
self .bitmask = llguidance .torch .allocate_token_bitmask (
75
- 1 , self .ll_tokenizer .vocab_size )
75
+ 1 , self .ll_tokenizer .vocab_size ) # type: ignore[attr-defined]
76
76
77
77
self .initialized = True
78
78
@@ -86,15 +86,17 @@ def __call__(
86
86
self ._initialize ()
87
87
88
88
if self .new_sampling and len (input_ids ) > 0 :
89
- self .ll_matcher .consume_token (input_ids [- 1 ])
90
- err = self .ll_matcher .get_error ()
89
+ self .ll_matcher .consume_token ( # type: ignore[attr-defined]
90
+ input_ids [- 1 ])
91
+ err = self .ll_matcher .get_error () # type: ignore[attr-defined]
91
92
if err :
92
93
logger .warning ("Error in LLMatcher: %s" , err )
93
94
94
95
llguidance .torch .fill_next_token_bitmask (self .ll_matcher , self .bitmask ,
95
96
0 )
96
97
llguidance .torch .apply_token_bitmask_inplace (
97
- scores , self .bitmask .to (scores .device ))
98
+ scores ,
99
+ self .bitmask .to (scores .device )) # type: ignore[attr-defined]
98
100
99
101
self .new_sampling = True
100
102
0 commit comments