Skip to content

Commit

Permalink
correct prediction of latent variances for 'fitc' and 'full_scale_tap…
Browse files Browse the repository at this point in the history
…ering' for Gaussian likelihood
  • Loading branch information
fabsig committed Apr 2, 2024
1 parent 30f2f3b commit 7b2efbb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
6 changes: 6 additions & 0 deletions R-package/tests/testthat/test_GPModel_gaussian_process.R
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,8 @@ if(Sys.getenv("GPBOOST_ALL_TESTS") == "GPBOOST_ALL_TESTS"){
y = y, X = X, params = params), file='NUL')
pred_var_no_approx <- predict(gp_model_no_approx, gp_coords_pred = coord_test_v1, cov_pars = cov_pars_pred,
X_pred = X_test_v1, predict_var = TRUE)
pred_var_lat_no_approx <- predict(gp_model_no_approx, gp_coords_pred = coord_test_v1, cov_pars = cov_pars_pred,
X_pred = X_test_v1, predict_var = TRUE, predict_response = FALSE)
pred_cov_no_approx <- predict(gp_model_no_approx, gp_coords_pred = coord_test_v1, cov_pars = cov_pars_pred,
X_pred = X_test_v1, predict_cov = TRUE)
X0 <- matrix(0, nrow=nrow(X), ncol=ncol(X))
Expand All @@ -1176,6 +1178,10 @@ if(Sys.getenv("GPBOOST_ALL_TESTS") == "GPBOOST_ALL_TESTS"){
X_pred = X_test_v1, predict_var = TRUE)
expect_lt(sum(abs(pred$mu - pred_var_no_approx$mu)),TOLERANCE_STRICT)
expect_lt(sum(abs(as.vector(pred$var) - as.vector(pred_var_no_approx$var))),TOLERANCE_STRICT)
pred <- predict(gp_model, gp_coords_pred = coord_test_v1, cov_pars = cov_pars_pred,
X_pred = X_test_v1, predict_var = TRUE, predict_response = FALSE)
expect_lt(sum(abs(pred$mu - pred_var_lat_no_approx$mu)),TOLERANCE_STRICT)
expect_lt(sum(abs(as.vector(pred$var) - as.vector(pred_var_lat_no_approx$var))),TOLERANCE_STRICT)
pred <- predict(gp_model, gp_coords_pred = coord_test_v1, cov_pars = cov_pars_pred,
X_pred = X_test_v1, predict_cov = TRUE)
expect_lt(sum(abs(pred$mu - pred_cov_no_approx$mu)),TOLERANCE_STRICT)
Expand Down
8 changes: 5 additions & 3 deletions include/GPBoost/re_model_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -2821,7 +2821,7 @@ namespace GPBoost {
else {// not gp_approx_ == "vecchia"
if (gp_approx_ == "fitc" || gp_approx_ == "full_scale_tapering") {
CalcPredPPFSA(cluster_i, num_data_per_cluster_pred, num_data_per_cluster_, gp_coords_mat_pred, predict_cov_mat,
predict_var_or_response, mean_pred_id, cov_mat_pred_id, var_pred_id, nsim_var_pred_, cg_delta_conv_pred_);
predict_var_or_response, predict_response, mean_pred_id, cov_mat_pred_id, var_pred_id, nsim_var_pred_, cg_delta_conv_pred_);
}
else {
CalcPred(cluster_i, num_data_pred, num_data_per_cluster_pred, data_indices_per_cluster_pred,
Expand Down Expand Up @@ -7813,6 +7813,7 @@ namespace GPBoost {
* \param gp_coords_mat_pred Coordinates for prediction locations
* \param calc_pred_cov If true, the covariance matrix is also calculated
* \param calc_pred_var If true, predictive variances are also calculated
* \param predict_response If true, the response variable (label) is predicted, otherwise the latent random effects
* \param[out] pred_mean Predictive mean (only for Gaussian likelihoods)
* \param[out] pred_cov Predictive covariance matrix (only for Gaussian likelihoods)
* \param[out] pred_var Predictive variances (only for Gaussian likelihoods)
Expand All @@ -7825,6 +7826,7 @@ namespace GPBoost {
const den_mat_t& gp_coords_mat_pred,
bool calc_pred_cov,
bool calc_pred_var,
bool predict_response,
vec_t& pred_mean,
T_mat& pred_cov,
vec_t& pred_var,
Expand Down Expand Up @@ -7923,7 +7925,7 @@ namespace GPBoost {
if (calc_pred_cov || calc_pred_var) {
// Add unconditional variances and covarainces
if (calc_pred_var) {
if (gauss_likelihood_) {
if (gauss_likelihood_ && predict_response) {
pred_var = vec_t::Ones(num_data_pred_cli);
}
else {
Expand All @@ -7939,7 +7941,7 @@ namespace GPBoost {
"Therefore, if this number is large we recommend only computing the predictive variances ");
}
pred_cov = T_mat(num_data_pred_cli, num_data_pred_cli);
if (gauss_likelihood_) {
if (gauss_likelihood_ && predict_response) {
pred_cov.setIdentity();
}
else {
Expand Down

0 comments on commit 7b2efbb

Please sign in to comment.