From c5a38eea9f04bbfda62b47668d12cdf915551166 Mon Sep 17 00:00:00 2001 From: chenyangkang Date: Mon, 29 Jan 2024 11:31:46 +0800 Subject: [PATCH] add checks and update docs --- docs/Examples/01.AdaSTEM_demo.ipynb | 2 +- docs/index.md | 2 +- stemflow/model/AdaSTEM.py | 8 ++++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/Examples/01.AdaSTEM_demo.ipynb b/docs/Examples/01.AdaSTEM_demo.ipynb index c514993..a003cbd 100644 --- a/docs/Examples/01.AdaSTEM_demo.ipynb +++ b/docs/Examples/01.AdaSTEM_demo.ipynb @@ -743,7 +743,7 @@ " ), # hurdel model for zero-inflated problem (e.g., count)\n", " save_gridding_plot = True,\n", " ensemble_fold=10, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n", - " min_ensemble_required=7, # Only points covered by > 7 stixels will be predicted\n", + " min_ensemble_required=7, # Only points covered by > 7 ensembles will be predicted\n", " grid_len_upper_threshold=25, # force splitting if the grid length exceeds 25\n", " grid_len_lower_threshold=5, # stop splitting if the grid length fall short 5 \n", " temporal_start=1, # The next 4 params define the temporal sliding window\n", diff --git a/docs/index.md b/docs/index.md index 833f120..8df4268 100644 --- a/docs/index.md +++ b/docs/index.md @@ -135,7 +135,7 @@ model = AdaSTEMRegressor( ), # hurdel model for zero-inflated problem (e.g., count) save_gridding_plot = True, ensemble_fold=10, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo - min_ensemble_required=7, # Only points covered by > 7 stixels will be predicted + min_ensemble_required=7, # Only points covered by > 7 ensembles will be predicted grid_len_upper_threshold=25, # force splitting if the grid length exceeds 25 grid_len_lower_threshold=5, # stop splitting if the grid length fall short 5 temporal_start=1, # The next 4 params define the temporal sliding window diff --git a/stemflow/model/AdaSTEM.py b/stemflow/model/AdaSTEM.py index b7d5ffb..d097d29 100644 --- a/stemflow/model/AdaSTEM.py +++ b/stemflow/model/AdaSTEM.py @@ -205,6 +205,9 @@ def __init__( self.Spatio2 = Spatio2 # 3. Gridding params + if min_ensemble_required > ensemble_fold: + raise ValueError("Not satisfied: min_ensemble_required <= ensemble_fold") + self.ensemble_fold = ensemble_fold self.min_ensemble_required = min_ensemble_required self.grid_len_upper_threshold = ( @@ -607,6 +610,11 @@ def SAC_predict( ) # pred = pred.reset_index(drop=False) + if len(pred) == 0: + raise ValueError( + "All samples are not predictable based on current settings!\nTry adjusting the 'points_lower_threshold', increase the grid size, or increase sample size!" + ) + pred = pred.droplevel(1, axis=0).reset_index(drop=False) pred = pred.pivot_table(index="index", columns="ensemble_index", values="pred") return pred