Skip to content

Commit

Permalink
add checks and update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Jan 29, 2024
1 parent 7b4cf0a commit c5a38ee
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/Examples/01.AdaSTEM_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@
" ), # hurdel model for zero-inflated problem (e.g., count)\n",
" save_gridding_plot = True,\n",
" ensemble_fold=10, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n",
" min_ensemble_required=7, # Only points covered by > 7 stixels will be predicted\n",
" min_ensemble_required=7, # Only points covered by > 7 ensembles will be predicted\n",
" grid_len_upper_threshold=25, # force splitting if the grid length exceeds 25\n",
" grid_len_lower_threshold=5, # stop splitting if the grid length fall short 5 \n",
" temporal_start=1, # The next 4 params define the temporal sliding window\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ model = AdaSTEMRegressor(
), # hurdel model for zero-inflated problem (e.g., count)
save_gridding_plot = True,
ensemble_fold=10, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo
min_ensemble_required=7, # Only points covered by > 7 stixels will be predicted
min_ensemble_required=7, # Only points covered by > 7 ensembles will be predicted
grid_len_upper_threshold=25, # force splitting if the grid length exceeds 25
grid_len_lower_threshold=5, # stop splitting if the grid length fall short 5
temporal_start=1, # The next 4 params define the temporal sliding window
Expand Down
8 changes: 8 additions & 0 deletions stemflow/model/AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ def __init__(
self.Spatio2 = Spatio2

# 3. Gridding params
if min_ensemble_required > ensemble_fold:
raise ValueError("Not satisfied: min_ensemble_required <= ensemble_fold")

self.ensemble_fold = ensemble_fold
self.min_ensemble_required = min_ensemble_required
self.grid_len_upper_threshold = (
Expand Down Expand Up @@ -607,6 +610,11 @@ def SAC_predict(
)

# pred = pred.reset_index(drop=False)
if len(pred) == 0:
raise ValueError(
"All samples are not predictable based on current settings!\nTry adjusting the 'points_lower_threshold', increase the grid size, or increase sample size!"
)

pred = pred.droplevel(1, axis=0).reset_index(drop=False)
pred = pred.pivot_table(index="index", columns="ensemble_index", values="pred")
return pred
Expand Down

0 comments on commit c5a38ee

Please sign in to comment.