Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 13 additions & 25 deletions TICC_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,21 @@


class TICC:
window_size = 10
number_of_clusters = 5
lambda_parameter = 11e-2
switch_penalty = 400
maxIters = 1000
threshold = 2e-5
write_out_file = False
prefix_string = ''
num_proc = 1
compute_BIC = False
num_blocks = None
cluster_reassignment = 20 # number of points to reassign to a 0 cluster

def __init__(self, window_size=10, number_of_clusters=5, lambda_parameter=11e-2,
beta=400, maxIters=1000, threshold=2e-5, write_out_file=False,
prefix_string="", num_proc=1, compute_BIC=False):
prefix_string="", num_proc=1, compute_BIC=False, cluster_reassignment=20):
"""
Parameters:
- window_size: size of the sliding window
- number_of_clusters: number of clusters
- lambda_parameter: sparsity parameter
- switch_penalty: temporal consistency parameter
- maxIters: number of iterations
- threshold: convergence threshold
- write_out_file: (bool) if true, prefix_string is output file dir
- prefix_string: output directory if necessary
- cluster_reassignment: number of points to reassign to a 0 cluster
"""
Parameters:
- window_size: size of the sliding window
- number_of_clusters: number of clusters
- lambda_parameter: sparsity parameter
- switch_penalty: temporal consistency parameter
- maxIters: number of iterations
- threshold: convergence threshold
- write_out_file: (bool) if true, prefix_string is output file dir
- prefix_string: output directory if necessary
"""
self.window_size = window_size
self.number_of_clusters = number_of_clusters
self.lambda_parameter = lambda_parameter
Expand All @@ -51,7 +39,7 @@ def __init__(self, window_size=10, number_of_clusters=5, lambda_parameter=11e-2,
self.prefix_string = prefix_string
self.num_proc = num_proc
self.compute_BIC = compute_BIC

self.cluster_reassignment = cluster_reassignment
self.num_blocks = self.window_size + 1
pd.set_option('display.max_columns', 500)
np.set_printoptions(formatter={'float': lambda x: "{0:0.4f}".format(x)})
Expand Down
13 changes: 8 additions & 5 deletions UnitTest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
import TICC_solver as TICC
from TICC_solver import TICC
import numpy as np
import sys

Expand All @@ -10,7 +10,9 @@ class TestStringMethods(unittest.TestCase):

def test_example(self):
fname = "example_data.txt"
(cluster_assignment, cluster_MRFs) = TICC.solve(window_size = 1,number_of_clusters = 8, lambda_parameter = 11e-2, beta = 600, maxIters = 100, threshold = 2e-5, write_out_file = False, input_file = fname, prefix_string = "output_folder/", num_proc=1)
ticc = TICC(window_size = 1,number_of_clusters = 8, lambda_parameter = 11e-2, beta = 600, maxIters = 100,
threshold = 2e-5, write_out_file = False, prefix_string = "output_folder/", num_proc=1)
(cluster_assignment, cluster_MRFs) = ticc.fit(input_file=fname)
assign = np.loadtxt("UnitTest_Data/Results.txt")
val = abs(assign - cluster_assignment)
self.assertEqual(sum(val), 0)
Expand All @@ -21,13 +23,14 @@ def test_example(self):
np.testing.assert_array_almost_equal(mrf, cluster_MRFs[i], decimal=3)
except AssertionError:
#Test failed
self.assertTrue(1==0)
self.assertTrue(1==0)


def test_multiExample(self):
fname = "example_data.txt"
(cluster_assignment, cluster_MRFs) = TICC.solve(window_size = 5,number_of_clusters = 5, lambda_parameter = 11e-2, beta = 600, maxIters = 100, threshold = 2e-5, write_out_file = False, input_file = fname, prefix_string = "output_folder/", num_proc=1)

ticc = TICC(window_size = 5,number_of_clusters = 5, lambda_parameter = 11e-2, beta = 600, maxIters = 100,
threshold = 2e-5, write_out_file = False, prefix_string = "output_folder/", num_proc=1)
(cluster_assignment, cluster_MRFs) = ticc.fit(input_file=fname)
assign = np.loadtxt("UnitTest_Data/multiResults.txt")
val = abs(assign - cluster_assignment)
self.assertEqual(sum(val), 0)
Expand Down