Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

continue setting up ci #3

Merged
merged 12 commits into from
Jul 9, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
device=cpu in resnet tests
  • Loading branch information
jeromedockes committed Jul 9, 2024
commit 4a4e1fa92bc73c71cfee68254558b1ac588a3cdb
18 changes: 9 additions & 9 deletions tests/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def test_numerical_data(n_classes, resnet_or_mlp):

# Train the classifier
if resnet_or_mlp == "resnet":
clf = Resnet_RTDL_D_Classifier()
clf = Resnet_RTDL_D_Classifier(device="cpu")
elif resnet_or_mlp == "mlp":
clf = MLP_RTDL_D_Classifier()
clf = MLP_RTDL_D_Classifier(device="cpu")
clf.fit(X_train, y_train, cat_features=[False] * 20) # Assuming no categorical features

# Predict and evaluate
Expand All @@ -51,9 +51,9 @@ def test_categorical_data(n_classes, resnet_or_mlp):

# Train the classifier with categorical feature
if resnet_or_mlp == "resnet":
clf = Resnet_RTDL_D_Classifier()
clf = Resnet_RTDL_D_Classifier(device="cpu")
elif resnet_or_mlp == "mlp":
clf = MLP_RTDL_D_Classifier()
clf = MLP_RTDL_D_Classifier(device="cpu")
clf.fit(X_train, y_train, cat_features=[False] * 20 + [True])

# Predict and evaluate
Expand All @@ -80,9 +80,9 @@ def test_regressor_numerical_categorical(tranformed_target, resnet_or_mlp):
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=43)

if resnet_or_mlp == "resnet":
regressor = Resnet_RTDL_D_Regressor(transformed_target=tranformed_target, random_state=41)
regressor = Resnet_RTDL_D_Regressor(transformed_target=tranformed_target, random_state=41, device="cpu")
elif resnet_or_mlp == "mlp":
regressor = MLP_RTDL_D_Regressor(transformed_target=tranformed_target, random_state=41)
regressor = MLP_RTDL_D_Regressor(transformed_target=tranformed_target, random_state=41, device="cpu")
regressor.fit(X_train, y_train, cat_features=cat_features)
predictions = regressor.predict(X_test)

Expand All @@ -98,9 +98,9 @@ def test_regressor_numerical_categorical(tranformed_target, resnet_or_mlp):

def create_model(regression, resnet_or_mlp, **kwargs):
if resnet_or_mlp == "resnet":
model = Resnet_RTDL_D_Regressor(**kwargs) if regression else Resnet_RTDL_D_Classifier(**kwargs)
model = Resnet_RTDL_D_Regressor(device="cpu", **kwargs) if regression else Resnet_RTDL_D_Classifier(device="cpu", **kwargs)
elif resnet_or_mlp == "mlp":
model = MLP_RTDL_D_Regressor(**kwargs) if regression else MLP_RTDL_D_Classifier(**kwargs)
model = MLP_RTDL_D_Regressor(device="cpu", **kwargs) if regression else MLP_RTDL_D_Classifier(device="cpu", **kwargs)
return model


Expand Down Expand Up @@ -205,4 +205,4 @@ def test_high_cardinality(seed, resnet_or_mlp):
# assert np.isfinite(history[:, 'valid_loss']).any()
# predictions = model.predict(X)
# assert not np.allclose(predictions, np.mean(y[100:])), "Predictions should not be the mean of the training set"
# assert model.alg_interface_.sub_split_interfaces[0].model.predict_mean == False
# assert model.alg_interface_.sub_split_interfaces[0].model.predict_mean == False
Loading