add llama2 model options #44

Closed · wants to merge 7 commits
10 changes: 9 additions & 1 deletion lib/assessment/config.py
@@ -1,6 +1,14 @@
VALID_LABELS = ["Extensive Evidence", "Convincing Evidence", "Limited Evidence", "No Evidence"]
# do not include gpt-4, so that we always know what version of the model we are using.
SUPPORTED_MODELS = ['gpt-4-0314', 'gpt-4-32k-0314', 'gpt-4-0613', 'gpt-4-32k-0613', 'gpt-4-1106-preview']
SUPPORTED_MODELS = [
    'meta.llama2-13b-chat-v1',
    'meta.llama2-70b-chat-v1',
    'gpt-4-0314',
    'gpt-4-32k-0314',
    'gpt-4-0613',
    'gpt-4-32k-0613',
    'gpt-4-1106-preview'
]
DEFAULT_MODEL = 'gpt-4-0613'
LESSONS = {
# "U3-2022-L10" : "1ROCbvHb3yWGVoQqzKAjwdaF0dSRPUjy_",
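
The Bedrock Llama 2 model IDs now sit alongside the OpenAI ones in a single list, so callers can select either family through the same configuration. As a minimal sketch (not part of this diff; the resolve_model helper name is hypothetical), a caller might validate a requested model like this:

from lib.assessment.config import SUPPORTED_MODELS, DEFAULT_MODEL

def resolve_model(requested=""):
    # Hypothetical helper: fall back to the default model when none is given,
    # and reject anything that is not in the supported list.
    if not requested:
        return DEFAULT_MODEL
    if requested not in SUPPORTED_MODELS:
        raise ValueError(f"Unsupported model: {requested}")
    return requested
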
71 changes: 71 additions & 0 deletions lib/assessment/label.py
@@ -5,6 +5,7 @@
import time
import requests
import logging
import boto3

from typing import List, Dict, Any
from lib.assessment.config import VALID_LABELS
@@ -48,6 +49,45 @@ def statically_label_student_work(self, rubric, student_code, student_id, exampl
        return None

    def ai_label_student_work(self, prompt, rubric, student_code, student_id, examples=[], num_responses=0, temperature=0.0, llm_model=""):
        # Route to the OpenAI or Meta (Bedrock) implementation based on the model name prefix
        if llm_model.startswith("gpt"):
            return self.openai_label_student_work(prompt, rubric, student_code, student_id, examples=examples, num_responses=num_responses, temperature=temperature, llm_model=llm_model)
        elif llm_model.startswith("meta"):
            return self.meta_label_student_work(prompt, rubric, student_code, student_id, examples=examples, num_responses=num_responses, temperature=temperature, llm_model=llm_model)
        else:
            raise Exception("Unknown model: {}".format(llm_model))

    def meta_label_student_work(self, prompt, rubric, student_code, student_id, examples=[], num_responses=0, temperature=0.0, llm_model=""):
        bedrock = boto3.client(service_name='bedrock-runtime')

        meta_prompt = self.compute_meta_prompt(prompt, rubric, student_code, examples=examples)
        body = json.dumps({
            "prompt": meta_prompt,
            "max_gen_len": 1024,
            "temperature": temperature,
        })
        accept = 'application/json'
        content_type = 'application/json'
        response = bedrock.invoke_model(body=body, modelId=llm_model, accept=accept, contentType=content_type)

        response_body = json.loads(response.get('body').read())
        #logging.info(f"raw AI response:\n{response_body}")
        generation = response_body.get('generation')

        data = self.get_json_data_if_valid(generation, rubric, student_id)
        #logging.info(f"AI response json:\n{json.dumps(data, indent=2)}")

        return {
            'metadata': {
                'agent': 'meta',
                'request': body,
            },
            'data': data,
        }

    def compute_meta_prompt(self, prompt, rubric, student_code, examples=[]):
        # Wrap the system prompt in Llama 2 chat [INST] tags, then append the rubric and student code
        return f"[INST]{prompt}[/INST]\n\nRubric:\n{rubric}\n\nStudent Code:\n{student_code}\n\nEvaluation (JSON):\n"

    def openai_label_student_work(self, prompt, rubric, student_code, student_id, examples=[], num_responses=0, temperature=0.0, llm_model=""):
        # Determine the OpenAI URL and headers
        api_url = 'https://api.openai.com/v1/chat/completions'
        headers = {
@@ -178,6 +218,7 @@ def tsv_data_from_choices(self, info, rubric, student_id):
        tsv_data = self.get_consensus_response(tsv_data_choices, student_id)
        return tsv_data

    # TODO: rename to compute_openai_messages
    def compute_messages(self, prompt, rubric, student_code, examples=[]):
        messages = [
            {'role': 'system', 'content': f"{prompt}\n\nRubric:\n{rubric}"}
@@ -188,6 +229,36 @@ def compute_messages(self, prompt, rubric, student_code, examples=[]):
        messages.append({'role': 'user', 'content': student_code})
        return messages

    def get_json_data_if_valid(self, response_text, rubric, student_id):
        # ensure that the first non-whitespace character is '['
        if not response_text or not response_text.strip().startswith('['):
            logging.error(f"{student_id} Response does not start with '[': {response_text}")
            return None

        # capture all data from the first '[' to the first ']', inclusive
        # (use search rather than match so leading whitespace does not prevent a match)
        match = re.search(r'(\[[^\]]+\])', response_text)
        if not match:
            logging.error(f"{student_id} Invalid response: no valid JSON data: {response_text}")
            return None
        json_text = match.group(1)

        # parse the JSON data
        try:
            data = json.loads(json_text)
        except json.JSONDecodeError as e:
            logging.error(f"{student_id} JSON decoding error: {e}\n{json_text}")
            return None

        # rename Grade to Label so downstream code sees a consistent key
        for row in data:
            if "Grade" in row:
                row['Label'] = row['Grade']
                del row['Grade']

        # TODO: sanitize and validate json

        return data

    def get_tsv_data_if_valid(self, response_text, rubric, student_id, choice_index=None, reraise=False):
        choice_text = f"Choice {choice_index}: " if choice_index is not None else ''
        if not response_text:
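
For reference, a minimal standalone sketch of the Bedrock call path that meta_label_student_work wraps — assuming boto3 is configured with AWS credentials and Bedrock access to the Llama 2 chat models; the prompt text here is only an example, not the prompt this PR uses:

import json
import boto3

bedrock = boto3.client(service_name='bedrock-runtime')

# Llama 2 chat models on Bedrock take a single prompt string plus generation parameters.
body = json.dumps({
    "prompt": "[INST]Assess the student code against the rubric.[/INST]\n\nEvaluation (JSON):\n",
    "max_gen_len": 1024,
    "temperature": 0.0,
})

response = bedrock.invoke_model(
    body=body,
    modelId='meta.llama2-70b-chat-v1',
    accept='application/json',
    contentType='application/json',
)

# The generated text comes back under the 'generation' key of the response body.
generation = json.loads(response['body'].read())['generation']
print(generation)
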