Skip to content

Commit

Permalink
customized function for llm invoking
Browse files Browse the repository at this point in the history
  • Loading branch information
HuXiangkun committed Sep 18, 2024
1 parent c575b53 commit 0f95e9d
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 80 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ You can first setup a [demo website](./demo/) and then use the web UI to try Ref
## 🚀 Quick Start

### Setup Environment
First create a python environment using conda or virtualenv. Clone this repo and change path into the root directory. Then install:
First create a python environment using conda or virtualenv. Then install:
```bash
pip install -e .
pip install refchecker
python -m spacy download en_core_web_sm
```

Install optional dependencies to use open source extractors (Mistral, Mixtral) or enable acceleration for RepCChecker.
```bash
pip install -e .[open-extractor,repcex]
pip install refchecker[open-extractor,repcex]
```

### Code Examples
Expand Down
12 changes: 3 additions & 9 deletions refchecker/checker/checker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def check(
merge_psg: bool = True,
is_joint: bool = False,
joint_check_num: int = 5,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
"""
Expand Down Expand Up @@ -97,9 +95,7 @@ def check(
questions=batch_questions,
is_joint=True,
joint_check_num=joint_check_num,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)
if merge_psg:
Expand Down Expand Up @@ -139,9 +135,7 @@ def check(
responses=[inp[2] for inp in input_flattened],
questions=[inp[3] for inp in input_flattened],
is_joint=False,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func
custom_llm_api_func=custom_llm_api_func,
)

ret = [[x] + y for x, y in zip(ret, input_ids)]
Expand Down
12 changes: 3 additions & 9 deletions refchecker/checker/llm_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def _check(
questions: List[str] = None,
is_joint: bool = False,
joint_check_num: int = 5,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
"""
Expand Down Expand Up @@ -127,9 +125,7 @@ def _check(
model=self.model,
max_new_tokens=joint_check_num * 10 + 100,
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)

Expand Down Expand Up @@ -208,9 +204,7 @@ def _check(
model=self.model,
max_new_tokens=10,
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)

Expand Down
20 changes: 5 additions & 15 deletions refchecker/extractor/extractor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,23 @@ def extract(
batch_responses,
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
if self.claim_format == 'triplet':
result = self.extract_claim_triplets(
batch_responses=batch_responses,
batch_questions=batch_questions,
max_new_tokens=max_new_tokens,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)
elif self.claim_format == 'subsentence':
result = self.extract_subsentence_claims(
batch_responses=batch_responses,
batch_questions=batch_questions,
max_new_tokens=max_new_tokens,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)
return result
Expand All @@ -48,9 +42,7 @@ def extract_claim_triplets(
batch_responses,
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
raise NotImplementedError
Expand All @@ -60,9 +52,7 @@ def extract_subsentence_claims(
batch_responses,
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
raise NotImplementedError
Expand Down
16 changes: 4 additions & 12 deletions refchecker/extractor/llm_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ def extract_subsentence_claims(
batch_responses,
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
"""Extract subsentence claims from the response text.
Expand Down Expand Up @@ -75,9 +73,7 @@ def extract_subsentence_claims(
n_choices=1,
max_new_tokens=max_new_tokens,
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)

Expand All @@ -103,9 +99,7 @@ def extract_claim_triplets(
batch_responses,
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
"""Extract KG triplets from the response text.
Expand Down Expand Up @@ -151,9 +145,7 @@ def extract_claim_triplets(
n_choices=1,
max_new_tokens=max_new_tokens,
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)

Expand Down
36 changes: 4 additions & 32 deletions refchecker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ def get_model_batch_response(
n_choices=1,
max_new_tokens=500,
api_base=None,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
"""
Expand All @@ -99,35 +97,9 @@ def get_model_batch_response(
"""
if not prompts or len(prompts) == 0:
raise ValueError("Invalid input.")

if sagemaker_client is not None:
parameters = {
"max_new_tokens": max_new_tokens,
"temperature": temperature
}
if sagemaker_params is not None:
for k, v in sagemaker_params.items():
if k in parameters:
parameters[k] = v
response_list = []
for prompt in prompts:
r = sagemaker_client.invoke_endpoint(
EndpointName=model,
Body=json.dumps(
{
"inputs": prompt,
"parameters": parameters,
}
),
ContentType="application/json",
)
if sagemaker_get_response_func is not None:
response = sagemaker_get_response_func(r)
else:
r = json.loads(r['Body'].read().decode('utf8'))
response = r['outputs'][0]
response_list.append(response)
return response_list

if custom_func is not None:
return custom_func(prompts)
else:
message_list = []
for prompt in prompts:
Expand Down

0 comments on commit 0f95e9d

Please sign in to comment.