add trust_remote_code in tokenizer to fix baichuan issue (#2725)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sywangyi authored Nov 7, 2024
1 parent b1f9044 commit 97f7a22
Showing 3 changed files with 12 additions and 3 deletions.
router/src/lib.rs: 8 additions & 2 deletions

@@ -27,6 +27,7 @@ pub enum Tokenizer {
     Python {
         tokenizer_name: String,
         revision: Option<String>,
+        trust_remote_code: bool,
     },
     Rust(tokenizers::Tokenizer),
 }
@@ -38,15 +39,20 @@ impl<'a> PyTokenizer<'a> {
         py: Python<'a>,
         tokenizer_name: String,
         revision: Option<String>,
+        trust_remote_code: bool,
     ) -> PyResult<PyTokenizer<'a>> {
         let transformers = py.import_bound("transformers")?;
         let auto = transformers.getattr("AutoTokenizer")?;
         let from_pretrained = auto.getattr("from_pretrained")?;
         let args = (tokenizer_name,);
         let kwargs = if let Some(rev) = &revision {
-            [("revision", rev.to_string())].into_py_dict_bound(py)
+            [
+                ("revision", rev.to_string().into_py(py)),
+                ("trust_remote_code", trust_remote_code.into_py(py)),
+            ]
+            .into_py_dict_bound(py)
         } else {
-            pyo3::types::PyDict::new_bound(py)
+            [("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py)
         };
         let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
         tracing::info!("Loaded a python tokenizer");
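Background for the change above: tokenizers such as Baichuan's ship custom Python code on the Hugging Face Hub, and transformers' AutoTokenizer.from_pretrained only executes that code when trust_remote_code=True is passed, so the flag has to reach the kwargs dict in both branches. Because the kwargs array now mixes a String and a bool, each value is first converted to a PyObject with into_py before into_py_dict_bound. A minimal standalone sketch of that pattern follows (illustrative values, not part of the commit; assumes pyo3 with the auto-initialize feature):

// Minimal sketch of the kwargs pattern used above; values are examples only.
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;

fn main() -> PyResult<()> {
    Python::with_gil(|py| -> PyResult<()> {
        // A String and a bool cannot share one array type directly, so both are
        // converted to PyObject with into_py before building the Python dict.
        let kwargs = [
            ("revision", "main".to_string().into_py(py)),
            ("trust_remote_code", true.into_py(py)),
        ]
        .into_py_dict_bound(py);
        assert_eq!(kwargs.len(), 2);
        Ok(())
    })
}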
router/src/server.rs: 1 addition & 0 deletions

@@ -1829,6 +1829,7 @@ pub async fn run(
             Tokenizer::Python {
                 tokenizer_name: tokenizer_name.clone(),
                 revision: revision.clone(),
+                trust_remote_code,
             }
         }
     };
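For readers following the flow: run() already receives a trust_remote_code flag (it is used here bare, without extra plumbing), and this hunk simply copies it into the enum variant. A simplified sketch of the surrounding selection logic is shown below; only the Tokenizer::Python construction appears in the commit, and the match shape plus the legacy_rust_tokenizer name are assumptions for illustration.

// Hypothetical sketch of how the flag is threaded through run().
let tokenizer = match legacy_rust_tokenizer {
    Some(t) => Tokenizer::Rust(t),
    None => Tokenizer::Python {
        tokenizer_name: tokenizer_name.clone(),
        revision: revision.clone(),
        trust_remote_code, // forwarded from run()'s arguments
    },
};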
router/src/validation.rs: 3 additions & 1 deletion

@@ -439,9 +439,11 @@ fn tokenizer_worker(
         Tokenizer::Python {
             tokenizer_name,
             revision,
+            trust_remote_code,
         } => {
             pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> {
-                let tokenizer = PyTokenizer::from_py(py, tokenizer_name, revision)?;
+                let tokenizer =
+                    PyTokenizer::from_py(py, tokenizer_name, revision, trust_remote_code)?;
                 // Loop over requests
                 while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
                     receiver.blocking_recv()
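For completeness, a hypothetical caller-side sketch: with the enum variant extended, a call site that builds a Python tokenizer only has to set the new field, and the flag reaches AutoTokenizer inside the worker with no further changes. The model id below is only an example of a tokenizer that ships remote code; it is not taken from this commit.

// Hypothetical usage sketch; the model id is illustrative.
let tokenizer = Tokenizer::Python {
    tokenizer_name: "baichuan-inc/Baichuan2-7B-Chat".to_string(),
    revision: None,
    trust_remote_code: true, // without this, AutoTokenizer refuses to load custom tokenizer code
};
// The variant is then handed to tokenizer_worker exactly as before; only the extra field is new.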