Reprover
wellecks committed Oct 19, 2023
1 parent 9e4e857 commit ca2aa04
Showing 3 changed files with 85 additions and 18 deletions.
26 changes: 13 additions & 13 deletions README.md
@@ -93,29 +93,29 @@ By default, `llmstep` uses a Pythia 2.8b language model fine-tuned on [LeanDojo
- [`llmstep` model on Huggingface](https://huggingface.co/wellecks/llmstep-mathlib4-pythia2.8b)


The model is fine-tuned on sequences of the form:
```bash
[GOAL]tactic-state[PROOFSTEP]next-tactic[END]
```
This format corresponds to the proofstep objective from [Han et al., ICLR 2022](https://arxiv.org/abs/2102.06203).\
The [python/train](python/train) directory shows how the model was fine-tuned.
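
For illustration, here is a minimal sketch (not the repository's training code) of how a tactic state and its next tactic could be assembled into this format:
```python
# Hypothetical helper illustrating the proofstep training format.
def proofstep_example(tactic_state, next_tactic):
    return '[GOAL]%s[PROOFSTEP]%s[END]' % (tactic_state, next_tactic)

# Prints: [GOAL]n : ℕ ⊢ n + 0 = n[PROOFSTEP]simp[END]
print(proofstep_example('n : ℕ ⊢ n + 0 = n', 'simp'))
```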

#### Reprover
You can use the non-retrieval version of [Reprover](https://github.com/lean-dojo/ReProver) by running:

```bash
python python/server_encdec.py
```
By default, this runs the `leandojo-lean4-tacgen-byt5-small` model.\
This model is particularly useful on CPU due to its small parameter count.
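
Under the hood, `python/server_encdec.py` (added in this commit, shown below) loads the model with `transformers.AutoModelForSeq2SeqLM` and conditions directly on the tactic state rather than on the `[GOAL]`/`[PROOFSTEP]` template. A condensed, standalone sketch of that generation step, using a hypothetical tactic state:
```python
import transformers

name = 'kaiyuy/leandojo-lean4-tacgen-byt5-small'
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(name)

# ReProver-style models take the raw tactic state as the prompt.
input_ids = tokenizer.encode('n : ℕ\n⊢ n + 0 = n', return_tensors='pt')
out = model.generate(input_ids, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```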

#### Fine-tuning your own model
The scripts in [python/train](python/train) show how to fine-tune a model.

#### Using a different model
Swap in other language models with the `--hf-model` argument:

Swap in other decoder-only language models with the `--hf-model` argument:
```bash
python server.py --hf-model some/other-model-7B
```
We recommend using a fine-tuned model, though in principle fine-tuning is not strictly needed. \
`llmstep` assumes the model uses the proofstep format described above, but this is easy to modify.

Use `--hf-model` with `python/server_encdec.py` for encoder-decoder models.
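For example (the model name here is a placeholder):
```bash
python python/server_encdec.py --hf-model some/other-encdec-model
```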

#### Speed
Starting the server downloads and loads the default language model, so you will likely experience a delay the first time `llmstep` is run.
Roughly speaking, when `server.py` is run on a typical MacBook Pro, `llmstep` provides suggestions in a few seconds; with a GPU, suggestions take roughly 1 second; and with vLLM, suggestions take less than 1 second.
Actual suggestion latency is variable and depends on multiple factors.
The `llmstep_prompt` function in `server.py` determines the expected input and output format. If needed, you can modify this function for your model.
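For instance, a hypothetical variant of `llmstep_prompt` for a model fine-tuned with different delimiters (illustrative only; use whatever format your model expects):
```python
# Hypothetical replacement for llmstep_prompt in server.py.
def llmstep_prompt(tactic_state, prefix):
    return 'GOAL: %s\nTACTIC: %s' % (tactic_state, prefix)
```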


## Additional Notes
12 changes: 7 additions & 5 deletions python/server.py
@@ -12,7 +12,8 @@ def load_hf(hf_model):
model = transformers.GPTNeoXForCausalLM.from_pretrained(args.hf_model)
tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(args.hf_model)
else:
raise NotImplementedError(hf_model)
model = transformers.AutoModelForCausalLM.from_pretrained(args.hf_model)
tokenizer = transformers.AutoTokenizer.from_pretrained(args.hf_model)

if torch.cuda.is_available():
model.cuda()
@@ -94,11 +95,12 @@ def do_POST(self):
self.wfile.write(json.dumps(error_response).encode('utf-8'))


def get_config(args):
# Prompt template for the default model.
def llmstep_prompt(tactic_state, prefix):
return '[GOAL]%s[PROOFSTEP]%s' % (tactic_state, prefix)
# Prompt template for the default model.
def llmstep_prompt(tactic_state, prefix):
return '[GOAL]%s[PROOFSTEP]%s' % (tactic_state, prefix)


def get_config(args):
config = {
'LLMSTEP_MODEL': args.hf_model,
'LLMSTEP_TEMPERATURES': args.temperatures,
65 changes: 65 additions & 0 deletions python/server_encdec.py
@@ -0,0 +1,65 @@
from server import LLMStepServer, get_argparser, get_config, print_config

import transformers


def load_hf_encdec(model_name):
    print("Loading model...")
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name)
    print("Done")
    return model, tokenizer


def hf_encdec_generate(
    model,
    tokenizer,
    prompt,
    temperatures,
    num_samples,
    max_new_tokens=128
):
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)
    texts = []
    for temp in temperatures:
        out = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=temp > 0,
            temperature=temp,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=num_samples if temp > 0 else 1
        )
        texts.extend(tokenizer.batch_decode(
            out, skip_special_tokens=True
        ))
    texts = list(set(texts))
    return texts


def reprover_prompt(tactic_state, prefix):
    return '%s%s' % (tactic_state, prefix)


def get_reprover_config(args):
    config = get_config(args)
    config['LLMSTEP_PROMPT'] = reprover_prompt
    return config


if __name__ == '__main__':
    parser = get_argparser()
    parser.set_defaults(hf_model='kaiyuy/leandojo-lean4-tacgen-byt5-small')
    args = parser.parse_args()

    config = get_reprover_config(args)
    print_config(config)

    model, tokenizer = load_hf_encdec(args.hf_model)

    httpd = LLMStepServer(
        model, tokenizer, hf_encdec_generate, config
    )

    print('Server started')
    httpd.serve_forever()
