Skip to content

Commit 2ca6652

Browse files
committed
Fix some issues
1 parent 8e1d6f6 commit 2ca6652

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
accelerate
12
fire
23
interegular
34
regex==2023.8.8

syncode/language_model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,11 @@ def get_tokenized_input(self, prompt: Union[str, list], batch_size: int):
176176
raise ValueError("Prompt should be either a string or a list! It is currently of type: "+str(type(prompt)))
177177

178178
input_batch = [prompt_str for _ in range(batch_size)]
179-
inputs = self.tokenizer(input_batch, return_tensors="pt").to(self.model.device)
179+
inputs = self.tokenizer(
180+
input_batch,
181+
return_tensors="pt",
182+
pad_to_multiple_of=8,
183+
).to(self.model.device)
180184

181185
return inputs
182186

tests/test_language_model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,17 @@ def get_vocab(self) -> Dict[str, int]:
5656
return {v: i for i, v in enumerate(self.vocab)}
5757

5858
class TestHuggingFaceModel(unittest.TestCase):
59-
@unittest.skip("Only for local testing")
6059
def test_stop_word(self):
6160
torch.manual_seed(0)
62-
syncode = Syncode(model="microsoft/phi-2", mode='original')
61+
syncode = Syncode(model="microsoft/phi-2", mode='original', device='cpu')
6362
prompt = "Generate a json for the country nigeria.\n```json\n"
6463
stop_words = ["```"]
6564
output = syncode.infer(prompt, stop_words=stop_words)[0]
6665
assert output.endswith('```')
6766

68-
@unittest.skip("Only for local testing")
6967
def test_stop_word2(self):
7068
torch.manual_seed(0)
71-
syncode = Syncode(model="microsoft/phi-2", mode='original')
69+
syncode = Syncode(model="microsoft/phi-2", mode='original', device='cpu')
7270
prompt = "def add(a, b):\n"
7371
stop_words = ["\n\n"]
7472
output = syncode.infer(prompt, stop_words=stop_words)[0]

0 commit comments

Comments
 (0)