diff --git a/src/tables_io/ioUtils.py b/src/tables_io/ioUtils.py
index 09ad207..c1a0005 100644
--- a/src/tables_io/ioUtils.py
+++ b/src/tables_io/ioUtils.py
@@ -877,17 +877,20 @@ def writeDataFramesToPq(dataFrames, filepath, **kwargs):
         Path to output file
     """
+    basepath, ext = os.path.splitext(filepath)
+    if not ext:  # pragma: no cover
+        ext = "." + FILE_FORMAT_SUFFIX_MAP[PANDAS_PARQUET]
     for k, v in dataFrames.items():
-        _ = v.to_parquet(f"{filepath}{k}.pq", **kwargs)
+        _ = v.to_parquet(f"{basepath}{k}{ext}", **kwargs)


-def readPqToDataFrames(basepath, keys=None, allow_missing_keys=False, columns=None, **kwargs):
+def readPqToDataFrames(filepath, keys=None, allow_missing_keys=False, columns=None, **kwargs):
     """
     Reads `pandas.DataFrame` objects from an parquet file.

     Parameters
     ----------
-    basepath: `str`
+    filepath: `str`
         Path to input file

     keys : `list`
@@ -914,6 +917,9 @@ def readPqToDataFrames(basepath, keys=None, allow_missing_keys=False, columns=No
     if keys is None:  # pragma: no cover
         keys = [""]
     dataframes = OrderedDict()
+    basepath, ext = os.path.splitext(filepath)
+    if not ext:  # pragma: no cover
+        ext = "." + FILE_FORMAT_SUFFIX_MAP[PANDAS_PARQUET]
     for key in keys:
         try:
             column_list = None
@@ -923,7 +929,7 @@ def readPqToDataFrames(basepath, keys=None, allow_missing_keys=False, columns=No
                     column_list = columns
             print("column_list", column_list)

-            dataframes[key] = readPqToDataFrame(f"{basepath}{key}.pq", columns=column_list, **kwargs)
+            dataframes[key] = readPqToDataFrame(f"{basepath}{key}{ext}", columns=column_list, **kwargs)
         except FileNotFoundError as msg:  # pragma: no cover
             if allow_missing_keys:
                 continue
@@ -1068,8 +1074,7 @@ def readNative(filepath, fmt=None, keys=None, allow_missing_keys=False, **kwargs
     if fType == PANDAS_HDF5:
         return readH5ToDataFrames(filepath, keys=keys)
     if fType == PANDAS_PARQUET:
-        basepath = os.path.splitext(filepath)[0]
-        return readPqToDataFrames(basepath, keys, allow_missing_keys, **kwargs)
+        return readPqToDataFrames(filepath, keys, allow_missing_keys, **kwargs)
     raise TypeError(f"Unsupported FileType {fType}")  # pragma: no cover


@@ -1195,15 +1200,15 @@ def iterator(filepath, tType=None, fmt=None, **kwargs):
         yield start, stop, convert(data, tType)


-def writeNative(odict, basename):
+def writeNative(odict, filepath):
     """Write a file or files with tables

     Parameters
     ----------
     odict : `OrderedDict`, (`str`, `Tablelike`)
         The data to write
-    basename : `str`
-        Basename for the file to write. The suffix will be applied based on the object type.
+    filepath : `str`
+        File name for the file to write. If there's no suffix, it will be applied based on the object type.
     """
     istable = False
     if istablelike(odict):
@@ -1221,7 +1226,8 @@
     except KeyError as msg:  # pragma: no cover
         raise KeyError(f"No native format for table type {tType}") from msg
     fmt = FILE_FORMAT_SUFFIX_MAP[fType]
-    filepath = basename + "." + fmt
+    if not os.path.splitext(filepath)[1]:
+        filepath = filepath + '.' + fmt

     if istable:
         odict = OrderedDict([(DEFAULT_TABLE_KEY[fmt], odict)])
@@ -1241,34 +1247,33 @@
         writeRecarraysToFits(odict, filepath)
         return filepath
     if fType == PANDAS_PARQUET:
-        writeDataFramesToPq(odict, basename)
-        return basename
+        writeDataFramesToPq(odict, filepath)
+        return filepath
     raise TypeError(f"Unsupported Native file type {fType}")  # pragma: no cover


-def write(obj, basename, fmt=None):
+def write(obj, filepath, fmt=None):
     """Write a file or files with tables

     Parameters
     ----------
     obj : `Tablelike` or `TableDictLike`
         The data to write
-    basename : `str`
-        Basename for the file to write. The suffix will be applied based on the object type.
+    filepath : `str`
+        File name for the file to write. If there's no suffix, it will be applied based on the object type.

     fmt : `str` or `None`
         The output file format, If `None` this will use `writeNative`
     """
     if fmt is None:
-        splitpath = os.path.splitext(basename)
+        splitpath = os.path.splitext(filepath)
         if not splitpath[1]:
-            return writeNative(obj, basename)
-        basename = splitpath[0]
+            return writeNative(obj, filepath)
         fmt = splitpath[1][1:]
     try:
         fType = FILE_FORMAT_SUFFIXS[fmt]
     except KeyError as msg:  # pragma: no cover
-        raise KeyError(f"Unknown file format {fmt}, options are {list(FILE_FORMAT_SUFFIXS.keys())}") from msg
+        raise KeyError(f"Unknown file format {fmt} from {filepath}, options are {list(FILE_FORMAT_SUFFIXS.keys())}") from msg

     if istablelike(obj):
         odict = OrderedDict([(DEFAULT_TABLE_KEY[fmt], obj)])
@@ -1286,16 +1291,16 @@
         raise KeyError(f"Native file type not known for {fmt}") from msg

         forcedOdict = convert(odict, nativeTType)
-        return writeNative(forcedOdict, basename)
+        return writeNative(forcedOdict, filepath)
+
+    if not os.path.splitext(filepath)[1]:
+        filepath = filepath + '.' + fmt

     if fType == ASTROPY_FITS:
         forcedOdict = convert(odict, AP_TABLE)
-        filepath = f"{basename}.fits"
         writeApTablesToFits(forcedOdict, filepath)
         return filepath
     if fType == PANDAS_HDF5:
         forcedOdict = convert(odict, PD_DATAFRAME)
-        filepath = f"{basename}.h5"
         writeDataFramesToH5(forcedOdict, filepath)
         return filepath
diff --git a/tests/test_io.py b/tests/test_io.py
index c968857..adce991 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -102,8 +102,6 @@ def _do_loopback_single(self, tType, basepath, fmt, keys=None, **kwargs):

         basepath_native = "%s_native" % basepath
         filepath_native = write(obj_c, basepath_native)
-        if keys is not None:
-            filepath_native += ".pq"
         self._files.append(filepath_native)
         obj_r_native = read(filepath_native, tType=tType, keys=keys)
         table_r_native = convert(obj_r_native, types.AP_TABLE)
@@ -192,6 +190,8 @@ def testPQLoopback(self):
             types.PD_DATAFRAME, "test_out", "pq", ["data", "md"], columns={"md": ["a"], "data": ["scalar"]}
         )
         self._do_loopback_single(types.PD_DATAFRAME, "test_out_single", "pq", [""])
+        self._do_loopback_single(types.PD_DATAFRAME, "test_out_single_v2", "parquet", [""])
+        self._do_iterator("test_out_single.pq", types.PD_DATAFRAME, chunk_size=50)
         self._do_iterator("test_out_single.pq", types.PD_DATAFRAME, chunk_size=50, columns=["scalar"])
         self._do_open("test_out_single.pq")
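The change repeated in writeDataFramesToPq, readPqToDataFrames, writeNative, and write is the same small pattern: split the supplied path with os.path.splitext and fall back to the format's default suffix only when the caller gave none. Below is a minimal standalone sketch of that pattern; the constant values and the helper name with_default_suffix are illustrative stand-ins, not part of tables_io.

import os

# Illustrative stand-ins for the tables_io module constants (assumed values).
PANDAS_PARQUET = "pandas_parquet"
FILE_FORMAT_SUFFIX_MAP = {PANDAS_PARQUET: "pq"}


def with_default_suffix(filepath, file_type=PANDAS_PARQUET):
    """Keep an existing suffix; otherwise append the default one for file_type."""
    basepath, ext = os.path.splitext(filepath)
    if not ext:
        ext = "." + FILE_FORMAT_SUFFIX_MAP[file_type]
    return basepath + ext


print(with_default_suffix("test_out_single"))             # test_out_single.pq
print(with_default_suffix("test_out_single_v2.parquet"))  # test_out_single_v2.parquet

With the extension resolved up front, writeDataFramesToPq can splice the table key between the base path and the suffix (f"{basepath}{key}{ext}"), so per-key parquet files keep whatever extension the caller supplied, and readPqToDataFrames reverses the same naming when it loads them back.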