Skip to content

Commit

Permalink
fix typos with get default bucket prefix for sm session (deepjavalibr…
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk authored and KexinFeng committed Aug 16, 2023
1 parent 246df41 commit 14d5880
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions tests/integration/llm/sagemaker-endpoint-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def get_sagemaker_session(default_bucket=DEFAULT_BUCKET,


def delete_s3_test_artifacts(sagemaker_session):
bucket = sagemaker_session.get_default_bucket()
prefix = sagemaker_session.get_default_bucket_prefix()
bucket = sagemaker_session.default_bucket()
prefix = sagemaker_session.default_bucket_prefix
s3 = boto3.resource("s3")
s3.Bucket(bucket).objects.filter(Prefix=prefix).delete()

Expand All @@ -150,7 +150,8 @@ def get_name_for_resource(name):

def mme_test(name):
config = MME_CONFIGS.get(name)
session = get_sagemaker_session(default_bucket_prefix="mme-tests")
session = get_sagemaker_session(
default_bucket_prefix=get_name_for_resource("mme-tests"))
models = config.get("models")
created_models = []
mme = None
Expand All @@ -168,8 +169,8 @@ def mme_test(name):
created_models.append(model)

mme = MultiDataModel(get_name_for_resource(name),
"s3://" + session.get_default_bucket() + '/' +
session.get_default_bucket_prefix(),
"s3://" + session.default_bucket() + '/' +
session.default_bucket_prefix,
config.get("prefix"),
image_uri=config.get("image_uri"),
role=ROLE,
Expand All @@ -181,7 +182,6 @@ def mme_test(name):
config.get("instance_type", DEFAULT_INSTANCE_TYPE),
serializer=sagemaker.serializers.JSONSerializer(),
deserializer=sagemaker.deserializers.JSONDeserializer())
assert len(created_models) == len(list(mme.list_models()))
for model in list(mme.list_models()):
outputs = predictor.predict(DEFAULT_PAYLOAD, target_model=model)
print(outputs)
Expand All @@ -202,7 +202,8 @@ def mme_test(name):
def no_code_endpoint_test(name):
config = HUGGING_FACE_NO_CODE_CONFIGS.get(name)
data = config.get("payload", DEFAULT_PAYLOAD)
session = get_sagemaker_session(default_bucket_prefix="no-code-tests")
session = get_sagemaker_session(
default_bucket_prefix=get_name_for_resource("no-code-tests"))
model = None
predictor = None
try:
Expand Down Expand Up @@ -236,7 +237,7 @@ def single_model_endpoint_test(name):
config = SINGLE_MODEL_ENDPOINT_CONFIGS.get(name)
data = config.get("payload", DEFAULT_PAYLOAD)
session = get_sagemaker_session(
default_bucket_prefix="single_endpoint-tests")
default_bucket_prefix=get_name_for_resource("single_endpoint-tests"))
model = None
predictor = None
try:
Expand Down

0 comments on commit 14d5880

Please sign in to comment.