Skip to content

Commit

Permalink
add ST_Kfold class to stemflow.model_selection.py; This is to match t…
Browse files Browse the repository at this point in the history
…he KFold class in sklearn.model_selection, and it only generates indices instead of data splits. As part of #47
  • Loading branch information
chenyangkang committed Sep 20, 2024
1 parent d6d0cb9 commit a0ef330
Showing 1 changed file with 113 additions and 3 deletions.
116 changes: 113 additions & 3 deletions stemflow/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ def ST_train_test_split(
Returns:
X_train, X_test, y_train, y_test
"""
# random seed
rng = check_random_state(random_state)
Expand Down Expand Up @@ -112,7 +110,7 @@ def ST_CV(
random_state: Union[np.random.RandomState, None, int] = None,
CV: int = 3,
) -> Generator[Tuple[DataFrame, DataFrame, ndarray, ndarray], None, None]:
"""Spatial Temporal train-test split
"""A function to generate spatiotemporal train-test-split data. To only generate indexes, see class `ST_Kfold`.
Args:
X:
Expand Down Expand Up @@ -195,3 +193,115 @@ def ST_CV(
y_test = np.array(y).flatten()[test_indexes].reshape(-1, 1)

yield X_train, X_test, y_train, y_test


class ST_Kfold:
    """Spatiotemporal K-fold splitter that yields positional row indices."""

    def __init__(
        self,
        Spatio1: str = "longitude",
        Spatio2: str = "latitude",
        Temporal1: str = "DOY",
        Spatio_blocks_count: int = 10,
        Temporal_blocks_count: int = 10,
        random_state: Union[np.random.RandomState, None, int] = None,
        n_splits: int = 3,
    ) -> None:
        """Spatial Temporal Kfold generator class.

        While the `ST_CV` function yields the data directly
        (X_train, X_test, y_train, y_test), this class generates only indices,
        mirroring the `KFold` class in `sklearn.model_selection`.

        Args:
            Spatio1:
                column name of spatial indicator 1
            Spatio2:
                column name of spatial indicator 2
            Temporal1:
                column name of temporal indicator 1
            Spatio_blocks_count:
                Number of evenly spaced bin edges placed between the minimum and
                maximum of each spatial column.
            Temporal_blocks_count:
                Number of evenly spaced bin edges placed between the minimum and
                maximum of the temporal column.
            random_state:
                random state used to shuffle the spatiotemporal cells before
                they are dealt into folds
            n_splits:
                number of folds

        Raises:
            ValueError: if `n_splits` is not a positive integer.

        Example:
            ```
            from stemflow.model_selection import ST_Kfold
            ST_KFold_generator = ST_Kfold(n_splits=5,
                                          Spatio1="longitude",
                                          Spatio2="latitude",
                                          Temporal1="DOY",
                                          Spatio_blocks_count=10,
                                          Temporal_blocks_count=10,
                                          random_state=42).split(X)
            for train_indexes, test_indexes in ST_KFold_generator:
                X_train = X.iloc[train_indexes, :]
                X_test = X.iloc[test_indexes, :]
                ...
            ```
        """
        # Validate first so an invalid argument never produces a half-initialised object.
        if not (isinstance(n_splits, int) and n_splits > 0):
            raise ValueError("n_splits should be a positive integer")

        self.rng = check_random_state(random_state)
        self.Spatio1 = Spatio1
        self.Spatio2 = Spatio2
        self.Temporal1 = Temporal1
        self.Spatio_blocks_count = Spatio_blocks_count
        self.Temporal_blocks_count = Temporal_blocks_count
        self.n_splits = n_splits

    def split(self, X: DataFrame) -> Generator[Tuple[ndarray, ndarray], None, None]:
        """Yield train/test positional indices, one pair per fold.

        Every row is assigned to a spatiotemporal cell by binning its two spatial
        columns and its temporal column. The distinct cells are shuffled with
        `self.rng` and dealt into `n_splits` contiguous groups; each group in
        turn supplies the test rows, and all remaining rows are the train rows.
        Every cell belongs to exactly one test fold.

        Args:
            X:
                Training variables. Must contain the columns named by
                `Spatio1`, `Spatio2` and `Temporal1`.

        Yields:
            Generator[Tuple[ndarray, ndarray], None, None]: train_index, test_index.
            Both are ascending integer arrays of row positions, suitable for `X.iloc`.

        Raises:
            TypeError: if `X` is not a pandas DataFrame (raised when iteration starts).
        """
        import warnings

        # validate
        if not isinstance(X, DataFrame):
            type_x = str(type(X))
            raise TypeError(f"X input should be pandas.core.frame.DataFrame, Got {type_x}")

        # bin edges for each axis
        spatio1_edges = np.linspace(X[self.Spatio1].min(), X[self.Spatio1].max(), self.Spatio_blocks_count)
        spatio2_edges = np.linspace(X[self.Spatio2].min(), X[self.Spatio2].max(), self.Spatio_blocks_count)
        temporal_edges = np.linspace(X[self.Temporal1].min(), X[self.Temporal1].max(), self.Temporal_blocks_count)

        # one "a_b_c" label per row identifying the cell the row falls into
        cell_labels = np.array(
            [
                f"{a}_{b}_{c}"
                for a, b, c in zip(
                    np.digitize(X[self.Spatio1], spatio1_edges),
                    np.digitize(X[self.Spatio2], spatio2_edges),
                    np.digitize(X[self.Temporal1], temporal_edges),
                )
            ],
            dtype=str,
        )

        # Shuffle a plain list of the sorted unique labels so that the cell order
        # for a given random_state stays what it was before this rewrite.
        unique_cells = list(np.unique(cell_labels))
        self.rng.shuffle(unique_cells)

        if len(unique_cells) < self.n_splits:
            warnings.warn(
                f"Only {len(unique_cells)} spatiotemporal cells are available for "
                f"n_splits={self.n_splits}; some folds will have an empty test set."
            )

        # array_split hands the remainder cells to the leading folds instead of
        # dropping them, so no cell is left out of the test sets.
        folds = np.array_split(np.array(unique_cells, dtype=str), self.n_splits)

        row_positions = np.arange(len(cell_labels))
        for fold_cells in folds:
            in_test = np.isin(cell_labels, fold_cells)
            yield row_positions[~in_test], row_positions[in_test]

Check warning on line 307 in stemflow/model_selection.py

View check run for this annotation

Codecov / codecov/patch

stemflow/model_selection.py#L307

Added line #L307 was not covered by tests

0 comments on commit a0ef330

Please sign in to comment.