Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add indexing with numpy index arrays #16

Merged
merged 11 commits into from
Nov 19, 2015
Next Next commit
add channel transform to channel (both grid and particles), passes
attributes thorugh function before setting target attributes
  • Loading branch information
ipelupessy committed Nov 14, 2015
commit 3e1e50df1938a94a5ac3a64bd7334b1c2e4f5f6f
37 changes: 36 additions & 1 deletion src/amuse/datamodel/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,42 @@ def copy_all_attributes(self):

def copy_overlapping_attributes(self):
names_to_copy = self.get_overlapping_attributes()
self.copy_attributes(names_to_copy)
self.copy_attributes(names_to_copy)

def transform_values(self, attributes, f):
values = self.source.get_values_in_store(self.index, attributes)
return f(*values)

def transform(self, target, function, source):
""" Copy and transform values of one attribute from the source set to the target set.

:argument target: name of the attributes in the target set
:argument function: function used for transform, should return tuple
:argument source: name of the attribute in the source set

>>> from amuse.datamodel import Grid
>>> grid1 = Grid(2)
>>> grid2 = Grid(2)
>>> grid1.attribute1 = 1
>>> grid1.attribute2 = 2
>>> channel = grid1.new_channel_to(grid2)
>>> channel.transform(["attribute3","attribute4"], lambda x,y: (y+x,y-x), ["attribute1","attribute2"])
>>> print grid2.attribute3
[3 3]
>>> print grid2.attribute4
[1 1]

"""
if not self.target.can_extend_attributes():
target_attributes = self.target.get_defined_settable_attribute_names()
if not set(target).issubset(set(target_attributes)):
raise Exception("trying to set unsettable attributes {0}".format(
list(set(target)-set(target_attributes))) )
converted=self.transform_values(source, function)
if len(converted) != len(target):
raise Exception("function {0} returns {1} values while target attributes are {2} of length {3}".format(
function.__name__, len(converted), target, len(target)))
self.target.set_values_in_store(self.index, target, converted)

class SamplePointOnCellCenter(object):
def __init__(self, grid, point):
Expand Down
39 changes: 39 additions & 0 deletions src/amuse/datamodel/particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3043,6 +3043,45 @@ def copy_attribute(self, name, target_name = None):
data = self.from_particles.get_values_in_store(self.from_indices, [name,])
self.to_particles.set_values_in_store(self.to_indices, [target_name,], data)

def transform_values(self, attributes, f):
values = self.from_particles.get_values_in_store(self.from_indices, attributes)
return f(*values)

def transform(self, target, function, source):
""" Copy and transform values of one attribute from the source set to the target set.

:argument target: name of the attributes in the target set
:argument function: function used for transform, should return tuple
:argument source: name of the attribute in the source set

>>> from amuse.datamodel import Particles
>>> particles1 = Particles(3)
>>> particles2 = particles1.copy()
>>> particles1.attribute1 = 1
>>> particles1.attribute2 = 2
>>> channel = particles1.new_channel_to(particles2)
>>> channel.transform(["attribute3"], lambda x,y: (x+y,), ["attribute1","attribute2"])
>>> print particles2.attribute3
[3 3 3]

"""
self._reindex()

if len(self.keys) == 0:
return

if not self.to_particles.can_extend_attributes():
target_attributes = self.to_particles.get_defined_settable_attribute_names()
if not set(target).issubset(set(target_attributes)):
raise Exception("trying to set unsettable attributes {0}".format(
list(set(target)-set(target_attributes))) )
converted=self.transform_values(source, function)
if len(converted) != len(target):
raise Exception("function {0} returns {1} values while target attributes are {2} of length {3}".format(
function.__name__, len(converted), target, len(target)))
self.to_particles.set_values_in_store(self.to_indices, target, converted)


class Channels(object):
def __init__(self, channels=None):
self._channels = []
Expand Down
16 changes: 16 additions & 0 deletions test/core_tests/test_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,22 @@ def test39(self):
self.assertEqual(sub[-1,-1].x,sub.x[-1,-1])
self.assertEqual(sub[-1,-2].x,sub.x[-1,-2])

def test40(self):
grid1 = datamodel.new_regular_grid((5,4,2), [1.0, 1.0, 1.0] | units.m)
grid2 = datamodel.new_regular_grid((5,4,2), [1.0, 1.0, 1.0] | units.m)
grid1.m1 = 1
grid1.m2 = 2
channel = grid1.new_channel_to(grid2)
channel.transform(["m3"],lambda x,y: (x,),["m1","m2"])
self.assertEquals(grid2.m3, 1)
channel.transform(["m3"],lambda x,y: (y,),["m1","m2"])
self.assertEquals(grid2.m3, 2)
channel.transform(["m3"],lambda x,y: (x+y,),["m1","m2"])
self.assertEquals(grid2.m3, 3)
channel.transform(["m3","m4"],lambda x,y: (x+y,2*x-y),["m1","m2"])
self.assertEquals(grid2.m3, 3)
self.assertEquals(grid2.m4, 0)


class TestGridFactories(amusetest.TestCase):
def test1(self):
Expand Down
14 changes: 14 additions & 0 deletions test/core_tests/test_particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,20 @@ def setup_test_channels(self):

return channel1, channel2, particles1, particles2

def test16(self):

particles1 = datamodel.Particles(keys=[10,11])
particles1.mass = [1,2] | units.kg
particles1.vx = [10,12] | units.m/units.s

particles2 = datamodel.Particles(keys=[11,10])

channel = particles1.new_channel_to(particles2)
channel.transform(["momentum"], lambda x,y: (x*y,),["mass","vx"])

self.assertEquals(particles2.momentum,[2*12,1*10] | units.kg*units.m/units.s)


class TestParticlesSuperset(amusetest.TestCase):

def test1(self):
Expand Down