Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 53 additions & 24 deletions spotfire/sbdf.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -247,29 +247,30 @@ cdef class _ImportContext:
"""Object to store information for each column as it is imported."""
cdef int numpy_type_num
cdef sbdf_c.sbdf_valuetype value_type
cdef np_c.ndarray values_array
cdef np_c.ndarray invalid_array
cdef list values_arrays
cdef list invalid_arrays

def __init__(self, numpy_type_num: int):
def __init__(self, numpy_type_num: int, vt: sbdf_c.sbdf_valuetype):
"""Initialize the import context, including the holding arrays.

:param numpy_type_num: NumPy type number for the value array; see
https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types for more
information
:param vt: SBDF value type
"""
# Store the NumPy type number
self.numpy_type_num = numpy_type_num

# Initialize the SBDF value type
self.value_type = sbdf_c.sbdf_valuetype(0)
self.value_type = vt

# Create a zero-element array for holding values
cdef np_c.npy_intp shape[1]
shape[0] = <np_c.npy_intp>0
self.values_array = np_c.PyArray_SimpleNew(1, shape, self.numpy_type_num)
self.values_arrays = []

# Create a zero-element array for holding invalids
self.invalid_array = np_c.PyArray_SimpleNew(1, shape, np_c.NPY_BOOL)
self.invalid_arrays = []

cdef (int, sbdf_c.sbdf_object*, sbdf_c.sbdf_object*) get_values_and_invalid(self,
sbdf_c.sbdf_columnslice* col_slice):
Expand Down Expand Up @@ -319,7 +320,8 @@ cdef class _ImportContext:
"""
cdef np_c.npy_intp shape[1]
shape[0] = <np_c.npy_intp>count
return np_c.PyArray_SimpleNewFromData(1, shape, self.numpy_type_num, data)
snfd = np_c.PyArray_SimpleNewFromData(1, shape, self.numpy_type_num, data)
return np_c.PyArray_NewCopy(snfd, np_c.NPY_ORDER.NPY_CORDER)

cdef np_c.ndarray new_slice_from_empty(self, int count):
"""Create a NumPy slice ``ndarray`` capable of holding the given amount of data, to be filled in later.
Expand All @@ -341,7 +343,8 @@ cdef class _ImportContext:
cdef np_c.npy_intp shape[1]
shape[0] = <np_c.npy_intp>count
if invalid != NULL:
return np_c.PyArray_SimpleNewFromData(1, shape, np_c.NPY_BOOL, <void*>invalid.data)
snfd = np_c.PyArray_SimpleNewFromData(1, shape, np_c.NPY_BOOL, <void*>invalid.data)
return np_c.PyArray_NewCopy(snfd, np_c.NPY_ORDER.NPY_CORDER)
else:
return np_c.PyArray_ZEROS(1, shape, np_c.NPY_BOOL, 0)

Expand All @@ -351,22 +354,30 @@ cdef class _ImportContext:
:param values_slice: values NumPy slice array to append
:param invalid_slice: invalid NumPy slice array to append
"""
self.values_array = np.append(self.values_array, values_slice)
self.invalid_array = np.append(self.invalid_array, invalid_slice)
self.values_arrays.append(values_slice)
self.invalid_arrays.append(invalid_slice)

cpdef np_c.ndarray get_values_array(self):
"""Get the full table values ``ndarray``.

:return: the full values NumPy array
"""
return self.values_array
# Build concatenated numpy array
if self.values_arrays:
return np.concatenate(self.values_arrays)
else:
return np.array([], dtype=np.dtype(self.get_numpy_dtype()))

cpdef np_c.ndarray get_invalid_array(self):
"""Get the full table invalid ``ndarray``.

:return: the full invalid NumPy array
"""
return self.invalid_array
# Build concatenated numpy array
if self.invalid_arrays:
return np.concatenate(self.invalid_arrays)
else:
return np.array([], dtype=np.bool_)

def get_pandas_dtype_name(self) -> str:
"""Get the correct Pandas dtype for this column.
Expand All @@ -384,6 +395,24 @@ cdef class _ImportContext:
else:
return "object"

def get_numpy_dtype(self):
"""Get the correct NumPy dtype for this ctype.

:return: the NumPy dtype name for this ctype
"""
if self.numpy_type_num == np_c.NPY_INT32:
return "int32"
elif self.numpy_type_num == np_c.NPY_INT64:
return "int64"
elif self.numpy_type_num == np_c.NPY_FLOAT32:
return "float32"
elif self.numpy_type_num == np_c.NPY_FLOAT64:
return "float64"
elif self.numpy_type_num == np_c.NPY_BOOL:
return "bool"
else:
return "object"

def get_spotfire_type_name(self) -> str:
"""Get the correct Spotfire type name for this column.

Expand Down Expand Up @@ -679,40 +708,40 @@ def import_data(sbdf_file):
column_names.append(col_name.decode('utf-8'))

if col_type.id == sbdf_c.SBDF_BOOLTYPEID:
importer_contexts.append(_ImportContext(np_c.NPY_BOOL))
importer_contexts.append(_ImportContext(np_c.NPY_BOOL, col_type))
importer_fns[i] = _import_vts_numpy
elif col_type.id == sbdf_c.SBDF_DOUBLETYPEID:
importer_contexts.append(_ImportContext(np_c.NPY_FLOAT64))
importer_contexts.append(_ImportContext(np_c.NPY_FLOAT64, col_type))
importer_fns[i] = _import_vts_numpy
elif col_type.id == sbdf_c.SBDF_LONGTYPEID:
importer_contexts.append(_ImportContext(np_c.NPY_INT64))
importer_contexts.append(_ImportContext(np_c.NPY_INT64, col_type))
importer_fns[i] = _import_vts_numpy
elif col_type.id == sbdf_c.SBDF_FLOATTYPEID:
importer_contexts.append(_ImportContext(np_c.NPY_FLOAT32))
importer_contexts.append(_ImportContext(np_c.NPY_FLOAT32, col_type))
importer_fns[i] = _import_vts_numpy
elif col_type.id == sbdf_c.SBDF_INTTYPEID:
importer_contexts.append(_ImportContext(np_c.NPY_INT32))
importer_contexts.append(_ImportContext(np_c.NPY_INT32, col_type))
importer_fns[i] = _import_vts_numpy
elif col_type.id == sbdf_c.SBDF_DATETIMETYPEID:
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT))
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT, col_type))
importer_fns[i] = _import_vt_datetime
elif col_type.id == sbdf_c.SBDF_DATETYPEID:
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT))
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT, col_type))
importer_fns[i] = _import_vt_date
elif col_type.id == sbdf_c.SBDF_TIMESPANTYPEID:
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT))
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT, col_type))
importer_fns[i] = _import_vt_timespan
elif col_type.id == sbdf_c.SBDF_TIMETYPEID:
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT))
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT, col_type))
importer_fns[i] = _import_vt_time
elif col_type.id == sbdf_c.SBDF_STRINGTYPEID:
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT))
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT, col_type))
importer_fns[i] = _import_vt_string
elif col_type.id == sbdf_c.SBDF_BINARYTYPEID:
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT))
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT, col_type))
importer_fns[i] = _import_vt_bytes
elif col_type.id == sbdf_c.SBDF_DECIMALTYPEID:
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT))
importer_contexts.append(_ImportContext(np_c.NPY_OBJECT, col_type))
importer_fns[i] = _import_vt_decimal
else:
raise SBDFError(f"column '{col_name}' has unsupported type id {col_type.id}")
Expand Down