Skip to content

Commit

Permalink
Merge pull request #342 from ipelupessy/work_on_samplepoints
Browse files Browse the repository at this point in the history
work on grid for samplePoints with explicit coordinate arrays
  • Loading branch information
ipelupessy authored Nov 22, 2018
2 parents 34f97fc + b45e5fb commit cb8cce4
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 53 deletions.
20 changes: 19 additions & 1 deletion src/amuse/datamodel/grid_attributes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from amuse.units.quantities import zero, as_vector_quantity
from amuse.units.quantities import zero, as_vector_quantity, column_stack

import numpy

Expand Down Expand Up @@ -206,3 +206,21 @@ def get_overlap_with(grid, grid1,eps=None):
# gridminx,gridminy=sys.grid.get_minimum_position()
# gridmaxx,gridmaxy=sys.grid.get_maximum_position()

@grids.BaseGrid.function_for_set
def get_index(grid, pos=None, **kwargs):
raise Exception("not implemented for a {0} grid".format(grid.__class__.__name__))

@grids.RegularBaseGrid.function_for_set
def get_index(grid, pos=None, **kwargs):
pos=grid._get_array_of_positions_from_arguments(pos=pos,**kwargs)
offset = pos - grid.get_minimum_position()
indices = (offset / grid.cellsize())
return numpy.floor(indices).astype(numpy.int)

@grids.BaseGrid.function_for_set
def _get_array_of_positions_from_arguments(grid, **kwargs):
return grids._get_array_of_positions_from_arguments(grid.get_axes_names(), **kwargs)




73 changes: 51 additions & 22 deletions src/amuse/datamodel/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from amuse.units import generic_unit_system
from amuse.units import quantities
from amuse.units.quantities import Quantity
from amuse.units.quantities import VectorQuantity
from amuse.units.quantities import new_quantity
from amuse.units.quantities import zero
from amuse.units.quantities import column_stack
from amuse.support import exceptions
from amuse.datamodel.base import *
from amuse.datamodel.memory_storage import *
Expand Down Expand Up @@ -102,17 +104,22 @@ def empty_copy(self):
result._private.collection_attributes = self._private.collection_attributes._copy_for_collection(result)
return result

def samplePoint(self, position, must_return_values_on_cell_center = False):
if must_return_values_on_cell_center:
return SamplePointOnCellCenter(self, position)
def samplePoint(self, position=None, method="nearest", **kwargs):
if method in ["nearest"]:
return SamplePointOnCellCenter(self, position=position, **kwargs)
elif method in ["interpolation", "linear"]:
return SamplePointWithInterpolation(self, position=position, **kwargs)
else:
return SamplePointWithIntepolation(self, position)
raise Exception("unknown sample method")

def samplePoints(self, positions, must_return_values_on_cell_center = False):
if must_return_values_on_cell_center:
return SamplePointsOnGrid(self, positions, SamplePointOnCellCenter)
def samplePoints(self, positions=None, method="nearest", **kwargs):
if method in ["nearest"]:
return SamplePointsOnGrid(self, positions, SamplePointOnCellCenter, **kwargs)
elif method in ["interpolation", "linear"]:
return SamplePointsOnGrid(self, positions, SamplePointWithInterpolation, **kwargs)
else:
return SamplePointsOnGrid(self, positions, SamplePointWithIntepolation)
raise Exception("unknown sample method")


def __len__(self):
return self.shape[0]
Expand Down Expand Up @@ -245,6 +252,18 @@ def create(cls,*args,**kwargs):
print ("Grid.create deprecated, use new_regular_grid instead")
return new_regular_grid(*args,**kwargs)

def get_axes_names(self):
if "position" in self.GLOBAL_DERIVED_ATTRIBUTES:
result=self.GLOBAL_DERIVED_ATTRIBUTES["position"].attribute_names
elif "position" in self._derived_attributes:
result=self._derived_attributes["position"].attribute_names
else:
try:
result=self._axes_names
except:
raise Exception("do not know how to find axes_names")
return list(result)

class UnstructuredGrid(BaseGrid):
GLOBAL_DERIVED_ATTRIBUTES=CompositeDictionary(BaseGrid.GLOBAL_DERIVED_ATTRIBUTES)
class StructuredBaseGrid(BaseGrid):
Expand Down Expand Up @@ -741,19 +760,17 @@ def transform(self, target, function, source):
self.target.set_values_in_store(self.index, target, converted)

class SamplePointOnCellCenter(object):
def __init__(self, grid, point):
def __init__(self, grid, point=None, **kwargs):
self.grid = grid
self.point = point
self.point = self.grid._get_array_of_positions_from_arguments(pos=point, **kwargs)

@late
def position(self):
return self.cell.position

@late
def index(self):
offset = self.point - self.grid.get_minimum_position()
indices = (offset / self.grid.cellsize())
return numpy.floor(indices).astype(numpy.int)
return self.grid.get_index(self.point)

@late
def isvalid(self):
Expand All @@ -772,7 +789,7 @@ def get_value_of_attribute(self, name_of_the_attribute):
def __getattr__(self, name_of_the_attribute):
return self.get_value_of_attribute(name_of_the_attribute)

class SamplePointWithIntepolation(object):
class SamplePointWithInterpolation(object):
"""
Vxyz =
V000 (1 - x) (1 - y) (1 - z) +
Expand All @@ -785,20 +802,18 @@ class SamplePointWithIntepolation(object):
V111 x y z
"""

def __init__(self, grid, point):
def __init__(self, grid, point=None, **kwargs):
self.grid = grid
self.point = point
self.point = self.grid._get_array_of_positions_from_arguments(pos=point, **kwargs)

@late
def position(self):
return self.point

@late
def index(self):
offset = self.point - self.grid.get_minimum_position()
indices = (offset / self.grid.cellsize())
return numpy.floor(indices)

return self.grid.get_index(self.point)

@late
def index_for_000_cell(self):
offset = self.point - self.grid[0,0,0].position
Expand Down Expand Up @@ -875,8 +890,9 @@ def __getattr__(self, name_of_the_attribute):

class SamplePointsOnGrid(object):

def __init__(self, grid, points, samples_factory = SamplePointWithIntepolation):
def __init__(self, grid, points=None, samples_factory = SamplePointWithInterpolation, **kwargs):
self.grid = grid
points=self.grid._get_array_of_positions_from_arguments(pos=points,**kwargs)
self.samples = [samples_factory(grid, x) for x in points]
self.samples = [x for x in self.samples if x.isvalid ]

Expand Down Expand Up @@ -908,7 +924,7 @@ def __len__(self):

class SamplePointsOnMultipleGrids(object):

def __init__(self, grids, points, samples_factory = SamplePointWithIntepolation, index_factory = None):
def __init__(self, grids, points, samples_factory = SamplePointWithInterpolation, index_factory = None):
self.grids = grids
self.points = points
self.samples_factory = samples_factory
Expand Down Expand Up @@ -1047,3 +1063,16 @@ def grids_for_points(self, points):
index = numpy.floor(index).astype(numpy.int)
index_of_grid = self.grids_on_index[tuple(index)]
return self.grids[index_of_grid]


# convenience function to convert input arguments to positions (or vector of "points")
def _get_array_of_positions_from_arguments(axes_names, **kwargs):
if kwargs.get('pos',None):
return kwargs['pos']
if kwargs.get('position',None):
return kwargs['position']

coordinates=[kwargs[x] for x in axes_names]
if numpy.rank(coordinates[0])==0:
return VectorQuantity.new_from_scalar_quantities(*coordinates)
return column_stack(coordinates)
7 changes: 7 additions & 0 deletions src/amuse/units/quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,13 @@ def concatenate(quantities):
concatenated = numpy.concatenate(numbers)
return VectorQuantity(concatenated, unit)

def column_stack( args ):
args_=[to_quantity(x) for x in args]
units=set([x.unit for x in args_])
if len(units)==1:
return new_quantity(numpy.column_stack([x.number for x in args_]),args_[0].unit)
else:
return numpy.column_stack(args)

def arange(start, stop, step):
if not is_quantity(start):
Expand Down
15 changes: 15 additions & 0 deletions test/core_tests/test_grid_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,18 @@ def xtest4(self):
def xtest5(self):
grid=new_unstructured_grid((10,20),(10,10))
self.assertEqual(grid.cellsize(),[1.,0.5])

def test6(self):
grid=new_regular_grid((10,10),(10,10))
self.assertEquals(grid.get_index((6.2,3.7)), [6,3])
self.assertEquals(grid.get_index(x=[6.2],y=[3.7]), [6,3])
self.assertEquals(grid.get_index(y=[6.2],x=[3.7]), [3,6])
self.assertEquals(grid.get_index(y=6.2,x=3.7)[0], 3)
self.assertEquals(grid.get_index(y=6.2,x=3.7)[1], 6)

def test7(self):
grid=new_regular_grid((10,10),[20,10] | units.m,axes_names="ab")
self.assertEquals(grid.get_index([16.2,3.7] | units.m), [8,3])
self.assertEquals(grid.get_index(a=16.2 | units.m,b=3.7 | units.m), [8,3])
self.assertEquals(grid.get_index(a=[16.2, 4.5] | units.m,b=[3.7,4.2] | units.m), [[8,3],[2,4]])

58 changes: 29 additions & 29 deletions test/core_tests/test_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,52 +896,52 @@ class TestGridSampling(amusetest.TestCase):
def test1(self):
grid = datamodel.new_regular_grid((5,5,5), [10.0, 10.0, 10.0] | units.m)
grid.mass = grid.x.value_in(units.m) | units.kg
sample = grid.samplePoint([3.0,3.0,3.0]| units.m)
sample = grid.samplePoint([3.0,3.0,3.0]| units.m,method="interpolation")
self.assertEquals(sample.index , [1,1,1])
sample = grid.samplePoint([2.5,2.5,2.5]| units.m)
sample = grid.samplePoint([2.5,2.5,2.5]| units.m,method="interpolation")
self.assertEquals(sample.index , [1,1,1])
sample = grid.samplePoint([3.5,3.5,3.5]| units.m)
sample = grid.samplePoint([3.5,3.5,3.5]| units.m,method="interpolation")
self.assertEquals(sample.index , [1,1,1])

for x in range(0,200):
sample = grid.samplePoint([0.0 + (x/100.0),4.0+(x/100.0),6.0+(x/100.0)]| units.m)
sample = grid.samplePoint([0.0 + (x/100.0),4.0+(x/100.0),6.0+(x/100.0)]| units.m,method="interpolation")
self.assertEquals(sample.index , [0,2,3])

for x in range(200,400):
sample = grid.samplePoint([0.0 + (x/100.0),4.0+(x/100.0),6.0+(x/100.0)]| units.m)
sample = grid.samplePoint([0.0 + (x/100.0),4.0+(x/100.0),6.0+(x/100.0)]| units.m,method="interpolation")
self.assertEquals(sample.index , [1,3,4])

def test2(self):
grid = datamodel.new_regular_grid((5,5,5), [10.0, 10.0, 10.0] | units.m)
grid.mass = grid.x.value_in(units.m) | units.kg
sample = grid.samplePoint([3.0,3.0,3.0]| units.m)
sample = grid.samplePoint([3.0,3.0,3.0]| units.m,method="interpolation")
self.assertEquals(sample.index_for_000_cell , [1,1,1])
sample = grid.samplePoint([2.5,2.5,2.5]| units.m)
sample = grid.samplePoint([2.5,2.5,2.5]| units.m,method="interpolation")
self.assertEquals(sample.index_for_000_cell , [0,0,0])
sample = grid.samplePoint([3.5,3.5,3.5]| units.m)
sample = grid.samplePoint([3.5,3.5,3.5]| units.m,method="interpolation")
self.assertEquals(sample.index_for_000_cell , [1,1,1])
sample = grid.samplePoint([4.5,4.5,4.5]| units.m)
sample = grid.samplePoint([4.5,4.5,4.5]| units.m,method="interpolation")
self.assertEquals(sample.index_for_000_cell , [1,1,1])
self.assertEquals(sample.index , [2,2,2])

for x in range(0,100):

sample = grid.samplePoint([0.0 + (x/100.0),4.0+(x/100.0),6.0+(x/100.0)]| units.m)
sample = grid.samplePoint([0.0 + (x/100.0),4.0+(x/100.0),6.0+(x/100.0)]| units.m,method="interpolation")
self.assertEquals(sample.index_for_000_cell , [-1,1,2])
for x in range(100,300):

sample = grid.samplePoint([0.0 + (x/100.0),4.0+(x/100.0),6.0+(x/100.0)]| units.m)
sample = grid.samplePoint([0.0 + (x/100.0),4.0+(x/100.0),6.0+(x/100.0)]| units.m,method="interpolation")
self.assertEquals(sample.index_for_000_cell , [0,2,3])

for x in range(300,400):
sample = grid.samplePoint([0.0 + (x/100.0),4.0+(x/100.0),6.0+(x/100.0)]| units.m)
sample = grid.samplePoint([0.0 + (x/100.0),4.0+(x/100.0),6.0+(x/100.0)]| units.m,method="interpolation")
self.assertEquals(sample.index_for_000_cell , [1,3,4])


def test3(self):
grid = datamodel.new_regular_grid((5,5,5), [10.0, 10.0, 10.0] | units.m)
grid.mass = grid.x.value_in(units.m) | units.kg
sample = grid.samplePoint([3.0,3.0,3.0]| units.m)
sample = grid.samplePoint([3.0,3.0,3.0]| units.m,method="interpolation")
self.assertEquals(sample.index_for_000_cell , [1,1,1])
self.assertEquals(sample.surrounding_cell_indices , [
[1,1,1],
Expand All @@ -957,7 +957,7 @@ def test3(self):
def test4(self):
grid = datamodel.new_regular_grid((5,5,5), [10.0, 10.0, 10.0] | units.m)
grid.mass = grid.x.value_in(units.m) | units.kg
sample = grid.samplePoint([3.0,3.0,3.0]| units.m)
sample = grid.samplePoint([3.0,3.0,3.0]| units.m,method="interpolation")
self.assertEquals(sample.surrounding_cells[0].position , [3.0,3.0,3.0] | units.m )
self.assertEquals(sample.surrounding_cells[1].position , [5.0,3.0,3.0] | units.m )
self.assertEquals(sample.surrounding_cells[-1].position , [5.0,5.0,5.0] | units.m )
Expand All @@ -966,7 +966,7 @@ def test4(self):
def test5(self):
grid = datamodel.new_regular_grid((5,5,5), [10.0, 10.0, 10.0] | units.m)
grid.mass = grid.x.value_in(units.m) | units.kg
sample = grid.samplePoint([3.0,3.0,3.0]| units.m)
sample = grid.samplePoint([3.0,3.0,3.0]| units.m,method="interpolation")
masses = sample.get_values_of_attribute("mass")
self.assertEquals(masses[0] , 3.0 | units.kg )
self.assertEquals(masses[1] , 5.0 | units.kg )
Expand All @@ -992,50 +992,50 @@ def test6(self):
grid = datamodel.new_regular_grid((5,5,5), [10.0, 10.0, 10.0] | units.m)
grid.mass = grid.x.value_in(units.m) | units.kg
for xpos in numpy.arange(3.0,5.0,0.1):
sample = grid.samplePoint([xpos,3.0,3.0]| units.m)
sample = grid.samplePoint([xpos,3.0,3.0]| units.m,method="interpolation")
self.assertAlmostRelativeEquals(sample.mass , (3.0 | units.kg) + ((2.0 * (xpos - 3.0) / 2.0) | units.kg) )

sample = grid.samplePoint([xpos,3.0,3.0]| units.m)
sample = grid.samplePoint([xpos,3.0,3.0]| units.m,method="interpolation")
self.assertAlmostRelativeEquals(sample.mass , (3.0 | units.kg) + ((2.0 * (xpos - 3.0) / 2.0) | units.kg) )

sample = grid.samplePoint([xpos,5.0,3.0]| units.m)
sample = grid.samplePoint([xpos,5.0,3.0]| units.m,method="interpolation")
self.assertAlmostRelativeEquals(sample.mass , (3.0 | units.kg) + ((2.0 * (xpos - 3.0) / 2.0) | units.kg) )

sample = grid.samplePoint([xpos,3.0,5.0]| units.m)
sample = grid.samplePoint([xpos,3.0,5.0]| units.m,method="interpolation")
self.assertAlmostRelativeEquals(sample.mass , (3.0 | units.kg) + ((2.0 * (xpos - 3.0) / 2.0) | units.kg) )


sample = grid.samplePoint([4.0,4.0,4.0]| units.m)
sample = grid.samplePoint([4.0,4.0,4.0]| units.m,method="interpolation")
self.assertAlmostRelativeEquals(sample.mass , (4.0 | units.kg))

def test7(self):
grid = datamodel.new_regular_grid((5,5,5), [10.0, 10.0, 10.0] | units.m)
grid.mass = grid.x.value_in(units.m) | units.kg
sample = grid.samplePoint([3.0,3.0,3.0]| units.m)
sample = grid.samplePoint([3.0,3.0,3.0]| units.m,method="interpolation")
self.assertTrue(sample.isvalid)
sample = grid.samplePoint([11.0,3.0,3.0]| units.m)
sample = grid.samplePoint([11.0,3.0,3.0]| units.m,method="interpolation")
self.assertFalse(sample.isvalid)
sample = grid.samplePoint([3.0,-1.0,3.0]| units.m)
sample = grid.samplePoint([3.0,-1.0,3.0]| units.m,method="interpolation")
self.assertFalse(sample.isvalid)


def test8(self):
grid = datamodel.new_regular_grid((5,5,5), [10.0, 10.0, 10.0] | units.m)
grid.mass = grid.x.value_in(units.m) | units.kg
sample = grid.samplePoint([3.0,3.0,3.0]| units.m, must_return_values_on_cell_center = True)
sample = grid.samplePoint([3.0,3.0,3.0]| units.m, method="nearest")
self.assertEquals(sample.position, [3.0,3.0,3.0]| units.m)
self.assertEquals(sample.mass, 3.0 | units.kg)
sample = grid.samplePoint([3.5,3.0,3.0]| units.m, must_return_values_on_cell_center = True)
sample = grid.samplePoint([3.5,3.0,3.0]| units.m, method="nearest")
self.assertEquals(sample.position, [3.0,3.0,3.0]| units.m)
self.assertEquals(sample.mass, 3.0 | units.kg)

def test9(self):
grid = datamodel.new_regular_grid((5,5,5), [10.0, 10.0, 10.0] | units.m)
grid.mass = grid.x.value_in(units.m) | units.kg
sample = grid.samplePoint([3.0,3.0,3.0]| units.m, must_return_values_on_cell_center = False)
sample = grid.samplePoint([3.0,3.0,3.0]| units.m, method="linear")
self.assertEquals(sample.position, [3.0,3.0,3.0]| units.m)
self.assertEquals(sample.mass, 3.0 | units.kg)
sample = grid.samplePoint([3.5,3.0,3.0]| units.m, must_return_values_on_cell_center = False)
sample = grid.samplePoint([3.5,3.0,3.0]| units.m, method="linear")
self.assertEquals(sample.position, [3.5,3.0,3.0]| units.m)
self.assertEquals(sample.mass, 3.5 | units.kg)

Expand All @@ -1045,7 +1045,7 @@ class TestGridSamplingMultiplePoints(amusetest.TestCase):
def test1(self):
grid = datamodel.new_regular_grid((5,5,5), [10.0, 10.0, 10.0] | units.m)
grid.mass = grid.x.value_in(units.m) | units.kg
samples = grid.samplePoints([[3.0,3.0,3.0], [4.0,3.0,3.0]]| units.m)
samples = grid.samplePoints([[3.0,3.0,3.0], [4.0,3.0,3.0]]| units.m, method="linear")
self.assertEquals(len(samples), 2)
self.assertEquals(samples.position[0] , [3.0,3.0,3.0]| units.m)
self.assertEquals(samples.position[0] , samples[0].position)
Expand All @@ -1055,7 +1055,7 @@ def test1(self):
def test2(self):
grid = datamodel.new_regular_grid((5,5,5), [10.0, 10.0, 10.0] | units.m)
grid.mass = grid.x.value_in(units.m) | units.kg
samples = grid.samplePoints([[3.5,3.0,3.0], [4.5,3.0,3.0]]| units.m)
samples = grid.samplePoints([[3.5,3.0,3.0], [4.5,3.0,3.0]]| units.m, method="linear")
self.assertEquals(len(samples), 2)
self.assertEquals(samples.mass , [3.5, 4.5] | units.kg)

Expand Down
10 changes: 9 additions & 1 deletion test/core_tests/test_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,5 +532,13 @@ def test4(self):

self.assertAlmostRelativeEquals(y, fit_values, 1)


def test5(self):
a=[1,2,3] | units.m
b=[4,5,6] | units.m

ab1=quantities.column_stack((a,b))
ab2=quantities.column_stack((a.number,b.number)) | units.m

self.assertEquals(ab1,ab2)


0 comments on commit cb8cce4

Please sign in to comment.