Skip to content

Commit

Permalink
I noticed that the hparams in the config file that were being accesse…
Browse files Browse the repository at this point in the history
…d had copy-pasted code that had the wrong llm name. Updated to the right name according to class object
  • Loading branch information
RyanSaxe committed Mar 29, 2024
1 parent 47cf703 commit 75e11cb
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 7 deletions.
2 changes: 1 addition & 1 deletion llms/anthropic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def __init__(self, name):
self.api_key = config['llms']['anthropic']['api_key'].strip()

self.hparams = config['hparams']
self.hparams.update(config['llms']['openai'].get('hparams') or {})
self.hparams.update(config['llms']['anthropic'].get('hparams') or {})

def make_request(self, conversation, add_image=None, logit_bias=None, max_tokens=None):
conversation = [{"role": "user" if i%2 == 0 else "assistant", "content": content} for i,content in enumerate(conversation)]
Expand Down
3 changes: 1 addition & 2 deletions llms/cohere_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ def __init__(self, name):
api_key = config['llms']['cohere']['api_key'].strip()
self.client = cohere.Client(api_key)
self.name = name
config = json.load(open("config.json"))
self.hparams = config['hparams']
self.hparams.update(config['llms']['openai'].get('hparams') or {})
self.hparams.update(config['llms']['cohere'].get('hparams') or {})

def make_request(self, conversation, add_image=None, max_tokens=None):
prior_messages = [{"role": "USER" if i%2 == 0 else "CHATBOT", "message": content} for i,content in enumerate(conversation[:-1])]
Expand Down
3 changes: 1 addition & 2 deletions llms/moonshot_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ def __init__(self, name):
api_key = config['llms']['moonshot']['api_key'].strip()
self.client = OpenAI(api_key=api_key, base_url='https://api.moonshot.cn/v1')
self.name = name
config = json.load(open("config.json"))
self.hparams = config['hparams']
self.hparams.update(config['llms']['openai'].get('hparams') or {})
self.hparams.update(config['llms']['moonshot'].get('hparams') or {})

def make_request(self, conversation, add_image=None, max_tokens=None):
conversation = [{"role": "user" if i%2 == 0 else "assistant", "content": content} for i,content in enumerate(conversation)]
Expand Down
1 change: 0 additions & 1 deletion llms/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def __init__(self, name):
api_key = config['llms']['openai']['api_key'].strip()
self.client = OpenAI(api_key=api_key)
self.name = name
config = json.load(open("config.json"))
self.hparams = config['hparams']
self.hparams.update(config['llms']['openai'].get('hparams') or {})

Expand Down
2 changes: 1 addition & 1 deletion llms/vertexai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, name):
self.name = name
config = json.load(open("config.json"))
self.hparams = config['hparams']
self.hparams.update(config['llms']['mistral'].get('hparams') or {})
self.hparams.update(config['llms']['vertexai'].get('hparams') or {})

self.name = name

Expand Down

0 comments on commit 75e11cb

Please sign in to comment.