diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml index d56e01c6b..8a36e5db6 100644 --- a/.github/workflows/e2e-tests.yml +++ b/.github/workflows/e2e-tests.yml @@ -12,8 +12,6 @@ on: jobs: e2e-tests: - env: - UV_HTTP_TIMEOUT: 1000 defaults: run: shell: bash @@ -33,7 +31,7 @@ jobs: restore-keys: ${{inputs.plugin}} - name: Install uv run: | - python -m pip install "uv==0.1.13" + python -m pip install "uv==0.2.21" - name: Install dependencies run: | cd ${{ inputs.plugin }} diff --git a/.github/workflows/kedro-datasets.yml b/.github/workflows/kedro-datasets.yml index e28cacc57..991a12731 100644 --- a/.github/workflows/kedro-datasets.yml +++ b/.github/workflows/kedro-datasets.yml @@ -38,8 +38,6 @@ jobs: check-docs: runs-on: ubuntu-latest - env: - UV_HTTP_TIMEOUT: 1000 steps: - name: Checkout code uses: actions/checkout@v4 @@ -55,7 +53,7 @@ jobs: restore-keys: kedro-datasets - name: Install uv run: | - python -m pip install "uv==0.1.13" + python -m pip install "uv==0.2.21" - name: Install dependencies run: | cd kedro-datasets diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 400fe994c..1594b493d 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -12,8 +12,6 @@ on: jobs: lint: - env: - UV_HTTP_TIMEOUT: 1000 defaults: run: shell: bash @@ -33,7 +31,7 @@ jobs: restore-keys: ${{inputs.plugin}} - name: Install uv run: | - python -m pip install "uv==0.1.13" + python -m pip install "uv==0.2.21" - name: Install dependencies run: | cd ${{ inputs.plugin }} diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 27e437315..088a55d11 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -14,8 +14,6 @@ jobs: unit-tests: runs-on: ${{ inputs.os }} - env: - UV_HTTP_TIMEOUT: 1000 defaults: run: shell: bash @@ -45,7 +43,7 @@ jobs: uses: microsoft/setup-msbuild@v2 - name: Install uv run: | - python -m pip install "uv==0.1.13" + python -m pip install "uv==0.2.21" - name: Install dependencies run: | cd ${{ inputs.plugin }} diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index 9bd3cb0ff..a4931c1fa 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -1,5 +1,11 @@ # Upcoming Release ## Major features and improvements +## Bug fixes and other changes +## Breaking Changes +## Community contributions + +# Release 4.0.0 +## Major features and improvements * Added the following new **experimental** datasets: @@ -9,21 +15,34 @@ | `langchain.ChatCohereDataset` | A dataset for loading a ChatCohere langchain model. | `kedro_datasets_experimental.langchain` | | `langchain.OpenAIEmbeddingsDataset` | A dataset for loading a OpenAIEmbeddings langchain model. | `kedro_datasets_experimental.langchain` | | `langchain.ChatOpenAIDataset` | A dataset for loading a ChatOpenAI langchain model. | `kedro_datasets_experimental.langchain` | +| `rioxarray.GeoTIFFDataset` | A dataset for loading and saving geotiff raster data | `kedro_datasets_experimental.rioxarray` | | `netcdf.NetCDFDataset` | A dataset for loading and saving "*.nc" files. | `kedro_datasets_experimental.netcdf` | -* `netcdf.NetCDFDataset` moved from `kedro_datasets` to `kedro_datasets_experimental`. * Added the following new core datasets: + | Type | Description | Location | |-------------------------------------|-----------------------------------------------------------|-----------------------------------------| | `dask.CSVDataset` | A dataset for loading a CSV files using `dask` | `kedro_datasets.dask` | * Extended preview feature to `yaml.YAMLDataset`. +## Bug fixes and other changes +* Added `metadata` parameter for a few datasets + +## Breaking Changes +* `netcdf.NetCDFDataset` moved from `kedro_datasets` to `kedro_datasets_experimental`. + ## Community contributions Many thanks to the following Kedroids for contributing PRs to this release: +* [Ian Whalen](https://github.com/ianwhale) +* [Charles Guan](https://github.com/charlesbmi) +* [Thomas Gölles](https://github.com/tgoelles) * [Lukas Innig](https://github.com/derluke) * [Michael Sexton](https://github.com/michaelsexton) +* [michal-mmm](https://github.com/michal-mmm) + + # Release 3.0.1 diff --git a/kedro-datasets/docs/source/api/kedro_datasets.rst b/kedro-datasets/docs/source/api/kedro_datasets.rst index 0109ebefc..4a7868d38 100644 --- a/kedro-datasets/docs/source/api/kedro_datasets.rst +++ b/kedro-datasets/docs/source/api/kedro_datasets.rst @@ -11,56 +11,56 @@ kedro_datasets :toctree: :template: autosummary/class.rst - kedro_datasets.api.APIDataset - kedro_datasets.biosequence.BioSequenceDataset - kedro_datasets.dask.CSVDataset - kedro_datasets.dask.ParquetDataset - kedro_datasets.databricks.ManagedTableDataset - kedro_datasets.email.EmailMessageDataset - kedro_datasets.geopandas.GeoJSONDataset - kedro_datasets.holoviews.HoloviewsWriter - kedro_datasets.huggingface.HFDataset - kedro_datasets.huggingface.HFTransformerPipelineDataset - kedro_datasets.ibis.TableDataset - kedro_datasets.json.JSONDataset - kedro_datasets.matlab.MatlabDataset - kedro_datasets.matplotlib.MatplotlibWriter - kedro_datasets.networkx.GMLDataset - kedro_datasets.networkx.GraphMLDataset - kedro_datasets.networkx.JSONDataset - kedro_datasets.pandas.CSVDataset - kedro_datasets.pandas.DeltaTableDataset - kedro_datasets.pandas.ExcelDataset - kedro_datasets.pandas.FeatherDataset - kedro_datasets.pandas.GBQQueryDataset - kedro_datasets.pandas.GBQTableDataset - kedro_datasets.pandas.GenericDataset - kedro_datasets.pandas.HDFDataset - kedro_datasets.pandas.JSONDataset - kedro_datasets.pandas.ParquetDataset - kedro_datasets.pandas.SQLQueryDataset - kedro_datasets.pandas.SQLTableDataset - kedro_datasets.pandas.XMLDataset - kedro_datasets.partitions.IncrementalDataset - kedro_datasets.partitions.PartitionedDataset - kedro_datasets.pickle.PickleDataset - kedro_datasets.pillow.ImageDataset - kedro_datasets.plotly.JSONDataset - kedro_datasets.plotly.PlotlyDataset - kedro_datasets.polars.CSVDataset - kedro_datasets.polars.EagerPolarsDataset - kedro_datasets.polars.LazyPolarsDataset - kedro_datasets.redis.PickleDataset - kedro_datasets.snowflake.SnowparkTableDataset - kedro_datasets.spark.DeltaTableDataset - kedro_datasets.spark.SparkDataset - kedro_datasets.spark.SparkHiveDataset - kedro_datasets.spark.SparkJDBCDataset - kedro_datasets.spark.SparkStreamingDataset - kedro_datasets.svmlight.SVMLightDataset - kedro_datasets.tensorflow.TensorFlowModelDataset - kedro_datasets.text.TextDataset - kedro_datasets.tracking.JSONDataset - kedro_datasets.tracking.MetricsDataset - kedro_datasets.video.VideoDataset - kedro_datasets.yaml.YAMLDataset + api.APIDataset + biosequence.BioSequenceDataset + dask.CSVDataset + dask.ParquetDataset + databricks.ManagedTableDataset + email.EmailMessageDataset + geopandas.GeoJSONDataset + holoviews.HoloviewsWriter + huggingface.HFDataset + huggingface.HFTransformerPipelineDataset + ibis.TableDataset + json.JSONDataset + matlab.MatlabDataset + matplotlib.MatplotlibWriter + networkx.GMLDataset + networkx.GraphMLDataset + networkx.JSONDataset + pandas.CSVDataset + pandas.DeltaTableDataset + pandas.ExcelDataset + pandas.FeatherDataset + pandas.GBQQueryDataset + pandas.GBQTableDataset + pandas.GenericDataset + pandas.HDFDataset + pandas.JSONDataset + pandas.ParquetDataset + pandas.SQLQueryDataset + pandas.SQLTableDataset + pandas.XMLDataset + partitions.IncrementalDataset + partitions.PartitionedDataset + pickle.PickleDataset + pillow.ImageDataset + plotly.JSONDataset + plotly.PlotlyDataset + polars.CSVDataset + polars.EagerPolarsDataset + polars.LazyPolarsDataset + redis.PickleDataset + snowflake.SnowparkTableDataset + spark.DeltaTableDataset + spark.SparkDataset + spark.SparkHiveDataset + spark.SparkJDBCDataset + spark.SparkStreamingDataset + svmlight.SVMLightDataset + tensorflow.TensorFlowModelDataset + text.TextDataset + tracking.JSONDataset + tracking.MetricsDataset + video.VideoDataset + yaml.YAMLDataset diff --git a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst index fbae09589..c6e443564 100644 --- a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst +++ b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst @@ -11,8 +11,9 @@ kedro_datasets_experimental :toctree: :template: autosummary/class.rst - kedro_datasets_experimental.langchain.ChatAnthropicDataset - kedro_datasets_experimental.langchain.ChatCohereDataset - kedro_datasets_experimental.langchain.ChatOpenAIDataset - kedro_datasets_experimental.langchain.OpenAIEmbeddingsDataset - kedro_datasets_experimental.netcdf.NetCDFDataset + langchain.ChatAnthropicDataset + langchain.ChatCohereDataset + langchain.ChatOpenAIDataset + langchain.OpenAIEmbeddingsDataset + netcdf.NetCDFDataset + rioxarray.GeoTIFFDataset diff --git a/kedro-datasets/kedro_datasets/__init__.py b/kedro-datasets/kedro_datasets/__init__.py index dc901c852..def06a600 100644 --- a/kedro-datasets/kedro_datasets/__init__.py +++ b/kedro-datasets/kedro_datasets/__init__.py @@ -1,7 +1,7 @@ """``kedro_datasets`` is where you can find all of Kedro's data connectors.""" __all__ = ["KedroDeprecationWarning"] -__version__ = "3.0.1" +__version__ = "4.0.0" import sys import warnings diff --git a/kedro-datasets/kedro_datasets_experimental/rioxarray/__init__.py b/kedro-datasets/kedro_datasets_experimental/rioxarray/__init__.py new file mode 100644 index 000000000..b1f52ce01 --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/rioxarray/__init__.py @@ -0,0 +1,13 @@ +"""``AbstractDataset`` implementation to load/save data from/to a geospatial raster files.""" +from __future__ import annotations + +from typing import Any + +import lazy_loader as lazy + +# https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 +GeoTIFFDataset: Any + +__getattr__, __dir__, __all__ = lazy.attach( + __name__, submod_attrs={"geotiff_dataset": ["GeoTIFFDataset"]} +) diff --git a/kedro-datasets/kedro_datasets_experimental/rioxarray/geotiff_dataset.py b/kedro-datasets/kedro_datasets_experimental/rioxarray/geotiff_dataset.py new file mode 100644 index 000000000..b69dea574 --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/rioxarray/geotiff_dataset.py @@ -0,0 +1,209 @@ +"""GeoTIFFDataset loads geospatial raster data and saves it to a local geoiff file. The +underlying functionality is supported by rioxarray and xarray. A read rasterdata file +returns a xarray.DataArray object. +""" +import logging +from copy import deepcopy +from pathlib import PurePosixPath +from typing import Any + +import fsspec +import rasterio +import rioxarray as rxr +import xarray +from kedro.io import AbstractVersionedDataset, DatasetError +from kedro.io.core import Version, get_filepath_str, get_protocol_and_path +from rasterio.crs import CRS +from rasterio.transform import from_bounds + +logger = logging.getLogger(__name__) + +SUPPORTED_DIMS = [("band", "x", "y"), ("x", "y")] +DEFAULT_NO_DATA_VALUE = -9999 +SUPPORTED_FILE_FORMATS = [".tif", ".tiff"] + + +class GeoTIFFDataset(AbstractVersionedDataset[xarray.DataArray, xarray.DataArray]): + """``GeoTIFFDataset`` loads and saves rasterdata files and reads them as xarray + DataArrays. The underlying functionality is supported by rioxarray, rasterio and xarray. + + Reading and writing of single and multiband GeoTIFFs data is supported. There are sanity checks to ensure that a coordinate reference system (CRS) is present. + Supported dimensions are ("band", "x", "y") and ("x", "y") and xarray.DataArray with other dimension can not be saved to a GeoTIFF file. + Have a look at netcdf if this is what you need. + + + .. code-block:: yaml + + sentinal_data: + type: rioxarray.GeoTIFFDataset + filepath: sentinal_data.tif + + Example usage for the + `Python API `_: + + .. code-block:: pycon + + >>> from kedro_datasets.rioxarray import GeoTIFFDataset + >>> import xarray as xr + >>> import numpy as np + >>> + >>> data = xr.DataArray( + ... np.random.randn(2, 3, 2), + ... dims=("band", "y", "x"), + ... coords={"band": [1, 2], "y": [0.5, 1.5, 2.5], "x": [0.5, 1.5]} + ... ) + >>> data_crs = data.rio.write_crs("epsg:4326") + >>> data_spatial_dims = data_crs.rio.set_spatial_dims("x", "y") + >>> dataset = GeoTIFFDataset(filepath="test.tif") + >>> dataset.save(data_spatial_dims) + >>> reloaded = dataset.load() + >>> xr.testing.assert_allclose(data_spatial_dims, reloaded, rtol=1e-5) + + """ + + DEFAULT_LOAD_ARGS: dict[str, Any] = {} + DEFAULT_SAVE_ARGS: dict[str, Any] = {} + + def __init__( # noqa: PLR0913 + self, + *, + filepath: str, + load_args: dict[str, Any] | None = None, + save_args: dict[str, Any] | None = None, + version: Version | None = None, + metadata: dict[str, Any] | None = None, + ): + """Creates a new instance of ``GeoTIFFDataset`` pointing to a concrete + geospatial raster data file. + + + Args: + filepath: Filepath in POSIX format to a rasterdata file. + The prefix should be any protocol supported by ``fsspec``. + load_args: rioxarray options for loading rasterdata files. + Here you can find all available arguments: + https://corteva.github.io/rioxarray/html/rioxarray.html#rioxarray-open-rasterio + All defaults are preserved. + save_args: options for rioxarray for data without the band dimension and rasterio otherwhise. + version: If specified, should be an instance of + ``kedro.io.core.Version``. If its ``load`` attribute is + None, the latest version will be loaded. If its ``save`` + attribute is None, save version will be autogenerated. + metadata: Any arbitrary metadata. + This is ignored by Kedro, but may be consumed by users or external plugins. + """ + protocol, path = get_protocol_and_path(filepath, version) + self._protocol = protocol + self._fs = fsspec.filesystem(self._protocol) + self.metadata = metadata + + super().__init__( + filepath=PurePosixPath(path), + version=version, + exists_function=self._fs.exists, + glob_function=self._fs.glob, + ) + + # Handle default load and save arguments + self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) + if load_args is not None: + self._load_args.update(load_args) + self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) + if save_args is not None: + self._save_args.update(save_args) + + def _describe(self) -> dict[str, Any]: + return { + "filepath": self._filepath, + "protocol": self._protocol, + "load_args": self._load_args, + "save_args": self._save_args, + "version": self._version, + } + + def _load(self) -> xarray.DataArray: + load_path = self._get_load_path().as_posix() + with rasterio.open(load_path) as data: + tags = data.tags() + data = rxr.open_rasterio(load_path, **self._load_args) + data.attrs.update(tags) + self._sanity_check(data) + logger.info(f"found coordinate rerence system {data.rio.crs}") + return data + + def _save(self, data: xarray.DataArray) -> None: + self._sanity_check(data) + save_path = get_filepath_str(self._get_save_path(), self._protocol) + if not save_path.endswith(tuple(SUPPORTED_FILE_FORMATS)): + raise ValueError( + f"Unsupported file format. Supported formats are: {SUPPORTED_FILE_FORMATS}" + ) + if "band" in data.dims: + self._save_multiband(data, save_path) + else: + data.rio.to_raster(save_path, **self._save_args) + self._fs.invalidate_cache(save_path) + + def _exists(self) -> bool: + try: + load_path = get_filepath_str(self._get_load_path(), self._protocol) + except DatasetError: + return False + + return self._fs.exists(load_path) + + def _release(self) -> None: + super()._release() + self._invalidate_cache() + + def _invalidate_cache(self) -> None: + """Invalidate underlying filesystem caches.""" + filepath = get_filepath_str(self._filepath, self._protocol) + self._fs.invalidate_cache(filepath) + + def _save_multiband(self, data: xarray.DataArray, save_path: str): + """Saving multiband raster data to a geotiff file.""" + bands_data = [data.sel(band=band) for band in data.band.values] + transform = from_bounds( + west=data.x.min(), + south=data.y.min(), + east=data.x.max(), + north=data.y.max(), + width=data[0].shape[1], + height=data[0].shape[0], + ) + + nodata_value = ( + data.rio.nodata if data.rio.nodata is not None else DEFAULT_NO_DATA_VALUE + ) + crs = data.rio.crs + + meta = { + "driver": "GTiff", + "height": bands_data[0].shape[0], + "width": bands_data[0].shape[1], + "count": len(bands_data), + "dtype": str(bands_data[0].dtype), + "crs": crs, + "transform": transform, + "nodata": nodata_value, + } + with rasterio.open(save_path, "w", **meta) as dst: + for idx, band in enumerate(bands_data, start=1): + dst.write(band.data, idx, **self._save_args) + + def _sanity_check(self, data: xarray.DataArray) -> None: + """Perform sanity checks on the data to ensure it meets the requirements.""" + if not isinstance(data, xarray.DataArray): + raise NotImplementedError( + "Currently only supporting xarray.DataArray while saving raster data." + ) + + if not isinstance(data.rio.crs, CRS): + raise ValueError("Dataset lacks a coordinate reference system.") + + if all(set(data.dims) != set(dims) for dims in SUPPORTED_DIMS): + raise ValueError( + f"Data has unsupported dimensions: {data.dims}. Supported dimensions are: {SUPPORTED_DIMS}" + ) diff --git a/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/__init__.py b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/cog.tif b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/cog.tif new file mode 100644 index 000000000..e2bc24a1c Binary files /dev/null and b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/cog.tif differ diff --git a/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_geotiff_dataset.py b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_geotiff_dataset.py new file mode 100644 index 000000000..7f217eee6 --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_geotiff_dataset.py @@ -0,0 +1,181 @@ +from pathlib import Path + +import numpy as np +import pytest +import rasterio +import xarray as xr +from kedro.io import DatasetError +from rasterio.crs import CRS + +from kedro_datasets_experimental.rioxarray.geotiff_dataset import GeoTIFFDataset + + +@pytest.fixture +def cog_file_path() -> str: + cog_file_path = Path(__file__).parent / "cog.tif" + return cog_file_path.as_posix() + +@pytest.fixture +def multi1_file_path() -> str: + path = Path(__file__).parent / "test_multi1.tif" + return path.as_posix() + +@pytest.fixture +def multi2_file_path() -> str: + path = Path(__file__).parent / "test_multi2.tif" + return path.as_posix() + +@pytest.fixture +def synthetic_xarray(): + """Create a synthetic xarray.DataArray with CRS information.""" + data = xr.DataArray( + np.random.rand(100, 100), + dims=("y", "x"), + coords={"x": np.linspace(0, 100, 100), "y": np.linspace(0, 100, 100)} + ) + data.rio.write_crs("epsg:4326", inplace=True) + return data + +@pytest.fixture +def synthetic_xarray_multiband(): + """Create a synthetic xarray.DataArray with CRS information.""" + data = xr.DataArray( + np.random.rand(10, 100, 100), + dims=("band", "y", "x"), + coords={"x": np.linspace(0, 100, 100), "y": np.linspace(0, 100, 100)} + ) + data.rio.write_crs("epsg:4326", inplace=True) + return data + +@pytest.fixture +def synthetic_xarray_many_vars_no_band(): + """Create a synthetic xarray.DataArray with CRS information.""" + data = xr.DataArray( + np.random.rand(2,3,4, 100, 100), + dims=("var1","var2","var3","y", "x"), + coords={"x": np.linspace(0, 100, 100), "y": np.linspace(0, 100, 100)} + ) + data.rio.write_crs("epsg:4326", inplace=True) + return data + + +@pytest.fixture +def cog_geotiff_dataset(cog_file_path, save_args) -> GeoTIFFDataset: + return GeoTIFFDataset(filepath=cog_file_path, save_args=save_args) + + +def test_load_cog_geotiff(cog_geotiff_dataset): + """Test loading cloud optimised geotiff reloading the data set.""" + loaded_xr = cog_geotiff_dataset.load() + assert isinstance(loaded_xr.rio.crs, CRS) + assert isinstance(loaded_xr, xr.DataArray) + assert loaded_xr.shape == (1, 500, 500) + assert loaded_xr.dims == ("band", "y", "x") + +def test_load_save_cog(tmp_path,cog_file_path): + """Test loading a multiband raster file.""" + dataset = GeoTIFFDataset(filepath=cog_file_path) + loaded_xr = dataset.load() + band1_data = loaded_xr.sel(band=1) + target_file = tmp_path / "tmp22.tif" + dataset_to = GeoTIFFDataset(filepath=str(target_file)) + dataset_to.save(loaded_xr) + reloaded_xr = dataset_to.load() + assert target_file.exists() + assert isinstance(loaded_xr.rio.crs, CRS) + assert isinstance(loaded_xr, xr.DataArray) + assert len(loaded_xr.band) == 1 + assert loaded_xr.dims == ("band", "y", "x") + assert loaded_xr.shape == (1, 500, 500) + assert np.isclose(band1_data.values.std(), 4688.72624578268) + assert (loaded_xr.values == reloaded_xr.values).all() + + + +def test_load_save_multi1(tmp_path,multi1_file_path): + """Test loading a multiband raster file.""" + dataset = GeoTIFFDataset(filepath=multi1_file_path) + dataset_to = GeoTIFFDataset(filepath=str(tmp_path / "tmp.tif")) + loaded_xr = dataset.load() + band1_data = loaded_xr.sel(band=1) + assert isinstance(loaded_xr.rio.crs, CRS) + assert isinstance(loaded_xr, xr.DataArray) + BAND_COUNT = 2 + assert len(loaded_xr.band) == BAND_COUNT + assert loaded_xr.shape == (BAND_COUNT, 5, 5) + assert loaded_xr.dims == ("band", "y", "x") + assert np.isclose(band1_data.values.std(), 0.015918046) + dataset_to.save(loaded_xr) + reloaded_xr = dataset_to.load() + assert (loaded_xr.values == reloaded_xr.values).all() + +def test_load_geotiff_with_tags(tmp_path, synthetic_xarray): + filepath = tmp_path / "test_with_tags.tif" + tags = {"TAG_KEY": "TAG_VALUE", "ANOTHER_TAG": "ANOTHER_VALUE"} + with rasterio.open( + filepath, "w", driver="GTiff", height=100, width=100, count=1, dtype=str(synthetic_xarray.dtype), + crs="EPSG:4326" + ) as dst: + dst.write(synthetic_xarray.values, 1) + dst.update_tags(**tags) + + dataset = GeoTIFFDataset(filepath=str(filepath)) + loaded_xr = dataset.load() + + assert loaded_xr.attrs["TAG_KEY"] == "TAG_VALUE" + assert loaded_xr.attrs["ANOTHER_TAG"] == "ANOTHER_VALUE" + + assert isinstance(loaded_xr, xr.DataArray) + assert isinstance(loaded_xr.rio.crs, CRS) + assert loaded_xr.shape == (1, 100, 100) + +def test_load_no_crs(multi2_file_path): + """Test loading a multiband raster file.""" + dataset = GeoTIFFDataset(filepath=multi2_file_path) + with pytest.raises(DatasetError): + dataset.load() + +def test_load_not_tif(): + """Test loading a multiband raster file.""" + dataset = GeoTIFFDataset(filepath="whatever.nc") + with pytest.raises(DatasetError): + dataset.load() + + +def test_exists(tmp_path, synthetic_xarray): + """Test `exists` method invocation for both existing and + nonexistent data set.""" + dataset = GeoTIFFDataset(filepath=str(tmp_path / "tmp.tif")) + assert not dataset.exists() + dataset.save(synthetic_xarray) + assert dataset.exists() + +@pytest.mark.parametrize("xarray_fixture", [ + "synthetic_xarray_multiband", + "synthetic_xarray", +]) +def test_save_and_load_geotiff(tmp_path, request, xarray_fixture): + """Test saving and reloading the data set.""" + xarray_data = request.getfixturevalue(xarray_fixture) + dataset = GeoTIFFDataset(filepath=str(tmp_path / "tmp.tif")) + dataset.save(xarray_data) + assert dataset.exists() + reloaded_xr = dataset.load() + assert isinstance(reloaded_xr, xr.DataArray) + assert isinstance(reloaded_xr.rio.crs, CRS) + assert reloaded_xr.dims == ("band", "y", "x") + assert (xarray_data.values == reloaded_xr.values).all() + +def test_save_and_load_geotiff_no_band(tmp_path, synthetic_xarray_many_vars_no_band): + """this test should fail because the data array has no band dimension""" + dataset = GeoTIFFDataset(filepath=str(tmp_path / "tmp.tif")) + with pytest.raises(DatasetError): + dataset.save(synthetic_xarray_many_vars_no_band) + +def test_load_missing_file(tmp_path): + """Check the error when trying to load missing file.""" + dataset = GeoTIFFDataset(filepath=str(tmp_path / "tmp.tif")) + assert not dataset._exists(), "File unexpectedly exists" + pattern = r"Failed while loading data from data set GeoTIFFDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + dataset.load() diff --git a/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_multi1.tif b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_multi1.tif new file mode 100644 index 000000000..bfc0c6a2c Binary files /dev/null and b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_multi1.tif differ diff --git a/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_multi2.tif b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_multi2.tif new file mode 100644 index 000000000..6dfbedb6a Binary files /dev/null and b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_multi2.tif differ diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index d9eede037..f3012ea0d 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -167,9 +167,13 @@ langchain-openaiembeddingsdataset = ["langchain-openai~=0.1.7"] langchain-chatanthropicdataset = ["langchain-anthropic~=0.1.13", "langchain-community~=0.2.0"] langchain-chatcoheredataset = ["langchain-cohere~=0.1.5", "langchain-community~=0.2.0"] langchain = ["kedro-datasets[langchain-chatopenaidataset,langchain-openaiembeddingsdataset,langchain-chatanthropicdataset,langchain-chatcoheredataset ]"] + netcdf-netcdfdataset = ["h5netcdf>=1.2.0","netcdf4>=1.6.4","xarray>=2023.1.0"] netcdf = ["kedro-datasets[netcdf-netcdfdataset]"] +rioxarray-geotiffdataset = ["rioxarray>=0.15.0"] +rioxarray = ["kedro-datasets[rioxarray-geotiffdataset]"] + # Docs requirements docs = [ "kedro-sphinx-theme==2024.4.0", @@ -179,6 +183,7 @@ docs = [ # Test requirements test = [ + "accelerate<0.32", # Temporary pin "adlfs~=2023.1", "bandit>=1.6.2, <2.0", "behave==1.2.6", @@ -270,6 +275,7 @@ experimental = [ "h5netcdf>=1.2.0", "netcdf4>=1.6.4", "xarray>=2023.1.0", + "rioxarray", ] # All requirements diff --git a/kedro-telemetry/README.md b/kedro-telemetry/README.md index 47584a3e4..c5ca2f816 100644 --- a/kedro-telemetry/README.md +++ b/kedro-telemetry/README.md @@ -45,6 +45,7 @@ To withdraw consent, you can change the `consent` variable to `false` in `.telem ```yaml consent: false ``` +You can also set `DO_NOT_TRACK` or `KEDRO_DISABLE_TELEMETRY` environment variable to `True`. Or you can uninstall the plugin: diff --git a/kedro-telemetry/RELEASE.md b/kedro-telemetry/RELEASE.md index f50ecaf7c..f92ddf548 100644 --- a/kedro-telemetry/RELEASE.md +++ b/kedro-telemetry/RELEASE.md @@ -1,4 +1,5 @@ # Upcoming release +* Added `DO_NOT_TRACK` and `KEDRO_DISABLE_TELEMETRY` environment variables to skip telemetry. # Release 0.5.0 * Updated the plugin to generate a unique project UUID for kedro project and store it in `pyproject.toml`. diff --git a/kedro-telemetry/kedro_telemetry/__init__.py b/kedro-telemetry/kedro_telemetry/__init__.py index 5c139beba..c8c39dd9b 100644 --- a/kedro-telemetry/kedro_telemetry/__init__.py +++ b/kedro-telemetry/kedro_telemetry/__init__.py @@ -1,3 +1,7 @@ """Kedro plugin for collecting Kedro usage data.""" __version__ = "0.5.0" + +import logging + +logging.getLogger(__name__).setLevel(logging.INFO) diff --git a/kedro-telemetry/kedro_telemetry/plugin.py b/kedro-telemetry/kedro_telemetry/plugin.py index f0b2485ee..6afb3ef95 100644 --- a/kedro-telemetry/kedro_telemetry/plugin.py +++ b/kedro-telemetry/kedro_telemetry/plugin.py @@ -8,12 +8,10 @@ import os import sys import uuid -from copy import deepcopy from datetime import datetime from pathlib import Path from typing import Any -import click import requests import toml import yaml @@ -43,6 +41,10 @@ "TRAVIS", # https://docs.travis-ci.com/user/environment-variables/#default-environment-variables "BUILDKITE", # https://buildkite.com/docs/pipelines/environment-variables } +_SKIP_TELEMETRY_ENV_VAR_KEYS = ( + "DO_NOT_TRACK", + "KEDRO_DISABLE_TELEMETRY", +) TIMESTAMP_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" CONFIG_FILENAME = "telemetry.toml" PYPROJECT_CONFIG_NAME = "pyproject.toml" @@ -150,101 +152,119 @@ def _generate_new_uuid(full_path: str) -> str: return "" -class KedroTelemetryCLIHooks: +class KedroTelemetryHook: """Hook to send CLI command data to Heap""" + def __init__(self): + self._consent = None + self._sent = False + self._event_properties = None + self._project_path = None + self._user_uuid = None + @cli_hook_impl def before_command_run( self, project_metadata: ProjectMetadata, command_args: list[str] ): """Hook implementation to send command run data to Heap""" - try: - if not project_metadata: # in package mode - return - - consent = _check_for_telemetry_consent(project_metadata.project_path) - if not consent: - logger.debug( - "Kedro-Telemetry is installed, but you have opted out of " - "sharing usage analytics so none will be collected.", - ) - return - # get KedroCLI and its structure from actual project root - cli = KedroCLI(project_path=project_metadata.project_path) - cli_struct = _get_cli_structure(cli_obj=cli, get_help=False) - masked_command_args = _mask_kedro_cli( - cli_struct=cli_struct, command_args=command_args - ) - main_command = masked_command_args[0] if masked_command_args else "kedro" + if not project_metadata: # in package mode + return - logger.debug("You have opted into product usage analytics.") - user_uuid = _get_or_create_uuid() - project_properties = _get_project_properties( - user_uuid, project_metadata.project_path / PYPROJECT_CONFIG_NAME - ) - cli_properties = _format_user_cli_data( - project_properties, masked_command_args - ) + self._consent = _check_for_telemetry_consent(project_metadata.project_path) + if not self._consent: + self._opt_out_notification() + return - _send_heap_event( - event_name=f"Command run: {main_command}", - identity=user_uuid, - properties=cli_properties, - ) + # get KedroCLI and its structure from actual project root + cli = KedroCLI(project_path=project_metadata.project_path) + cli_struct = _get_cli_structure(cli_obj=cli, get_help=False) + masked_command_args = _mask_kedro_cli( + cli_struct=cli_struct, command_args=command_args + ) - # send generic event too, so it's easier in data processing - generic_properties = deepcopy(cli_properties) - generic_properties["main_command"] = main_command - _send_heap_event( - event_name="CLI command", - identity=user_uuid, - properties=generic_properties, - ) - except Exception as exc: - logger.warning( - "Something went wrong in hook implementation to send command run data to Heap. " - "Exception: %s", - exc, - ) + self._user_uuid = _get_or_create_uuid() + event_properties = _get_project_properties( + self._user_uuid, project_metadata.project_path / PYPROJECT_CONFIG_NAME + ) + event_properties["command"] = ( + f"kedro {' '.join(masked_command_args)}" if masked_command_args else "kedro" + ) + event_properties["main_command"] = ( + masked_command_args[0] if masked_command_args else "kedro" + ) -class KedroTelemetryProjectHooks: - """Hook to send project statistics data to Heap""" + self._event_properties = event_properties + + @cli_hook_impl + def after_command_run(self): + if self._consent and not self._sent: + self._send_telemetry_heap_event("CLI command") @hook_impl def after_context_created(self, context): """Hook implementation to send project statistics data to Heap""" - self.consent = _check_for_telemetry_consent(context.project_path) - self.project_path = context.project_path + + if self._consent is None: + self._consent = _check_for_telemetry_consent(context.project_path) + if not self._consent: + self._opt_out_notification() + self._project_path = context.project_path @hook_impl def after_catalog_created(self, catalog): - if not self.consent: - logger.debug( - "Kedro-Telemetry is installed, but you have opted out of " - "sharing usage analytics so none will be collected.", - ) + if self._consent is False: return - logger.debug("You have opted into product usage analytics.") - default_pipeline = pipelines.get("__default__") # __default__ - user_uuid = _get_or_create_uuid() - project_properties = _get_project_properties( - user_uuid, self.project_path / PYPROJECT_CONFIG_NAME + if not self._user_uuid: + self._user_uuid = _get_or_create_uuid() + + if not self._event_properties: + self._event_properties = _get_project_properties( + self._user_uuid, self._project_path / PYPROJECT_CONFIG_NAME + ) + + project_properties = _format_project_statistics_data( + catalog, default_pipeline, pipelines ) + self._event_properties.update(project_properties) - project_statistics_properties = _format_project_statistics_data( - project_properties, catalog, default_pipeline, pipelines + self._send_telemetry_heap_event("Kedro Project Statistics") + + def _opt_out_notification(self): + logger.info( + "Kedro-Telemetry is installed, but you have opted out of " + "sharing usage analytics so none will be collected.", ) - _send_heap_event( - event_name="Kedro Project Statistics", - identity=user_uuid, - properties=project_statistics_properties, + + def _send_telemetry_heap_event(self, event_name: str): + """Hook implementation to send command run data to Heap""" + + logger.info( + "Kedro is sending anonymous usage data with the sole purpose of improving the product. " + "No personal data or IP addresses are stored on our side. " + "If you want to opt out, set the `KEDRO_DISABLE_TELEMETRY` or `DO_NOT_TRACK` environment variables, " + "or create a `.telemetry` file in the current working directory with the contents `consent: false`. " + "Read more at https://docs.kedro.org/en/stable/configuration/telemetry.html" ) + try: + _send_heap_event( + event_name=event_name, + identity=self._user_uuid, + properties=self._event_properties, + ) + self._sent = True + except Exception as exc: + logger.warning( + "Something went wrong in hook implementation to send command run data to Heap. " + "Exception: %s", + exc, + ) + def _is_known_ci_env(known_ci_env_var_keys: set[str]): # Most CI tools will set the CI environment variable to true @@ -276,33 +296,20 @@ def _get_project_properties(user_uuid: str, pyproject_path: Path) -> dict: return properties -def _format_user_cli_data( - properties: dict, - command_args: list[str], -): - """Add format CLI command data to send to Heap.""" - cli_properties = properties.copy() - cli_properties["command"] = ( - f"kedro {' '.join(command_args)}" if command_args else "kedro" - ) - return cli_properties - - def _format_project_statistics_data( - properties: dict, catalog: DataCatalog, default_pipeline: Pipeline, project_pipelines: dict, ): """Add project statistics to send to Heap.""" - project_statistics_properties = properties.copy() + project_statistics_properties = {} project_statistics_properties["number_of_datasets"] = sum( 1 for c in catalog.list() if not c.startswith("parameters") and not c.startswith("params:") ) project_statistics_properties["number_of_nodes"] = ( - len(default_pipeline.nodes) if default_pipeline else None + len(default_pipeline.nodes) if default_pipeline else None # type: ignore ) project_statistics_properties["number_of_pipelines"] = len(project_pipelines.keys()) return project_statistics_properties @@ -347,14 +354,21 @@ def _send_heap_event( def _check_for_telemetry_consent(project_path: Path) -> bool: + """ + Use telemetry consent from ".telemetry" file if it exists and has a valid format. + Telemetry is considered as opt-in otherwise. + """ telemetry_file_path = project_path / ".telemetry" - if not telemetry_file_path.exists(): - return _confirm_consent(telemetry_file_path) - with open(telemetry_file_path, encoding="utf-8") as telemetry_file: - telemetry = yaml.safe_load(telemetry_file) - if _is_valid_syntax(telemetry): - return telemetry["consent"] - return _confirm_consent(telemetry_file_path) + + for env_var in _SKIP_TELEMETRY_ENV_VAR_KEYS: + if os.environ.get(env_var): + return False + if telemetry_file_path.exists(): + with open(telemetry_file_path, encoding="utf-8") as telemetry_file: + telemetry = yaml.safe_load(telemetry_file) + if _is_valid_syntax(telemetry): + return telemetry["consent"] + return True def _is_valid_syntax(telemetry: Any) -> bool: @@ -363,35 +377,4 @@ def _is_valid_syntax(telemetry: Any) -> bool: ) -def _confirm_consent(telemetry_file_path: Path) -> bool: - try: - with telemetry_file_path.open("w") as telemetry_file: - confirm_msg = ( - "As an open-source project, we collect usage analytics. \n" - "We cannot see nor store information contained in " - "a Kedro project. \nYou can find out more by reading our " - "privacy notice: \n" - "https://github.com/kedro-org/kedro-plugins/tree/main/kedro-telemetry#" - "privacy-notice \n" - "Do you opt into usage analytics? " - ) - if click.confirm(confirm_msg): - yaml.dump({"consent": True}, telemetry_file) - click.secho("You have opted into product usage analytics.", fg="green") - return True - click.secho( - "You have opted out of product usage analytics, so none will be collected.", - fg="green", - ) - yaml.dump({"consent": False}, telemetry_file) - return False - except Exception as exc: - logger.warning( - "Failed to confirm consent. No data was sent to Heap. Exception: %s", - exc, - ) - return False - - -cli_hooks = KedroTelemetryCLIHooks() -project_hooks = KedroTelemetryProjectHooks() +telemetry_hook = KedroTelemetryHook() diff --git a/kedro-telemetry/pyproject.toml b/kedro-telemetry/pyproject.toml index bda085e85..32fb3d0b8 100644 --- a/kedro-telemetry/pyproject.toml +++ b/kedro-telemetry/pyproject.toml @@ -43,10 +43,10 @@ test = [ ] [project.entry-points."kedro.cli_hooks"] -kedro-telemetry = "kedro_telemetry.plugin:cli_hooks" +kedro-telemetry = "kedro_telemetry.plugin:telemetry_hook" [project.entry-points."kedro.hooks"] -kedro-telemetry = "kedro_telemetry.plugin:project_hooks" +kedro-telemetry = "kedro_telemetry.plugin:telemetry_hook" [tool.setuptools] include-package-data = true diff --git a/kedro-telemetry/tests/test_masking.py b/kedro-telemetry/tests/test_masking.py index edd7efe93..778e85a54 100644 --- a/kedro-telemetry/tests/test_masking.py +++ b/kedro-telemetry/tests/test_masking.py @@ -20,6 +20,7 @@ PACKAGE_NAME = "cli_tools_dummy_package" DEFAULT_KEDRO_COMMANDS = [ "catalog", + "info", "ipython", "jupyter", "micropkg", diff --git a/kedro-telemetry/tests/test_plugin.py b/kedro-telemetry/tests/test_plugin.py index 7348d63f4..9c2ba65fd 100644 --- a/kedro-telemetry/tests/test_plugin.py +++ b/kedro-telemetry/tests/test_plugin.py @@ -1,3 +1,4 @@ +import logging import sys from pathlib import Path @@ -13,11 +14,10 @@ from kedro_telemetry import __version__ as TELEMETRY_VERSION from kedro_telemetry.plugin import ( + _SKIP_TELEMETRY_ENV_VAR_KEYS, KNOWN_CI_ENV_VAR_KEYS, - KedroTelemetryCLIHooks, - KedroTelemetryProjectHooks, + KedroTelemetryHook, _check_for_telemetry_consent, - _confirm_consent, _is_known_ci_env, ) @@ -121,8 +121,8 @@ def fake_sub_pipeline(): return mock_sub_pipeline -class TestKedroTelemetryCLIHooks: - def test_before_command_run(self, mocker, fake_metadata): +class TestKedroTelemetryHook: + def test_before_command_run(self, mocker, fake_metadata, caplog): mocker.patch( "kedro_telemetry.plugin._check_for_telemetry_consent", return_value=True ) @@ -140,9 +140,12 @@ def test_before_command_run(self, mocker, fake_metadata): ) mocked_heap_call = mocker.patch("kedro_telemetry.plugin._send_heap_event") - telemetry_hook = KedroTelemetryCLIHooks() - command_args = ["--version"] - telemetry_hook.before_command_run(fake_metadata, command_args) + + with caplog.at_level(logging.INFO): + telemetry_hook = KedroTelemetryHook() + command_args = ["--version"] + telemetry_hook.before_command_run(fake_metadata, command_args) + telemetry_hook.after_command_run() expected_properties = { "username": "user_uuid", "project_id": "digested", @@ -150,8 +153,8 @@ def test_before_command_run(self, mocker, fake_metadata): "telemetry_version": TELEMETRY_VERSION, "python_version": sys.version, "os": sys.platform, - "command": "kedro --version", "is_ci_env": True, + "command": "kedro --version", } generic_properties = { **expected_properties, @@ -159,11 +162,6 @@ def test_before_command_run(self, mocker, fake_metadata): } expected_calls = [ - mocker.call( - event_name="Command run: --version", - identity="user_uuid", - properties=expected_properties, - ), mocker.call( event_name="CLI command", identity="user_uuid", @@ -171,6 +169,20 @@ def test_before_command_run(self, mocker, fake_metadata): ), ] assert mocked_heap_call.call_args_list == expected_calls + assert any( + "Kedro is sending anonymous usage data with the sole purpose of improving the product. " + "No personal data or IP addresses are stored on our side. " + "If you want to opt out, set the `KEDRO_DISABLE_TELEMETRY` or `DO_NOT_TRACK` environment variables, " + "or create a `.telemetry` file in the current working directory with the contents `consent: false`. " + "Read more at https://docs.kedro.org/en/stable/configuration/telemetry.html" + in record.message + for record in caplog.records + ) + assert not any( + "Kedro-Telemetry is installed, but you have opted out of " + "sharing usage analytics so none will be collected." in record.message + for record in caplog.records + ) def test_before_command_run_with_tools(self, mocker, fake_metadata): mocker.patch( @@ -192,9 +204,10 @@ def test_before_command_run_with_tools(self, mocker, fake_metadata): mocked_heap_call = mocker.patch("kedro_telemetry.plugin._send_heap_event") mocker.patch("builtins.open", mocker.mock_open(read_data=MOCK_PYPROJECT_TOOLS)) mocker.patch("pathlib.Path.exists", return_value=True) - telemetry_hook = KedroTelemetryCLIHooks() + telemetry_hook = KedroTelemetryHook() command_args = ["--version"] telemetry_hook.before_command_run(fake_metadata, command_args) + telemetry_hook.after_command_run() expected_properties = { "username": "user_uuid", "project_id": "digested", @@ -213,11 +226,6 @@ def test_before_command_run_with_tools(self, mocker, fake_metadata): } expected_calls = [ - mocker.call( - event_name="Command run: --version", - identity="user_uuid", - properties=expected_properties, - ), mocker.call( event_name="CLI command", identity="user_uuid", @@ -244,9 +252,10 @@ def test_before_command_run_empty_args(self, mocker, fake_metadata): ) mocked_heap_call = mocker.patch("kedro_telemetry.plugin._send_heap_event") - telemetry_hook = KedroTelemetryCLIHooks() + telemetry_hook = KedroTelemetryHook() command_args = [] telemetry_hook.before_command_run(fake_metadata, command_args) + telemetry_hook.after_command_run() expected_properties = { "username": "user_uuid", "project_id": "digested", @@ -263,11 +272,6 @@ def test_before_command_run_empty_args(self, mocker, fake_metadata): } expected_calls = [ - mocker.call( - event_name="Command run: kedro", - identity="user_uuid", - properties=expected_properties, - ), mocker.call( event_name="CLI command", identity="user_uuid", @@ -277,29 +281,45 @@ def test_before_command_run_empty_args(self, mocker, fake_metadata): assert mocked_heap_call.call_args_list == expected_calls - def test_before_command_run_no_consent_given(self, mocker, fake_metadata): + def test_before_command_run_no_consent_given(self, mocker, fake_metadata, caplog): mocker.patch( "kedro_telemetry.plugin._check_for_telemetry_consent", return_value=False ) mocked_heap_call = mocker.patch("kedro_telemetry.plugin._send_heap_event") - telemetry_hook = KedroTelemetryCLIHooks() - command_args = ["--version"] - telemetry_hook.before_command_run(fake_metadata, command_args) + with caplog.at_level(logging.INFO): + telemetry_hook = KedroTelemetryHook() + command_args = ["--version"] + telemetry_hook.before_command_run(fake_metadata, command_args) mocked_heap_call.assert_not_called() + assert not any( + "Kedro is sending anonymous usage data with the sole purpose of improving the product. " + "No personal data or IP addresses are stored on our side. " + "If you want to opt out, set the `KEDRO_DISABLE_TELEMETRY` or `DO_NOT_TRACK` environment variables, " + "or create a `.telemetry` file in the current working directory with the contents `consent: false`. " + "Read more at https://docs.kedro.org/en/latest/configuration/telemetry.html" + in record.message + for record in caplog.records + ) + assert any( + "Kedro-Telemetry is installed, but you have opted out of " + "sharing usage analytics so none will be collected." in record.message + for record in caplog.records + ) def test_before_command_run_connection_error(self, mocker, fake_metadata, caplog): mocker.patch( "kedro_telemetry.plugin._check_for_telemetry_consent", return_value=True ) - telemetry_hook = KedroTelemetryCLIHooks() + telemetry_hook = KedroTelemetryHook() command_args = ["--version"] mocked_post_request = mocker.patch( "requests.post", side_effect=requests.exceptions.ConnectionError() ) telemetry_hook.before_command_run(fake_metadata, command_args) + telemetry_hook.after_command_run() msg = "Failed to send data to Heap. Exception of type 'ConnectionError' was raised." assert msg in caplog.messages[-1] mocked_post_request.assert_called() @@ -315,9 +335,10 @@ def test_before_command_run_anonymous(self, mocker, fake_metadata): mocker.patch("builtins.open", side_effect=OSError) mocked_heap_call = mocker.patch("kedro_telemetry.plugin._send_heap_event") - telemetry_hook = KedroTelemetryCLIHooks() + telemetry_hook = KedroTelemetryHook() command_args = ["--version"] telemetry_hook.before_command_run(fake_metadata, command_args) + telemetry_hook.after_command_run() expected_properties = { "username": "", "command": "kedro --version", @@ -334,11 +355,6 @@ def test_before_command_run_anonymous(self, mocker, fake_metadata): } expected_calls = [ - mocker.call( - event_name="Command run: --version", - identity="", - properties=expected_properties, - ), mocker.call( event_name="CLI command", identity="", @@ -354,10 +370,11 @@ def test_before_command_run_heap_call_error(self, mocker, fake_metadata, caplog) mocked_heap_call = mocker.patch( "kedro_telemetry.plugin._send_heap_event", side_effect=Exception ) - telemetry_hook = KedroTelemetryCLIHooks() + telemetry_hook = KedroTelemetryHook() command_args = ["--version"] telemetry_hook.before_command_run(fake_metadata, command_args) + telemetry_hook.after_command_run() msg = ( "Something went wrong in hook implementation to send command run data to" " Heap. Exception:" @@ -371,8 +388,6 @@ def test_check_for_telemetry_consent_given(self, mocker, fake_metadata): with open(telemetry_file_path, "w", encoding="utf-8") as telemetry_file: yaml.dump({"consent": True}, telemetry_file) - mock_create_file = mocker.patch("kedro_telemetry.plugin._confirm_consent") - mock_create_file.assert_not_called() assert _check_for_telemetry_consent(fake_metadata.project_path) def test_check_for_telemetry_consent_not_given(self, mocker, fake_metadata): @@ -381,29 +396,28 @@ def test_check_for_telemetry_consent_not_given(self, mocker, fake_metadata): with open(telemetry_file_path, "w", encoding="utf-8") as telemetry_file: yaml.dump({"consent": False}, telemetry_file) - mock_create_file = mocker.patch("kedro_telemetry.plugin._confirm_consent") - mock_create_file.assert_not_called() assert not _check_for_telemetry_consent(fake_metadata.project_path) - def test_check_for_telemetry_consent_empty_file(self, mocker, fake_metadata): + @mark.parametrize("env_var", _SKIP_TELEMETRY_ENV_VAR_KEYS) + def test_check_for_telemetry_consent_skip_telemetry_with_env_var( + self, monkeypatch, fake_metadata, env_var + ): + monkeypatch.setenv(env_var, "True") Path(fake_metadata.project_path, "conf").mkdir(parents=True) telemetry_file_path = fake_metadata.project_path / ".telemetry" - mock_create_file = mocker.patch( - "kedro_telemetry.plugin._confirm_consent", return_value=True - ) + with open(telemetry_file_path, "w", encoding="utf-8") as telemetry_file: + yaml.dump({"consent": True}, telemetry_file) - assert _check_for_telemetry_consent(fake_metadata.project_path) - mock_create_file.assert_called_once_with(telemetry_file_path) + assert not _check_for_telemetry_consent(fake_metadata.project_path) - def test_check_for_telemetry_no_consent_empty_file(self, mocker, fake_metadata): + def test_check_for_telemetry_consent_empty_file(self, mocker, fake_metadata): Path(fake_metadata.project_path, "conf").mkdir(parents=True) telemetry_file_path = fake_metadata.project_path / ".telemetry" - mock_create_file = mocker.patch( - "kedro_telemetry.plugin._confirm_consent", return_value=False - ) - assert not _check_for_telemetry_consent(fake_metadata.project_path) - mock_create_file.assert_called_once_with(telemetry_file_path) + with open(telemetry_file_path, "w", encoding="utf-8") as telemetry_file: + yaml.dump({}, telemetry_file) + + assert _check_for_telemetry_consent(fake_metadata.project_path) def test_check_for_telemetry_consent_file_no_consent_field( self, mocker, fake_metadata @@ -413,37 +427,14 @@ def test_check_for_telemetry_consent_file_no_consent_field( with open(telemetry_file_path, "w", encoding="utf8") as telemetry_file: yaml.dump({"nonsense": "bla"}, telemetry_file) - mock_create_file = mocker.patch( - "kedro_telemetry.plugin._confirm_consent", return_value=True - ) - assert _check_for_telemetry_consent(fake_metadata.project_path) - mock_create_file.assert_called_once_with(telemetry_file_path) def test_check_for_telemetry_consent_file_invalid_yaml(self, mocker, fake_metadata): Path(fake_metadata.project_path, "conf").mkdir(parents=True) telemetry_file_path = fake_metadata.project_path / ".telemetry" telemetry_file_path.write_text("invalid_ yaml") - mock_create_file = mocker.patch( - "kedro_telemetry.plugin._confirm_consent", return_value=True - ) - assert _check_for_telemetry_consent(fake_metadata.project_path) - mock_create_file.assert_called_once_with(telemetry_file_path) - - def test_confirm_consent_yaml_dump_error(self, mocker, fake_metadata, caplog): - Path(fake_metadata.project_path, "conf").mkdir(parents=True) - telemetry_file_path = fake_metadata.project_path / ".telemetry" - mocker.patch("yaml.dump", side_efyfect=Exception) - - assert not _confirm_consent(telemetry_file_path) - - msg = ( - "Failed to confirm consent. No data was sent to Heap. Exception: " - "pytest: reading from stdin while output is captured! Consider using `-s`." - ) - assert msg in caplog.messages[-1] @mark.parametrize( "env_vars,result", @@ -470,8 +461,6 @@ def test_check_is_known_ci_env(self, monkeypatch, env_vars, result): known_ci_vars.discard("GITHUB_ACTION") assert _is_known_ci_env(known_ci_vars) == result - -class TestKedroTelemetryProjectHooks: def test_after_context_created_without_kedro_run( # noqa: PLR0913 self, mocker, @@ -504,7 +493,7 @@ def test_after_context_created_without_kedro_run( # noqa: PLR0913 mocker.patch("kedro_telemetry.plugin.toml.dump") # Without CLI invoked - i.e. `session.run` in Jupyter/IPython - telemetry_hook = KedroTelemetryProjectHooks() + telemetry_hook = KedroTelemetryHook() telemetry_hook.after_context_created(fake_context) telemetry_hook.after_catalog_created(fake_catalog) @@ -562,12 +551,12 @@ def test_after_context_created_with_kedro_run( # noqa: PLR0913 mocker.patch("kedro_telemetry.plugin.toml.load") mocker.patch("kedro_telemetry.plugin.toml.dump") # CLI run first - telemetry_cli_hook = KedroTelemetryCLIHooks() + telemetry_cli_hook = KedroTelemetryHook() command_args = ["--version"] telemetry_cli_hook.before_command_run(fake_metadata, command_args) # Follow by project run - telemetry_hook = KedroTelemetryProjectHooks() + telemetry_hook = KedroTelemetryHook() telemetry_hook.after_context_created(fake_context) telemetry_hook.after_catalog_created(fake_catalog) @@ -593,8 +582,7 @@ def test_after_context_created_with_kedro_run( # noqa: PLR0913 properties=expected_properties, ) - # CLI hook makes the first 2 calls, the 3rd one is the Project hook - assert mocked_heap_call.call_args_list[2] == expected_call + assert mocked_heap_call.call_args_list[0] == expected_call def test_after_context_created_with_kedro_run_and_tools( # noqa: PLR0913 self, @@ -627,12 +615,12 @@ def test_after_context_created_with_kedro_run_and_tools( # noqa: PLR0913 mocker.patch("pathlib.Path.exists", return_value=True) # CLI run first - telemetry_cli_hook = KedroTelemetryCLIHooks() + telemetry_cli_hook = KedroTelemetryHook() command_args = ["--version"] telemetry_cli_hook.before_command_run(fake_metadata, command_args) # Follow by project run - telemetry_hook = KedroTelemetryProjectHooks() + telemetry_hook = KedroTelemetryHook() telemetry_hook.after_context_created(fake_context) telemetry_hook.after_catalog_created(fake_catalog) @@ -659,8 +647,8 @@ def test_after_context_created_with_kedro_run_and_tools( # noqa: PLR0913 identity="user_uuid", properties=expected_properties, ) - # CLI hook makes the first 2 calls, the 3rd one is the Project hook - assert mocked_heap_call.call_args_list[2] == expected_call + + assert mocked_heap_call.call_args_list[0] == expected_call def test_after_context_created_no_consent_given(self, mocker): fake_context = mocker.Mock() @@ -669,7 +657,7 @@ def test_after_context_created_no_consent_given(self, mocker): ) mocked_heap_call = mocker.patch("kedro_telemetry.plugin._send_heap_event") - telemetry_hook = KedroTelemetryProjectHooks() + telemetry_hook = KedroTelemetryHook() telemetry_hook.after_context_created(fake_context) mocked_heap_call.assert_not_called()