diff --git a/README.md b/README.md
index 2b16de4..b01753b 100644
--- a/README.md
+++ b/README.md
@@ -93,29 +93,29 @@ By default, `llmstep` uses a Pythia 2.8b language model fine-tuned on [LeanDojo
 - [`llmstep` model on Huggingface](https://huggingface.co/wellecks/llmstep-mathlib4-pythia2.8b)
 
-The model is fine-tuned on sequences of the form:
-```bash
-[GOAL]tactic-state[PROOFSTEP]next-tactic[END]
-```
-This format corresponds to the proofstep objective from [Han et al ICLR 2022](https://arxiv.org/abs/2102.06203).\
 The [python/train](python/train) directory shows how the model was fine-tuned.
 
+#### ReProver
+You can use the non-retrieval version of [ReProver](https://github.com/lean-dojo/ReProver) by running:
+
+```
+python python/server_encdec.py
+```
+By default, this runs the `leandojo-lean4-tacgen-byt5-small` model.\
+This model is particularly useful on CPU due to its small parameter count.
+
 #### Fine-tuning your own model
 The scripts in [python/train](python/train) show how to finetune a model.
 
 #### Using a different model
-Swap in other language models with the `--hf-model` argument:
+
+Swap in other decoder-only language models with the `--hf-model` argument:
 ```bash
 python server.py --hf-model some/other-model-7B
 ```
-We recommend using a fine-tuned model, though in principle fine-tuning is not strictly needed. \
-`llmstep` assumes the model uses the proofstep format described above, but this is easy to modify.
-
+Use `--hf-model` with `python/server_encdec.py` for encoder-decoder models.
 
-#### Speed
-Starting the server downloads the default language model, and loads the model. As a result, you will likely experience a delay the first time `llmstep` is run.
 
-Roughly speaking, when `server.py` is run on a typical MacBook Pro, `llmstep` provides suggestions in a few seconds, with a GPU suggestions take ~1 second, and with vLLM suggestions take less than 1 second.
 
-Actual suggestion latency is variable and depends on multiple factors.
+The `llmstep_prompt` function in `server.py` determines the expected input and output format. If needed, you can modify this function for your model.
 
 ## Additional Notes
diff --git a/python/server.py b/python/server.py
index 988775c..a25d373 100644
--- a/python/server.py
+++ b/python/server.py
@@ -12,7 +12,8 @@ def load_hf(hf_model):
         model = transformers.GPTNeoXForCausalLM.from_pretrained(args.hf_model)
         tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(args.hf_model)
     else:
-        raise NotImplementedError(hf_model)
+        model = transformers.AutoModelForCausalLM.from_pretrained(args.hf_model)
+        tokenizer = transformers.AutoTokenizer.from_pretrained(args.hf_model)
 
     if torch.cuda.is_available():
         model.cuda()
@@ -94,11 +95,12 @@ def do_POST(self):
             self.wfile.write(json.dumps(error_response).encode('utf-8'))
 
 
-def get_config(args):
-    # Prompt template for the default model.
-    def llmstep_prompt(tactic_state, prefix):
-        return '[GOAL]%s[PROOFSTEP]%s' % (tactic_state, prefix)
+# Prompt template for the default model.
+def llmstep_prompt(tactic_state, prefix):
+    return '[GOAL]%s[PROOFSTEP]%s' % (tactic_state, prefix)
+
 
+def get_config(args):
     config = {
         'LLMSTEP_MODEL': args.hf_model,
         'LLMSTEP_TEMPERATURES': args.temperatures,
diff --git a/python/server_encdec.py b/python/server_encdec.py
new file mode 100644
index 0000000..891661c
--- /dev/null
+++ b/python/server_encdec.py
@@ -0,0 +1,65 @@
+from server import LLMStepServer, get_argparser, get_config, print_config
+
+import transformers
+
+
+def load_hf_encdec(model_name):
+    print("Loading model...")
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name)
+    print("Done")
+    return model, tokenizer
+
+
+def hf_encdec_generate(
+    model,
+    tokenizer,
+    prompt,
+    temperatures,
+    num_samples,
+    max_new_tokens=128
+):
+    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)
+    texts = []
+    for temp in temperatures:
+        out = model.generate(
+            input_ids,
+            max_new_tokens=max_new_tokens,
+            do_sample=temp > 0,
+            temperature=temp,
+            pad_token_id=tokenizer.eos_token_id,
+            num_return_sequences=num_samples if temp > 0 else 1
+        )
+        texts.extend(tokenizer.batch_decode(
+            out, skip_special_tokens=True
+        ))
+    texts = list(set(texts))
+    return texts
+
+
+def reprover_prompt(tactic_state, prefix):
+    return '%s%s' % (tactic_state, prefix)
+
+
+def get_reprover_config(args):
+    config = get_config(args)
+    config['LLMSTEP_PROMPT'] = reprover_prompt
+    return config
+
+
+if __name__ == '__main__':
+    parser = get_argparser()
+    parser.set_defaults(hf_model='kaiyuy/leandojo-lean4-tacgen-byt5-small')
+    args = parser.parse_args()
+
+    config = get_reprover_config(args)
+    print_config(config)
+
+    model, tokenizer = load_hf_encdec(args.hf_model)
+
+    httpd = LLMStepServer(
+        model, tokenizer, hf_encdec_generate, config
+    )
+
+    print('Server started')
+    httpd.serve_forever()
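
As a reference for the README note that `llmstep_prompt` determines the expected input and output format, the sketch below contrasts the two prompt templates touched by this patch. Both functions are copied from the diff; the tactic state and prefix are made-up values for illustration only.

```python
# llmstep_prompt is the template from python/server.py (default Pythia model);
# reprover_prompt is the template from python/server_encdec.py.

def llmstep_prompt(tactic_state, prefix):
    return '[GOAL]%s[PROOFSTEP]%s' % (tactic_state, prefix)

def reprover_prompt(tactic_state, prefix):
    return '%s%s' % (tactic_state, prefix)

# Made-up example inputs for illustration.
state = 'n : Nat\n⊢ n + 0 = n'
prefix = 'simp'

print(llmstep_prompt(state, prefix))
# [GOAL]n : Nat
# ⊢ n + 0 = n[PROOFSTEP]simp

print(reprover_prompt(state, prefix))
# n : Nat
# ⊢ n + 0 = nsimp
```

A model fine-tuned on the `[GOAL]...[PROOFSTEP]` objective will generally not produce useful suggestions when served with the plain concatenation, and vice versa, which is why swapping in a different model may require editing the prompt function.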
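
The new generation helper can also be exercised without starting the HTTP server. This is a minimal sanity-check sketch, assuming the command is run from the `python/` directory (so that `from server import ...` inside `server_encdec.py` resolves) and that the default ReProver checkpoint downloads successfully; the tactic state is a made-up example.

```python
# Standalone check of the functions added in python/server_encdec.py.
from server_encdec import load_hf_encdec, hf_encdec_generate, reprover_prompt

# Loads the default byt5-small tactic generator on CPU.
model, tokenizer = load_hf_encdec('kaiyuy/leandojo-lean4-tacgen-byt5-small')

# Build a prompt from a made-up tactic state with an empty prefix.
prompt = reprover_prompt('n : Nat\n⊢ n + 0 = n', '')

# One greedy pass (temperature 0) plus one sampled temperature, 5 samples each.
suggestions = hf_encdec_generate(
    model, tokenizer, prompt,
    temperatures=[0.0, 0.7],
    num_samples=5,
)
print(suggestions)
```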