Add replicate support, Fixes h2oai#603
pseudotensor committed Aug 3, 2023
1 parent c64434c commit 73f82ae
Showing 15 changed files with 336 additions and 134 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -14,7 +14,7 @@ Query and summarize your documents or just chat with local private GPT LLMs usin
- **Variety** of models supported (LLaMa2, Falcon, Vicuna, WizardLM including AutoGPTQ, 4-bit/8-bit, LORA)
- **GPU** support from HF and LLaMa.cpp GGML models, and **CPU** support using HF, LLaMa.cpp, and GPT4ALL models
- **Linux, Docker, MAC, and Windows** support
- **Inference Servers** support (HF TGI server, vLLM, Gradio, ExLLaMa, OpenAI)
- **Inference Servers** support (HF TGI server, vLLM, Gradio, ExLLaMa, Replicate, OpenAI)
- **OpenAI-compliant Python client API** for client-server control
- **Evaluate** performance using reward models
- **Quality** maintained with over 250 unit and integration tests taking over 4 GPU-hours
6 changes: 6 additions & 0 deletions client/h2ogpt_client/_core.py
@@ -70,6 +70,7 @@ def create(
system_pre_context: str = "",
langchain_mode: LangChainMode = LangChainMode.DISABLED,
add_chat_history_to_context: bool = True,
system_prompt: str = '',
) -> "TextCompletion":
"""
Creates a new text completion.
@@ -95,6 +96,7 @@ def create(
:param system_pre_context: directly pre-appended without prompt processing
:param langchain_mode: LangChain mode
:param add_chat_history_to_context: Whether to add chat history to context
:param system_prompt: Universal system prompt to override prompt_type's system prompt
"""
params = _utils.to_h2ogpt_params(locals().copy())
params["instruction"] = "" # empty when chat_mode is False
@@ -106,6 +108,7 @@ def create(
params["instruction_nochat"] = None # future prompt
params["langchain_mode"] = langchain_mode.value # convert to serializable type
params["add_chat_history_to_context"] = True
params['system_prompt'] = ''
params["langchain_action"] = LangChainAction.QUERY.value
params["langchain_agents"] = []
params["top_k_docs"] = 4 # langchain: number of document chunks
@@ -179,6 +182,7 @@ def create(
system_pre_context: str = "",
langchain_mode: LangChainMode = LangChainMode.DISABLED,
add_chat_history_to_context: bool = True,
system_prompt: str = '',
) -> "ChatCompletion":
"""
Creates a new chat completion.
@@ -204,6 +208,7 @@ def create(
:param system_pre_context: directly pre-appended without prompt processing
:param langchain_mode: LangChain mode
:param add_chat_history_to_context: Whether to add chat history to context
:param system_prompt: Universal system prompt to override prompt_type's system prompt
"""
params = _utils.to_h2ogpt_params(locals().copy())
params["instruction"] = None # future prompts
@@ -215,6 +220,7 @@ def create(
params["instruction_nochat"] = "" # empty when chat_mode is True
params["langchain_mode"] = langchain_mode.value # convert to serializable type
params["add_chat_history_to_context"] = True
params["system_prompt"] = ''
params["langchain_action"] = LangChainAction.QUERY.value
params["langchain_agents"] = []
params["top_k_docs"] = 4 # langchain: number of document chunks
1 change: 1 addition & 0 deletions client/h2ogpt_client/_utils.py
@@ -24,6 +24,7 @@
iinput_nochat="input_context_for_instruction",
langchain_mode="langchain_mode",
add_chat_history_to_context="add_chat_history_to_context",
system_prompt="system_prompt",
langchain_action="langchain_action",
langchain_agents="langchain_agents",
top_k_docs="langchain_top_k_docs",
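
For context, a minimal usage sketch of the new `system_prompt` parameter through the Python client is shown below. It assumes a local h2oGPT server and that `h2ogpt_client.Client` exposes `text_completion.create(...)` (per the diff above) together with a synchronous completion helper; treat the server URL, prompt, and helper name as placeholders rather than verified API.

```python
# Hedged sketch: pass the new system_prompt through the h2ogpt_client API.
# The server URL and the complete_sync() helper are assumptions, not verified here.
from h2ogpt_client import Client

client = Client("http://localhost:7860")  # local h2oGPT Gradio server (assumed)
text_completion = client.text_completion.create(
    system_prompt="You are a terse, factual assistant.",  # parameter added in this commit
)
print(text_completion.complete_sync("Why is the sky blue?"))
```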
21 changes: 21 additions & 0 deletions docs/README_InferenceServers.md
@@ -335,6 +335,27 @@ Note: `vllm_chat` ChatCompletion is not supported by vLLM project.

Note: vLLM has a bug in its stopping-sequence handling in that it does not return the last token, unlike OpenAI, so a hack is in place for `prompt_type=human_bot`, and other prompts may need similar hacks. See `fix_text()` in `src/prompter.py`.

## Replicate Inference Server-Client

If you have a Replicate API key and set the environment variable `REPLICATE_API_TOKEN`, then you can access Replicate models via Gradio by running:
```bash
pip install replicate
export REPLICATE_API_TOKEN=<key>
python generate.py --inference_server="replicate:<replicate model string>" --base_model="<HF model name>"
```
where `<key>` should be replaced by your Replicate key and `<replicate model string>` by the Replicate model name, e.g. `a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5`. Here we use [LLaMa-V2](https://replicate.com/a16z-infra/llama13b-v2-chat) as an example. `<HF model name>` should be replaced by the equivalent HuggingFace model name; if that is unknown or has no exact match, choose whichever HF model has the most similar tokenizer. The `prompt_type` in h2oGPT is unused except for system prompting, if chosen.

For example, for LLaMa-2 7B:
```bash
python generate.py --inference_server="replicate:lucataco/llama-2-7b-chat:6ab580ab4eef2c2b440f2441ec0fc0ace5470edaf2cbea50b8550aec0b3fbd38" --base_model="TheBloke/Llama-2-7b-Chat-GPTQ"
```
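
Before pointing h2oGPT at Replicate, it can help to sanity-check the token and model string directly with the `replicate` Python package. A minimal sketch, reusing the LLaMa-V2 model string from above (for text models, `replicate.run` returns an iterator of output chunks):

```python
# Quick sanity check that REPLICATE_API_TOKEN and the model string are valid.
# Assumes `pip install replicate` and that REPLICATE_API_TOKEN is exported.
import replicate

model = ("a16z-infra/llama13b-v2-chat:"
         "df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5")

output = replicate.run(model, input={"prompt": "Say hello in one sentence."})
print("".join(output))  # join the streamed chunks into one string
```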

Replicate is **not** recommended for private document question-answering, but it is sufficient when full privacy is not required; only chunks of documents are sent to the LLM for each response.

Issues:
* `requests.exceptions.JSONDecodeError: Expecting value: line 1 column 1 (char 0)`
  * Replicate sometimes returns malformed JSON; this appears to occur at random (one possible workaround, sketched below, is to retry the request).
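
A minimal retry sketch for that transient failure, under the assumption that a repeated call usually succeeds (illustrative only, not part of this commit or of h2oGPT):

```python
# Sketch of a retry wrapper for the transient JSONDecodeError seen with Replicate.
import time
import replicate
import requests

def run_with_retry(model, inputs, retries=3, delay=2.0):
    for attempt in range(retries):
        try:
            return "".join(replicate.run(model, input=inputs))
        except requests.exceptions.JSONDecodeError:
            if attempt == retries - 1:
                raise
            time.sleep(delay)  # brief backoff before retrying
```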

## h2oGPT start-up vs. in-app selection

When using `generate.py`, specifying the `--base_model` or `--inference_server` on the CLI is not required. One can also add any model and server URL (with optional port) in the **Model** tab at the bottom:
2 changes: 1 addition & 1 deletion reqs_optional/requirements_optional_langchain.txt
@@ -1,5 +1,5 @@
# optional for chat with PDF
langchain==0.0.235
langchain==0.0.250
pypdf==3.12.2
# avoid textract, requires old six
#textract==1.6.5
2 changes: 1 addition & 1 deletion src/cli.py
@@ -19,7 +19,7 @@ def run_cli( # for local function:
trust_remote_code=None, offload_folder=None, rope_scaling=None, max_seq_len=None, compile_model=None,
# for some evaluate args
stream_output=None, async_output=None, num_async=None,
prompt_type=None, prompt_dict=None,
prompt_type=None, prompt_dict=None, system_prompt=None,
temperature=None, top_p=None, top_k=None, num_beams=None,
max_new_tokens=None, min_new_tokens=None, early_stopping=None, max_time=None, repetition_penalty=None,
num_return_sequences=None, do_sample=None, chat=None,
9 changes: 7 additions & 2 deletions src/client_test.py
@@ -98,6 +98,7 @@ def get_args(prompt, prompt_type=None, chat=False, stream_output=False,
iinput_nochat='', # only for chat=False
langchain_mode=langchain_mode,
add_chat_history_to_context=add_chat_history_to_context,
system_prompt='',
langchain_action=langchain_action,
langchain_agents=langchain_agents,
top_k_docs=top_k_docs,
@@ -248,8 +249,12 @@ def test_client_chat_stream(prompt_type='human_bot'):
langchain_agents=[])


def run_client_chat(prompt, stream_output, max_new_tokens,
langchain_mode, langchain_action, langchain_agents,
def run_client_chat(prompt='',
stream_output=None,
max_new_tokens=128,
langchain_mode='Disabled',
langchain_action=LangChainAction.QUERY.value,
langchain_agents=[],
prompt_type=None, prompt_dict=None):
client = get_client(serialize=False)

4 changes: 1 addition & 3 deletions src/eval.py
@@ -1,4 +1,3 @@
import inspect
import os
import traceback
import numpy as np
@@ -9,13 +8,12 @@
from evaluate_params import eval_func_param_names, eval_extra_columns
from gen import get_context, get_score_model, get_model, evaluate, check_locals
from prompter import Prompter
from src.enums import LangChainMode
from utils import clear_torch_cache, NullContext, get_kwargs, makedirs


def run_eval( # for local function:
base_model=None, lora_weights=None, inference_server=None,
prompt_type=None, prompt_dict=None,
prompt_type=None, prompt_dict=None, system_prompt=None,
debug=None, chat=False, chat_context=None,
stream_output=None, async_output=None, num_async=None,
eval_filename=None, eval_prompts_only_num=None, eval_prompts_only_seed=None, eval_as_output=None,
2 changes: 1 addition & 1 deletion src/evaluate_params.py
@@ -1,6 +1,5 @@
input_args_list = ['model_state', 'my_db_state', 'selection_docs_state']


no_default_param_names = [
'instruction',
'iinput',
@@ -34,6 +33,7 @@
'iinput_nochat',
'langchain_mode',
'add_chat_history_to_context',
'system_prompt',
'langchain_action',
'langchain_agents',
'top_k_docs',
44 changes: 35 additions & 9 deletions src/gen.py
@@ -70,6 +70,7 @@ def main(
inference_server: str = "",
prompt_type: Union[int, str] = None,
prompt_dict: typing.Dict = None,
system_prompt: str = '',

model_lock: typing.List[typing.Dict[str, str]] = None,
model_lock_columns: int = None,
@@ -232,8 +233,12 @@ def main(
e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003
Or Address can be "vllm:IP:port" or "vllm_chat:IP:port" for OpenAI-compliant vLLM endpoint
Note: vllm_chat not supported by vLLM project.
--inference_server=replicate:<model name string> will use a Replicate server, requiring a Replicate key.
e.g. <model name string> looks like "a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5"
:param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
:param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
:param system_prompt: Universal system prompt to use if the model supports one (e.g. LLaMa2), regardless of the prompt_type definition.
Useful for the langchain case to control behavior, or for OpenAI and Replicate.
:param model_lock: Lock models to specific combinations, for ease of use and extending to many models
Only used if gradio = True
List of dicts, each dict has base_model, tokenizer_base_model, lora_weights, inference_server, prompt_type, and prompt_dict
@@ -383,7 +388,7 @@ def main(
:param use_llm_if_no_docs: Whether to use LLM even if no documents, when langchain_mode=UserData or MyData or custom
:param load_db_if_exists: Whether to load chroma db if exists or re-generate db
:param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
:param use_system_prompt: Whether to use system prompt (e.g. llama2 safe system prompt)
:param use_system_prompt: Whether to use system prompt (e.g. llama2 safe system prompt and OpenAI).
:param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
:param use_openai_embedding: Whether to use OpenAI embeddings for vector db
:param use_openai_model: Whether to use OpenAI model for use with vector db
@@ -696,7 +701,7 @@ def main(
get_generate_params(model_lower,
chat,
stream_output, show_examples,
prompt_type, prompt_dict,
prompt_type, prompt_dict, system_prompt,
temperature, top_p, top_k, num_beams,
max_new_tokens, min_new_tokens, early_stopping, max_time,
repetition_penalty, num_return_sequences,
@@ -1208,12 +1213,28 @@ def get_model(
# Don't return None, None for model, tokenizer so triggers
return client, tokenizer, 'http'
if isinstance(inference_server, str) and (
inference_server.startswith('openai') or inference_server.startswith('vllm')):
inference_server.startswith('openai') or
inference_server.startswith('vllm') or
inference_server.startswith('replicate')):
if inference_server.startswith('openai'):
assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY"
# Don't return None, None for model, tokenizer so triggers
# include small token cushion
tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50)
if inference_server.startswith('replicate'):
assert len(inference_server.split(':')) >= 3, "Expected replicate:model string, got %s" % inference_server
assert os.getenv('REPLICATE_API_TOKEN'), "Set environment for REPLICATE_API_TOKEN"
assert max_seq_len is not None, "Please pass --max_seq_len=<max_seq_len> for replicate models."
try:
import replicate as replicate_python
except ImportError:
raise ImportError(
"Could not import replicate python package. "
"Please install it with `pip install replicate`."
)
# Don't return None, None for model, tokenizer so triggers
# include small token cushion
tokenizer = FakeTokenizer(model_max_length=max_seq_len - 50)
return inference_server, tokenizer, inference_server
assert not inference_server, "Malformed inference_server=%s" % inference_server
if base_model in non_hf_types:
@@ -1543,6 +1564,7 @@ def evaluate(
iinput_nochat,
langchain_mode,
add_chat_history_to_context,
system_prompt,
langchain_action,
langchain_agents,
top_k_docs,
@@ -1756,7 +1778,7 @@ def evaluate(
else:
db = None
t_generate = time.time()
langchain_only_model = base_model in non_hf_types or load_exllama
langchain_only_model = base_model in non_hf_types or load_exllama or inference_server.startswith('replicate')
do_langchain_path = langchain_mode not in [False, 'Disabled', 'LLM'] or \
langchain_only_model or \
force_langchain_evaluate
@@ -1794,6 +1816,7 @@ def evaluate(
answer_with_sources=answer_with_sources,
append_sources_to_answer=append_sources_to_answer,
add_chat_history_to_context=add_chat_history_to_context,
system_prompt=system_prompt,
use_openai_embedding=use_openai_embedding,
use_openai_model=use_openai_model,
hf_embedding_model=hf_embedding_model,
@@ -1870,8 +1893,9 @@ def evaluate(
# don't clear torch cache here, delays multi-generation, and bot(), all_bot(), and evaluate_nochat() do it
return

if inference_server.startswith('vllm') or inference_server.startswith('openai') or inference_server.startswith(
'http'):
if inference_server.startswith('vllm') or \
inference_server.startswith('openai') or \
inference_server.startswith('http'):
if inference_server.startswith('vllm') or inference_server.startswith('openai'):
where_from = "openai_client"
openai, inf_type = set_openai(inference_server)
@@ -1916,10 +1940,11 @@ def evaluate(
elif inf_type == 'vllm_chat' or inference_server == 'openai_chat':
if inf_type == 'vllm_chat':
raise NotImplementedError('%s not supported by vLLM' % inf_type)
openai_system_prompt = system_prompt or "You are a helpful assistant."
responses = openai.ChatCompletion.create(
model=base_model,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "system", "content": openai_system_prompt},
{'role': 'user',
'content': prompt,
}
@@ -2494,7 +2519,7 @@ def generate_with_exceptions(func, *args, prompt='', inputs_decoded='', raise_ge
def get_generate_params(model_lower,
chat,
stream_output, show_examples,
prompt_type, prompt_dict,
prompt_type, prompt_dict, system_prompt,
temperature, top_p, top_k, num_beams,
max_new_tokens, min_new_tokens, early_stopping, max_time,
repetition_penalty, num_return_sequences,
@@ -2663,7 +2688,8 @@ def mean(a):""", ''] + params_list,

# move to correct position
for example in examples:
example += [chat, '', '', LangChainMode.DISABLED.value, True, LangChainAction.QUERY.value, [],
example += [chat, '', '', LangChainMode.DISABLED.value, True, system_prompt,
LangChainAction.QUERY.value, [],
top_k_docs, chunk, chunk_size, DocumentSubset.Relevant.name, [], '', '',
]
# adjust examples if non-chat mode
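
For reference, the `replicate:<model name string>` value accepted by `--inference_server` embeds Replicate's own `owner/name:version` string, which is why `get_model` above asserts at least three colon-separated parts. A standalone illustration of that string format (not h2oGPT code):

```python
# Illustration of the --inference_server string format checked in get_model above.
inference_server = ("replicate:a16z-infra/llama13b-v2-chat:"
                    "df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5")

assert len(inference_server.split(':')) >= 3  # "replicate", model name, version hash

# Splitting once on ':' recovers the Replicate model string after the prefix.
prefix, model_string = inference_server.split(':', 1)
print(prefix)        # -> replicate
print(model_string)  # -> a16z-infra/llama13b-v2-chat:df7690...
```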