Skip to content

Commit

Permalink
Fix test for dev version of statsmodels (mlflow#4570)
Browse files Browse the repository at this point in the history
* Fix test for dev version of statsmodels

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* apply asarray

Signed-off-by: harupy <hkawamura0130@gmail.com>

* remove unnecessary blank line

Signed-off-by: harupy <hkawamura0130@gmail.com>

* fix comment

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove if-else

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove unused import

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
  • Loading branch information
harupy authored Jul 19, 2021
1 parent 2b7a0ea commit 7082a1a
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions tests/statsmodels/model_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from statsmodels.tsa.arima.model import ARIMA
from scipy.linalg import toeplitz


ModelWithResults = namedtuple("ModelWithResults", ["model", "alg", "inference_dataframe"])


Expand Down Expand Up @@ -51,10 +52,18 @@ def failing_logit_model():
return ModelWithResults(model=model, alg=log_reg, inference_dataframe=X)


def get_dataset(name):
dataset_module = getattr(sm.datasets, name)
data = dataset_module.load()
data.exog = np.asarray(data.exog)
data.endog = np.asarray(data.endog)
return data


@pytest.fixture(scope="session")
def gls_model():
# Generalized Least Squares (GLS)
data = sm.datasets.longley.load(as_pandas=False)
data = get_dataset("longley")
data.exog = sm.add_constant(data.exog)
ols_resid = sm.OLS(data.endog, data.exog).fit().resid
res_fit = sm.OLS(ols_resid[1:], ols_resid[:-1]).fit()
Expand Down Expand Up @@ -114,7 +123,7 @@ def rolling_ols_model():
# Rolling Ordinary Least Squares (Rolling OLS)
from statsmodels.regression.rolling import RollingOLS

data = sm.datasets.longley.load(as_pandas=False)
data = get_dataset("longley")
exog = sm.add_constant(data.exog, prepend=False)
rolling_ols = RollingOLS(data.endog, exog)
model = rolling_ols.fit(reset=50)
Expand All @@ -127,7 +136,7 @@ def rolling_wls_model():
# Rolling Weighted Least Squares (Rolling WLS)
from statsmodels.regression.rolling import RollingWLS

data = sm.datasets.longley.load(as_pandas=False)
data = get_dataset("longley")
exog = sm.add_constant(data.exog, prepend=False)
rolling_wls = RollingWLS(data.endog, exog)
model = rolling_wls.fit(reset=50)
Expand Down Expand Up @@ -187,7 +196,7 @@ def gee_model():
@pytest.fixture(scope="session")
def glm_model():
# Generalized Linear Model (GLM)
data = sm.datasets.scotland.load(as_pandas=False)
data = get_dataset("scotland")
data.exog = sm.add_constant(data.exog)
glm = sm.GLM(data.endog, data.exog, family=sm.families.Gamma())
model = glm.fit()
Expand Down

0 comments on commit 7082a1a

Please sign in to comment.