Skip to content

Commit

Permalink
additional params for sagemaker
Browse files Browse the repository at this point in the history
  • Loading branch information
HuXiangkun committed Sep 5, 2024
1 parent d5cd604 commit ba57354
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "refchecker"
version = "0.2.5"
version = "0.2.6"
description = "RefChecker provides automatic checking pipeline for detecting fine-grained hallucinations generated by Large Language Models."
authors = [
"Xiangkun Hu <xiangkhu@amazon.com>",
Expand Down
5 changes: 4 additions & 1 deletion refchecker/checker/checker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def check(
is_joint: bool = False,
joint_check_num: int = 5,
sagemaker_client=None,
sagemaker_params=None,
**kwargs
):
"""
Expand Down Expand Up @@ -96,6 +97,7 @@ def check(
is_joint=True,
joint_check_num=joint_check_num,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
**kwargs
)
if merge_psg:
Expand Down Expand Up @@ -135,7 +137,8 @@ def check(
responses=[inp[2] for inp in input_flattened],
questions=[inp[3] for inp in input_flattened],
is_joint=False,
sagemaker_client=sagemaker_client
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params
)

ret = [[x] + y for x, y in zip(ret, input_ids)]
Expand Down
3 changes: 3 additions & 0 deletions refchecker/checker/llm_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _check(
is_joint: bool = False,
joint_check_num: int = 5,
sagemaker_client=None,
sagemaker_params=None,
**kwargs
):
"""
Expand Down Expand Up @@ -126,6 +127,7 @@ def _check(
max_new_tokens=joint_check_num * 10 + 100,
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
**kwargs
)

Expand Down Expand Up @@ -205,6 +207,7 @@ def _check(
max_new_tokens=10,
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
**kwargs
)

Expand Down
5 changes: 5 additions & 0 deletions refchecker/extractor/extractor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def extract(
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
**kwargs
):
if self.claim_format == 'triplet':
Expand All @@ -25,6 +26,7 @@ def extract(
batch_questions=batch_questions,
max_new_tokens=max_new_tokens,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
**kwargs
)
elif self.claim_format == 'subsentence':
Expand All @@ -33,6 +35,7 @@ def extract(
batch_questions=batch_questions,
max_new_tokens=max_new_tokens,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
**kwargs
)
return result
Expand All @@ -43,6 +46,7 @@ def extract_claim_triplets(
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
**kwargs
):
raise NotImplementedError
Expand All @@ -53,6 +57,7 @@ def extract_subsentence_claims(
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
**kwargs
):
raise NotImplementedError
Expand Down
3 changes: 3 additions & 0 deletions refchecker/extractor/llm_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def extract_subsentence_claims(
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
**kwargs
):
"""Extract subsentence claims from the response text.
Expand Down Expand Up @@ -74,6 +75,7 @@ def extract_subsentence_claims(
max_new_tokens=max_new_tokens,
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
**kwargs
)

Expand Down Expand Up @@ -146,6 +148,7 @@ def extract_claim_triplets(
max_new_tokens=max_new_tokens,
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
**kwargs
)

Expand Down
3 changes: 3 additions & 0 deletions refchecker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def get_model_batch_response(
max_new_tokens=500,
api_base=None,
sagemaker_client=None,
sagemaker_params=None,
**kwargs
):
"""
Expand Down Expand Up @@ -108,6 +109,8 @@ def get_model_batch_response(
"logits_processor" : None,
# "remove_invalid_values" : True
}
if sagemaker_params is not None:
parameters.update(sagemaker_params)
response_list = []
for prompt in prompts:
r = sagemaker_client.invoke_endpoint(
Expand Down

0 comments on commit ba57354

Please sign in to comment.