From 5e10a2044ec0331bd83cc7b1c57d2f28be870e32 Mon Sep 17 00:00:00 2001 From: ipelupessy Date: Wed, 10 Oct 2018 17:35:53 +0200 Subject: [PATCH] fix cython test for boolean --- src/amuse/rfi/tools/create_cython.py | 24 +++++-------------- .../test_cython_implementation.py | 4 ---- .../test_fortran_implementation.py | 2 +- 3 files changed, 7 insertions(+), 23 deletions(-) diff --git a/src/amuse/rfi/tools/create_cython.py b/src/amuse/rfi/tools/create_cython.py index 67ff7c813..50b91706f 100644 --- a/src/amuse/rfi/tools/create_cython.py +++ b/src/amuse/rfi/tools/create_cython.py @@ -129,7 +129,7 @@ def output_function_parameters(self): if parameter.datatype == 'string': raise Exception("unknown...") else: - self.out + 'numpy.ndarray[' + spec.type + ', ndim=1, mode="c"] ' + name + self.out + 'numpy.ndarray[' + spec.type + ', ndim=1, mode="c", cast=True] ' + name elif self.specification.must_handle_array and parameter.direction == LegacyFunctionSpecification.LENGTH: self.out + 'int ' + name else: @@ -227,7 +227,7 @@ def output_output_variables(self): if parameter.datatype == 'string': raise Exception("unknown...") else: - self.out.lf() + 'cdef numpy.ndarray[' + spec.type + ', ndim=1, mode="c"] inout_' + parameter.name + ' = ' + parameter.name + '.value' + self.out.lf() + 'cdef numpy.ndarray[' + spec.type + ', ndim=1, mode="c", cast=True] inout_' + parameter.name + ' = ' + parameter.name + '.value' elif self.specification.must_handle_array and parameter.direction == LegacyFunctionSpecification.OUT: if parameter.datatype == 'string': raise Exception("unknown...") @@ -235,7 +235,7 @@ def output_output_variables(self): spec = self.dtype_to_spec[parameter.datatype] - self.out.lf() + 'cdef numpy.ndarray[' + spec.type + ', ndim=1, mode="c"] output_' + parameter.name + ' = ' + 'numpy.zeros('+self.length_parameter().name +', dtype = '+self.numpy_dtype(spec)+')' + self.out.lf() + 'cdef numpy.ndarray[' + spec.type + ', ndim=1, mode="c", cast=True] output_' + parameter.name + ' = ' + 'numpy.zeros('+self.length_parameter().name +', dtype = '+self.numpy_dtype(spec)+')' elif parameter.direction == LegacyFunctionSpecification.OUT: spec = self.dtype_to_spec[parameter.datatype] @@ -839,7 +839,7 @@ def dtype_to_fortran_type(self): 'int64' : 'INTEGER(kind = c_long)', 'float32' : 'REAL(kind=c_float)', 'float64' : 'REAL(kind=c_double)', - 'bool' : 'INTEGER(kind = c_int)', + 'bool' : 'LOGICAL(kind = c_bool)', 'string' : 'type(C_ptr)' } @@ -1053,7 +1053,7 @@ def output_function_logical_input_copies(self): if not parameter.datatype == 'bool': continue if parameter.direction == LegacyFunctionSpecification.IN or parameter.direction == LegacyFunctionSpecification.INOUT: - self.out.lf() + 'logical_{0} = {0} .EQ. 1'.format(parameter.name) + self.out.lf() + 'logical_{0} = {0}'.format(parameter.name) @@ -1064,19 +1064,7 @@ def output_function_logical_output_copies(self): if not parameter.datatype == 'bool': continue if parameter.direction == LegacyFunctionSpecification.OUT or parameter.direction == LegacyFunctionSpecification.INOUT: - if self.specification.must_handle_array: - self.out.lf() + '{0} = 0'.format(parameter.name) - self.out.lf() + 'where(logical_{0}) {0} = 1'.format(parameter.name) - else: - self.out.lf() + 'if (logical_{0}) then'.format(parameter.name) - self.out.lf() + ' {0} = 1'.format(parameter.name) - self.out.lf() + 'else' - self.out.lf() + '{0} = 0'.format(parameter.name) - self.out.lf() + 'end if' - - - - + self.out.lf() +' {0} = logical_{0}'.format(parameter.name) diff --git a/test/compile_tests/test_cython_implementation.py b/test/compile_tests/test_cython_implementation.py index 6ecc6def2..306294878 100644 --- a/test/compile_tests/test_cython_implementation.py +++ b/test/compile_tests/test_cython_implementation.py @@ -263,10 +263,6 @@ def tearDown(self): def test22(self): self.skip("this test uses mpi internals, skip here") - def test16c(self): - # https://stackoverflow.com/questions/49058191/boolean-numpy-arrays-with-cython - self.skip("boolean must_handle_array argument not yet working for cython") - def skip_if_no_cython(self): if sys.hexversion > 0x03000000: diff --git a/test/compile_tests/test_fortran_implementation.py b/test/compile_tests/test_fortran_implementation.py index caa58f96c..47aeb809d 100644 --- a/test/compile_tests/test_fortran_implementation.py +++ b/test/compile_tests/test_fortran_implementation.py @@ -161,7 +161,7 @@ logical :: input(n), output(n) integer :: echo_logical2, n,i - output(i)=.FALSE. + output(1:n)=.FALSE. do i=1,n if(input(i)) then output(i) = .TRUE.