Skip to content

Commit

Permalink
fix cython test for boolean
Browse files Browse the repository at this point in the history
  • Loading branch information
ipelupessy committed Oct 10, 2018
1 parent d273fe8 commit 5e10a20
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 23 deletions.
24 changes: 6 additions & 18 deletions src/amuse/rfi/tools/create_cython.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def output_function_parameters(self):
if parameter.datatype == 'string':
raise Exception("unknown...")
else:
self.out + 'numpy.ndarray[' + spec.type + ', ndim=1, mode="c"] ' + name
self.out + 'numpy.ndarray[' + spec.type + ', ndim=1, mode="c", cast=True] ' + name
elif self.specification.must_handle_array and parameter.direction == LegacyFunctionSpecification.LENGTH:
self.out + 'int ' + name
else:
Expand Down Expand Up @@ -227,15 +227,15 @@ def output_output_variables(self):
if parameter.datatype == 'string':
raise Exception("unknown...")
else:
self.out.lf() + 'cdef numpy.ndarray[' + spec.type + ', ndim=1, mode="c"] inout_' + parameter.name + ' = ' + parameter.name + '.value'
self.out.lf() + 'cdef numpy.ndarray[' + spec.type + ', ndim=1, mode="c", cast=True] inout_' + parameter.name + ' = ' + parameter.name + '.value'
elif self.specification.must_handle_array and parameter.direction == LegacyFunctionSpecification.OUT:
if parameter.datatype == 'string':
raise Exception("unknown...")

spec = self.dtype_to_spec[parameter.datatype]


self.out.lf() + 'cdef numpy.ndarray[' + spec.type + ', ndim=1, mode="c"] output_' + parameter.name + ' = ' + 'numpy.zeros('+self.length_parameter().name +', dtype = '+self.numpy_dtype(spec)+')'
self.out.lf() + 'cdef numpy.ndarray[' + spec.type + ', ndim=1, mode="c", cast=True] output_' + parameter.name + ' = ' + 'numpy.zeros('+self.length_parameter().name +', dtype = '+self.numpy_dtype(spec)+')'
elif parameter.direction == LegacyFunctionSpecification.OUT:
spec = self.dtype_to_spec[parameter.datatype]

Expand Down Expand Up @@ -839,7 +839,7 @@ def dtype_to_fortran_type(self):
'int64' : 'INTEGER(kind = c_long)',
'float32' : 'REAL(kind=c_float)',
'float64' : 'REAL(kind=c_double)',
'bool' : 'INTEGER(kind = c_int)',
'bool' : 'LOGICAL(kind = c_bool)',
'string' : 'type(C_ptr)'
}

Expand Down Expand Up @@ -1053,7 +1053,7 @@ def output_function_logical_input_copies(self):
if not parameter.datatype == 'bool':
continue
if parameter.direction == LegacyFunctionSpecification.IN or parameter.direction == LegacyFunctionSpecification.INOUT:
self.out.lf() + 'logical_{0} = {0} .EQ. 1'.format(parameter.name)
self.out.lf() + 'logical_{0} = {0}'.format(parameter.name)



Expand All @@ -1064,19 +1064,7 @@ def output_function_logical_output_copies(self):
if not parameter.datatype == 'bool':
continue
if parameter.direction == LegacyFunctionSpecification.OUT or parameter.direction == LegacyFunctionSpecification.INOUT:
if self.specification.must_handle_array:
self.out.lf() + '{0} = 0'.format(parameter.name)
self.out.lf() + 'where(logical_{0}) {0} = 1'.format(parameter.name)
else:
self.out.lf() + 'if (logical_{0}) then'.format(parameter.name)
self.out.lf() + ' {0} = 1'.format(parameter.name)
self.out.lf() + 'else'
self.out.lf() + '{0} = 0'.format(parameter.name)
self.out.lf() + 'end if'




self.out.lf() +' {0} = logical_{0}'.format(parameter.name)



Expand Down
4 changes: 0 additions & 4 deletions test/compile_tests/test_cython_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,6 @@ def tearDown(self):
def test22(self):
self.skip("this test uses mpi internals, skip here")

def test16c(self):
# https://stackoverflow.com/questions/49058191/boolean-numpy-arrays-with-cython
self.skip("boolean must_handle_array argument not yet working for cython")

def skip_if_no_cython(self):

if sys.hexversion > 0x03000000:
Expand Down
2 changes: 1 addition & 1 deletion test/compile_tests/test_fortran_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@
logical :: input(n), output(n)
integer :: echo_logical2, n,i
output(i)=.FALSE.
output(1:n)=.FALSE.
do i=1,n
if(input(i)) then
output(i) = .TRUE.
Expand Down

0 comments on commit 5e10a20

Please sign in to comment.