Skip to content

Commit 300cb66

Browse files
committed
add bayesian example
1 parent 31a05e0 commit 300cb66

File tree

1 file changed

+149
-0
lines changed

1 file changed

+149
-0
lines changed

example/03-bayesian.py

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Fit a regularized Bayesian Linear Regression model to a data set and compare the
2+
# model to Ordinary Least Squares (OLS) and a ridge regression model fit so as to
3+
# minimize the leave-one-out cross-validation (LOOCV) error.
4+
#
5+
# 1. Generate the data set.
6+
#
7+
# 2. Fit OLS, Bayesian ridge regression, and ridge regression models.
8+
#
9+
# 3. Compare the prediction errors (measured in error variance) of the different models.
10+
#
11+
# 4. Compare the amount of shrinkage of Bayesian ridge regression and LOOCV ridge regression.
12+
#
13+
# 5. Compare the predicted noise variance and weight variance of OLS and Bayesian ridge
14+
# regression.
15+
16+
####################################################################################################
17+
# Part 1: Generate data set
18+
####################################################################################################
19+
import numpy as np
20+
np.random.seed(0)
21+
22+
def generate_correlation_matrix(p, param):
    """Build a p x p matrix with ones on the diagonal and ``param`` elsewhere.

    Args:
        p: Number of rows and columns.
        param: Value written to every off-diagonal entry.

    Returns:
        A float ndarray of shape (p, p).

    Note:
        For p > 1 the result is positive semi-definite (a valid correlation
        matrix) only when -1 / (p - 1) <= param <= 1; this is not checked.
    """
    # dtype=float keeps the result floating point even if `param` is an int.
    res = np.full((p, p), param, dtype=float)
    np.fill_diagonal(res, 1.0)
    return res
32+
33+
def generate_design_matrix(n, K):
    """Sample n rows from a zero-mean multivariate normal with covariance K.

    Args:
        n: Number of rows to draw.
        K: Square covariance matrix; its first dimension sets the number of
            columns of the result.

    Returns:
        The sampled ndarray, one row per draw. Samples come from NumPy's
        module-level ``np.random`` state.
    """
    num_features = K.shape[0]
    origin = np.zeros(num_features)
    return np.random.multivariate_normal(mean=origin, cov=K, size=n)
36+
37+
def generate_weights(p):
    """Draw p independent standard-normal weights from ``np.random``."""
    return np.random.normal(loc=0.0, scale=1.0, size=p)
39+
40+
def generate_data_set(n, K, signal_noise_ratio):
    """Generate a random linear-regression problem with a chosen signal-to-noise ratio.

    Args:
        n: Number of observations (rows of X).
        K: Feature covariance matrix; its first dimension is the number of features.
        signal_noise_ratio: Target value of w' K w after rescaling. The noise added
            to y has unit variance, so this is the signal-to-noise ratio.

    Returns:
        Tuple ``(X, y, w)``: design matrix, noisy targets, and the rescaled
        weights that generated them.
    """
    num_features = K.shape[0]
    # Random draws happen in this order: design matrix, weights, then noise.
    X = generate_design_matrix(n, K)
    raw_weights = generate_weights(num_features)
    # w' K w is the variance of x . w when x has covariance K; rescale the
    # weights so that this variance equals signal_noise_ratio.
    raw_signal_var = raw_weights.dot(K.dot(raw_weights))
    w = raw_weights * np.sqrt(signal_noise_ratio / raw_signal_var)
    noise = np.random.normal(size=n)
    y = np.dot(X, w) + noise
    return X, y, w
48+
49+
# p features and n observations; off-diagonal feature correlation is 0.5.
p, n = 10, 20
signal_noise_ratio = 1.0
K = generate_correlation_matrix(p=p, param=0.5)
X, y, w_true = generate_data_set(n=n, K=K, signal_noise_ratio=signal_noise_ratio)
54+
55+
56+
####################################################################################################
57+
# Part 2: Fit models
58+
####################################################################################################
59+
# Estimators used below. All are fit without an intercept because the data set
# is generated with zero-mean features and no offset term.
from sklearn.linear_model import LinearRegression
from bbai.glm import BayesianRidgeRegression, RidgeRegression

# OLS
model_ols = LinearRegression(fit_intercept=False)
model_ols.fit(X, y)

# Bayesian Linear Regression
model_bay = BayesianRidgeRegression(fit_intercept=False)
model_bay.fit(X, y)

# Ridge Regression (fit to optimize LOOCV error)
model_rr = RidgeRegression(fit_intercept=False)
model_rr.fit(X, y)
73+
74+
75+
####################################################################################################
76+
# Part 3: Measure and compare the prediction errors for each model
77+
####################################################################################################
78+
def compute_prediction_error_variance(K, w_true, w, noise_variance=1.0):
    """Return the prediction error variance of the weights ``w``.

    Computes ``noise_variance + (w - w_true)' K (w - w_true)``: the irreducible
    noise plus the extra variance caused by the weight error under feature
    covariance ``K``.

    Args:
        K: Feature covariance matrix of shape (p, p).
        w_true: Weights that generated the data (array-like, length p).
        w: Estimated weights to evaluate (array-like, length p).
        noise_variance: Variance of the additive noise. Defaults to 1.0, the
            unit variance of the noise added in ``generate_data_set``.

    Returns:
        The prediction error variance as a scalar.
    """
    # asarray lets callers pass plain Python sequences as well as ndarrays.
    delta = np.asarray(w) - np.asarray(w_true)
    return noise_variance + np.dot(delta, np.dot(K, delta))
82+
83+
# Score each estimate against the generating weights. Scoring w_true against
# itself gives the noise-only baseline.
err_variance_true = compute_prediction_error_variance(K=K, w_true=w_true, w=w_true)
err_variance_ols = compute_prediction_error_variance(
    K=K, w_true=w_true, w=model_ols.coef_
)
err_variance_bay = compute_prediction_error_variance(
    K=K, w_true=w_true, w=model_bay.weight_mean_vector_
)
err_variance_rr = compute_prediction_error_variance(
    K=K, w_true=w_true, w=model_rr.coef_
)

print("===== prediction error variance")
print("err_variance_true =", err_variance_true)
print("err_variance_ols =", err_variance_ols)
print("err_variance_bay =", err_variance_bay)
print("err_variance_rr =", err_variance_rr)
93+
94+
# Prints:
95+
# err_variance_true = 1.0
96+
# err_variance_ols = 2.2741141849936
97+
# err_variance_bay = 1.2253596030087812
98+
# err_variance_rr = 1.3346410286197081
99+
100+
####################################################################################################
101+
# Part 4: Compare the amount of regularization, measured by shrinkage, between ridge regression
102+
# and the expected weight variance of Bayesian ridge regression
103+
####################################################################################################
104+
def print_shrinkage_comparison_table(w_ols, w_bay, w_rr):
    """Print how strongly two regularized fits shrink the OLS weights.

    One row is printed per coefficient j with ``|w_bay[j] / w_ols[j]|`` and
    ``|w_rr[j] / w_ols[j]|``, followed by a "total" row holding the ratios of
    the Euclidean norms ``||w_bay|| / ||w_ols||`` and ``||w_rr|| / ||w_ols||``.

    Args:
        w_ols: OLS weights (array-like); the reference for both ratios.
        w_bay: Weights compared in the "bay" column.
        w_rr: Weights compared in the "rr" column.

    Raises:
        ValueError: If the three weight vectors do not share the same shape.
    """
    # NOTE(review): nothing in the visible file calls this function, so the
    # Part 4 table is never printed -- confirm whether a call was intended.
    w_ols = np.asarray(w_ols)
    w_bay = np.asarray(w_bay)
    w_rr = np.asarray(w_rr)
    if not (w_ols.shape == w_bay.shape == w_rr.shape):
        raise ValueError(
            "w_ols, w_bay and w_rr must have the same shape, got "
            f"{w_ols.shape}, {w_bay.shape} and {w_rr.shape}"
        )

    # Per-coefficient shrinkage, computed for all coefficients at once.
    ratios_bay = np.abs(w_bay / w_ols)
    ratios_rr = np.abs(w_rr / w_ols)

    print("coef\tbay\t\t\trr")
    for j, (ratio_bay, ratio_rr) in enumerate(zip(ratios_bay, ratios_rr)):
        print(j, "\t", ratio_bay, "\t", ratio_rr)

    # Overall shrinkage: ratio of Euclidean norms.
    w_ols_norm = np.linalg.norm(w_ols)
    shrink_bay = np.linalg.norm(w_bay) / w_ols_norm
    shrink_rr = np.linalg.norm(w_rr) / w_ols_norm
    print("total\t", shrink_bay, "\t", shrink_rr)
117+
118+
####################################################################################################
119+
# Part 5: Compare OLS expected noise and weight variances to the expected noise and weight
120+
# variances from Bayesian ridge regression
121+
####################################################################################################
122+
# OLS noise-variance estimate: residual sum of squares divided by n - p.
err_ols = y - X.dot(model_ols.coef_)
s_ols = err_ols.dot(err_ols) / (n - p)
print("===== noise variance comparison")
print("noise_variance_ols =", s_ols)
print("noise_variance_bay =", model_bay.noise_variance_mean_)
127+
128+
# Prints
129+
# noise_variance_ols = 0.4113088518293715
130+
# noise_variance_bay = 0.646895178795739
131+
132+
print("===== weight variance comparison")
# OLS weight covariance estimate: s_ols * (X'X)^-1. Only its diagonal (the
# per-coefficient variances) is printed next to the Bayesian model's diagonal.
w_covariance_ols = s_ols * np.linalg.inv(X.T.dot(X))
print("coef\tOLS\t\t\tbay")
for j in range(p):
    print(
        j,
        "\t",
        w_covariance_ols[j, j],
        "\t",
        model_bay.weight_covariance_matrix_[j, j],
    )
137+
138+
# Prints
139+
# coef OLS bay
140+
# 0 0.08077309122198699 0.07028402634979392
141+
# 1 0.11934260399465739 0.10372106176189841
142+
# 2 0.07565674651068977 0.09961127194993058
143+
# 3 0.05984623288735467 0.07126271235803656
144+
# 4 0.06790137334901475 0.07153243474011801
145+
# 5 0.04090207206067381 0.06168627188610626
146+
# 6 0.06252983263310996 0.06966266955483112
147+
# 7 0.2007543244631264 0.15580956937954868
148+
# 8 0.06512850980887401 0.06067162578733245
149+
# 9 0.03501167166038174 0.03794523021332901

0 commit comments

Comments
 (0)