[Cache] Fix wrong cache reuse when a local file is used as input in a pipeline (#34743)
* fix cache bug

* fix test case

* fix code style

* format

* fix code style

* test format

* fix code format

* Add debug log

* Add debug log

* fix code style

* format code

* fix code style

* format code
lalala123123 authored Mar 18, 2024
1 parent 5934b3c commit 2e3eb58
Showing 2 changed files with 45 additions and 0 deletions.
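Why the cache was wrong, in a minimal sketch (simplified stand-ins, not the SDK's actual code): the anonymous hash was computed from the component's REST spec only, which records a local input's path but not its content, so two submissions with different file contents at the same path collided on one cache entry. Mixing a content hash into the hashed dict breaks that collision:

import hashlib
import json

def hash_dict(d):
    # Simplified stand-in for azure.ai.ml._utils.utils.hash_dict.
    return hashlib.sha256(json.dumps(d, sort_keys=True).encode()).hexdigest()

# The spec records only the path of a local input, not what is inside it.
spec = {"jobs": {"train": {"inputs": {"data": {"path": "./data"}}}}}
hash_before = hash_dict(spec)  # identical even after ./data is edited

# The fix adds the file/folder content hash to the dict before hashing.
spec["jobs"]["train"]["inputs"]["data"]["content_hash"] = "<hash of ./data contents>"
hash_after = hash_dict(spec)
assert hash_before != hash_after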
@@ -6,7 +6,9 @@
 import json
 import logging
+import os
 import re
+import time
 import typing
 from collections import Counter
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -16,6 +18,7 @@
 from azure.ai.ml._restclient.v2022_10_01.models import ComponentVersion, ComponentVersionProperties
 from azure.ai.ml._schema import PathAwareSchema
 from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentSchema
+from azure.ai.ml._utils._asset_utils import get_object_hash
 from azure.ai.ml._utils.utils import hash_dict, is_data_binding_expression
 from azure.ai.ml.constants._common import ARM_ID_PREFIX, ASSET_ARM_ID_REGEX_FORMAT, COMPONENT_TYPE
 from azure.ai.ml.constants._component import ComponentSource, NodeType
@@ -334,6 +337,27 @@ def _get_anonymous_hash(self) -> str:
         # command component), so we just use rest object to generate hash for pipeline component,
         # which doesn't have reuse issue.
         component_interface_dict = self._to_rest_object().properties.component_spec
+        # Hash local inputs in pipeline component jobs
+        for job_name, job in self.jobs.items():
+            if getattr(job, "inputs", None):
+                for input_name, input_value in job.inputs.items():
+                    try:
+                        if (
+                            isinstance(input_value._data, Input)
+                            and input_value.path
+                            and os.path.exists(input_value.path)
+                        ):
+                            start_time = time.time()
+                            component_interface_dict["jobs"][job_name]["inputs"][input_name][
+                                "content_hash"
+                            ] = get_object_hash(input_value.path)
+                            module_logger.debug(
+                                "Takes %s seconds to calculate the content hash of local input %s",
+                                time.time() - start_time,
+                                input_value.path,
+                            )
+                    except ValidationException:
+                        pass
         hash_value: str = hash_dict(
             component_interface_dict,
             keys_to_omit=[
@@ -2194,3 +2194,24 @@ def test_pipeline_job_with_data_binding_expression_on_spark_resource(self, mock_
             "instance_type": "${{parent.inputs.instance_type}}",
             "runtime_version": "3.2.0",
         }
+
+    def test_local_input_in_pipeline_job(self, client: MLClient, tmp_path: Path):
+        file_path = tmp_path / "mock_input_file"
+        file_path.touch(exist_ok=True)
+        component_path = "./tests/test_configs/components/1in1out.yaml"
+        component_func = load_component(source=component_path)
+
+        @pipeline()
+        def pipeline_with_local_input():
+            input_folder = Input(type="uri_folder", path=tmp_path)
+            component_func(input1=input_folder)
+
+        pipeline_obj = pipeline_with_local_input()
+        pipeline_obj.component.jobs["one_in_one_out"]._component = "mock_component_id"
+        pipeline_hash_id = pipeline_obj.component._get_anonymous_hash()
+
+        with open(file_path, "w") as f:
+            f.write("mock_file")
+
+        new_pipeline_hash_id = pipeline_obj.component._get_anonymous_hash()
+        assert new_pipeline_hash_id != pipeline_hash_id
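For context, a hypothetical end-user sketch of the scenario this test guards (workspace details, paths, component spec, and the input1 parameter below are all placeholders): after editing a file under the local input folder, a resubmitted pipeline now produces a different anonymous component hash, so a new anonymous component is registered instead of silently reusing the stale cached one.

from azure.ai.ml import Input, MLClient, load_component
from azure.ai.ml.dsl import pipeline
from azure.identity import DefaultAzureCredential

# Placeholder workspace details.
ml_client = MLClient(
    DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>"
)
train = load_component(source="./components/train.yaml")  # hypothetical component spec

@pipeline()
def pipeline_with_local_data():
    # Local folder input: its content hash now feeds the component hash.
    train(input1=Input(type="uri_folder", path="./data"))

# Editing a file under ./data between submissions changes the anonymous hash,
# so this submission no longer hits the cache entry from the previous run.
ml_client.jobs.create_or_update(pipeline_with_local_data())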