Skip to content

Commit

Permalink
Fix prompt removal logic and update test case
Browse files Browse the repository at this point in the history
  • Loading branch information
venondev authored and rlouf committed Dec 16, 2023
1 parent 57e5cd3 commit 516709d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
15 changes: 10 additions & 5 deletions outlines/generate/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ def __call__(
if isinstance(prompts, str):
prompts = [prompts]

prompt_lengths = [len(prompt) for prompt in prompts]

if rng is None:
rng = torch.Generator(device=self.device)
rng.seed()
Expand All @@ -82,10 +80,17 @@ def __call__(
except StopIteration:
break

sequences = self.tokenizer.decode(last_state.token_ids)
generated = [
sequence[length:] for sequence, length in zip(sequences, prompt_lengths)
# Get the number of tokens in the prompts
prompt_token_ids = init_state[0]
prompt_lengths = [len(prompt_token_ids[i]) for i in range(len(prompts))]

# Remove the prompts from the generated sequences
token_ids = [
cur_token_ids[length:]
for cur_token_ids, length in zip(last_state.token_ids, prompt_lengths)
]

generated = self.tokenizer.decode(token_ids)
formatted = [self.format_sequence(sequence) for sequence in generated]

return formatted if len(formatted) > 1 else formatted[0]
Expand Down
13 changes: 7 additions & 6 deletions tests/generate/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,28 @@
def test_sequence_generator_class():
class MockFSM:
def next_state(self, state, next_token_ids):
return 0
return 4

def allowed_token_ids(self, _):
return []
return [4]

def is_final_state(self, _):
return True

class MockTokenizer:
def encode(self, _):
return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])
# Input: "test"
return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1, 1]])

def decode(self, _):
return ["testx"]
def decode(self, tokens):
return ["testx"[i] for i in tokens]

class MockModel:
def __init__(self):
self.tokenizer = MockTokenizer()

def __call__(*_):
return torch.tensor([[0, 1, 2, 3]], dtype=torch.float), None
return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None

def sampler(biased_logits, *_):
return torch.argmax(biased_logits, keepdims=True)
Expand Down

0 comments on commit 516709d

Please sign in to comment.