Merge pull request #67 from eren23/gpt4-first-branch
gpt4 branch
kaanozbudak authored Mar 29, 2023
2 parents 064d2c8 + 0ca54fd · commit fecb45d
Showing 2 changed files with 37 additions and 17 deletions.
5 changes: 4 additions & 1 deletion knowledgegpt/extractors/base_extractor.py
@@ -5,7 +5,8 @@

 class BaseExtractor:
     def __init__(self, dataframe=None, embedding_extractor="hf", model_lang="en", is_turbo=False, index_type="basic",
-                 verbose=False, index_path=None, prompt_template=None):
+                 verbose=False, index_path=None, is_gpt4=False, prompt_template=None):
+
         """
         :param dataframe: if you have own df use it else choose correct extractor
         :param embedding_extractor: default hf, openai
@@ -28,6 +29,7 @@ def __init__(self, dataframe=None, embedding_extractor="hf", model_lang="en", is
         self.is_turbo = is_turbo
         self.index_type = index_type
         self.verbose = verbose
+        self.is_gpt4 = is_gpt4
         self.prompt_template = prompt_template
         self.messages = []

@@ -100,6 +102,7 @@ def extract(self, query, max_tokens, load_index=False) -> tuple[str, str, list]:
             embedding_type=self.embedding_extractor,
             model_lang=self.model_lang,
             is_turbo=self.is_turbo,
+            is_gpt4=self.is_gpt4,
             verbose=self.verbose,
             messages=self.messages,
             max_tokens=max_tokens,
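To see the new flag end to end, here is a minimal usage sketch (not part of this PR; the unpacking order of extract()'s tuple[str, str, list] return value is an assumption). Note that is_gpt4 only takes effect together with is_turbo, because utils_completion.py takes the davinci completion path whenever is_turbo is false:

from knowledgegpt.extractors.base_extractor import BaseExtractor

extractor = BaseExtractor(
    embedding_extractor="hf",
    model_lang="en",
    is_turbo=True,   # must be set: gpt-4 is routed through the chat branch
    is_gpt4=True,    # new in this PR
)

# extract() is annotated -> tuple[str, str, list] above; unpacking order assumed.
answer, prompt, messages = extractor.extract("What does the text say?", max_tokens=300)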
49 changes: 33 additions & 16 deletions knowledgegpt/utils/utils_completion.py
@@ -5,16 +5,23 @@
 import numpy as np
 import tiktoken

-COMPLETIONS_API_PARAMS = {
-    "temperature": 0.0,
-    "model": "text-davinci-003",
-    "max_tokens": 1000,
-}
-
-COMPLETIONS_API_PARAMS_TURBO = {
-    "temperature": 0.0,
-    "model": "gpt-3.5-turbo",
-    "max_tokens": 1000,
-}
+model_types = {
+    "gpt-3.5-turbo": {
+        "temperature": 0.0,
+        "model": "gpt-3.5-turbo",
+        "max_tokens": 1000,
+    },
+    "gpt-4": {
+        "temperature": 0.0,
+        "model": "gpt-4",
+        "max_tokens": 4096,
+    },
+    "davinci": {
+        "temperature": 0.0,
+        "model": "text-davinci-003",
+        "max_tokens": 1000,
+    }
+}


@@ -25,7 +32,8 @@ def answer_query_with_context(
     verbose: bool = False,
     embedding_type: str = "hf",
     model_lang: str = "en",
-    is_turbo: str = False,
+    is_turbo: bool = False,
+    is_gpt4: bool = False,
     messages: list = None,
     index_type: str = "basic",
     max_tokens=1000,
@@ -79,24 +87,33 @@
         print(prompt)


-    if not is_turbo:
-        prompt_len = len(encoding.encode(prompt))
-        COMPLETIONS_API_PARAMS["max_tokens"] = 2000 - prompt_len
+    if not is_turbo :
+        prompt_len = len(encoding.encode(prompt))
+        model_types["davinci"]["max_tokens"] = 2000 - prompt_len
         response = openai.Completion.create(
             prompt=prompt,
-            **COMPLETIONS_API_PARAMS
+            ** model_types["davinci"]
         )
     else:
-        messages_token_length = encoding.encode(str(messages))
-        COMPLETIONS_API_PARAMS_TURBO["max_tokens"] = 4096 - len(messages_token_length)
-        response = openai.ChatCompletion.create(
-
-            messages=messages,
-            **COMPLETIONS_API_PARAMS_TURBO,
-        )
+        if is_gpt4:
+            messages_token_length = encoding.encode(str(messages))
+            model_types["gpt-4"]["max_tokens"] = 8192 - len(messages_token_length)
+
+            response = openai.ChatCompletion.create(
+                messages=messages,
+                **model_types["gpt-4"],
+            )
+        else:
+            messages_token_length = encoding.encode(str(messages))
+            model_types["gpt-3.5-turbo"]["max_tokens"] = 4096 - len(messages_token_length)
+
+            response = openai.ChatCompletion.create(
+
+                messages=messages,
+                **model_types["gpt-3.5-turbo"],
+            )

     if is_turbo:
         messages.append({"role": "assistant", "content": response["choices"][0]["message"]["content"].strip(" \n")})
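Both new chat branches apply the same budgeting rule: completion budget = assumed context window (4096 for gpt-3.5-turbo, 8192 for gpt-4) minus the tokens already consumed by the serialized chat history. A standalone sketch of that arithmetic (the cl100k_base encoding name is an assumption for illustration; the module resolves its own encoding elsewhere):

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")  # assumed encoding, for illustration
messages = [{"role": "user", "content": "Summarize the indexed document."}]

context_windows = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}  # values used in the diff
used = len(encoding.encode(str(messages)))  # counts the Python repr of the list
budget = context_windows["gpt-4"] - used    # becomes max_tokens for the API call
print(budget)

Since str(messages) measures the list's repr rather than the exact chat tokens the API will count, the resulting budget is an approximation.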
