Commit e3247fe: Add chain of thought example

alonsosilvaallende authored and rlouf committed Aug 8, 2024
1 parent 951b020 commit e3247fe
Showing 3 changed files with 125 additions and 0 deletions.

123 changes: 123 additions & 0 deletions docs/cookbook/chain_of_thought.md
# Chain of thought


Chain of thought is a prompting technique introduced in the paper ["Chain-of-Thought Prompting Elicits Reasoning in Large Language Models"](https://arxiv.org/abs/2201.11903), in which the authors prompt the model to generate a series of intermediate reasoning steps, improving the ability of LLMs to perform complex reasoning.
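Before turning to structured output, the core idea can be illustrated with a plain few-shot prompt. The sketch below is adapted from the style of examples in the paper; the exact wording is ours:

```python
# A minimal few-shot chain-of-thought prompt: the worked example shows the
# model intermediate reasoning steps before the final answer, and the model
# is expected to imitate that pattern for the new question.
cot_prompt = (
    "Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. "
    "Each can has 3 tennis balls. How many tennis balls does he have now?\n"
    "A: Roger started with 5 balls. 2 cans of 3 tennis balls each is 6 "
    "tennis balls. 5 + 6 = 11. The answer is 11.\n"
    "Q: The cafeteria had 23 apples. If they used 20 to make lunch and "
    "bought 6 more, how many apples do they have?\n"
    "A:"
)
print(cot_prompt)
```

The guide below achieves the same effect without few-shot examples, by constraining the output format instead.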

In this guide, we use [outlines](https://outlines-dev.github.io/outlines/) to apply chain of thought through structured output.

We run the model with [llama.cpp](https://github.com/ggerganov/llama.cpp) via the [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) library. Outlines supports llama-cpp-python, but we need to install it ourselves:

```shell
pip install llama-cpp-python
```

We pull a quantized GGUF model. In this guide, we use [Hermes-2-Pro-Llama-3-8B](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF) by [NousResearch](https://nousresearch.com/) from [HuggingFace](https://huggingface.co/):

```shell
wget https://hf.co/NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF/resolve/main/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf
```

We initialize the model:

```python
import llama_cpp
from llama_cpp import Llama
from outlines import generate, models

llm = Llama(
    "/path/to/model/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
        "NousResearch/Hermes-2-Pro-Llama-3-8B"
    ),
    n_gpu_layers=-1,
    flash_attn=True,
    n_ctx=8192,
    verbose=False,
)
model = models.LlamaCpp(llm)
```

## Chain of thought

We first define our Pydantic class for a reasoning step:

```python
from pydantic import BaseModel, Field

class Reasoning_Step(BaseModel):
    reasoning_step: str = Field(..., description="Reasoning step")
```

We then define the Pydantic class for the reasoning process, which consists of a list of reasoning steps and a conclusion, and we get its JSON schema:

```python
from typing import List

class Reasoning(BaseModel):
    reasoning: List[Reasoning_Step] = Field(..., description="List of reasoning steps")
    conclusion: str = Field(..., description="Conclusion")

json_schema = Reasoning.model_json_schema()
```
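As a quick sanity check (a self-contained sketch assuming pydantic v2), we can build a `Reasoning` instance by hand and dump it to JSON; this is exactly the shape the constrained generator will be forced to produce:

```python
import json
from typing import List

from pydantic import BaseModel, Field

class Reasoning_Step(BaseModel):
    reasoning_step: str = Field(..., description="Reasoning step")

class Reasoning(BaseModel):
    reasoning: List[Reasoning_Step] = Field(..., description="List of reasoning steps")
    conclusion: str = Field(..., description="Conclusion")

# Hand-built example with the same structure the model must emit
example = Reasoning(
    reasoning=[Reasoning_Step(reasoning_step="Compare the fractional parts.")],
    conclusion="9.9 is bigger.",
)
payload = json.loads(example.model_dump_json())
print(payload["conclusion"])
```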

We could generate directly from the JSON schema, but in this guide we will instead convert it to a regular expression:

```python
from outlines.integrations.utils import convert_json_schema_to_str
from outlines.fsm.json_schema import build_regex_from_schema

schema_str = convert_json_schema_to_str(json_schema=json_schema)
regex_str = build_regex_from_schema(schema_str)
```
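To get an intuition for what `build_regex_from_schema` does, here is a hand-written toy pattern (ours, not the library's output) for a one-field schema: it accepts a JSON object with a single string-valued `conclusion` key and rejects anything else.

```python
import re

# Toy stand-in for the generated pattern: one object, one string field.
# The real pattern for the Reasoning schema is much larger but works the
# same way: only strings matching it can be sampled.
toy_regex = r'\{\s*"conclusion"\s*:\s*"[^"\\]*"\s*\}'

assert re.fullmatch(toy_regex, '{"conclusion": "9.11 is bigger."}')
assert re.fullmatch(toy_regex, '{"conclusion": 42}') is None  # not a string
print("toy regex behaves as expected")
```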

We then need to adapt our prompt to the [Hermes prompt format for JSON schema](https://github.com/NousResearch/Hermes-Function-Calling?tab=readme-ov-file#prompt-format-for-json-mode--structured-outputs):

```python
def generate_hermes_prompt(user_prompt):
    return (
        "<|im_start|>system\n"
        "You are a world class AI model who answers questions in JSON. "
        f"Here's the json schema you must adhere to:\n<schema>\n{json_schema}\n</schema><|im_end|>\n"
        "<|im_start|>user\n"
        + user_prompt
        + "<|im_end|>"
        + "\n<|im_start|>assistant\n"
        "<schema>"
    )
```
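To see what the assembled prompt looks like without loading a model, a self-contained variant (taking the schema as an argument, with a placeholder schema string) can be printed directly:

```python
def generate_hermes_prompt(user_prompt, json_schema):
    # Same assembly as above, with the schema passed in explicitly
    return (
        "<|im_start|>system\n"
        "You are a world class AI model who answers questions in JSON. "
        f"Here's the json schema you must adhere to:\n<schema>\n{json_schema}\n</schema><|im_end|>\n"
        "<|im_start|>user\n"
        + user_prompt
        + "<|im_end|>"
        + "\n<|im_start|>assistant\n"
        "<schema>"
    )

# Placeholder schema, just to inspect the prompt shape
prompt = generate_hermes_prompt("Which is bigger?", '{"type": "object"}')
print(prompt)
```

Note that the prompt ends with `<schema>`, priming the model to continue with the structured answer.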

Take the following user prompt as an example:

```python
user_prompt = "9.11 and 9.9 -- which is bigger?"
```

We can use `generate.regex` by passing the regex string we built from the JSON schema, and call the generator with the Hermes prompt:

```python
generator = generate.regex(model, regex_str)
prompt = generate_hermes_prompt(user_prompt)
response = generator(prompt, max_tokens=1024, temperature=0, seed=42)
```

We obtain the reasoning steps as well as the conclusion:

```python
import json

json_response = json.loads(response)

print(json_response["reasoning"])
print(json_response["conclusion"])
# [{'reasoning_step': 'Both 9.11 and 9.9 are decimal numbers.'},
# {'reasoning_step': 'When comparing decimal numbers, we look at the numbers after the decimal point.'},
# {'reasoning_step': 'In this case, 9.11 has the number 1 after the decimal point, while 9.9 has the number 9.'},
# {'reasoning_step': 'Since 1 is greater than 9, 9.11 is greater than 9.9.'}]
# 9.11 is bigger.
```

We notice that the fourth reasoning step is wrong ("Since 1 is greater than 9, 9.11 is greater than 9.9."), so we could probably give the model some examples for this particular task.
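Indeed, a one-line check confirms the conclusion is wrong:

```python
# The model concluded that 9.11 is bigger, but as numbers:
print(9.11 > 9.9)  # → False
```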

This example was originally contributed by [Alonso Silva](https://github.com/alonsosilvaallende).
1 change: 1 addition & 0 deletions docs/cookbook/index.md
- [SimToM](simtom.md): Improve LLMs' Theory of Mind capabilities with perspective-taking prompting and JSON-structured generation.
- [Q&A with Citations](qa-with-citations.md): Answer questions and provide citations using JSON-structured generation.
- [Knowledge Graph Generation](knowledge_graph_extraction.md): Generate a Knowledge Graph from unstructured text using JSON-structured generation.
- [Chain Of Thought (CoT)](chain_of_thought.md): Generate a series of intermediate reasoning steps using regex-structured generation.
1 change: 1 addition & 0 deletions mkdocs.yml
nav:
- Perspective-taking prompting: cookbook/simtom.md
- Question-answering with citations: cookbook/qa-with-citations.md
- Knowledge Graph Extraction: cookbook/knowledge_graph_extraction.md
- Chain of Thought (CoT): cookbook/chain_of_thought.md
- Run on the cloud:
- BentoML: cookbook/deploy-using-bentoml.md
- Cerebrium: cookbook/deploy-using-cerebrium.md