[FEAT] Add spark test (#464)
* feat: add spark tests

* fix: pyspark version
AzulGarza authored Apr 20, 2023
1 parent beb2992 commit 3aac9c5
Showing 5 changed files with 21 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -62,7 +62,7 @@ jobs:
if: ${{ matrix.os != 'windows-latest' }}
run: nbdev_test --skip_file_re '(distributed|ets).*.ipynb' --pause 1.0

- name: Run integrattion tests
- name: Run integration tests
if: ${{ matrix.os != 'windows-latest' }}
run: |
pip install ".[dev]" pytest
13 changes: 13 additions & 0 deletions action_files/test_spark.py
@@ -0,0 +1,13 @@
from pyspark.sql import SparkSession

from statsforecast.utils import generate_series
from .utils import pipeline

def test_spark_flow():
n_series = 2
horizon = 7
series = generate_series(n_series).reset_index()
series['unique_id'] = series['unique_id'].astype(str)
spark = SparkSession.builder.getOrCreate()
series = spark.createDataFrame(series).repartition(2, 'unique_id')
pipeline(series, n_series, horizon)
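
The new test imports a pipeline helper from action_files/utils.py, which is not part of this diff. Below is a minimal sketch of what such a helper might look like, assuming it fits a single StatsForecast model through the fugue-backed distributed interface and checks the size of the result; the model choice (Naive), the daily frequency, and the shape assertion are illustrative assumptions, not the repository's actual helper.

# action_files/utils.py -- hypothetical sketch, not the real helper from this repo
from statsforecast import StatsForecast
from statsforecast.models import Naive


def pipeline(series, n_series, horizon):
    # StatsForecast dispatches through fugue when it receives a Spark DataFrame,
    # so the same forecast call works for pandas and Spark inputs.
    sf = StatsForecast(models=[Naive()], freq='D')
    forecasts = sf.forecast(df=series, h=horizon)
    # With a Spark input the result is a Spark DataFrame; collect it to inspect.
    forecasts = forecasts.toPandas()
    # Expect one row per series per horizon step (unique_id, ds, model column).
    assert len(forecasts) == n_series * horizon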
4 changes: 3 additions & 1 deletion dev/environment.yml
@@ -8,14 +8,16 @@ dependencies:
- numba>=0.55.0
- numpy>=1.21.6
- pandas>=1.3.5
- pyspark>=3.3
- pip
- prophet
- scipy>=1.7.3
- statsmodels>=0.13.2
- tabulate
- plotly
- pip:
- fugue[dask,ray]
- nbdev
- tqdm
- plotly-resampler
- supersmoother
- tqdm
3 changes: 2 additions & 1 deletion settings.ini
@@ -18,7 +18,8 @@ status = 2
requirements = matplotlib numba>=0.55.0 numpy>=1.21.6 pandas>=1.3.5 plotly scipy>=1.7.3 statsmodels>=0.13.2 tqdm plotly-resampler fugue>=0.8.1
ray_requirements = fugue[ray]>=0.8.1 protobuf>=3.15.3,<4.0.0
dask_requirements = fugue[dask]>=0.8.1
dev_requirements = nbdev black mypy flake8 ray protobuf>=3.15.3,<4.0.0 matplotlib pmdarima prophet scikit-learn fugue[dask,ray]>=0.8.1 datasetsforecast supersmoother
spark_requirements = fugue[spark]>=0.8.1
dev_requirements = nbdev black mypy flake8 ray protobuf>=3.15.3,<4.0.0 matplotlib pmdarima prophet scikit-learn fugue[dask,ray,spark]>=0.8.1 datasetsforecast supersmoother
nbs_path = nbs
doc_path = _docs
recursive = True
3 changes: 2 additions & 1 deletion setup.py
@@ -31,6 +31,7 @@
dev_requirements = (cfg.get('dev_requirements') or '').split()
dask_requirements = cfg.get('dask_requirements', '').split()
ray_requirements = cfg.get('ray_requirements', '').split()
spark_requirements = cfg.get('spark_requirements', '').split()

setuptools.setup(
name = 'statsforecast',
@@ -44,7 +45,7 @@
packages = setuptools.find_packages(),
include_package_data = True,
install_requires = requirements,
extras_require={'dev': dev_requirements, 'dask': dask_requirements, 'ray': ray_requirements,},
extras_require={'dev': dev_requirements, 'dask': dask_requirements, 'ray': ray_requirements, 'spark': spark_requirements,},
dependency_links = cfg.get('dep_links','').split(),
python_requires = '>=' + cfg['min_python'],
long_description = open('README.md', encoding='utf8').read(),
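
With the spark extra wired through settings.ini and setup.py, installing statsforecast[spark] should pull in fugue[spark]>=0.8.1 (and, presumably, pyspark with it). A minimal, hypothetical local sanity check after something like `pip install "statsforecast[spark]"`; the install command and the version print are assumptions, not part of this commit:

# Hypothetical check that the package and a local Spark session both work
from pyspark.sql import SparkSession
import statsforecast

spark = SparkSession.builder.getOrCreate()
print(statsforecast.__version__, spark.version)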
