Skip to content

Commit e3690c3

Browse files
committed
Prevent automatic type conversion of complex data types (closes #135)
1 parent f92c107 commit e3690c3

File tree

2 files changed

+43
-12
lines changed

2 files changed

+43
-12
lines changed

spectral/io/spyfile.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,13 @@ def load(self, **kwargs):
196196
for k in list(kwargs.keys()):
197197
if k not in ('dtype', 'scale'):
198198
raise ValueError('Invalid keyword %s.' % str(k))
199-
dtype = kwargs.get('dtype', ImageArray.format)
199+
ctypes = [np.dtype(f'complex{b}').name for b in (64, 128, 256)]
200+
if 'dtype' in kwargs:
201+
dtype = kwargs['dtype']
202+
elif np.dtype(self.dtype).name in ctypes:
203+
dtype = self.dtype
204+
else:
205+
dtype = ImageArray.format
200206
data = array.array(typecode('b'))
201207
self.fid.seek(self.offset)
202208
data.fromfile(self.fid, self.nrows * self.ncols *
@@ -210,7 +216,8 @@ def load(self, **kwargs):
210216
npArray = npArray.transpose([1, 2, 0])
211217
else:
212218
npArray.shape = (self.nrows, self.ncols, self.nbands)
213-
npArray = npArray.astype(dtype)
219+
if np.dtype(dtype).name != npArray.dtype.name:
220+
npArray = npArray.astype(dtype)
214221
if self.scale_factor != 1 and kwargs.get('scale', True):
215222
npArray = npArray / float(self.scale_factor)
216223
imarray = ImageArray(npArray, self)

spectral/tests/spyfile.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
from spectral.tests.spytest import SpyTest
2626

2727

28-
def assert_almost_equal(a, b, **kwargs):
29-
if not np.allclose(a, b, **kwargs):
30-
raise Exception('NOPE')
28+
assert_almost_equal = np.testing.assert_allclose
29+
30+
def assert_allclose (a, b, **kwargs):
31+
np.testing.assert_allclose(np.array(a), np.array(b), **kwargs)
3132

3233
class SpyFileTest(SpyTest):
3334
'''Tests that SpyFile methods read data correctly from files.'''
@@ -144,14 +145,14 @@ def test_load(self):
144145
data = self.image.load()
145146
spyf = self.image
146147

147-
load_assert = np.allclose
148+
load_assert = assert_allclose
148149
load_assert(data[i, j, k], self.value)
149150
first_band = spyf[:, :, 0]
150151
load_assert(data[:, :, 0], first_band)
151152
# This is checking if different ImageArray and SpyFile indexing
152153
# results are the same shape, so we can't just reuse the already
153154
# loaded first band.
154-
load_assert(data[:, 0, 0], spyf[:, 0, 0])
155+
load_assert(data[:, 0, 0].squeeze(), spyf[:, 0, 0].squeeze())
155156
load_assert(data[0, 0, 0], spyf[0, 0, 0])
156157
load_assert(data[0, 0], spyf[0, 0])
157158
load_assert(data[-1, -1, -1], spyf[-1, -1, -1])
@@ -267,7 +268,7 @@ def run(self):
267268
os.mkdir(testdir)
268269
image = spy.open_image(self.filename)
269270
basename = os.path.join(testdir,
270-
os.path.splitext(self.filename)[0])
271+
os.path.splitext(os.path.split(self.filename)[-1])[0])
271272
interleaves = ('bil', 'bip', 'bsq')
272273
ends = ('big', 'little')
273274
cases = itertools.product(interleaves, self.dtypes, ends)
@@ -298,17 +299,40 @@ def run(self):
298299
test = SpyFileTest(testimg, self.datum, self.value)
299300
test.run()
300301

302+
def create_complex_test_files(dtypes):
303+
'''Create test files with complex data'''
304+
if not os.path.isdir(testdir):
305+
os.mkdir(testdir)
306+
tests = []
307+
shape = (100, 200, 64)
308+
datum = (33, 44, 25)
309+
for t in dtypes:
310+
X = np.array(np.random.rand(*shape) + 1j * np.random.rand(*shape),
311+
dtype=t)
312+
fname = os.path.join(testdir, f'test_{t}.hdr')
313+
spy.envi.save_image(fname, X)
314+
tests.append((fname, datum, X[datum]))
315+
return tests
301316

302317
def run():
303318
tests = [('92AV3C.lan', (99, 99, 99), 2057.0)]
304-
# tests = [('92AV3C.lan', (99, 99, 99), 2057.0),
305-
# ('f970619t01p02_r02_sc04.a.rfl', (99, 99, 99), 0.2311),
306-
# ('cup95eff.int.hdr', (99, 99, 33), 0.1842)]
307319
for (fname, datum, value) in tests:
308320
try:
309321
check = find_file_path(fname)
310322
suite = SpyFileTestSuite(fname, datum, value,
311-
dtypes=('i2', 'i4', 'f4', 'f8'))
323+
dtypes=('i2', 'i4', 'f4', 'f8', 'c8', 'c16'))
324+
suite.run()
325+
except FileNotFoundError:
326+
print('File "%s" not found. Skipping.' % fname)
327+
328+
# Run tests for complex data types
329+
dtypes = ['complex64', 'complex128']
330+
tests = create_complex_test_files(dtypes)
331+
for (dtype, (fname, datum, value)) in zip(dtypes, tests):
332+
try:
333+
check = find_file_path(fname)
334+
suite = SpyFileTestSuite(fname, datum, value,
335+
dtypes=(dtype,))
312336
suite.run()
313337
except FileNotFoundError:
314338
print('File "%s" not found. Skipping.' % fname)

0 commit comments

Comments
 (0)