llvm: Cleanup (#1728)
Refactor to make better use of shared routines
jvesely authored Jul 27, 2020
2 parents c7afc35 + 93bda72 commit f094293
Showing 6 changed files with 48 additions and 63 deletions.
10 changes: 8 additions & 2 deletions psyneulink/core/components/component.py
@@ -506,6 +506,7 @@
RESET_STATEFUL_FUNCTION_WHEN, VALUE, VARIABLE
from psyneulink.core.globals.log import LogCondition
from psyneulink.core.scheduling.time import Time, TimeScale
from psyneulink.core.globals.sampleiterator import SampleIterator
from psyneulink.core.globals.parameters import \
Defaults, Parameter, ParameterAlias, ParameterError, ParametersBase, copy_parameter_value
from psyneulink.core.globals.preferences.basepreferenceset import BasePreferenceSet, VERBOSE_PREF
@@ -1297,7 +1298,7 @@ def _get_compilation_params(self):
"input_port_variables", "results", "simulation_results",
"monitor_for_control", "feature_values", "simulation_ids",
"input_labels_dict", "output_labels_dict",
"modulated_mechanisms", "search_space", "grid",
"modulated_mechanisms", "grid",
"activation_derivative_fct", "input_specification",
# Shape mismatch
"costs", "auto", "hetero",
@@ -1314,7 +1315,7 @@ def _is_compilation_param(p):
#FIXME: this should use defaults
val = p.get()
# Check if the value type is valid for compilation
return not isinstance(val, (str, dict, ComponentsMeta,
return not isinstance(val, (str, ComponentsMeta,
ContentAddressableList, type(max),
type(_is_compilation_param),
type(self._get_compilation_params)))
@@ -1369,6 +1370,11 @@ def _get_param_initializer(self, context):
def _convert(x):
if isinstance(x, Enum):
return x.value
elif isinstance(x, SampleIterator):
if isinstance(x.generator, list):
return (float(v) for v in x.generator)
else:
return (float(x.start), float(x.step), int(x.num))
try:
return (_convert(i) for i in x)
except TypeError:
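The new SampleIterator branch above mirrors the two LLVM layouts used later in builder_context.py: an explicit list of samples becomes a sequence of floats, while a range-style iterator becomes a (start, step, num) triple. A minimal standalone sketch of that conversion (MiniSampleIterator is a hypothetical stand-in for psyneulink's SampleIterator, and tuples are used here instead of the generators returned by _convert):

class MiniSampleIterator:
    def __init__(self, generator, start=None, step=None, num=None):
        self.generator = generator   # an explicit list, or a range-like specification
        self.start, self.step, self.num = start, step, num

def convert(x):
    if isinstance(x, MiniSampleIterator):
        if isinstance(x.generator, list):
            # explicit samples -> floats only
            return tuple(float(v) for v in x.generator)
        # range-like spec -> (start, step, count), matching the compiled struct layout
        return (float(x.start), float(x.step), int(x.num))
    return x

print(convert(MiniSampleIterator([1, 2, 3])))                           # (1.0, 2.0, 3.0)
print(convert(MiniSampleIterator(range(10), start=0, step=1, num=10)))  # (0.0, 1.0, 10)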
74 changes: 27 additions & 47 deletions psyneulink/core/components/functions/optimizationfunctions.py
@@ -1352,10 +1352,10 @@ def reset_grid(self):
def _gen_llvm_function(self, *, ctx:pnlvm.LLVMBuilderContext, tags:frozenset):
if "select_min" in tags:
return self._gen_llvm_select_min_function(ctx=ctx, tags=tags)
if self._is_composition_optimize():
# self.objective_function may be bound method of
# an OptimizationControlMechanism
ocm = self.objective_function.__self__
ocm = self._get_optimized_composition()
if ocm is not None:
# self.objective_function may be a bound method of
# OptimizationControlMechanism
extra_args = [ctx.get_param_struct_type(ocm.agent_rep).as_pointer(),
ctx.get_state_struct_type(ocm.agent_rep).as_pointer(),
ctx.get_data_struct_type(ocm.agent_rep).as_pointer()]
@@ -1382,77 +1382,57 @@ def _get_input_struct_type(self, ctx):

return ctx.convert_python_struct_to_llvm_ir(variable)

def _is_composition_optimize(self):
# self.objective_function may be bound method of
# an OptimizationControlMechanism
return hasattr(self.objective_function, '__self__')
def _get_optimized_composition(self):
# self.objective_function may be a bound method of
# OptimizationControlMechanism
return getattr(self.objective_function, '__self__', None)
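The helper relies on a standard Python property: a bound method exposes its owning instance via __self__, while a plain function has no such attribute, so getattr(..., '__self__', None) returns the OptimizationControlMechanism or None in one step. A small illustration with hypothetical classes:

class Controller:
    def evaluate(self, sample):
        return sum(sample)

def plain_objective(sample):
    return sum(sample)

bound = Controller().evaluate
print(getattr(bound, '__self__', None))            # <__main__.Controller object at 0x...>
print(getattr(plain_objective, '__self__', None))  # None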

def _get_param_ids(self):
ids = super()._get_param_ids() + ["search_space"]
if self._is_composition_optimize():
ids = super()._get_param_ids()
if self._get_optimized_composition() is not None:
ids.append("objective_function")

return ids

def _get_search_dim_type(self, ctx, d):
if isinstance(d.generator, list):
# Make sure we only generate float values
return ctx.convert_python_struct_to_llvm_ir([float(x) for x in d.generator])
if isinstance(d, SampleIterator):
return pnlvm.ir.LiteralStructType((ctx.float_ty, ctx.float_ty, ctx.int32_ty))
assert False, "Unsupported dimension type: {}".format(d)

def _get_param_struct_type(self, ctx):
param_struct = ctx.get_param_struct_type(super())
search_space = (self._get_search_dim_type(ctx, d) for d in self.search_space)
search_space_struct = pnlvm.ir.LiteralStructType(search_space)

if self._is_composition_optimize():
ocm = self._get_optimized_composition()
if ocm is not None:
# self.objective_function is a bound method of
# an OptimizationControlMechanism
ocm = self.objective_function.__self__
# OptimizationControlMechanism
obj_func_params = ocm._get_evaluate_param_struct_type(ctx)
return pnlvm.ir.LiteralStructType([*param_struct,
search_space_struct,
obj_func_params])

return pnlvm.ir.LiteralStructType([*param_struct, search_space_struct])

def _get_search_dim_init(self, context, d):
if isinstance(d.generator, list):
return tuple(d.generator)
if isinstance(d, SampleIterator):
return (d.start, d.step, d.num)

assert False, "Unsupported dimension type: {}".format(d)
return param_struct

def _get_param_initializer(self, context):
param_initializer = super()._get_param_initializer(context)
grid_init = (self._get_search_dim_init(context, d) for d in self.search_space)

if self._is_composition_optimize():
ocm = self._get_optimized_composition()
if ocm is not None:
# self.objective_function is a bound method of
# an OptimizationControlMechanism
ocm = self.objective_function.__self__
# OptimizationControlMechanism
obj_func_param_init = ocm._get_evaluate_param_initializer(context)
return (*param_initializer, tuple(grid_init), obj_func_param_init)
return (*param_initializer, obj_func_param_init)

return (*param_initializer, tuple(grid_init))
return param_initializer

def _get_state_ids(self):
ids = super()._get_state_ids()
if self._is_composition_optimize():
if self._get_optimized_composition() is not None:
ids.append("objective_function")

return ids

def _get_state_struct_type(self, ctx):
state_struct = ctx.get_state_struct_type(super())

if self._is_composition_optimize():
ocm = self._get_optimized_composition()
if ocm is not None:
# self.objective_function is a bound method of
# an OptimizationControlMechanism
ocm = self.objective_function.__self__
# OptimizationControlMechanism
obj_func_state = ocm._get_evaluate_state_struct_type(ctx)
state_struct = pnlvm.ir.LiteralStructType([*state_struct,
obj_func_state])
@@ -1462,10 +1442,10 @@ def _get_state_struct_type(self, ctx):
def _get_state_initializer(self, context):
state_initializer = super()._get_state_initializer(context)

if self._is_composition_optimize():
ocm = self._get_optimized_composition()
if ocm is not None:
# self.objective_function is a bound method of
# an OptimizationControlMechanism
ocm = self.objective_function.__self__
# OptimizationControlMechanism
obj_func_state_init = ocm._get_evaluate_state_initializer(context)
state_initializer = (*state_initializer, obj_func_state_init)

@@ -1792,7 +1772,7 @@ def _function(self,
format(repr(DIRECTION), self.name, direction)


ocm = self.objective_function.__self__ if self._is_composition_optimize() else None
ocm = self._get_optimized_composition()
if ocm is not None and \
ocm.parameters.comp_execution_mode._get(context).startswith("PTX"):
opt_sample, opt_value, all_samples, all_values = self._run_cuda_grid(ocm, variable, context)
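The dropped _get_search_dim_type/_get_search_dim_init helpers and the hand-appended "search_space" id follow from the component.py change above: once "search_space" is no longer blacklisted, the generic parameter machinery picks it up automatically. An illustrative-only sketch of that idea (class and attribute names here are hypothetical, not psyneulink's API):

class Base:
    _compilation_blacklist = {"grid"}            # "search_space" is no longer excluded
    params = {"seed": 0, "search_space": [0.1, 0.2], "grid": None}

    def _get_param_ids(self):
        return [p for p in self.params if p not in self._compilation_blacklist]

class GridSearchLike(Base):
    def _get_param_ids(self):
        ids = super()._get_param_ids()           # already contains "search_space"
        # only composition-specific extras (e.g. "objective_function") get appended here
        return ids

print(GridSearchLike()._get_param_ids())         # ['seed', 'search_space']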
@@ -732,7 +732,6 @@ def __init__(self,
self.parameters.val_size._set(len(self.previous_value[VALS][0]), Context())

self.has_initializers = True
self.stateful_attributes = ["random_state", "previous_value"]

def _get_state_ids(self):
return super()._get_state_ids() + ["ring_memory"]
@@ -772,14 +771,9 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
var_val_ptr = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(1)])

# Zero output
builder.store(arg_out.type.pointee(None), arg_out)
out_key_ptr = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(0)])
out_val_ptr = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(1)])
with pnlvm.helpers.array_ptr_loop(builder, out_key_ptr, "zero_key") as (b, i):
out_ptr = b.gep(out_key_ptr, [ctx.int32_ty(0), i])
b.store(out_ptr.type.pointee(0), out_ptr)
with pnlvm.helpers.array_ptr_loop(builder, out_val_ptr, "zero_val") as (b, i):
out_ptr = b.gep(out_val_ptr, [ctx.int32_ty(0), i])
b.store(out_ptr.type.pointee(0), out_ptr)

# Check retrieval probability
retr_ptr = builder.alloca(pnlvm.ir.IntType(1))
@@ -813,7 +807,7 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
builder.gep(distance_arg_in, [ctx.int32_ty(0),
ctx.int32_ty(0)]))
selection_arg_in = builder.alloca(pnlvm.ir.ArrayType(distance_f.args[3].type.pointee, max_entries))
with pnlvm.helpers.for_loop_zero_inc(builder, entries, "distance_loop") as (b,idx):
with pnlvm.helpers.for_loop_zero_inc(builder, entries, "distance_loop") as (b, idx):
compare_ptr = b.gep(keys_ptr, [ctx.int32_ty(0), idx])
b.store(b.load(compare_ptr),
b.gep(distance_arg_in, [ctx.int32_ty(0), ctx.int32_ty(1)]))
@@ -833,7 +827,7 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
builder.store(ctx.int32_ty(0), selected_idx_ptr)
with pnlvm.helpers.for_loop_zero_inc(builder, entries, "distance_loop") as (b,idx):
selection_val = b.load(b.gep(selection_arg_out, [ctx.int32_ty(0), idx]))
non_zero = b.fcmp_ordered('!=', selection_val, ctx.float_ty(0))
non_zero = b.fcmp_ordered('!=', selection_val, selection_val.type(0))
with b.if_then(non_zero):
b.store(idx, selected_idx_ptr)

@@ -842,8 +836,8 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
selected_idx]))
selected_val = builder.load(builder.gep(vals_ptr, [ctx.int32_ty(0),
selected_idx]))
builder.store(selected_key, builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(0)]))
builder.store(selected_val, builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(1)]))
builder.store(selected_key, out_key_ptr)
builder.store(selected_val, out_val_ptr)

# Check storage probability
store_ptr = builder.alloca(pnlvm.ir.IntType(1))
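Two idioms in the memory-function changes above deserve a note: llvmlite types are callable, so storing arg_out.type.pointee(None) writes a single zeroinitializer over the whole key/value output struct instead of looping over both arrays, and selection_val.type(0) builds the zero constant from the value's own type rather than assuming ctx.float_ty. A minimal self-contained llvmlite sketch of the single-store zeroing (struct shape chosen arbitrarily):

from llvmlite import ir

mod = ir.Module(name="zero_demo")
out_ty = ir.LiteralStructType([ir.ArrayType(ir.DoubleType(), 4),    # keys
                               ir.ArrayType(ir.DoubleType(), 4)])   # values
fn = ir.Function(mod, ir.FunctionType(ir.VoidType(), [out_ty.as_pointer()]), name="zero_out")
builder = ir.IRBuilder(fn.append_basic_block(name="entry"))
arg_out, = fn.args
builder.store(arg_out.type.pointee(None), arg_out)   # one store, whole struct zeroed
builder.ret_void()
print(mod)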
@@ -1025,8 +1025,7 @@ def _get_evaluate_param_struct_type(self, ctx):

def _get_evaluate_param_initializer(self, context):
num_estimates = self.parameters.num_estimates.get(context) or 0
# FIXME: The intensity cost function is not setup with the right execution id
intensity_cost = tuple(op.intensity_cost_function._get_param_initializer(None) for op in self.output_ports)
intensity_cost = tuple(op.intensity_cost_function._get_param_initializer(context) for op in self.output_ports)
return (intensity_cost, num_estimates)

def _get_evaluate_state_struct_type(self, ctx):
6 changes: 6 additions & 0 deletions psyneulink/core/llvm/builder_context.py
@@ -20,6 +20,7 @@
from typing import Set
import weakref
from psyneulink.core.scheduling.time import Time
from psyneulink.core.globals.sampleiterator import SampleIterator
from psyneulink.core import llvm as pnlvm
from . import codegen
from .debug import debug_env
@@ -305,6 +306,11 @@ def convert_python_struct_to_llvm_ir(self, t):
return pnlvm.builtins.get_mersenne_twister_state_struct(self)
elif isinstance(t, Time):
return ir.ArrayType(self.int32_ty, len(Time._time_scale_attr_map))
elif isinstance(t, SampleIterator):
if isinstance(t.generator, list):
return ir.ArrayType(self.float_ty, len(t.generator))
# Generic iterator is {start, increment, count}
return ir.LiteralStructType((self.float_ty, self.float_ty, self.int32_ty))
assert False, "Don't know how to convert {}".format(type(t))
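These two lowerings match the initializers emitted by Component._get_param_initializer above: a list-backed SampleIterator compiles to a fixed-size float array, anything else to a {start, increment, count} struct. A hedged sketch using llvmlite directly (double/i32 are assumed to match the context's default float_ty/int32_ty):

from llvmlite import ir

float_ty, int32_ty = ir.DoubleType(), ir.IntType(32)

def sample_iterator_ir_type(generator_is_list, num_samples=None):
    if generator_is_list:
        # explicit list of samples -> fixed-size float array
        return ir.ArrayType(float_ty, num_samples)
    # generic iterator -> {start, increment, count}
    return ir.LiteralStructType((float_ty, float_ty, int32_ty))

print(sample_iterator_ir_type(True, 3))   # [3 x double]
print(sample_iterator_ir_type(False))     # {double, double, i32}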


2 changes: 1 addition & 1 deletion psyneulink/core/llvm/execution.py
@@ -241,7 +241,7 @@ def execute(self, variable):

class CompExecution(CUDAExecution):

def __init__(self, composition, execution_ids=[None], additional_tags=frozenset()):
def __init__(self, composition, execution_ids=[None], *, additional_tags=frozenset()):
super().__init__(buffers=['state_struct', 'param_struct', 'data_struct', 'conditions'])
self._composition = composition
self._execution_contexts = [
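Making additional_tags keyword-only means callers can no longer pass it positionally after execution_ids. A standalone sketch of the signature change (a stand-in function, not the real CompExecution class):

def comp_execution_init(composition, execution_ids=[None], *, additional_tags=frozenset()):
    # everything after '*' must be passed by keyword
    return composition, execution_ids, additional_tags

comp_execution_init("comp", [None], additional_tags=frozenset({"simulation"}))   # OK
# comp_execution_init("comp", [None], frozenset({"simulation"}))  # TypeError: too many positional arguments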
