Skip to content

Commit

Permalink
Merge pull request #1729 from PrincetonUniversity/devel
Browse files Browse the repository at this point in the history
Devel
  • Loading branch information
dillontsmith authored Jul 27, 2020
2 parents 001a7a6 + f094293 commit 52996d6
Show file tree
Hide file tree
Showing 16 changed files with 329 additions and 191 deletions.
19 changes: 13 additions & 6 deletions psyneulink/core/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@
RESET_STATEFUL_FUNCTION_WHEN, VALUE, VARIABLE
from psyneulink.core.globals.log import LogCondition
from psyneulink.core.scheduling.time import Time, TimeScale
from psyneulink.core.globals.sampleiterator import SampleIterator
from psyneulink.core.globals.parameters import \
Defaults, Parameter, ParameterAlias, ParameterError, ParametersBase, copy_parameter_value
from psyneulink.core.globals.preferences.basepreferenceset import BasePreferenceSet, VERBOSE_PREF
Expand Down Expand Up @@ -1095,16 +1096,16 @@ def __init__(self,
k: v for k, v in parameter_values.items() if k in self.parameters.names() and getattr(self.parameters, k).function_parameter
}

v = call_with_pruned_args(
var = call_with_pruned_args(
self._handle_default_variable,
default_variable=default_variable,
size=size,
**parameter_values
)
if v is None:
if var is None:
default_variable = self.defaults.variable
else:
default_variable = v
default_variable = var
self.defaults.variable = default_variable
self.parameters.variable._user_specified = True

Expand Down Expand Up @@ -1269,7 +1270,8 @@ def _state_values(p):
val = p.get(context)
if isinstance(val, Component):
return val._get_state_values(context)
return val
return [val for i in range(p.history_min_length + 1)]

return tuple(map(_state_values, self._get_compilation_state()))

def _get_state_initializer(self, context):
Expand All @@ -1296,7 +1298,7 @@ def _get_compilation_params(self):
"input_port_variables", "results", "simulation_results",
"monitor_for_control", "feature_values", "simulation_ids",
"input_labels_dict", "output_labels_dict",
"modulated_mechanisms", "search_space", "grid",
"modulated_mechanisms", "grid",
"activation_derivative_fct", "input_specification",
# Shape mismatch
"costs", "auto", "hetero",
Expand All @@ -1313,7 +1315,7 @@ def _is_compilation_param(p):
#FIXME: this should use defaults
val = p.get()
# Check if the value type is valid for compilation
return not isinstance(val, (str, dict, ComponentsMeta,
return not isinstance(val, (str, ComponentsMeta,
ContentAddressableList, type(max),
type(_is_compilation_param),
type(self._get_compilation_params)))
Expand Down Expand Up @@ -1368,6 +1370,11 @@ def _get_param_initializer(self, context):
def _convert(x):
if isinstance(x, Enum):
return x.value
elif isinstance(x, SampleIterator):
if isinstance(x.generator, list):
return (float(v) for v in x.generator)
else:
return (float(x.start), float(x.step), int(x.num))
try:
return (_convert(i) for i in x)
except TypeError:
Expand Down
74 changes: 27 additions & 47 deletions psyneulink/core/components/functions/optimizationfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,10 +1352,10 @@ def reset_grid(self):
def _gen_llvm_function(self, *, ctx:pnlvm.LLVMBuilderContext, tags:frozenset):
if "select_min" in tags:
return self._gen_llvm_select_min_function(ctx=ctx, tags=tags)
if self._is_composition_optimize():
# self.objective_function may be bound method of
# an OptimizationControlMechanism
ocm = self.objective_function.__self__
ocm = self._get_optimized_composition()
if ocm is not None:
# self.objective_function may be a bound method of
# OptimizationControlMechanism
extra_args = [ctx.get_param_struct_type(ocm.agent_rep).as_pointer(),
ctx.get_state_struct_type(ocm.agent_rep).as_pointer(),
ctx.get_data_struct_type(ocm.agent_rep).as_pointer()]
Expand All @@ -1382,77 +1382,57 @@ def _get_input_struct_type(self, ctx):

return ctx.convert_python_struct_to_llvm_ir(variable)

def _is_composition_optimize(self):
# self.objective_function may be bound method of
# an OptimizationControlMechanism
return hasattr(self.objective_function, '__self__')
def _get_optimized_composition(self):
# self.objective_function may be a bound method of
# OptimizationControlMechanism
return getattr(self.objective_function, '__self__', None)

def _get_param_ids(self):
ids = super()._get_param_ids() + ["search_space"]
if self._is_composition_optimize():
ids = super()._get_param_ids()
if self._get_optimized_composition() is not None:
ids.append("objective_function")

return ids

def _get_search_dim_type(self, ctx, d):
if isinstance(d.generator, list):
# Make sure we only generate float values
return ctx.convert_python_struct_to_llvm_ir([float(x) for x in d.generator])
if isinstance(d, SampleIterator):
return pnlvm.ir.LiteralStructType((ctx.float_ty, ctx.float_ty, ctx.int32_ty))
assert False, "Unsupported dimension type: {}".format(d)

def _get_param_struct_type(self, ctx):
param_struct = ctx.get_param_struct_type(super())
search_space = (self._get_search_dim_type(ctx, d) for d in self.search_space)
search_space_struct = pnlvm.ir.LiteralStructType(search_space)

if self._is_composition_optimize():
ocm = self._get_optimized_composition()
if ocm is not None:
# self.objective_function is a bound method of
# an OptimizationControlMechanism
ocm = self.objective_function.__self__
# OptimizationControlMechanism
obj_func_params = ocm._get_evaluate_param_struct_type(ctx)
return pnlvm.ir.LiteralStructType([*param_struct,
search_space_struct,
obj_func_params])

return pnlvm.ir.LiteralStructType([*param_struct, search_space_struct])

def _get_search_dim_init(self, context, d):
if isinstance(d.generator, list):
return tuple(d.generator)
if isinstance(d, SampleIterator):
return (d.start, d.step, d.num)

assert False, "Unsupported dimension type: {}".format(d)
return param_struct

def _get_param_initializer(self, context):
param_initializer = super()._get_param_initializer(context)
grid_init = (self._get_search_dim_init(context, d) for d in self.search_space)

if self._is_composition_optimize():
ocm = self._get_optimized_composition()
if ocm is not None:
# self.objective_function is a bound method of
# an OptimizationControlMechanism
ocm = self.objective_function.__self__
# OptimizationControlMechanism
obj_func_param_init = ocm._get_evaluate_param_initializer(context)
return (*param_initializer, tuple(grid_init), obj_func_param_init)
return (*param_initializer, obj_func_param_init)

return (*param_initializer, tuple(grid_init))
return param_initializer

def _get_state_ids(self):
ids = super()._get_state_ids()
if self._is_composition_optimize():
if self._get_optimized_composition() is not None:
ids.append("objective_function")

return ids

def _get_state_struct_type(self, ctx):
state_struct = ctx.get_state_struct_type(super())

if self._is_composition_optimize():
ocm = self._get_optimized_composition()
if ocm is not None:
# self.objective_function is a bound method of
# an OptimizationControlMechanism
ocm = self.objective_function.__self__
# OptimizationControlMechanism
obj_func_state = ocm._get_evaluate_state_struct_type(ctx)
state_struct = pnlvm.ir.LiteralStructType([*state_struct,
obj_func_state])
Expand All @@ -1462,10 +1442,10 @@ def _get_state_struct_type(self, ctx):
def _get_state_initializer(self, context):
state_initializer = super()._get_state_initializer(context)

if self._is_composition_optimize():
ocm = self._get_optimized_composition()
if ocm is not None:
# self.objective_function is a bound method of
# an OptimizationControlMechanism
ocm = self.objective_function.__self__
# OptimizationControlMechanism
obj_func_state_init = ocm._get_evaluate_state_initializer(context)
state_initializer = (*state_initializer, obj_func_state_init)

Expand Down Expand Up @@ -1792,7 +1772,7 @@ def _function(self,
format(repr(DIRECTION), self.name, direction)


ocm = self.objective_function.__self__ if self._is_composition_optimize() else None
ocm = self._get_optimized_composition()
if ocm is not None and \
ocm.parameters.comp_execution_mode._get(context).startswith("PTX"):
opt_sample, opt_value, all_samples, all_values = self._run_cuda_grid(ocm, variable, context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,6 @@ def __init__(self,
self.parameters.val_size._set(len(self.previous_value[VALS][0]), Context())

self.has_initializers = True
self.stateful_attributes = ["random_state", "previous_value"]

def _get_state_ids(self):
return super()._get_state_ids() + ["ring_memory"]
Expand Down Expand Up @@ -772,14 +771,9 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
var_val_ptr = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(1)])

# Zero output
builder.store(arg_out.type.pointee(None), arg_out)
out_key_ptr = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(0)])
out_val_ptr = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(1)])
with pnlvm.helpers.array_ptr_loop(builder, out_key_ptr, "zero_key") as (b, i):
out_ptr = b.gep(out_key_ptr, [ctx.int32_ty(0), i])
b.store(out_ptr.type.pointee(0), out_ptr)
with pnlvm.helpers.array_ptr_loop(builder, out_val_ptr, "zero_val") as (b, i):
out_ptr = b.gep(out_val_ptr, [ctx.int32_ty(0), i])
b.store(out_ptr.type.pointee(0), out_ptr)

# Check retrieval probability
retr_ptr = builder.alloca(pnlvm.ir.IntType(1))
Expand Down Expand Up @@ -813,7 +807,7 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
builder.gep(distance_arg_in, [ctx.int32_ty(0),
ctx.int32_ty(0)]))
selection_arg_in = builder.alloca(pnlvm.ir.ArrayType(distance_f.args[3].type.pointee, max_entries))
with pnlvm.helpers.for_loop_zero_inc(builder, entries, "distance_loop") as (b,idx):
with pnlvm.helpers.for_loop_zero_inc(builder, entries, "distance_loop") as (b, idx):
compare_ptr = b.gep(keys_ptr, [ctx.int32_ty(0), idx])
b.store(b.load(compare_ptr),
b.gep(distance_arg_in, [ctx.int32_ty(0), ctx.int32_ty(1)]))
Expand All @@ -833,7 +827,7 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
builder.store(ctx.int32_ty(0), selected_idx_ptr)
with pnlvm.helpers.for_loop_zero_inc(builder, entries, "distance_loop") as (b,idx):
selection_val = b.load(b.gep(selection_arg_out, [ctx.int32_ty(0), idx]))
non_zero = b.fcmp_ordered('!=', selection_val, ctx.float_ty(0))
non_zero = b.fcmp_ordered('!=', selection_val, selection_val.type(0))
with b.if_then(non_zero):
b.store(idx, selected_idx_ptr)

Expand All @@ -842,8 +836,8 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
selected_idx]))
selected_val = builder.load(builder.gep(vals_ptr, [ctx.int32_ty(0),
selected_idx]))
builder.store(selected_key, builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(0)]))
builder.store(selected_val, builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(1)]))
builder.store(selected_key, out_key_ptr)
builder.store(selected_val, out_val_ptr)

# Check storage probability
store_ptr = builder.alloca(pnlvm.ir.IntType(1))
Expand Down
75 changes: 60 additions & 15 deletions psyneulink/core/components/mechanisms/mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ class `UserList <https://docs.python.org/3.6/library/collections.html?highlight=
PARAMETER_PORT, PARAMETER_PORT_PARAMS, PARAMETER_PORTS, PROJECTIONS, REFERENCE_VALUE, RESULT, \
TARGET_LABELS_DICT, VALUE, VARIABLE, WEIGHT
from psyneulink.core.globals.parameters import Parameter
from psyneulink.core.scheduling.condition import Condition
from psyneulink.core.scheduling.condition import Condition, TimeScale
from psyneulink.core.globals.preferences.preferenceset import PreferenceLevel
from psyneulink.core.globals.registry import register_category, remove_instance_from_registry
from psyneulink.core.globals.utilities import \
Expand Down Expand Up @@ -2947,41 +2947,67 @@ def _fill_input(b, s_input, i):

def _gen_llvm_invoke_function(self, ctx, builder, function, params, state, variable, *, tags:frozenset):
fun = ctx.import_llvm_function(function, tags=tags)
fun_in, builder = self._gen_llvm_function_input_parse(builder, ctx, fun, variable)
fun_out = builder.alloca(fun.args[3].type.pointee)

builder.call(fun, [params, state, fun_in, fun_out])
builder.call(fun, [params, state, variable, fun_out])

return fun_out, builder

def _gen_llvm_is_finished_cond(self, ctx, builder, params, state, current):
def _gen_llvm_is_finished_cond(self, ctx, builder, params, state):
return pnlvm.ir.IntType(1)(1)

def _gen_llvm_function_internal(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):

ip_output, builder = self._gen_llvm_input_ports(ctx, builder,
params, state, arg_in)
def _gen_llvm_mechanism_functions(self, ctx, builder, params, state, arg_in,
ip_output, *, tags:frozenset):

# Default mechanism runs only the main function
f_params_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, "function")
f_params, builder = self._gen_llvm_param_ports_for_obj(
self.function, f_params_ptr, ctx, builder, params, state, arg_in)

f_state = pnlvm.helpers.get_state_ptr(builder, self, state, "function")
value, builder = self._gen_llvm_invoke_function(ctx, builder, self.function, f_params, f_state, ip_output, tags=tags)

return self._gen_llvm_invoke_function(ctx, builder, self.function,
f_params, f_state, ip_output,
tags=tags)

def _gen_llvm_function_internal(self, ctx, builder, params, state, arg_in,
arg_out, *, tags:frozenset):

ip_output, builder = self._gen_llvm_input_ports(ctx, builder,
params, state, arg_in)

value, builder = self._gen_llvm_mechanism_functions(ctx, builder, params,
state, arg_in,
ip_output,
tags=tags)

# Update execution counter
exec_count_ptr = pnlvm.helpers.get_state_ptr(builder, self, state, "execution_count")
exec_count = builder.load(exec_count_ptr)
exec_count = builder.fadd(exec_count, exec_count.type(1))
builder.store(exec_count, exec_count_ptr)

# Update internal clock (i.e. num_executions parameter)
num_executions_ptr = pnlvm.helpers.get_state_ptr(builder, self, state, "num_executions")
for scale in [TimeScale.TIME_STEP, TimeScale.PASS, TimeScale.TRIAL, TimeScale.RUN]:
num_exec_time_ptr = builder.gep(num_executions_ptr, [ctx.int32_ty(0), ctx.int32_ty(scale.value)])
new_val = builder.load(num_exec_time_ptr)
new_val = builder.add(new_val, ctx.int32_ty(1))
builder.store(new_val, num_exec_time_ptr)

builder = self._gen_llvm_output_ports(ctx, builder, value, params, state, arg_in, arg_out)
is_finished_cond = self._gen_llvm_is_finished_cond(ctx, builder, params,
state, value)
return builder, is_finished_cond

def _gen_llvm_function_input_parse(self, builder, ctx, func, func_in):
return func_in, builder
val_ptr = pnlvm.helpers.get_state_ptr(builder, self, state, "value")
if val_ptr.type.pointee == value.type.pointee:
pnlvm.helpers.push_state_val(builder, self, state, "value", value)
else:
# FIXME: Does this need some sort of parsing?
warnings.warn("Shape mismatch: function result does not match mechanism value: {}".format(value.type.pointee, val_ptr.type.pointee))

# is_finished should be checked after output ports ran
is_finished_f = ctx.import_llvm_function(self, tags=tags.union({"is_finished"}))
is_finished_cond = builder.call(is_finished_f, [params, state, arg_in,
arg_out])
return builder, is_finished_cond

def _gen_llvm_function_reset(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
assert "reset" in tags
Expand All @@ -2995,6 +3021,25 @@ def _gen_llvm_function_reset(self, ctx, builder, params, state, arg_in, arg_out,

return builder

def _gen_llvm_function(self, *, extra_args=[], ctx:pnlvm.LLVMBuilderContext, tags:frozenset):
if "is_finished" not in tags:
return super()._gen_llvm_function(extra_args=extra_args, ctx=ctx,
tags=tags)

# Keep all 4 standard arguments to ease invocation
args = [ctx.get_param_struct_type(self).as_pointer(),
ctx.get_state_struct_type(self).as_pointer(),
ctx.get_input_struct_type(self).as_pointer(),
ctx.get_output_struct_type(self).as_pointer()]

builder = ctx.create_llvm_function(args, self,
return_type=pnlvm.ir.IntType(1),
tags=tags)
params, state = builder.function.args[:2]
finished = self._gen_llvm_is_finished_cond(ctx, builder, params, state)
builder.ret(finished)
return builder.function

def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
assert "reset" not in tags

Expand Down
Loading

0 comments on commit 52996d6

Please sign in to comment.