Skip to content

Commit

Permalink
added tv regularization
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Nov 9, 2017
1 parent 5e3d727 commit ce8dd91
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
**/*Icon*
.idea
data/*
**/*DS_STORE*
69 changes: 69 additions & 0 deletions explore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os, sys, time, subprocess, h5py, argparse, logging, pickle, random
import numpy as np
from os.path import join as oj
import matplotlib.pyplot as plt
import seaborn as sns
from nilearn import datasets
from nilearn import plotting
import cvxpy as cvx
import cvxopt as cvxopt

X = np.loadtxt('data/Caltech_0051461_rois_dosenbach160.1D')[:, :-1]
print(X.shape)


# plot covariance matrix
def plot_cov():
covs = np.cov(X.transpose())
# plt.imshow(covs)
sns.clustermap(covs)
plt.savefig('figs/covs.pdf')


def plot_idxs(idxs_list):
for idx in idxs_list:
plt.plot(X[:, idx])
plt.savefig('figs/time_course.pdf')


def plot_connectome():
g = np.zeros(shape=(160, 160))

dos_coords = datasets.fetch_coords_dosenbach_2010()
dos_coords = dos_coords.rois
dos_coords_table = [[x, y, z] for (x, y, z) in dos_coords] # Reformat the atlas coordinates

f = plt.figure(figsize=(2.3, 3.5)) # 2.2,2.3
plotting.plot_connectome(g, dos_coords_table, display_mode='z',
output_file='figs/connectome.pdf',
annotate=False, figure=f, node_size=18)
# plt.show()


# plot_cov()
# plot_idxs([58, 139])
# plot_connectome()

y = X[:, 58]

# Set regularization parameter.
vlambda = 50
# Solve l1 trend filtering problem.
x = cvx.Variable(y.size)
obj = cvx.Minimize(0.5 * cvx.sum_squares(y - x)
+ vlambda * cvx.tv(x))
# + vlambda * cvx.norm(x, 1))
prob = cvx.Problem(obj)
# ECOS and SCS solvers fail to converge before
# the iteration limit. Use CVXOPT instead.
prob.solve(solver=cvx.CVXOPT, verbose=True)

# print('Solver status: ', prob.status)
# Check for error.
if prob.status != cvx.OPTIMAL:
raise Exception("Solver did not converge!")

plt.plot(y, label='original')
plt.plot(x.value, label='TV regularized')
plt.legend()
plt.show()
Binary file added figs/connectome.pdf
Binary file not shown.
Binary file added figs/covs.pdf
Binary file not shown.
Binary file added figs/time_course.pdf
Binary file not shown.

0 comments on commit ce8dd91

Please sign in to comment.