From 3aac9c56ed5b37b2deaf80e66c6077c6724a4449 Mon Sep 17 00:00:00 2001
From: fede
Date: Wed, 19 Apr 2023 18:00:35 -0600
Subject: [PATCH] [FEAT] Add spark test (#464)

* feat: add spark tests

* fix: pyspark version
---
 .github/workflows/ci.yaml  |  2 +-
 action_files/test_spark.py | 13 +++++++++++++
 dev/environment.yml        |  4 +++-
 settings.ini               |  3 ++-
 setup.py                   |  3 ++-
 5 files changed, 21 insertions(+), 4 deletions(-)
 create mode 100644 action_files/test_spark.py

diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 8f5a86a09..442921e7d 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -62,7 +62,7 @@ jobs:
         if: ${{ matrix.os != 'windows-latest' }}
         run: nbdev_test --skip_file_re '(distributed|ets).*.ipynb' --pause 1.0
 
-      - name: Run integrattion tests
+      - name: Run integration tests
         if: ${{ matrix.os != 'windows-latest' }}
         run: |
           pip install ".[dev]" pytest
diff --git a/action_files/test_spark.py b/action_files/test_spark.py
new file mode 100644
index 000000000..6c213898a
--- /dev/null
+++ b/action_files/test_spark.py
@@ -0,0 +1,13 @@
+from pyspark.sql import SparkSession
+
+from statsforecast.utils import generate_series
+from .utils import pipeline
+
+def test_spark_flow():
+    n_series = 2
+    horizon = 7
+    series = generate_series(n_series).reset_index()
+    series['unique_id'] = series['unique_id'].astype(str)
+    spark = SparkSession.builder.getOrCreate()
+    series = spark.createDataFrame(series).repartition(2, 'unique_id')
+    pipeline(series, n_series, horizon)
diff --git a/dev/environment.yml b/dev/environment.yml
index a0884ded1..283864dff 100644
--- a/dev/environment.yml
+++ b/dev/environment.yml
@@ -8,6 +8,7 @@ dependencies:
   - numba>=0.55.0
   - numpy>=1.21.6
   - pandas>=1.3.5
+  - pyspark>=3.3
   - pip
   - prophet
   - scipy>=1.7.3
@@ -15,7 +16,8 @@
   - tabulate
   - plotly
   - pip:
+    - fugue[dask,ray]
     - nbdev
-    - tqdm
     - plotly-resampler
     - supersmoother
+    - tqdm
diff --git a/settings.ini b/settings.ini
index 7d1482851..fe18d215a 100644
--- a/settings.ini
+++ b/settings.ini
@@ -18,7 +18,8 @@ status = 2
 requirements = matplotlib numba>=0.55.0 numpy>=1.21.6 pandas>=1.3.5 plotly scipy>=1.7.3 statsmodels>=0.13.2 tqdm plotly-resampler fugue>=0.8.1
 ray_requirements = fugue[ray]>=0.8.1 protobuf>=3.15.3,<4.0.0
 dask_requirements = fugue[dask]>=0.8.1
-dev_requirements = nbdev black mypy flake8 ray protobuf>=3.15.3,<4.0.0 matplotlib pmdarima prophet scikit-learn fugue[dask,ray]>=0.8.1 datasetsforecast supersmoother
+spark_requirements = fugue[spark]>=0.8.1
+dev_requirements = nbdev black mypy flake8 ray protobuf>=3.15.3,<4.0.0 matplotlib pmdarima prophet scikit-learn fugue[dask,ray,spark]>=0.8.1 datasetsforecast supersmoother
 nbs_path = nbs
 doc_path = _docs
 recursive = True
diff --git a/setup.py b/setup.py
index 22806b9d1..c7bbacdca 100644
--- a/setup.py
+++ b/setup.py
@@ -31,6 +31,7 @@
 dev_requirements = (cfg.get('dev_requirements') or '').split()
 dask_requirements = cfg.get('dask_requirements', '').split()
 ray_requirements = cfg.get('ray_requirements', '').split()
+spark_requirements = cfg.get('spark_requirements', '').split()
 
 setuptools.setup(
     name = 'statsforecast',
@@ -44,7 +45,7 @@
     packages = setuptools.find_packages(),
     include_package_data = True,
     install_requires = requirements,
-    extras_require={'dev': dev_requirements, 'dask': dask_requirements, 'ray': ray_requirements,},
+    extras_require={'dev': dev_requirements, 'dask': dask_requirements, 'ray': ray_requirements, 'spark': spark_requirements,},
     dependency_links = cfg.get('dep_links','').split(),
     python_requires = '>=' + cfg['min_python'],
     long_description = open('README.md', encoding='utf8').read(),
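
Note: test_spark.py imports pipeline from action_files/utils.py, which is not
part of this diff. For orientation, a minimal sketch of what such a helper
might look like follows; the model choice (Naive), the daily frequency, and
the row-count assertion are illustrative assumptions, not taken from the
actual utils module:

    # Hypothetical stand-in for the pipeline helper in action_files/utils.py.
    # Assumes StatsForecast dispatches to its fugue backend when the input
    # df is a Spark DataFrame, as exercised by test_spark_flow above.
    from statsforecast import StatsForecast
    from statsforecast.models import Naive

    def pipeline(series, n_series, horizon):
        sf = StatsForecast(models=[Naive()], freq='D')
        # With a Spark DataFrame as input, forecast() returns a Spark DataFrame.
        forecasts = sf.forecast(df=series, h=horizon).toPandas()
        # Expect one forecast row per series per horizon step.
        assert forecasts.shape[0] == n_series * horizon

With this patch applied, pip install "statsforecast[spark]" installs the new
extra, which pulls in fugue[spark]>=0.8.1 (see the settings.ini and setup.py
changes above).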