Skip to content

Commit b2e77b2

Browse files
author
Droid
committed
Implemented DBRX truss with model loading and prediction. Added requirements, config, and README files.
1 parent a0f0761 commit b2e77b2

File tree

6 files changed

+122
-0
lines changed

6 files changed

+122
-0
lines changed

dbrx_truss/README.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# DBRX Truss
2+
3+
This truss makes the [DBRX](https://huggingface.co/databricks/dbrx-instruct) model available on the Baseten platform for efficient inference. DBRX is an open-source large language model trained by Databricks. It is a 132B parameter model capable of instruction following and general language tasks.
4+
5+
## Setup
6+
7+
This truss requires Python 3.11 and the dependencies listed in `requirements.txt`. It is configured to run on A10G GPUs for optimal performance.
8+
9+
## Usage
10+
11+
Once deployed on Baseten, the truss exposes an endpoint for making prediction requests to the model.
12+
13+
### Request Format
14+
15+
Requests should be made with a JSON payload in the following format:
16+
17+
```json
18+
{
19+
"prompt": "What is machine learning?"
20+
}
21+
```
22+
23+
### Parameters
24+
25+
The following inference parameters can be configured in `config.yaml`:
26+
27+
- `max_new_tokens`: Max number of tokens to generate in the response (default: 100)
28+
- `temperature`: Controls randomness of output (default: 0.7)
29+
- `top_p`: Nucleus sampling probability threshold (default: 0.95)
30+
- `top_k`: Number of highest probability vocabulary tokens to keep (default: 50)
31+
- `repetition_penalty`: Penalty for repeated tokens (default: 1.01)
32+
33+
## Original Model
34+
35+
DBRX was developed and open-sourced by Databricks. For more information, see:
36+
37+
- [DBRX Model Card](https://github.com/databricks/dbrx/blob/master/MODEL_CARD_dbrx_instruct.md)
38+
- [Databricks Blog Post](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm)
39+
- [HuggingFace Model Page](https://huggingface.co/databricks/dbrx-instruct)
40+
41+
## About Baseten
42+
43+
This truss was created by [Baseten](https://www.baseten.co/) to enable easy deployment and serving of the open-source DBRX model at scale. Baseten is a platform for building powerful AI apps.

dbrx_truss/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Empty file

dbrx_truss/config.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
python_version: py311
2+
requirements_file: requirements.txt
3+
4+
resources:
5+
accelerator: A10G
6+
use_gpu: true
7+
8+
model_metadata:
9+
example_model_input: |
10+
{
11+
"prompt": "What is machine learning?"
12+
}
13+
repo_id: databricks/dbrx-instruct

dbrx_truss/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Empty file

dbrx_truss/model/model.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from typing import Dict
2+
3+
import torch
4+
from transformers import AutoModelForCausalLM, AutoTokenizer
5+
6+
7+
class Model:
8+
def __init__(self, data_dir: str, config: Dict, **kwargs):
9+
self.data_dir = data_dir
10+
self.config = config
11+
self.cuda_available = torch.cuda.is_available()
12+
13+
def load(self):
14+
self.tokenizer = AutoTokenizer.from_pretrained(
15+
"databricks/dbrx-instruct", trust_remote_code=True, token=True
16+
)
17+
18+
if self.cuda_available:
19+
self.model = AutoModelForCausalLM.from_pretrained(
20+
"databricks/dbrx-instruct",
21+
trust_remote_code=True,
22+
token=True,
23+
torch_dtype=(
24+
torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
25+
),
26+
device_map="auto",
27+
attn_implementation=(
28+
"flash_attention_2" if "flash_attn" in locals() else "eager"
29+
),
30+
)
31+
else:
32+
self.model = AutoModelForCausalLM.from_pretrained(
33+
"databricks/dbrx-instruct", trust_remote_code=True, token=True
34+
)
35+
36+
def predict(self, request: Dict) -> Dict:
37+
self.load() # Reload model for each request
38+
39+
prompt = request["prompt"]
40+
messages = [{"role": "user", "content": prompt}]
41+
42+
tokenized_input = self.tokenizer.apply_chat_template(
43+
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
44+
)
45+
tokenized_input = tokenized_input.to(self.model.device)
46+
47+
generated = self.model.generate(
48+
input_ids=tokenized_input,
49+
max_new_tokens=self.config.get("max_new_tokens", 100),
50+
temperature=self.config.get("temperature", 0.7),
51+
top_p=self.config.get("top_p", 0.95),
52+
top_k=self.config.get("top_k", 50),
53+
repetition_penalty=self.config.get("repetition_penalty", 1.01),
54+
pad_token_id=self.tokenizer.pad_token_id,
55+
)
56+
57+
decoded_output = self.tokenizer.batch_decode(generated)[0]
58+
response_text = decoded_output.split("<|im_start|> assistant\n")[-1]
59+
60+
return {"result": response_text}

dbrx_truss/requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
torch>=2.1.0
2+
transformers>=4.39.0
3+
accelerate==0.28.0
4+
tiktoken==0.4.0

0 commit comments

Comments
 (0)