Skip to content

Commit

Permalink
Chainlit bonus material fixes (rasbt#361)
Browse files Browse the repository at this point in the history
* fix cmd

* moved idx to device

* improved code with clone().detach()

* fixed path

* fix: added extra line for pep8

* updated .gitginore

* Update ch05/06_user_interface/app_orig.py

* Update ch05/06_user_interface/app_own.py

* Apply suggestions from code review

---------

Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
  • Loading branch information
d-kleine and rasbt authored Sep 18, 2024
1 parent ea9b4e8 commit eefe4bf
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ ch07/04_preference-tuning-with-dpo/loss-plot.pdf
# Other
ch05/06_user_interface/chainlit.md
ch05/06_user_interface/.chainlit
ch05/06_user_interface/.files

# Temporary OS-related files
.DS_Store
Expand Down
2 changes: 1 addition & 1 deletion ch05/06_user_interface/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ To implement this user interface, we use the open-source [Chainlit Python packag

First, we install the `chainlit` package via

```python
```bash
pip install chainlit
```

Expand Down
8 changes: 4 additions & 4 deletions ch05/06_user_interface/app_orig.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
token_ids_to_text,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_model_and_tokenizer():
"""
Expand Down Expand Up @@ -44,8 +46,6 @@ def get_model_and_tokenizer():

BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

gpt = GPTModel(BASE_CONFIG)
Expand All @@ -67,9 +67,9 @@ async def main(message: chainlit.Message):
"""
The main Chainlit function.
"""
token_ids = generate(
token_ids = generate( # function uses `with torch.no_grad()` internally already
model=model,
idx=text_to_token_ids(message.content, tokenizer), # The user text is provided via as `message.content`
idx=text_to_token_ids(message.content, tokenizer).to(device), # The user text is provided via as `message.content`
max_new_tokens=50,
context_size=model_config["context_length"],
top_k=1,
Expand Down
10 changes: 5 additions & 5 deletions ch05/06_user_interface/app_own.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
token_ids_to_text,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_model_and_tokenizer():
"""
Expand All @@ -34,16 +36,14 @@ def get_model_and_tokenizer():
"qkv_bias": False # Query-key-value bias
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = tiktoken.get_encoding("gpt2")

model_path = Path("..") / "01_main-chapter-code" / "model.pth"
if not model_path.exists():
print(f"Could not find the {model_path} file. Please run the chapter 5 code (ch05.ipynb) to generate the model.pth file.")
sys.exit()

checkpoint = torch.load("model.pth", weights_only=True)
checkpoint = torch.load(model_path, weights_only=True)
model = GPTModel(GPT_CONFIG_124M)
model.load_state_dict(checkpoint)
model.to(device)
Expand All @@ -60,9 +60,9 @@ async def main(message: chainlit.Message):
"""
The main Chainlit function.
"""
token_ids = generate(
token_ids = generate( # function uses `with torch.no_grad()` internally already
model=model,
idx=text_to_token_ids(message.content, tokenizer), # The user text is provided via as `message.content`
idx=text_to_token_ids(message.content, tokenizer).to(device), # The user text is provided via as `message.content`
max_new_tokens=50,
context_size=model_config["context_length"],
top_k=1,
Expand Down

0 comments on commit eefe4bf

Please sign in to comment.