Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Cache] Fix the wrong cache when local file input in a pipeline #34743

Merged
merged 14 commits into from
Mar 18, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

import json
import logging
import os
import re
import time
import typing
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple, Union
Expand All @@ -16,6 +18,7 @@
from azure.ai.ml._restclient.v2022_10_01.models import ComponentVersion, ComponentVersionProperties
from azure.ai.ml._schema import PathAwareSchema
from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentSchema
from azure.ai.ml._utils._asset_utils import get_object_hash
from azure.ai.ml._utils.utils import hash_dict, is_data_binding_expression
from azure.ai.ml.constants._common import ARM_ID_PREFIX, ASSET_ARM_ID_REGEX_FORMAT, COMPONENT_TYPE
from azure.ai.ml.constants._component import ComponentSource, NodeType
Expand Down Expand Up @@ -334,6 +337,27 @@ def _get_anonymous_hash(self) -> str:
# command component), so we just use rest object to generate hash for pipeline component,
# which doesn't have reuse issue.
component_interface_dict = self._to_rest_object().properties.component_spec
# Hash local inputs in pipeline component jobs
for job_name, job in self.jobs.items():
if getattr(job, "inputs", None):
for input_name, input_value in job.inputs.items():
try:
if (
isinstance(input_value._data, Input)
and input_value.path
and os.path.exists(input_value.path)
wangchao1230 marked this conversation as resolved.
Show resolved Hide resolved
):
start_time = time.time()
component_interface_dict["jobs"][job_name]["inputs"][input_name][
"content_hash"
] = get_object_hash(input_value.path)
module_logger.debug(
"Takes %s seconds to calculate the content hash of local input %s",
time.time() - start_time,
input_value.path,
)
except ValidationException:
wangchao1230 marked this conversation as resolved.
Show resolved Hide resolved
pass
hash_value: str = hash_dict(
component_interface_dict,
keys_to_omit=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2194,3 +2194,24 @@ def test_pipeline_job_with_data_binding_expression_on_spark_resource(self, mock_
"instance_type": "${{parent.inputs.instance_type}}",
"runtime_version": "3.2.0",
}

def test_local_input_in_pipeline_job(self, client: MLClient, tmp_path: Path):
    """Verify that the anonymous hash of a pipeline component changes when the
    content of a local (on-disk) input changes.

    Regression test for the caching bug where two pipelines pointing at the
    same local path but with different file content reused the same hash.
    """
    # Create an (initially empty) file inside the local input folder.
    local_file = tmp_path / "mock_input_file"
    local_file.touch(exist_ok=True)

    component_func = load_component(source="./tests/test_configs/components/1in1out.yaml")

    @pipeline()
    def pipeline_with_local_input():
        # The whole tmp_path folder is wired in as a local uri_folder input.
        folder_input = Input(type="uri_folder", path=tmp_path)
        component_func(input1=folder_input)

    pipeline_obj = pipeline_with_local_input()
    # Pin the inner component id so the hash difference can only come from
    # the local input's content, not from component re-registration.
    pipeline_obj.component.jobs["one_in_one_out"]._component = "mock_component_id"
    original_hash = pipeline_obj.component._get_anonymous_hash()

    # Mutate the local input's content; the content hash folded into the
    # component hash must now differ.
    with open(local_file, "w") as f:
        f.write("mock_file")

    updated_hash = pipeline_obj.component._get_anonymous_hash()
    assert updated_hash != original_hash