Implemented Minor Fixes and Enhancements (#28)
* Fixed Warnings Raised in Tile.py

* Enhanced Resampling Scheme For r_off and r_on in unpack_parameters

* Added Verbose Argument to patch_model and memtorch.mn Modules
coreylammie authored Feb 2, 2021
1 parent e40c71d commit 6496cb3
Showing 14 changed files with 118 additions and 46 deletions.
47 changes: 33 additions & 14 deletions memtorch/bh/StochasticParameter.py
@@ -2,7 +2,7 @@
import memtorch
import inspect
import math

import copy

def StochasticParameter(distribution=torch.distributions.normal.Normal, min=0, max=float('Inf'), function=True, **kwargs):
"""Method to model a stochatic parameter.
@@ -52,43 +52,62 @@ def f(return_mean=False):
else:
return f()

def unpack_parameters(local_args, failure_threshold=5):
def unpack_parameters(local_args, r_rel_tol=None, r_abs_tol=None, resample_threshold=5):
"""Method to sample from stochastic sample-value generators
Parameters
----------
local_args : locals()
    Local arguments with stochastic sample-value generators from which to sample.
failure_threshold : int
Failure threshold to raise an Exception if r_off and r_on are indistinguishable.
r_rel_tol : float
Relative threshold tolerance.
r_abs_tol : float
Absolute threshold tolerance.
resample_threshold : int
    Number of times to resample r_off and r_on when they fall within the threshold tolerance of each other before an exception is raised.
Returns
-------
**
locals() with sampled stochastic parameters.
"""
assert r_rel_tol is None or r_abs_tol is None, 'r_rel_tol or r_abs_tol must be None.'
assert type(resample_threshold) == int and resample_threshold >= 0, 'resample_threshold must be of type int and >= 0.'
if 'reference' in local_args:
return_mean = True
else:
return_mean = False

local_args_copy = copy.deepcopy(local_args)
for arg in local_args:
if callable(local_args[arg]) and '__' not in str(arg):
local_args[arg] = local_args[arg](return_mean=return_mean)

args = Dict2Obj(local_args)
if hasattr(args, 'r_off') and hasattr(args, 'r_on'):
assert type(failure_threshold) == int and failure_threshold > 0, 'Invalid failure_threshold value.'
failure_idx = 0
resample_idx = 0
r_off_generator = local_args_copy['r_off']
r_on_generator = local_args_copy['r_on']
while True:
failure_idx += 1
if failure_idx > failure_threshold:
raise Exception('r_off and r_on values are indistinguishable.')

if not math.isclose(args.r_off, args.r_on):
break
if r_abs_tol is None and r_rel_tol is not None:
if not math.isclose(args.r_off, args.r_on, rel_tol=r_rel_tol):
break
elif r_rel_tol is None and r_abs_tol is not None:
if not math.isclose(args.r_off, args.r_on, abs_tol=r_abs_tol):
break
else:
if not math.isclose(args.r_off, args.r_on):
break

if callable(r_off_generator) and callable(r_on_generator):
args.r_off = copy.deepcopy(r_off_generator)(return_mean=return_mean)
args.r_on = copy.deepcopy(r_on_generator)(return_mean=return_mean)
else:
raise Exception('Resample threshold exceeded (deterministic values used).')

resample_idx += 1
if resample_idx > resample_threshold:
raise Exception('Resample threshold exceeded.')

return args

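As a usage note, a minimal sketch of how the new resampling arguments might be called (the stochastic generator parameters below are illustrative, not taken from this commit):

import memtorch

# Hypothetical stochastic generators for r_off and r_on (ohms); the loc/scale
# values are illustrative only.
r_off = memtorch.bh.StochasticParameter(loc=2e5, scale=2e4, min=1)
r_on = memtorch.bh.StochasticParameter(loc=1e2, scale=1e1, min=1)

def sample(r_off=r_off, r_on=r_on, time_series_resolution=1e-10):
    # Resample up to 5 times whenever r_off and r_on fall within 10% of each
    # other (relative tolerance); an Exception is raised beyond that.
    args = memtorch.bh.unpack_parameters(locals(), r_rel_tol=0.1, resample_threshold=5)
    return args.r_off, args.r_on

print(sample())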
11 changes: 8 additions & 3 deletions memtorch/bh/crossbar/Tile.py
@@ -36,11 +36,16 @@ def update_array(self, new_array):
self.array = new_array
else:
new_col_cnt = new_array.shape[1]
if type(new_array) == np.ndarray:
new_array = torch.from_numpy(new_array)
else:
new_array = new_array.clone().detach()

if self.patch_num is None:
new_row_cnt = new_array.shape[0]
self.array[:new_row_cnt, : new_col_cnt] = torch.tensor(new_array)
self.array[:new_row_cnt, : new_col_cnt] = new_array
else:
self.array[:, :new_col_cnt] = torch.tensor(new_array)
self.array[:, :new_col_cnt] = new_array

def gen_tiles(tensor, tile_shape, input=False):
""" Method to generate a set of modular tiles representative of a tensor.
@@ -111,7 +116,7 @@ def gen_tiles(tensor, tile_shape, input=False):
new_tile_id = len(tiles)-1
tiles_map[tile_row][tile_column] = new_tile_id

tiles = torch.tensor([np.array(tile.array) for tile in tiles])
tiles = torch.tensor([np.array(tile.array.cpu()) for tile in tiles])
return tiles, tiles_map

def tile_matmul(mat_a_tiles, mat_a_tiles_map, mat_a_shape, mat_b_tiles, mat_b_tiles_map, mat_b_shape):
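The warning fixed above stemmed from re-wrapping existing data with torch.tensor(); a standalone sketch of the equivalent conversion logic, independent of the Tile class, is:

import numpy as np
import torch

def to_tensor(new_array):
    # numpy input is converted with torch.from_numpy, while tensor input is
    # cloned and detached, avoiding the UserWarning that torch.tensor() raises
    # when called on an existing tensor.
    if type(new_array) == np.ndarray:
        return torch.from_numpy(new_array)
    return new_array.clone().detach()

print(to_tensor(np.ones((2, 2))))
print(to_tensor(torch.ones(2, 2)))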
4 changes: 1 addition & 3 deletions memtorch/bh/memristor/Data_Driven.py
@@ -64,9 +64,7 @@ def __init__(self,
**kwargs):

args = memtorch.bh.unpack_parameters(locals())
super(Data_Driven, self).__init__(args.time_series_resolution, 0, 0)
self.r_off = args.r_off
self.r_on = args.r_on
super(Data_Driven, self).__init__(args.r_off, args.r_on, args.time_series_resolution, 0, 0)
self.A_p = args.A_p
self.A_n = args.A_n
self.t_p = args.t_p
4 changes: 1 addition & 3 deletions memtorch/bh/memristor/LinearIonDrift.py
@@ -40,11 +40,9 @@ def __init__(self,
**kwargs):

args = memtorch.bh.unpack_parameters(locals())
super(LinearIonDrift, self).__init__(args.time_series_resolution, args.pos_write_threshold, args.neg_write_threshold)
super(LinearIonDrift, self).__init__(args.r_off, args.r_on, args.time_series_resolution, args.pos_write_threshold, args.neg_write_threshold)
self.u_v = args.u_v
self.d = args.d
self.r_on = args.r_on
self.r_off = args.r_off
self.r_i = args.r_on
self.p = args.p
self.g = 1 / self.r_i
8 changes: 7 additions & 1 deletion memtorch/bh/memristor/Memristor.py
@@ -11,6 +11,10 @@ class Memristor(ABC):
"""
Parameters
----------
r_off : float
Off (maximum) resistance of the device (ohms).
r_on : float
On (minimum) resistance of the device (ohms).
time_series_resolution : float
Time series resolution (s).
pos_write_threshold : float
@@ -19,7 +23,9 @@ class Memristor(ABC):
Negative write threshold voltage (V).
"""

def __init__(self, time_series_resolution, pos_write_threshold=0, neg_write_threshold=0):
def __init__(self, r_off, r_on, time_series_resolution, pos_write_threshold=0, neg_write_threshold=0):
self.r_off = r_off
self.r_on = r_on
self.time_series_resolution = time_series_resolution
self.pos_write_threshold = pos_write_threshold
self.neg_write_threshold = neg_write_threshold
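With the updated base-class signature, device models forward r_off and r_on to Memristor rather than assigning them locally; a minimal sketch of a hypothetical subclass (the class name is a placeholder, and the remaining abstract device methods are omitted):

import memtorch
from memtorch.bh.memristor.Memristor import Memristor

class MyDevice(Memristor):
    # Hypothetical device model following the new constructor signature.
    def __init__(self, r_off=1e5, r_on=1e3, time_series_resolution=1e-10, **kwargs):
        args = memtorch.bh.unpack_parameters(locals())
        # r_off and r_on are now stored by the Memristor base class.
        super(MyDevice, self).__init__(args.r_off, args.r_on,
                                       args.time_series_resolution, 0, 0)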
4 changes: 1 addition & 3 deletions memtorch/bh/memristor/Stanford_PKU.py
@@ -79,9 +79,7 @@ def __init__(self,
**kwargs):

args = memtorch.bh.unpack_parameters(locals())
super(Stanford_PKU, self).__init__(args.time_series_resolution, 0, 0)
self.r_off = args.r_off
self.r_on = args.r_on
super(Stanford_PKU, self).__init__(args.r_off, args.r_on, args.time_series_resolution, 0, 0)
self.gap_init = args.gap_init
self.g_0 = args.g_0
self.V_0 = args.V_0
4 changes: 1 addition & 3 deletions memtorch/bh/memristor/VTEAM.py
@@ -52,9 +52,7 @@ def __init__(self,
**kwargs):

args = memtorch.bh.unpack_parameters(locals())
super(VTEAM, self).__init__(args.time_series_resolution, args.v_off, args.v_on)
self.r_off = args.r_off
self.r_on = args.r_on
super(VTEAM, self).__init__(args.r_off, args.r_on, args.time_series_resolution, args.v_off, args.v_on)
self.d = args.d
self.k_on = args.k_on
self.k_off = args.k_off
8 changes: 6 additions & 2 deletions memtorch/map/Module.py
@@ -8,7 +8,7 @@
from sklearn.metrics import r2_score


def naive_tune(module, input_shape):
def naive_tune(module, input_shape, verbose=True):
"""Method to determine a linear relationship between a memristive crossbar and the output for a given memristive module.
Parameters
@@ -17,6 +17,8 @@ def naive_tune(module, input_shape):
Memristive layer to tune.
input_shape : (int, int)
Shape of the randomly generated input used to tune a crossbar.
verbose : bool
Used to determine if verbose output is enabled (True) or disabled (False).
Returns
-------
@@ -40,5 +42,7 @@ def naive_tune(module, input_shape):
intercept = np.array(reg.intercept_).item()
transform_output = lambda x: x * coef + intercept
module.bias = tmp
print('Tuned %s. Coefficient of determination: %f [%f, %f]' % (module, reg.score(output, legacy_output), coef, intercept))
if verbose:
print('Tuned %s. Coefficient of determination: %f [%f, %f]' % (module, reg.score(output, legacy_output), coef, intercept))

return transform_output
10 changes: 7 additions & 3 deletions memtorch/mn/Conv1d.py
@@ -35,14 +35,17 @@ class Conv1d(nn.Conv1d):
Weight representation scheme.
tile_shape : (int, int)
Tile shape to use to store weights. If None, modular tiles are not used.
verbose : bool
Used to determine if verbose output is enabled (True) or disabled (False).
"""

def __init__(self, convolutional_layer, memristor_model, memristor_model_params, mapping_routine=naive_map, transistor=True, programming_routine=None,
programming_routine_params={}, p_l=None, scheme=memtorch.bh.Scheme.DoubleColumn, tile_shape=None, *args, **kwargs):
programming_routine_params={}, p_l=None, scheme=memtorch.bh.Scheme.DoubleColumn, tile_shape=None, verbose=True, *args, **kwargs):
assert isinstance(convolutional_layer, nn.Conv1d), 'convolutional_layer is not an instance of nn.Conv1d.'
self.device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
self.scheme = scheme
self.tile_shape = tile_shape
self.verbose = verbose
self.forward_legacy_enabled = True
super(Conv1d, self).__init__(convolutional_layer.in_channels, convolutional_layer.out_channels, convolutional_layer.kernel_size, **kwargs)
self.padding = convolutional_layer.padding
@@ -67,7 +70,8 @@ def __init__(self, convolutional_layer, memristor_model, memristor_model_params,
scheme=scheme,
tile_shape=tile_shape)
self.transform_output = lambda x: x
print('Patched %s -> %s' % (convolutional_layer, self))
if verbose:
print('Patched %s -> %s' % (convolutional_layer, self))

def forward(self, input):
"""Method to perform forward propagations.
@@ -128,7 +132,7 @@ def forward(self, input):

def tune(self, input_batch_size=8, input_shape=32):
"""Tuning method."""
self.transform_output = naive_tune(self, (input_batch_size, self.in_channels, input_shape))
self.transform_output = naive_tune(self, (input_batch_size, self.in_channels, input_shape), self.verbose)

def __str__(self):
return "bh.Conv1d(in_channels=%d, out_channels=%d, kernel_size=%d, stride=%d, padding=%d)" % \
10 changes: 7 additions & 3 deletions memtorch/mn/Conv2d.py
@@ -35,14 +35,17 @@ class Conv2d(nn.Conv2d):
Weight representation scheme.
tile_shape : (int, int)
Tile shape to use to store weights. If None, modular tiles are not used.
verbose : bool
Used to determine if verbose output is enabled (True) or disabled (False).
"""

def __init__(self, convolutional_layer, memristor_model, memristor_model_params, mapping_routine=naive_map, transistor=True, programming_routine=None,
programming_routine_params={}, p_l=None, scheme=memtorch.bh.Scheme.DoubleColumn, tile_shape=None, *args, **kwargs):
programming_routine_params={}, p_l=None, scheme=memtorch.bh.Scheme.DoubleColumn, tile_shape=None, verbose=True, *args, **kwargs):
assert isinstance(convolutional_layer, nn.Conv2d), 'convolutional_layer is not an instance of nn.Conv2d.'
self.device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
self.scheme = scheme
self.tile_shape = tile_shape
self.verbose = verbose
self.forward_legacy_enabled = True
super(Conv2d, self).__init__(convolutional_layer.in_channels, convolutional_layer.out_channels, convolutional_layer.kernel_size, **kwargs)
self.padding = convolutional_layer.padding
@@ -67,7 +70,8 @@ def __init__(self, convolutional_layer, memristor_model, memristor_model_params,
scheme=scheme,
tile_shape=tile_shape)
self.transform_output = lambda x: x
print('Patched %s -> %s' % (convolutional_layer, self))
if verbose:
print('Patched %s -> %s' % (convolutional_layer, self))

def forward(self, input):
"""Method to perform forward propagations.
@@ -133,7 +137,7 @@ def forward(self, input):

def tune(self, input_batch_size=8, input_shape=32):
"""Tuning method."""
self.transform_output = naive_tune(self, (input_batch_size, self.in_channels, input_shape, input_shape))
self.transform_output = naive_tune(self, (input_batch_size, self.in_channels, input_shape, input_shape), self.verbose)

def __str__(self):
return "bh.Conv2d(in_channels=%d, out_channels=%d, kernel_size=(%d, %d), stride=(%d, %d), padding=(%d, %d))" % \
12 changes: 8 additions & 4 deletions memtorch/mn/Conv3d.py
@@ -35,14 +35,17 @@ class Conv3d(nn.Conv3d):
Weight representation scheme.
tile_shape : (int, int)
Tile shape to use to store weights. If None, modular tiles are not used.
verbose : bool
Used to determine if verbose output is enabled (True) or disabled (False).
"""

def __init__(self, convolutional_layer, memristor_model, memristor_model_params, mapping_routine=naive_map, transistor=True, programming_routine=None,
programming_routine_params={}, p_l=None, scheme=memtorch.bh.Scheme.DoubleColumn, tile_shape=None, *args, **kwargs):
programming_routine_params={}, p_l=None, scheme=memtorch.bh.Scheme.DoubleColumn, tile_shape=None, verbose=True, *args, **kwargs):
assert isinstance(convolutional_layer, nn.Conv3d), 'convolutional_layer is not an instance of nn.Conv3d.'
self.device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
self.scheme = scheme
self.tile_shape = tile_shape
self.verbose = verbose
self.forward_legacy_enabled = True
super(Conv3d, self).__init__(convolutional_layer.in_channels, convolutional_layer.out_channels, convolutional_layer.kernel_size, **kwargs)
self.padding = convolutional_layer.padding
@@ -67,7 +70,8 @@ def __init__(self, convolutional_layer, memristor_model, memristor_model_params,
scheme=scheme,
tile_shape=tile_shape)
self.transform_output = lambda x: x
print('Patched %s -> %s' % (convolutional_layer, self))
if verbose:
print('Patched %s -> %s' % (convolutional_layer, self))

def forward(self, input):
"""Method to perform forward propagations.
@@ -95,7 +99,7 @@ def forward(self, input):
batch_input = nn.functional.pad(input[batch], pad=(self.padding[2], self.padding[2], self.padding[1], self.padding[1], self.padding[0], self.padding[0]))
else:
batch_input = input[batch]

unfolded_batch_input = batch_input.unfold(1, self.kernel_size[0], self.stride[0]).unfold(2, self.kernel_size[1], self.stride[1]).unfold(3, self.kernel_size[2], self.stride[2]) \
.permute(1, 2, 3, 0, 4, 5, 6).reshape(-1, self.in_channels * self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2])
unfolded_batch_input_shape = unfolded_batch_input.shape
@@ -133,7 +137,7 @@ def forward(self, input):

def tune(self, input_batch_size=4, input_shape=32):
"""Tuning method."""
self.transform_output = naive_tune(self, (input_batch_size, self.in_channels, input_shape, input_shape, input_shape))
self.transform_output = naive_tune(self, (input_batch_size, self.in_channels, input_shape, input_shape, input_shape), self.verbose)

def __str__(self):
return "bh.Conv3d(in_channels=%d, out_channels=%d, kernel_size=(%d, %d, %d), stride=(%d, %d, %d), padding=(%d, %d, %d))" % \
10 changes: 7 additions & 3 deletions memtorch/mn/Linear.py
@@ -35,14 +35,17 @@ class Linear(nn.Linear):
Weight representation scheme.
tile_shape : (int, int)
Tile shape to use to store weights. If None, modular tiles are not used.
verbose : bool
Used to determine if verbose output is enabled (True) or disabled (False).
"""

def __init__(self, linear_layer, memristor_model, memristor_model_params, mapping_routine=naive_map, transistor=True, programming_routine=None,
programming_routine_params={}, p_l=None, scheme=memtorch.bh.Scheme.DoubleColumn, tile_shape=None, **kwargs):
programming_routine_params={}, p_l=None, scheme=memtorch.bh.Scheme.DoubleColumn, tile_shape=None, verbose=True, *args, **kwargs):
assert isinstance(linear_layer, nn.Linear), 'linear_layer is not an instance of nn.Linear.'
self.device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
self.scheme = scheme
self.tile_shape = tile_shape
self.verbose = verbose
self.forward_legacy_enabled = True
super(Linear, self).__init__(linear_layer.in_features, linear_layer.out_features, **kwargs)
self.weight.data = linear_layer.weight.data
@@ -67,7 +70,8 @@ def __init__(self, linear_layer, memristor_model, memristor_model_params, mappin
scheme=scheme,
tile_shape=tile_shape)
self.transform_output = lambda x: x
print('Patched %s -> %s' % (linear_layer, self))
if verbose:
print('Patched %s -> %s' % (linear_layer, self))

def forward(self, input):
"""Method to perform forward propagations.
@@ -123,7 +127,7 @@ def forward(self, input):

def tune(self, input_shape=4098):
"""Tuning method."""
self.transform_output = naive_tune(self, (input_shape, self.in_features))
self.transform_output = naive_tune(self, (input_shape, self.in_features), self.verbose)

def __str__(self):
return "bh.Linear(in_features=%d, out_features=%d, bias=%s)" % (self.in_features, self.out_features, not self.bias is None)