From dc5006bab7bd48a0ecc128540c5631ee8e8a69c6 Mon Sep 17 00:00:00 2001 From: Julio Campagnolo Date: Fri, 13 Oct 2023 12:31:47 -0300 Subject: [PATCH 1/7] ci: do not publish to tests --- .github/workflows/build-n-publish.yml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.github/workflows/build-n-publish.yml b/.github/workflows/build-n-publish.yml index 2236eeec..3a2569be 100644 --- a/.github/workflows/build-n-publish.yml +++ b/.github/workflows/build-n-publish.yml @@ -32,12 +32,6 @@ jobs: --sdist --wheel --outdir dist/ - - name: Publish package to PyPi testing - if: github.ref == 'refs/heads/main' - uses: pypa/gh-action-pypi-publish@release/v1 - with: - password: ${{ secrets.TEST_PYPI_API_TOKEN }} - repository_url: https://test.pypi.org/legacy/ - name: Publish package to PyPi if: github.ref == 'refs/heads/main' uses: pypa/gh-action-pypi-publish@release/v1 From 71abf017f49ad37221ba62fd74fc679114adb5db Mon Sep 17 00:00:00 2001 From: Julio Campagnolo Date: Fri, 13 Oct 2023 12:49:02 -0300 Subject: [PATCH 2/7] start moving db to dbastable --- astropop/_db.py | 1075 --------------------- astropop/file_collection.py | 12 +- tests/test_db.py | 1804 ----------------------------------- 3 files changed, 6 insertions(+), 2885 deletions(-) delete mode 100644 astropop/_db.py delete mode 100644 tests/test_db.py diff --git a/astropop/_db.py b/astropop/_db.py deleted file mode 100644 index a6784bc7..00000000 --- a/astropop/_db.py +++ /dev/null @@ -1,1075 +0,0 @@ -# Licensed under a 3-clause BSD style license - see LICENSE.rst -"""Manage SQL databases in a simplier way.""" - -import sqlite3 as sql -import numpy as np -from astropy.table import Table - -from .logger import logger -from .py_utils import broadcast - - -__all__ = ['SQLDatabase', 'SQLTable', 'SQLRow', 'SQLColumn', 'SQLColumnMap'] - - -_ID_KEY = '__id__' - - -class _SQLViewerBase: - """Memview for SQL data. Not allowed to copy.""" - - def __copy__(self): - raise NotImplementedError('Cannot copy SQL viewing classes.') - - def __deepcopy__(self, memo): - raise NotImplementedError('Cannot copy SQL viewing classes.') - - -class _SQLRowIndexer(_SQLViewerBase): - """A class for indexing SQL rows. Safer method while removing rows. - - Index is obtained by indexof(self). - - Parameters - ---------- - row_list: list - The list of `_SQLRowIndexer` to get the index of. - """ - - def __init__(self, row_list): - self._row_list = row_list - - @property - def index(self): - """Get the index of the row.""" - return self._row_list.index(self) - - -class SQLColumnMap(): - """Map keywords to SQL columns.""" - - def __init__(self, db, map_table, map_key, map_column): - self.db = db - self.map = db[map_table] - self.key = map_key - self.col = map_column - - self._clear_cache() - - def add_column(self, name): - """Add a new column to the table.""" - name = name.lower() - - if name in self.keywords: - raise ValueError(f'{name} already exists') - - i = len(self.keywords)+1 - col = f'col_{i}' - while col in self.keywords: - i += 1 - col = f'col_{i}' - - self.map.add_rows({self.key: name, self.col: col}) - self._clear_cache() - return col - - def get_column_name(self, item, add_columns=False): - """Get the column name for a given keyword.""" - if not np.isscalar(item): - return [self.get_column_name(i) for i in item] - - item = item.lower() - if item not in self.keywords: - if add_columns: - return self.add_column(item) - raise KeyError(f'{item}') - - return self.columns[np.where(self.keywords == item)][0] - - def get_keyword(self, item): - """Get the keyword for a given column.""" - if not np.isscalar(item): - return [self.get_keyword(i) for i in item] - - item = item.lower() - if item not in self.columns: - raise KeyError(f'{item}') - - return self.keywords[np.where(self.columns == item)][0] - - def _clear_cache(self): - self._columns = None - self._keywords = None - - @property - def columns(self): - """Get the column names for the table.""" - if self._columns is None: - self._columns = np.array(self.map.select(columns=[self.col])) - return self._columns - - @property - def keywords(self): - """Get the keywords of the columns for the table.""" - if self._keywords is None: - self._keywords = np.array(self.map.select(columns=[self.key])) - return self._keywords - - def map_row(self, data, add_columns=False): - """Map a row to the columns.""" - if isinstance(data, dict): - d = {} - for k, v in data.items(): - if k in self.keywords or add_columns: - d[self.get_column_name(k, add_columns=add_columns)] = v - data = d - elif not isinstance(data, list): - raise ValueError('Only dict and list are supported') - return data - - def parse_where(self, where): - """Parse a where clause using column mappring.""" - if isinstance(where, dict): - return {self.get_column_name(k): v for k, v in where.items()} - raise TypeError('Only dict is supported') - - -class SQLTable(_SQLViewerBase): - """Handle an SQL table operations interfacing with the DB.""" - - def __init__(self, db, name, colmap=None): - """Initialize the table. - - Parameters - ---------- - db : SQLDatabase - The parent database object. - name : str - The name of the table in the database. - """ - self._db = db - self._name = name - self._colmap = colmap - - @property - def name(self): - """Get the name of the table.""" - return self._name - - @property - def db(self): - """Get the database name.""" - return self._db._db - - @property - def column_names(self): - """Get the column names of the current table.""" - names = self._db.column_names(self._name) - if self._colmap is not None: - return self._colmap.get_keyword(names) - return names - - @property - def values(self): - """Get the values of the current table.""" - return self.select() - - def select(self, **kwargs): - """Select rows from the table.""" - where = kwargs.pop('where', None) - order = kwargs.pop('order', None) - if self._colmap is not None: - if where is not None: - where = self._colmap.parse_where(where) - if order is not None: - order = self._colmap.get_column_name(order) - - return self._db.select(self._name, where=where, order=order, **kwargs) - - def as_table(self): - """Return the current table as an `~astropy.table.Table` object.""" - if len(self) == 0: - return Table(names=self.column_names) - return Table(rows=self.values, - names=self.column_names) - - def add_column(self, name, data=None): - """Add a column to the table.""" - if self._colmap is not None: - name = self._colmap.add_column(name) - self._db.add_column(self._name, name, data=data) - - def add_rows(self, data, add_columns=False): - """Add a row to the table.""" - # If keymappging is used, only dict and list - if self._colmap is not None: - data = self._colmap.map_row(data, add_columns=add_columns) - self._db.add_rows(self._name, data, add_columns=add_columns) - - def get_column(self, column): - """Get a given column from the table.""" - if self._colmap is not None: - column = self._colmap.get_column_name(column) - return self._db.get_column(self._name, column) - - def get_row(self, row): - """Get a given row from the table.""" - return self._db.get_row(self._name, row, column_map=self._colmap) - - def set_column(self, column, data): - """Set a given column in the table.""" - if self._colmap is not None: - column = self._colmap.get_column_name(column) - self._db.set_column(self._name, column, data) - - def set_row(self, row, data): - """Set a given row in the table.""" - if self._colmap is not None: - data = self._colmap.map_row(data) - self._db.set_row(self._name, row, data) - - def delete_column(self, column): - """Delete a given column from the table.""" - if self._colmap is not None: - column = self._colmap.get_column_name(column) - self._db.delete_column(self._name, column) - - def delete_row(self, row): - """Delete all rows from the table.""" - self._db.delete_row(self._name, row) - - def index_of(self, where): - """Get the index of the rows that match the given condition.""" - if self._colmap is not None: - where = self._colmap.parse_where(where) - return self._db.index_of(self._name, where) - - def _resolve_tuple(self, key): - """Resolve how tuples keys are handled.""" - col, row = key - _tuple_err = """Tuple items must be in the format table[col, row] or - table[row, col]. - """ - - if not isinstance(col, str): - # Try inverting - col, row = row, col - - if not isinstance(col, str): - raise KeyError(_tuple_err) - - if not isinstance(row, (int, slice, list, np.ndarray)): - raise KeyError(_tuple_err) - - return col, row - - def __getitem__(self, key): - """Get a row or a column from the table.""" - if isinstance(key, (int, np.int_)): - return self.get_row(key) - if isinstance(key, (str, np.str_)): - return self.get_column(key) - if isinstance(key, tuple): - if len(key) not in (1, 2): - raise KeyError(f'{key}') - if len(key) == 1: - return self[key[0]] - col, row = self._resolve_tuple(key) - return self[col][row] - raise KeyError(f'{key}') - - def __setitem__(self, key, value): - """Set a row or a column in the table.""" - if isinstance(key, int): - self.set_row(key, value) - elif isinstance(key, str): - self.set_column(key, value) - elif isinstance(key, tuple): - if len(key) not in (1, 2): - raise KeyError(f'{key}') - if len(key) == 1: - self[key[0]] = value - else: - col, row = self._resolve_tuple(key) - self[col][row] = value - else: - raise KeyError(f'{key}') - - def __len__(self): - """Get the number of rows in the table.""" - return self._db.count(self._name) - - def __contains__(self, item): - """Check if a given column is in the table.""" - return item in self.column_names - - def __iter__(self): - """Iterate over the rows of the table.""" - for i in self.select(): - yield i - - def __repr__(self): - """Get a string representation of the table.""" - s = f"{self.__class__.__name__} '{self.name}'" - s += f" in database '{self.db}':" - s += f"({len(self.column_names)} columns x {len(self)} rows)\n" - s += '\n'.join(self.as_table().__repr__().split('\n')[1:]) - return s - - -class SQLColumn(_SQLViewerBase): - """Handle an SQL column operations interfacing with the DB.""" - - def __init__(self, db, table, name): - """Initialize the column. - - Parameters - ---------- - db : SQLDatabase - The parent database object. - table : str - The name of the table in the database. - name : str - The column name in the table. - """ - self._db = db - self._table = table - self._name = name - - @property - def name(self): - """Get the name of the column.""" - return self._name - - @property - def values(self): - """Get the values of the current column.""" - vals = self._db.select(self._table, columns=[self._name]) - return [i[0] for i in vals] - - @property - def table(self): - """Get the table name.""" - return self._table - - def __getitem__(self, key): - """Get a row from the column.""" - if isinstance(key, (int, np.int_, slice)): - return self.values[key] - if isinstance(key, (list, np.ndarray)): - v = self.values - return [v[i] for i in key] - raise IndexError(f'{key}') - - def __setitem__(self, key, value): - """Set a row in the column.""" - if isinstance(key, (int, np.int_)): - self._db.set_item(self._table, self._name, key, value) - elif isinstance(key, (slice, list, np.ndarray)): - v = np.array(self.values) - v[key] = value - self._db.set_column(self._table, self._name, v) - else: - raise IndexError(f'{key}') - - def __len__(self): - """Get the number of rows in the column.""" - return len(self.values) - - def __iter__(self): - """Iterate over the column.""" - for i in self.values: - yield i - - def __contains__(self, item): - """Check if the column contains a given value.""" - return item in self.values - - def __repr__(self): - """Get a string representation of the column.""" - s = f"{self.__class__.__name__} {self._name} in table '{self._table}'" - s += f" ({len(self)} rows)" - return s - - -class SQLRow(_SQLViewerBase): - """Handle and SQL table row interfacing with the DB.""" - - def __init__(self, db, table, row_indexer, colmap=None): - """Initialize the row. - - Parameters - ---------- - db : SQLDatabase - The parent database object. - table : str - The name of the table in the database. - row_indexer : `~astropop._db._SQLRowIndexer` - The row index in the table. - """ - self._db = db - self._table = table - self._row_indexer = row_indexer - self._colmap = colmap - - @property - def column_names(self): - """Get the column names of the current table.""" - names = self._db.column_names(self._table) - if self._colmap is not None: - names = self._colmap.get_keyword(names) - return names - - @property - def table(self): - """Get the table name.""" - return self._table - - @property - def values(self): - """Get the values of the current row.""" - return self._db.select(self._table, where={_ID_KEY: self.index+1})[0] - - @property - def index(self): - """Get the index of the current row.""" - return self._row_indexer.index - - @property - def keys(self): - """Get the keys of the current row.""" - return self.column_names - - @property - def items(self): - """Get the items of the current row.""" - return zip(self.column_names, self.values) - - def as_dict(self): - """Get the row as a dict.""" - return dict(self.items) - - def __getitem__(self, key): - """Get a column from the row.""" - if isinstance(key, (str, np.str_)): - column = key - if self._colmap is not None: - column = self._colmap.get_column_name(key) - try: - return self._db.get_item(self._table, column, self.index) - except ValueError: - raise KeyError(f'{key}') - if isinstance(key, (int, np.int_)): - return self.values[key] - raise KeyError(f'{key}') - - def __setitem__(self, key, value): - """Set a column in the row.""" - if not isinstance(key, (str, np.str_)): - raise KeyError(f'{key}') - - column = key = key.lower() - if self._colmap is not None: - column = self._colmap.get_column_name(key) - if key not in self.column_names: - raise KeyError(f'{key}') - self._db.set_item(self._table, column, self.index, value) - - def __iter__(self): - """Iterate over the row.""" - for i in self.values: - yield i - - def __contains__(self, item): - """Check if the row contains a given value.""" - return item in self.values - - def __repr__(self): - """Get a string representation of the row.""" - s = f"{self.__class__.__name__} {self.index} in table '{self._table}' " - s += self.as_dict().__repr__() - return s - - -def _sanitize_colnames(data): - """Sanitize the colnames to avoid invalid characteres like '-'.""" - def _sanitize(key): - if len([ch for ch in key if not ch.isalnum() and ch != '_']) != 0: - raise ValueError(f'Invalid column name: {key}.') - return key.lower() - - if isinstance(data, dict): - d = data - colnames = _sanitize_colnames(list(data.keys())) - return dict(zip(colnames, d.values())) - if isinstance(data, str): - return _sanitize(data) - if not isinstance(data, (list, tuple, np.ndarray)): - raise TypeError(f'{type(data)} is not supported.') - - return [_sanitize(i) for i in data] - - -def _sanitize_value(data): - """Sanitize the value to avoid sql errors.""" - if data is None or isinstance(data, bytes): - return data - if isinstance(data, (str, np.str_)): - return f"{data}" - if np.isscalar(data) and np.isreal(data): - if isinstance(data, (int, np.integer)): - return int(data) - elif isinstance(data, (float, np.floating)): - return float(data) - if isinstance(data, (bool, np.bool_)): - return bool(data) - raise TypeError(f'{type(data)} is not supported.') - - -def _fix_row_index(row, length): - """Fix the row number to be a valid index.""" - if row < 0: - row += length - if row >= length or row < 0: - raise IndexError('Row index out of range.') - return row - - -def _dict2row(cols, **row): - values = [None]*len(cols) - for i, c in enumerate(cols): - if c in row.keys(): - values[i] = row[c] - else: - values[i] = None - return values - - -def _parse_where(where): - args = None - if where is None: - _where = None - elif isinstance(where, dict): - where = _sanitize_colnames(where) - for i, (k, v) in enumerate(where.items()): - v = _sanitize_value(v) - if i == 0: - _where = f"{k}=?" - args = [v] - else: - _where += f" AND {k}=?" - args.append(v) - elif isinstance(where, str): - _where = where - elif isinstance(where, (list, tuple)): - for w in where: - if not isinstance(w, str): - raise TypeError('if where is a list, it must be a list ' - f'of strings. Not {type(w)}.') - _where = ' AND '.join(where) - else: - raise TypeError('where must be a string, list of strings or' - ' dict.') - return _where, args - - -class SQLDatabase: - """Database creation and manipulation with SQL. - - Notes - ----- - - '__id__' is only for internal indexing. It is ignored on returns. - """ - - def __init__(self, db=':memory:', autocommit=True): - """Initialize the database. - - Parameters - ---------- - db : str - The name of the database file. If ':memory:' is given, the - database will be created in memory. - autocommit : bool (optional) - Whether to commit changes to the database after each operation. - Defaults to True. - """ - self._db = db - self._con = sql.connect(self._db) - self._cur = self._con.cursor() - self.autocommit = autocommit - - self._row_indexes = {} - self._build_row_indexes() - - def execute(self, command, arguments=None): - """Execute a SQL command in the database.""" - logger.debug('executing sql command: "%s"', - str.replace(command, '\n', ' ')) - try: - if arguments is None: - self._cur.execute(command) - else: - self._cur.execute(command, arguments) - res = self._cur.fetchall() - except sql.Error as e: - self._con.rollback() - raise e - - if self.autocommit: - self.commit() - return res - - def executemany(self, command, arguments): - """Execute a SQL command in the database.""" - logger.debug('executing sql command: "%s"', - str.replace(command, '\n', ' ')) - - try: - self._cur.executemany(command, arguments) - res = self._cur.fetchall() - except sql.Error as e: - self._con.rollback() - raise e - - if self.autocommit: - self.commit() - return res - - def commit(self): - """Commit the current transaction.""" - self._con.commit() - - def count(self, table, where=None): - """Get the number of rows in the table.""" - self._check_table(table) - comm = "SELECT COUNT(*) FROM " - comm += f"{table} " - where, args = _parse_where(where) - if where is not None: - comm += f"WHERE {where}" - comm += ";" - return self.execute(comm, args)[0][0] - - def select(self, table, columns=None, where=None, order=None, limit=None, - offset=None): - """Select rows from a table. - - Parameters - ---------- - columns : list (optional) - List of columns to select. If None, select all columns. - where : dict (optional) - Dictionary of conditions to select rows. Keys are column names, - values are values to compare. All rows equal to the values will - be selected. If None, all rows are selected. - order : str (optional) - Column name to order by. - limit : int (optional) - Number of rows to select. - """ - self._check_table(table) - if columns is None: - columns = self[table].column_names - elif isinstance(columns, str): - columns = [columns] - # only use sanitized column names - columns = ', '.join(_sanitize_colnames(columns)) - - comm = f"SELECT {columns} " - comm += f"FROM {table} " - args = [] - - where, args_w = _parse_where(where) - if where is not None: - comm += f"WHERE {where} " - if args_w is not None: - args += args_w - - if order is not None: - order = _sanitize_colnames(order) - comm += f"ORDER BY {order} ASC " - - if limit is not None: - comm += "LIMIT ? " - if not isinstance(limit, (int, np.integer)): - raise TypeError('limit must be an integer.') - args.append(int(limit)) - if offset is not None: - if limit is None: - raise ValueError('offset cannot be used without limit.') - if not isinstance(offset, (int, np.integer)): - raise TypeError('offset must be an integer.') - comm += "OFFSET ? " - args.append(int(offset)) - - comm = comm + ';' - - if args == []: - args = None - res = self.execute(comm, args) - return res - - def copy(self, indexes=None): - """Get a copy of the database.""" - return self.__copy__(indexes=indexes) - - def column_names(self, table): - """Get the column names of the table.""" - self._check_table(table) - comm = "SELECT * FROM " - comm += f"{table} LIMIT 1;" - self.execute(comm) - return [i[0].lower() for i in self._cur.description - if i[0].lower() != _ID_KEY.lower()] - - @property - def db(self): - """Get the database name.""" - return str(self._db) - - @property - def table_names(self): - """Get the table names in the database.""" - comm = "SELECT name FROM sqlite_master WHERE type='table';" - return [i[0] for i in self.execute(comm) if i[0] != 'sqlite_sequence'] - - def _check_table(self, table): - """Check if the table exists in the database.""" - if table not in self.table_names: - raise KeyError(f'Table "{table}" does not exist.') - - def _add_missing_columns(self, table, columns): - """Add missing columns to the table.""" - existing = set(self.column_names(table)) - for col in [i for i in columns if i not in existing]: - self.add_column(table, col) - - def _add_data_dict(self, table, data, add_columns=False, - skip_sanitize=False): - """Add data sotred in a dict to the table.""" - data = _sanitize_colnames(data) - if add_columns: - self._add_missing_columns(table, data.keys()) - - dict_row_list = _dict2row(cols=self.column_names(table), **data) - try: - rows = np.broadcast(*dict_row_list) - except ValueError: - rows = broadcast(*dict_row_list) - rows = list(zip(*rows.iters)) - self._add_data_list(table, rows, skip_sanitize=skip_sanitize) - - def _add_data_list(self, table, data, skip_sanitize=False): - """Add data stored in a list to the table.""" - if np.ndim(data) not in (1, 2): - raise ValueError('data must be a 1D or 2D array.') - - if np.ndim(data) == 1: - data = np.reshape(data, (1, len(data))) - - if np.shape(data)[1] != len(self.column_names(table)): - raise ValueError('data must have the same number of columns as ' - 'the table.') - - if not skip_sanitize: - data = [tuple(map(_sanitize_value, d)) for d in data] - comm = f"INSERT INTO {table} VALUES " - comm += f"(NULL, {', '.join(['?']*len(data[0]))})" - comm += ';' - self.executemany(comm, data) - - # Update the row indexes - rl = self._row_indexes[table] - rl.extend([_SQLRowIndexer(rl) for i in range(len(data))]) - - def _get_indexes(self, table): - """Get the indexes of the table.""" - comm = f"SELECT {_ID_KEY} FROM {table};" - return [i[0] for i in self.execute(comm)] - - def _update_indexes(self, table): - """Update the indexes of the table.""" - rows = list(range(1, self.count(table) + 1)) - origin = self._get_indexes(table) - comm = f"UPDATE {table} SET {_ID_KEY} = ? WHERE {_ID_KEY} = ?;" - self.executemany(comm, zip(rows, origin)) - - def _build_row_indexes(self): - """Build the row indexes.""" - for table in self.table_names: - size = self.count(table) - # Create the list that must be passed to _SQLRowIndexer - rl = [None]*size - self._row_indexes[table] = rl - for i in range(size): - self._row_indexes[table][i] = _SQLRowIndexer(rl) - - def add_table(self, table, columns=None, data=None): - """Create a table in database.""" - logger.debug('Initializing "%s" table.', table) - if table in self.table_names: - raise ValueError('table {table} already exists.') - - comm = f"CREATE TABLE '{table}'" - comm += f" (\n{_ID_KEY} INTEGER PRIMARY KEY AUTOINCREMENT" - - if columns is not None and data is not None: - raise ValueError('cannot specify both columns and data.') - if columns is not None: - comm += ",\n" - for i, name in enumerate(columns): - comm += f"\t'{name}'" - if i != len(columns) - 1: - comm += ",\n" - comm += "\n);" - - self.execute(comm) - - # Add the row indexer list - self._row_indexes[table] = [] - - if data is not None: - self.add_rows(table, data, add_columns=True) - - def add_column(self, table, column, data=None): - """Add a column to a table.""" - self._check_table(table) - - column = column.lower() - if data is not None and len(data) != len(self[table]) and \ - len(self[table]) != 0: - raise ValueError("data must have the same length as the table.") - - if column in (_ID_KEY, 'table', 'default'): - raise ValueError(f"{column} is a protected name.") - - col = _sanitize_colnames([column])[0] - comm = f"ALTER TABLE {table} ADD COLUMN '{col}' ;" - logger.debug('adding column "%s" to table "%s"', col, table) - self.execute(comm) - - # adding the data to the table - if data is not None: - self.set_column(table, column, data) - - def delete_column(self, table, column): - """Delete a column from a table.""" - self._check_table(table) - - if column in (_ID_KEY, 'table', 'default'): - raise ValueError(f"{column} is a protected name.") - if column not in self.column_names(table): - raise KeyError(f'Column "{column}" does not exist.') - - comm = f"ALTER TABLE {table} DROP COLUMN '{column}' ;" - logger.debug('deleting column "%s" from table "%s"', column, table) - self.execute(comm) - - def add_rows(self, table, data, add_columns=False, skip_sanitize=False): - """Add a dict row to a table. - - Parameters - ---------- - data : dict, list or `~numpy.ndarray` - Data to add to the table. If dict, keys are column names, - if list, the order of the values is the same as the order of - the column names. If `~numpy.ndarray`, dtype names are interpreted - as column names. - add_columns : bool (optional) - If True, add missing columns to the table. - """ - self._check_table(table) - if isinstance(data, (list, tuple)): - return self._add_data_list(table, data, - skip_sanitize=skip_sanitize) - if isinstance(data, dict): - return self._add_data_dict(table, data, add_columns=add_columns, - skip_sanitize=skip_sanitize) - if isinstance(data, np.ndarray): - names = data.dtype.names - if names is not None: - data = {n: data[n] for n in names} - return self._add_data_dict(table, data, - add_columns=add_columns, - skip_sanitize=skip_sanitize) - return self._add_data_list(table, data, - skip_sanitize=skip_sanitize) - if isinstance(data, Table): - data = {c: list(data[c]) for c in data.colnames} - return self._add_data_dict(table, data, add_columns=add_columns, - skip_sanitize=skip_sanitize) - - raise TypeError('data must be a dict, list, or numpy array. ' - f'Not {type(data)}.') - - def delete_row(self, table, index): - """Delete a row from the table.""" - self._check_table(table) - row = _fix_row_index(index, len(self[table])) - comm = f"DELETE FROM {table} WHERE {_ID_KEY}={row+1};" - self.execute(comm) - self._row_indexes[table].pop(row) - self._update_indexes(table) - - def drop_table(self, table): - """Drop a table from the database.""" - self._check_table(table) - comm = f"DROP TABLE {table};" - self.execute(comm) - del self._row_indexes[table] - - def get_table(self, table, column_map=None): - """Get a table from the database.""" - self._check_table(table) - return SQLTable(self, table, colmap=column_map) - - def get_row(self, table, index, column_map=None): - """Get a row from the table.""" - self._check_table(table) - index = _fix_row_index(index, len(self[table])) - row = self._row_indexes[table][index] - return SQLRow(self, table, row, colmap=column_map) - - def get_column(self, table, column): - """Get a column from the table.""" - column = column.lower() - if column not in self.column_names(table): - raise KeyError(f"column {column} does not exist.") - return SQLColumn(self, table, column) - - def get_item(self, table, column, row): - """Get an item from the table.""" - self._check_table(table) - row = _fix_row_index(row, len(self[table])) - column = _sanitize_colnames([column])[0] - return self.get_column(table, column)[row] - - def set_item(self, table, column, row, value): - """Set a value in a cell.""" - row = _fix_row_index(row, self.count(table)) - column = _sanitize_colnames([column])[0] - value = _sanitize_value(value) - self.execute(f"UPDATE {table} SET {column}=? " - f"WHERE {_ID_KEY}=?;", (value, row+1)) - - def set_row(self, table, row, data): - """Set a row in the table.""" - row = _fix_row_index(row, self.count(table)) - colnames = self.column_names(table) - - if isinstance(data, dict): - data = _dict2row(colnames, **data) - elif isinstance(data, (list, tuple, np.ndarray)): - if len(data) != len(colnames): - raise ValueError('data must have the same length as the ' - 'table.') - else: - raise TypeError('data must be a dict, list, or numpy array. ' - f'Not {type(data)}.') - - comm = f"UPDATE {table} SET " - comm += f"{', '.join(f'{i}=?' for i in colnames)} " - comm += f" WHERE {_ID_KEY}=?;" - self.execute(comm, tuple(list(map(_sanitize_value, data)) + [row+1])) - - def set_column(self, table, column, data): - """Set a column in the table.""" - tablen = self.count(table) - if column not in self.column_names(table): - raise KeyError(f"column {column} does not exist.") - if len(data) != tablen and tablen != 0: - raise ValueError("data must have the same length as the table.") - - if tablen == 0: - for i in range(len(data)): - self.add_rows(table, {}) - - col = _sanitize_colnames([column])[0] - comm = f"UPDATE {table} SET " - comm += f"{col}=? " - comm += f" WHERE {_ID_KEY}=?;" - args = list(zip([_sanitize_value(d) for d in data], - range(1, self.count(table)+1))) - self.executemany(comm, args) - - def index_of(self, table, where): - """Get the index(es) where a given condition is satisfied.""" - indx = self.select(table, _ID_KEY, where=where) - if len(indx) == 1: - return indx[0][0]-1 - return [i[0]-1 for i in indx] - - def __len__(self): - """Get the number of rows in the current table.""" - return len(self.table_names) - - def __del__(self): - """Delete the class, closing the db connection.""" - # ensure connection is closed. - self._con.close() - - def __setitem__(self, item, value): - """Set a row in the table.""" - if not isinstance(item, tuple): - raise KeyError('item must be a in the formats ' - 'db[table, row], db[table, column] or ' - 'db[table, column, row].') - if not isinstance(item[0], str): - raise KeyError('first item must be the table name.') - self.get_table(item[0])[item[1:]] = value - - def __getitem__(self, item): - """Get a items from the table.""" - if isinstance(item, (str, np.str_)): - return self.get_table(item) - if isinstance(item, tuple): - if not isinstance(item[0], str): - raise ValueError('first item must be the table name.') - return self.get_table(item[0])[item[1:]] - raise ValueError('items must be a string for table names, ' - 'or a tuple in the formats. ' - 'db[table, row], db[table, column] or ' - 'db[table, column, row].') - - def __repr__(self): - """Get a string representation of the table.""" - s = f"{self.__class__.__name__} '{self.db}' at {hex(id(self))}:" - if len(self) == 0: - s += '\n\tEmpty database.' - for i in self.table_names: - s += f"\n\t{i}: {len(self.column_names(i))} columns" - s += f" {len(self[i])} rows" - return s - - def __copy__(self, indexes=None): - """Copy the database. - - Parameters - ---------- - indexes : dict, optional - A dictionary of table names and their indexes to copy. - - Returns - ------- - db : SQLDatabase - A copy of the database. - """ - def _get_data(table, indx=None): - if indx is None: - return self.select(table) - if len(indx) == 0: - return None - indx = np.array(np.array(indx, dtype=int)+1, dtype=str) - where = f"{_ID_KEY} in ({','.join(indx)})" - return self.select(table, where=where) - - # when copying, always copy to memory - db = SQLDatabase(':memory:') - if indexes is None: - indexes = {} - for i in self.table_names: - db.add_table(i, columns=self.column_names(i)) - rows = _get_data(i, indexes.get(i, None)) - if rows is not None: - db.add_rows(i, rows, skip_sanitize=True) - return db diff --git a/astropop/file_collection.py b/astropop/file_collection.py index 33efc20c..c109cc1e 100644 --- a/astropop/file_collection.py +++ b/astropop/file_collection.py @@ -5,11 +5,12 @@ import glob from pathlib import Path import numpy as np +import sqlite3 as sql from astropy.io import fits from astropy.table import Column +from dbastable import SQLDatabase, _ID_KEY, SQLTable -from ._db import SQLDatabase, _ID_KEY, sql, SQLTable, SQLColumnMap from .fits_utils import _fits_extensions, \ _fits_extensions_with_compress from .framedata import check_framedata @@ -100,7 +101,9 @@ def __init__(self, location=None, files=None, ext=0, for i in kwargs.keys(): raise ValueError('Unknown parameter: {}'.format(i)) - self._db = SQLDatabase(database) + # As the headers may contain not allowed keywords, let's enable Base32 + # column names + self._db = SQLDatabase(database, allow_b32_colnames=True) if database == ':memory:': self._db_dir = None else: @@ -152,10 +155,7 @@ def _read_db(self, files, location, compression, update=False): self._ext = self._db[_metadata, 'ext'][0] self._location = self._db[_metadata, 'location'][0] self._compression = self._db[_metadata, 'compression'][0] - - cmap = SQLColumnMap(self._db, _keycolstable, - _keywords_col, _columns_col) - self._table = SQLTable(self._db, _headers, colmap=cmap) + self._table = SQLTable(self._db, _headers) if update or not initialized: self.update(files, location, compression) diff --git a/tests/test_db.py b/tests/test_db.py deleted file mode 100644 index 0e73a1b1..00000000 --- a/tests/test_db.py +++ /dev/null @@ -1,1804 +0,0 @@ -# Licensed under a 3-clause BSD style license - see LICENSE.rst -# flake8: noqa: F403, F405 - -import pytest -from astropop._db import SQLColumn, SQLRow, SQLTable, _ID_KEY, \ - SQLDatabase, _sanitize_colnames, \ - SQLColumnMap -import numpy as np -from astropy.table import Table -from astropop.testing import * -import sqlite3 -import copy - - -def test_sanitize_string(): - for i in ['test-2', 'test!2', 'test@2', 'test#2', 'test$2', - 'test&2', 'test*2', 'test(2)', 'test)2', 'test[2]', 'test]2', - 'test{2}', 'test}2', 'test|2', 'test\\2', 'test^2', 'test~2' - 'test"2', 'test\'2', 'test`2', 'test<2', 'test>2', 'test=2', - 'test,2', 'test;2', 'test:2', 'test?2', 'test/2']: - with pytest.raises(ValueError): - _sanitize_colnames(i) - - for i in ['test', 'test_1', 'test_1_2', 'test_1_2', 'Test', 'Test_1']: - assert_equal(_sanitize_colnames(i), i.lower()) - - -class Test_SQLColumnMap: - def cmap(self): - db = SQLDatabase() - db.add_table('key_columns', data={'keywords': ['key1', 'key 2', - 'key-3', 'key_4'], - 'columns': ['col1', 'col2', 'col3', - 'col4']}) - cmap = SQLColumnMap(db, 'key_columns', 'keywords', 'columns') - return cmap - - def test_columnmap_get_column_name(self): - cmap = self.cmap() - assert_equal(cmap.get_column_name('key1'), 'col1') - assert_equal(cmap.get_column_name('key 2'), 'col2') - assert_equal(cmap.get_column_name('key-3'), 'col3') - assert_equal(cmap.get_column_name('key_4'), 'col4') - - def test_columnmap_get_column_name_list(self): - cmap = self.cmap() - assert_equal(cmap.get_column_name(['key1', 'key 2']), ['col1', 'col2']) - - def test_columnmap_get_column_name_not_found(self): - cmap = self.cmap() - with pytest.raises(KeyError): - cmap.get_column_name('key5') - - def test_columnmap_get_keyword(self): - cmap = self.cmap() - assert_equal(cmap.get_keyword('col1'), 'key1') - assert_equal(cmap.get_keyword('col2'), 'key 2') - assert_equal(cmap.get_keyword('col3'), 'key-3') - assert_equal(cmap.get_keyword('col4'), 'key_4') - - def test_columnmap_get_keyword_list(self): - cmap = self.cmap() - assert_equal(cmap.get_keyword(['col1', 'col2']), ['key1', 'key 2']) - - def test_columnmap_get_keyword_not_found(self): - db = SQLDatabase() - db.add_table('key_columns', data={'keywords': ['key1', 'key 2', - 'key-3', 'key_4'], - 'columns': ['col1', 'col2', 'col3', - 'col4']}) - cmap = SQLColumnMap(db, 'key_columns', 'keywords', 'columns') - with pytest.raises(KeyError): - cmap.get_keyword('col5') - - def test_columnmap_map_row(self): - cmap = self.cmap() - row = {'key1': 1, 'key 2': 2, 'key-3': 3} - assert_equal(cmap.map_row(row), {'col1': 1, 'col2': 2, 'col3': 3}) - - def test_columnmap_map_row_add_column(self): - cmap = self.cmap() - row = {'key1': 1, 'key 2': 2, 'key 5': 3} - assert_equal(cmap.map_row(row, add_columns=True), - {'col1': 1, 'col2': 2, 'col_5': 3}) - - def test_columnmap_map_row_list(self): - cmap = self.cmap() - row = [1, 2, 3, 4] - assert_equal(cmap.map_row(row), [1, 2, 3, 4]) - - def test_columnmap_add_column(self): - cmap = self.cmap() - cmap.add_column('key 55') - assert_equal(cmap.get_column_name('key 55'), 'col_5') - - def test_columnmap_add_column_existing(self): - cmap = self.cmap() - with pytest.raises(ValueError): - cmap.add_column('key1') - - def test_columnmap_parse_where(self): - cmap = self.cmap() - with pytest.raises(TypeError): - cmap.parse_where('key1 = 1') - with pytest.raises(TypeError): - cmap.parse_where(['key1 = 1', 'key2 = 2']) - cmap.parse_where({'key1': 1, 'key 2': 2}) - - -class Test_SQLDatabase_Creation_Modify: - def test_sql_db_creation(self): - db = SQLDatabase(':memory:') - assert_equal(db.table_names, []) - assert_equal(len(db), 0) - - db.add_table('test') - assert_equal(db.table_names, ['test']) - assert_equal(len(db), 1) - - def test_sql_add_column_name_and_data(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - assert_equal(db.get_column('test', 'a').values, np.arange(10, 20)) - assert_equal(db.get_column('test', 'b').values, np.arange(20, 30)) - assert_equal(db.column_names('test'), ['a', 'b']) - assert_equal(len(db), 1) - assert_equal(db.table_names, ['test']) - - def test_sql_add_column_only_name(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a') - db.add_column('test', 'b') - - assert_equal(db.get_column('test', 'a').values, []) - assert_equal(db.get_column('test', 'b').values, []) - assert_equal(db.column_names('test'), ['a', 'b']) - assert_equal(len(db), 1) - assert_equal(db.table_names, ['test']) - - def test_sql_add_table_from_data_table(self): - db = SQLDatabase(':memory:') - d = Table(names=['a', 'b'], data=[np.arange(10, 20), np.arange(20, 30)]) - db.add_table('test', data=d) - - assert_equal(db.get_column('test', 'a').values, np.arange(10, 20)) - assert_equal(db.get_column('test', 'b').values, np.arange(20, 30)) - assert_equal(db.column_names('test'), ['a', 'b']) - assert_equal(len(db), 1) - assert_equal(db.table_names, ['test']) - - def test_sql_add_table_from_data_ndarray(self): - dtype = [('a', 'i4'), ('b', 'f8')] - data = np.array([(1, 2.0), (3, 4.0), (5, 6.0), (7, 8.0)], dtype=dtype) - db = SQLDatabase(':memory:') - db.add_table('test', data=data) - - assert_equal(db.get_column('test', 'a').values, [1, 3, 5, 7]) - assert_equal(db.get_column('test', 'b').values, [2.0, 4.0, 6.0, 8.0]) - assert_equal(db.column_names('test'), ['a', 'b']) - assert_equal(len(db), 1) - assert_equal(db.table_names, ['test']) - - def test_sql_add_table_from_data_dict(self): - d = {'a': np.arange(10, 20), 'b': np.arange(20, 30)} - db = SQLDatabase(':memory:') - db.add_table('test', data=d) - - assert_equal(db.get_column('test', 'a').values, np.arange(10, 20)) - assert_equal(db.get_column('test', 'b').values, np.arange(20, 30)) - assert_equal(db.column_names('test'), ['a', 'b']) - assert_equal(len(db), 1) - assert_equal(db.table_names, ['test']) - - def test_sql_add_table_from_data_ndarray_untyped(self): - # Untyped ndarray should fail in get column names - data = np.array([(1, 2.0), (3, 4.0), (5, 6.0), (7, 8.0)]) - db = SQLDatabase(':memory:') - with pytest.raises(ValueError): - db.add_table('test', data=data) - - @pytest.mark.parametrize('data, error', [([1, 2, 3], ValueError), - (1, TypeError), - (1.0, TypeError), - ('test', TypeError)]) - def test_sql_add_table_from_data_invalid(self, data, error): - db = SQLDatabase(':memory:') - with assert_raises(error): - db.add_table('test', data=data) - - def test_sql_add_table_columns(self): - db = SQLDatabase(':memory:') - db.add_table('test', columns=['a', 'b']) - - assert_equal(db.get_column('test', 'a').values, []) - assert_equal(db.get_column('test', 'b').values, []) - assert_equal(db.column_names('test'), ['a', 'b']) - assert_equal(len(db), 1) - assert_equal(db.table_names, ['test']) - - def test_sql_add_table_columns_data(self): - db = SQLDatabase(':memory:') - with pytest.raises(ValueError): - db.add_table('test', columns=['a', 'b'], data=[1, 2, 3]) - - def test_sql_add_row(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a') - db.add_column('test', 'b') - db.add_rows('test', dict(a=1, b=2)) - db.add_rows('test', dict(a=[3, 5], b=[4, 6])) - - assert_equal(db.get_column('test', 'a').values, [1, 3, 5]) - assert_equal(db.get_column('test', 'b').values, [2, 4, 6]) - assert_equal(len(db), 1) - assert_equal(db.table_names, ['test']) - - def test_sql_add_row_types(self): - db = SQLDatabase(':memory:') - db.add_table('test') - for k in ['a', 'b', 'c', 'd', 'e']: - db.add_column('test', k) - - db.add_rows('test', dict(a=1, b='a', c=True, d=b'a', e=3.14)) - db.add_rows('test', dict(a=2, b='b', c=False, d=b'b', e=2.71)) - - assert_equal(db.get_column('test', 'a').values, [1, 2]) - assert_equal(db.get_column('test', 'b').values, ['a', 'b']) - assert_equal(db.get_column('test', 'c').values, [1, 0]) - assert_equal(db.get_column('test', 'd').values, [b'a', b'b']) - assert_almost_equal(db.get_column('test', 'e').values, [3.14, 2.71]) - assert_equal(len(db), 1) - assert_equal(db.table_names, ['test']) - - def test_sql_add_row_invalid(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a') - db.add_column('test', 'b') - with assert_raises(ValueError): - db.add_rows('test', [1, 2, 3]) - - def test_sql_add_row_add_columns(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a') - db.add_column('test', 'b') - db.add_rows('test', dict(a=1, b=2)) - db.add_rows('test', dict(a=3, c=4), add_columns=False) - db.add_rows('test', dict(a=5, d=6), add_columns=True) - - assert_equal(db.get_column('test', 'a').values, [1, 3, 5]) - assert_equal(db.get_column('test', 'b').values, [2, None, None]) - assert_equal(len(db), 1) - assert_equal(db.table_names, ['test']) - assert_equal(db.column_names('test'), ['a', 'b', 'd']) - - def test_sql_add_row_superpass_64_limit(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_rows('test', {f'col{i}': np.arange(10) for i in range(128)}, - add_columns=True) - assert_equal(db.column_names('test'), [f'col{i}' for i in range(128)]) - - for i, v in enumerate(db.select('test')): - assert_equal(v, [i]*128) - - def test_sqltable_add_column(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db['test'].add_column('a') - db['test'].add_column('b') - db['test'].add_column('c', data=[1, 2, 3]) - - assert_equal(db.get_column('test', 'a').values, [None, None, None]) - assert_equal(db.get_column('test', 'b').values, [None, None, None]) - assert_equal(db.get_column('test', 'c').values, [1, 2, 3]) - assert_equal(len(db), 1) - assert_equal(db.table_names, ['test']) - assert_equal(db.column_names('test'), ['a', 'b', 'c']) - assert_equal(len(db['test']), 3) - - def test_sqltable_add_row_add_columns(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a') - db.add_column('test', 'b') - db['test'].add_rows(dict(a=1, b=2)) - db['test'].add_rows(dict(a=3, c=4), add_columns=False) - db['test'].add_rows(dict(a=5, d=6), add_columns=True) - - assert_equal(db.get_column('test', 'a').values, [1, 3, 5]) - assert_equal(db.get_column('test', 'b').values, [2, None, None]) - assert_equal(len(db), 1) - assert_equal(db.table_names, ['test']) - assert_equal(db.column_names('test'), ['a', 'b', 'd']) - - def test_sql_set_column(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', [1, 3, 5]) - db.add_column('test', 'b', [2, 4, 6]) - - db.set_column('test', 'a', [10, 20, 30]) - db.set_column('test', 'b', [20, 40, 60]) - - assert_equal(db.get_column('test', 'a').values, [10, 20, 30]) - assert_equal(db.get_column('test', 'b').values, [20, 40, 60]) - - with pytest.raises(KeyError): - db.set_column('test', 'c', [10, 20, 30]) - with pytest.raises(ValueError): - db.set_column('test', 'a', [10, 20, 30, 40]) - - def test_sql_set_row(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', [1, 3, 5]) - db.add_column('test', 'b', [2, 4, 6]) - - db.set_row('test', 0, dict(a=10, b=20)) - db.set_row('test', 1, [20, 40]) - db.set_row('test', 2, np.array([30, 60])) - - assert_equal(db.get_column('test', 'a').values, [10, 20, 30]) - assert_equal(db.get_column('test', 'b').values, [20, 40, 60]) - - with pytest.raises(IndexError): - db.set_row('test', 3, dict(a=10, b=20)) - with pytest.raises(IndexError): - db.set_row('test', -4, dict(a=10, b=20)) - - def test_sql_set_item(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', [1, 3, 5]) - db.add_column('test', 'b', [2, 4, 6]) - - db.set_item('test', 'a', 0, 10) - db.set_item('test', 'b', 1, 'a') - assert_equal(db.get_column('test', 'a').values, [10, 3, 5]) - assert_equal(db.get_column('test', 'b').values, [2, 'a', 6]) - - def test_sql_setitem_tuple_only(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - with pytest.raises(KeyError): - db[1] = 0 - with pytest.raises(KeyError): - db['notable'] = 0 - with pytest.raises(KeyError): - db[['test', 0]] = 0 - with pytest.raises(KeyError): - db[1, 0] = 0 - - def test_sql_setitem(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - db['test', 'a'] = np.arange(50, 60) - db['test', 0] = {'a': 1, 'b': 2} - db['test', 'b', 5] = -999 - - expect = np.transpose([np.arange(50, 60), np.arange(20, 30)]) - expect[0] = [1, 2] - expect[5, 1] = -999 - - assert_equal(db['test'].values, expect) - - def test_sql_droptable(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', [1, 3, 5]) - db.add_column('test', 'b', [2, 4, 6]) - - db.drop_table('test') - assert_equal(db.table_names, []) - with pytest.raises(KeyError): - db['test'] - - def test_sql_copy(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', [1, 3, 5]) - db.add_column('test', 'b', [2, 4, 6]) - - db2 = db.copy() - assert_equal(db2.table_names, ['test']) - assert_equal(db2.column_names('test'), ['a', 'b']) - assert_equal(db2.get_column('test', 'a').values, [1, 3, 5]) - assert_equal(db2.get_column('test', 'b').values, [2, 4, 6]) - - def test_sql_copy_indexes(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', np.arange(1, 101, 2)) - db.add_column('test', 'b', np.arange(2, 102, 2)) - - db2 = db.copy(indexes={'test': [30, 24, 32, 11]}) - assert_equal(db2.table_names, ['test']) - assert_equal(db2.column_names('test'), ['a', 'b']) - assert_equal(db2.get_column('test', 'a').values, [23, 49, 61, 65]) - assert_equal(db2.get_column('test', 'b').values, [24, 50, 62, 66]) - - def test_sql_copy_more_than_1000(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', np.arange(1, 5001, 1)) - db.add_column('test', 'b', np.arange(2, 5002, 1)) - - db2 = db.copy() - assert_equal(db2.table_names, ['test']) - assert_equal(db2.column_names('test'), ['a', 'b']) - assert_equal(db2.get_column('test', 'a').values, - np.arange(1, 5001, 1)) - assert_equal(db2.get_column('test', 'b').values, - np.arange(2, 5002, 1)) - - def test_sql_copy_more_than_1000_indexes(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', np.arange(1, 5001, 1)) - db.add_column('test', 'b', np.arange(2, 5002, 1)) - - db2 = db.copy(indexes={'test': np.arange(1000, 2500, 1)}) - assert_equal(db2.table_names, ['test']) - assert_equal(db2.column_names('test'), ['a', 'b']) - assert_equal(db2.get_column('test', 'a').values, - np.arange(1001, 2501, 1)) - assert_equal(db2.get_column('test', 'b').values, - np.arange(1002, 2502, 1)) - - def test_sql_delete_row(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', [1, 3, 5]) - db.add_column('test', 'b', [2, 4, 6]) - - db.delete_row('test', 1) - assert_equal(db.get_column('test', 'a').values, [1, 5]) - assert_equal(db.get_column('test', 'b').values, [2, 6]) - - with pytest.raises(IndexError): - db.delete_row('test', 2) - with pytest.raises(IndexError): - db.delete_row('test', -4) - - def test_sql_delete_column(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', [1, 3, 5]) - db.add_column('test', 'b', [2, 4, 6]) - - db.delete_column('test', 'b') - assert_equal(db.column_names('test'), ['a']) - assert_equal(db.get_column('test', 'a').values, [1, 3, 5]) - - with pytest.raises(KeyError, match='does not exist'): - db.delete_column('test', 'b') - with pytest.raises(ValueError, match='protected name'): - db.delete_column('test', 'table') - - -class Test_SQLDatabase_Access: - def test_sql_get_table(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - assert_equal(db.get_table('test').values, list(zip(np.arange(10, 20), - np.arange(20, 30)))) - assert_is_instance(db.get_table('test'), SQLTable) - - with pytest.raises(KeyError): - db.get_table('not_a_table') - - def test_sql_get_table_empty(self): - db = SQLDatabase(':memory:') - db.add_table('test') - - assert_equal(len(db.get_table('test')), 0) - assert_is_instance(db.get_table('test'), SQLTable) - - def test_sql_get_column(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - assert_equal(db.get_column('test', 'a').values, np.arange(10, 20)) - assert_equal(db.get_column('test', 'b').values, np.arange(20, 30)) - assert_is_instance(db.get_column('test', 'a'), SQLColumn) - assert_is_instance(db.get_column('test', 'b'), SQLColumn) - - # same access from table - assert_equal(db.get_table('test').get_column('a').values, np.arange(10, 20)) - assert_equal(db.get_table('test').get_column('b').values, np.arange(20, 30)) - assert_is_instance(db.get_table('test').get_column('a'), SQLColumn) - assert_is_instance(db.get_table('test').get_column('b'), SQLColumn) - - with pytest.raises(KeyError): - db.get_column('test', 'c') - with pytest.raises(KeyError): - db.get_table('test').get_column('c') - - def test_sql_get_row(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - assert_equal(db.get_row('test', 4).values, (14, 24)) - assert_is_instance(db.get_row('test', 4), SQLRow) - - assert_equal(db.get_row('test', -1).values, (19, 29)) - assert_is_instance(db.get_row('test', -1), SQLRow) - - # same access from table - assert_equal(db.get_table('test').get_row(4).values, [14, 24]) - assert_is_instance(db.get_table('test').get_row(4), SQLRow) - - with pytest.raises(IndexError): - db.get_row('test', 11) - with pytest.raises(IndexError): - db.get_row('test', -11) - with pytest.raises(IndexError): - db.get_table('test').get_row(11) - with pytest.raises(IndexError): - db.get_table('test').get_row(-11) - - def test_sql_getitem(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - assert_equal(db['test']['a'].values, np.arange(10, 20)) - assert_equal(db['test']['b'].values, np.arange(20, 30)) - assert_is_instance(db['test']['a'], SQLColumn) - assert_is_instance(db['test']['b'], SQLColumn) - - assert_equal(db['test'][4].values, (14, 24)) - assert_is_instance(db['test'][4], SQLRow) - assert_equal(db['test'][-1].values, (19, 29)) - assert_is_instance(db['test'][-1], SQLRow) - - with pytest.raises(KeyError): - db['test']['c'] - with pytest.raises(KeyError): - db['not_a_table']['a'] - - with pytest.raises(IndexError): - db['test'][11] - with pytest.raises(IndexError): - db['test'][-11] - - def test_sql_getitem_tuple(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - assert_equal(db['test', 'a'].values, np.arange(10, 20)) - assert_equal(db['test', 'b'].values, np.arange(20, 30)) - assert_is_instance(db['test', 'a'], SQLColumn) - assert_is_instance(db['test', 'b'], SQLColumn) - - assert_equal(db['test', 4].values, (14, 24)) - assert_is_instance(db['test', 4], SQLRow) - assert_equal(db['test', -1].values, (19, 29)) - assert_is_instance(db['test', -1], SQLRow) - - assert_equal(db['test', 'a', 4], 14) - assert_equal(db['test', 'b', 4], 24) - assert_equal(db['test', 'a', -1], 19) - assert_equal(db['test', 'b', -1], 29) - assert_equal(db['test', 4, 'a'], 14) - assert_equal(db['test', 4, 'b'], 24) - assert_equal(db['test', -1, 'a'], 19) - assert_equal(db['test', -1, 'b'], 29) - - def test_sql_getitem_table_force(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - with pytest.raises(ValueError): - db[1] - with pytest.raises(ValueError): - db[1, 2] - with pytest.raises(ValueError): - db[1, 2, 'test'] - with pytest.raises(ValueError): - db[[1, 2], 'test'] - - -class Test_SQLDatabase_PropsComms: - def test_sql_select_where(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - a = db.select('test', columns='a', where={'a': 15}) - assert_equal(a, 15) - - a = db.select('test', columns=['a', 'b'], where={'b': 22}) - assert_equal(a, [(12, 22)]) - - a = db.select('test', columns=['a', 'b'], where=None) - assert_equal(a, list(zip(np.arange(10, 20), np.arange(20, 30)))) - - a = db.select('test', columns=['a', 'b'], where=['a > 12', 'b < 26']) - assert_equal(a, [(13, 23), (14, 24), (15, 25)]) - - def test_sql_select_limit_offset(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - a = db.select('test', columns='a', limit=1) - assert_equal(a, 10) - - a = db.select('test', columns='a', limit=3, offset=2) - assert_equal(a, [[12], [13], [14]]) - - def test_sql_select_invalid(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - with pytest.raises(sqlite3.OperationalError, - match='no such column: c'): - db.select('test', columns=['c']) - - with pytest.raises(ValueError, - match='offset cannot be used without limit.'): - db.select('test', columns='a', offset=1) - - with pytest.raises(TypeError, match='where must be'): - db.select('test', columns='a', where=1) - - with pytest.raises(TypeError, match='if where is a list'): - db.select('test', columns='a', where=[1, 2, 3]) - - with pytest.raises(TypeError): - db.select('test', limit=3.14) - - with pytest.raises(TypeError): - db.select('test', order=5) - - def test_sql_select_order(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)[::-1]) - - a = db.select('test', order='b') - assert_equal(a, list(zip(np.arange(10, 20), - np.arange(20, 30)[::-1]))[::-1]) - - a = db.select('test', order='b', limit=2) - assert_equal(a, [(19, 20), (18, 21)]) - - a = db.select('test', order='b', limit=2, offset=2) - assert_equal(a, [(17, 22), (16, 23)]) - - a = db.select('test', order='b', where='a < 15') - assert_equal(a, [(14, 25), (13, 26), (12, 27), (11, 28), (10, 29)]) - - a = db.select('test', order='b', where='a < 15', limit=3) - assert_equal(a, [(14, 25), (13, 26), (12, 27)]) - - a = db.select('test', order='b', where='a < 15', limit=3, offset=2) - assert_equal(a, [(12, 27), (11, 28), (10, 29)]) - - def test_sql_count(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - assert_equal(db.count('test'), 10) - assert_equal(db.count('test', where={'a': 15}), 1) - assert_equal(db.count('test', where={'a': 15, 'b': 22}), 0) - assert_equal(db.count('test', where='a > 15'), 4) - assert_equal(db.count('test', where=['a > 15', 'b < 27']), 1) - - def test_sql_prop_db(self, tmp_path): - db = SQLDatabase(':memory:') - assert_equal(db.db, ':memory:') - - db = SQLDatabase(str(tmp_path / 'test.db')) - assert_equal(db.db, str(tmp_path / 'test.db')) - - def test_sql_prop_table_names(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_table('test2') - assert_equal(db.table_names, ['test', 'test2']) - - def test_sql_prop_column_names(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - assert_equal(db.column_names('test'), ['a', 'b']) - - def test_sql_repr(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - db.add_table('test2') - db.add_column('test2', 'a', data=np.arange(10, 20)) - db.add_column('test2', 'b', data=np.arange(20, 30)) - - expect = f"SQLDatabase ':memory:' at {hex(id(db))}:\n" - expect += "\ttest: 2 columns 10 rows\n" - expect += "\ttest2: 2 columns 10 rows" - assert_equal(repr(db), expect) - - db = SQLDatabase(':memory:') - expect = f"SQLDatabase ':memory:' at {hex(id(db))}:\n" - expect += "\tEmpty database." - assert_equal(repr(db), expect) - - def test_sql_len(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_table('test2') - db.add_table('test3') - - assert_equal(len(db), 3) - - def test_sql_index_of(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - assert_equal(db.index_of('test', {'a': 15}), 5) - assert_equal(db.index_of('test', 'b >= 27'), [7, 8, 9]) - assert_equal(db.index_of('test', {'a': 1, 'b': 2}), []) - - -class Test_SQLRow: - @property - def db(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - return db - - def test_row_copy_error(self): - db = self.db - row = db['test'][1] - with pytest.raises(NotImplementedError, match='Cannot copy'): - copy.copy(row) - with pytest.raises(NotImplementedError, match='Cannot copy'): - copy.deepcopy(row) - - def test_row_basic_properties(self): - db = self.db - row = db['test'][0] - assert_is_instance(row, SQLRow) - assert_equal(row.table, 'test') - assert_equal(row.index, 0) - assert_equal(row.column_names, ['a', 'b']) - assert_equal(row.keys, ['a', 'b']) - assert_equal(row.values, [10, 20]) - assert_equal(row.as_dict(), {'a': 10, 'b': 20}) - - def test_row_iter(self): - db = self.db - row = db['test'][0] - assert_is_instance(row, SQLRow) - - v = 10 - for i in row: - assert_equal(i, v) - v += 10 - - def test_row_getitem(self): - db = self.db - row = db['test'][0] - assert_is_instance(row, SQLRow) - - assert_equal(row['a'], 10) - assert_equal(row['b'], 20) - - with pytest.raises(KeyError): - row['c'] - - assert_equal(row[0], 10) - assert_equal(row[1], 20) - assert_equal(row[-1], 20) - assert_equal(row[-2], 10) - - with pytest.raises(IndexError): - row[2] - with pytest.raises(IndexError): - row[-3] - - def test_row_setitem(self): - db = self.db - row = db['test'][0] - assert_is_instance(row, SQLRow) - - row['a'] = 1 - row['b'] = 1 - assert_equal(db['test']['a'], [1, 11, 12, 13, 14, - 15, 16, 17, 18, 19]) - assert_equal(db['test']['b'], [1, 21, 22, 23, 24, - 25, 26, 27, 28, 29]) - - with pytest.raises(KeyError): - row['c'] = 1 - with pytest.raises(KeyError): - row[2] = 1 - with pytest.raises(KeyError): - row[-3] = 1 - - def test_row_contains(self): - db = self.db - row = db['test'][0] - assert_is_instance(row, SQLRow) - - assert_true(10 in row) - assert_true(20 in row) - assert_false('c' in row) - assert_false('a' in row) - assert_false('b' in row) - - def test_row_repr(self): - db = self.db - row = db['test'][0] - assert_is_instance(row, SQLRow) - assert_equal(repr(row), "SQLRow 0 in table 'test' {'a': 10, 'b': 20}") - - -class Test_SQLTable: - @property - def db(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - return db - - def test_table_copy_error(self): - db = self.db - table = db['test'] - with pytest.raises(NotImplementedError, match='Cannot copy'): - copy.copy(table) - with pytest.raises(NotImplementedError, match='Cannot copy'): - copy.deepcopy(table) - - def test_table_basic_properties(self): - db = self.db - table = db['test'] - assert_equal(table.name, 'test') - assert_equal(table.db, db.db) - assert_equal(table.column_names, ['a', 'b']) - assert_equal(table.values, list(zip(np.arange(10, 20), - np.arange(20, 30)))) - - def test_table_select(self): - db = self.db - table = db['test'] - - a = table.select() - assert_equal(a, list(zip(np.arange(10, 20), - np.arange(20, 30)))) - - a = table.select(order='a') - assert_equal(a, list(zip(np.arange(10, 20), - np.arange(20, 30)))) - - a = table.select(order='a', limit=2) - assert_equal(a, [(10, 20), (11, 21)]) - - a = table.select(order='a', limit=2, offset=2) - assert_equal(a, [(12, 22), (13, 23)]) - - a = table.select(order='a', where='a < 15') - assert_equal(a, [(10, 20), (11, 21), (12, 22), (13, 23), (14, 24)]) - - a = table.select(order='a', where='a < 15', limit=3) - assert_equal(a, [(10, 20), (11, 21), (12, 22)]) - - a = table.select(order='a', where='a < 15', limit=3, offset=2) - assert_equal(a, [(12, 22), (13, 23), (14, 24)]) - - a = table.select(columns=['a'], where='a < 15') - assert_equal(a, [(10,), (11,), (12,), (13,), (14,)]) - - def test_table_as_table(self): - db = self.db - table = db['test'] - - a = table.as_table() - assert_is_instance(a, Table) - assert_equal(a.colnames, ['a', 'b']) - assert_equal(a, Table(names=['a', 'b'], data=[np.arange(10, 20), - np.arange(20, 30)])) - - def test_table_as_table_empty(self): - db = SQLDatabase(':memory:') - db.add_table('test') - table = db['test'] - - a = table.as_table() - assert_is_instance(a, Table) - assert_equal(a.colnames, []) - assert_equal(a, Table()) - - def test_table_len(self): - db = self.db - table = db['test'] - assert_equal(len(table), 10) - - def test_table_iter(self): - db = self.db - table = db['test'] - - v = 10 - for i in table: - assert_equal(i, (v, v + 10)) - v += 1 - - def test_table_contains(self): - db = self.db - table = db['test'] - - assert_false(10 in table) - assert_false(20 in table) - assert_false('c' in table) - assert_true('a' in table) - assert_true('b' in table) - - def test_table_repr(self): - db = self.db - table = db['test'] - i = hex(id(table)) - - expect = "SQLTable 'test' in database ':memory:':" - expect += f"(2 columns x 10 rows)\n" - expect += '\n'.join(table.as_table().__repr__().split('\n')[1:]) - assert_is_instance(table, SQLTable) - assert_equal(repr(table), expect) - - def test_table_add_column(self): - db = self.db - table = db['test'] - - table.add_column('c', data=np.arange(10, 20)) - assert_equal(table.column_names, ['a', 'b', 'c']) - assert_equal(table.values, list(zip(np.arange(10, 20), - np.arange(20, 30), - np.arange(10, 20)))) - - table.add_column('d', data=np.arange(20, 30)) - assert_equal(table.column_names, ['a', 'b', 'c', 'd']) - assert_equal(table.values, list(zip(np.arange(10, 20), - np.arange(20, 30), - np.arange(10, 20), - np.arange(20, 30)))) - - def test_table_get_column(self): - db = self.db - table = db['test'] - - a = table.get_column('a') - assert_is_instance(a, SQLColumn) - assert_equal(a.values, np.arange(10, 20)) - - a = table.get_column('b') - assert_is_instance(a, SQLColumn) - assert_equal(a.values, np.arange(20, 30)) - - def test_table_set_column(self): - db = self.db - table = db['test'] - - table.set_column('a', np.arange(5, 15)) - assert_equal(table.column_names, ['a', 'b']) - assert_equal(table.values, list(zip(np.arange(5, 15), - np.arange(20, 30)))) - - def test_table_set_column_invalid(self): - db = self.db - table = db['test'] - - with assert_raises(ValueError): - table.set_column('a', np.arange(5, 16)) - - with assert_raises(KeyError): - table.set_column('c', np.arange(5, 15)) - - def test_table_add_row(self): - db = self.db - table = db['test'] - - table.add_rows({'a': -1, 'b': -1}) - assert_equal(table.column_names, ['a', 'b']) - assert_equal(len(table), 11) - assert_equal(table[-1].values, (-1, -1)) - - table.add_rows({'a': -2, 'c': -2}, add_columns=True) - assert_equal(table.column_names, ['a', 'b', 'c']) - assert_equal(len(table), 12) - assert_equal(table[-1].values, (-2, None, -2)) - - table.add_rows({'a': -3, 'd': -3}, add_columns=False) - assert_equal(table.column_names, ['a', 'b', 'c']) - assert_equal(len(table), 13) - assert_equal(table[-1].values, (-3, None, None)) - - # defult add_columns must be false - table.add_rows({'a': -4, 'b': -4, 'c': -4, 'd': -4}) - assert_equal(table.column_names, ['a', 'b', 'c']) - assert_equal(len(table), 14) - assert_equal(table[-1].values, (-4, -4, -4)) - - def test_table_add_row_invalid(self): - db = self.db - table = db['test'] - - with assert_raises(ValueError): - table.add_rows([1, 2, 3, 4]) - - with assert_raises(TypeError): - table.add_rows(2) - - def test_table_get_row(self): - db = self.db - table = db['test'] - - a = table.get_row(0) - assert_is_instance(a, SQLRow) - assert_equal(a.values, (10, 20)) - - a = table.get_row(1) - assert_is_instance(a, SQLRow) - assert_equal(a.values, (11, 21)) - - def test_table_set_row(self): - db = self.db - table = db['test'] - - table.set_row(0, {'a': 5, 'b': 15}) - expect = np.transpose([np.arange(10, 20), np.arange(20, 30)]) - expect[0] = [5, 15] - assert_equal(table.column_names, ['a', 'b']) - assert_equal(table.values, expect) - - expect[-1] = [-1, -1] - table.set_row(-1, {'a': -1, 'b': -1}) - assert_equal(table.column_names, ['a', 'b']) - assert_equal(table.values, expect) - - expect[-1] = [5, 5] - table.set_row(-1, [5, 5]) - assert_equal(table.column_names, ['a', 'b']) - assert_equal(table.values, expect) - - def test_table_set_row_invalid(self): - db = self.db - table = db['test'] - - with pytest.raises(IndexError): - table.set_row(10, {'a': -1, 'b': -1}) - with pytest.raises(IndexError): - table.set_row(-11, {'a': -1, 'b': -1}) - - with pytest.raises(TypeError): - table.set_row(0, 'a') - - def test_table_getitem_int(self): - db = self.db - table = db['test'] - assert_is_instance(table, SQLTable) - - assert_equal(table[0].values, (10, 20)) - assert_equal(table[-1].values, (19, 29)) - - with pytest.raises(IndexError): - table[10] - with pytest.raises(IndexError): - table[-11] - - def test_table_getitem_str(self): - db = self.db - table = db['test'] - assert_is_instance(table, SQLTable) - - assert_equal(table['a'].values, np.arange(10, 20)) - assert_equal(table['b'].values, np.arange(20, 30)) - - with pytest.raises(KeyError): - table['c'] - - def test_table_getitem_tuple(self): - db = self.db - table = db['test'] - assert_is_instance(table, SQLTable) - - assert_equal(table[('a',)].values, np.arange(10, 20)) - assert_is_instance(table[('a',)], SQLColumn) - assert_equal(table[(1,)].values, (11, 21)) - assert_is_instance(table[(1,)], SQLRow) - - with pytest.raises(KeyError): - table[('c')] - with pytest.raises(IndexError): - table[(11,)] - - def test_table_getitem_tuple_rowcol(self): - db = self.db - table = db['test'] - assert_is_instance(table, SQLTable) - - assert_equal(table['a', 0], 10) - assert_equal(table['a', 1], 11) - assert_equal(table['b', 0], 20) - assert_equal(table['b', 1], 21) - - assert_equal(table[0, 'a'], 10) - assert_equal(table[1, 'a'], 11) - assert_equal(table[0, 'b'], 20) - assert_equal(table[1, 'b'], 21) - - assert_equal(table['a', [0, 1, 2]], [10, 11, 12]) - assert_equal(table['b', [0, 1, 2]], [20, 21, 22]) - assert_equal(table[[0, 1, 2], 'b'], [20, 21, 22]) - assert_equal(table[[0, 1, 2], 'a'], [10, 11, 12]) - - assert_equal(table['a', 2:5], [12, 13, 14]) - assert_equal(table['b', 2:5], [22, 23, 24]) - assert_equal(table[2:5, 'b'], [22, 23, 24]) - assert_equal(table[2:5, 'a'], [12, 13, 14]) - - with pytest.raises(KeyError): - table['c', 0] - with pytest.raises(IndexError): - table['a', 11] - - with pytest.raises(KeyError): - table[0, 0] - with pytest.raises(KeyError): - table['b', 'a'] - with pytest.raises(KeyError): - table[0, 1, 2] - with pytest.raises(KeyError): - table[0, 'a', 'b'] - - def test_table_setitem_int(self): - db = self.db - table = db['test'] - assert_is_instance(table, SQLTable) - - table[0] = {'a': 5, 'b': 15} - expect = np.transpose([np.arange(10, 20), np.arange(20, 30)]) - expect[0] = [5, 15] - assert_equal(table.column_names, ['a', 'b']) - assert_equal(table.values, expect) - - table[-1] = {'a': -1, 'b': -1} - expect[-1] = [-1, -1] - assert_equal(table.column_names, ['a', 'b']) - assert_equal(table.values, expect) - - with pytest.raises(IndexError): - table[10] = {'a': -1, 'b': -1} - with pytest.raises(IndexError): - table[-11] = {'a': -1, 'b': -1} - - def test_table_setitem_str(self): - db = self.db - table = db['test'] - assert_is_instance(table, SQLTable) - - table['a'] = np.arange(40, 50) - expect = np.transpose([np.arange(40, 50), np.arange(20, 30)]) - assert_equal(table.column_names, ['a', 'b']) - assert_equal(table.values, expect) - - table['b'] = np.arange(10, 20) - expect = np.transpose([np.arange(40, 50), np.arange(10, 20)]) - assert_equal(table.column_names, ['a', 'b']) - assert_equal(table.values, expect) - - with pytest.raises(KeyError): - table['c'] = np.arange(10, 20) - - def test_table_setitem_tuple(self): - db = self.db - table = db['test'] - assert_is_instance(table, SQLTable) - - table[('a',)] = np.arange(40, 50) - expect = np.transpose([np.arange(40, 50), np.arange(20, 30)]) - assert_equal(table.column_names, ['a', 'b']) - assert_equal(table.values, expect) - - table[(1,)] = {'a': -1, 'b': -1} - expect[1] = [-1, -1] - assert_equal(table.column_names, ['a', 'b']) - assert_equal(table.values, expect) - - with pytest.raises(KeyError): - table[('c',)] = np.arange(10, 20) - with pytest.raises(IndexError): - table[(11,)] = np.arange(10, 20) - - def test_table_setitem_tuple_multiple(self): - db = self.db - table = db['test'] - assert_is_instance(table, SQLTable) - expect = np.transpose([np.arange(10, 20), np.arange(20, 30)]) - - table[('a', 1)] = 57 - expect[1, 0] = 57 - table['b', -1] = 32 - expect[-1, 1] = 32 - table[0, 'a'] = -1 - expect[0, 0] = -1 - table[5, 'b'] = 99 - expect[5, 1] = 99 - table['a', 3:6] = -999 - expect[3:6, 0] = -999 - table['b', [2, 7]] = -888 - expect[[2, 7], 1] = -888 - assert_equal(table.values, expect) - - with pytest.raises(KeyError): - table[('c',)] = np.arange(10, 20) - with pytest.raises(IndexError): - table[(11,)] = np.arange(10, 20) - with pytest.raises(KeyError): - table['a', 'c'] = None - with pytest.raises(KeyError): - table[2:5] = 2 - with pytest.raises(KeyError): - table[1, 2, 3] = 3 - - def test_table_indexof(self): - db = self.db - table = db['test'] - assert_equal(table.index_of({'a': 15}), 5) - assert_equal(table.index_of({'a': 50}), []) - assert_equal(table.index_of('a < 13'), [0, 1, 2]) - - def test_table_delete_row(self): - db = self.db - table = db['test'] - assert_is_instance(table, SQLTable) - - table.delete_row(0) - expect = np.transpose([np.arange(11, 20), np.arange(21, 30)]) - assert_equal(table.column_names, ['a', 'b']) - assert_equal(table.values, expect) - - table.delete_row(-1) - expect = np.transpose([np.arange(11, 19), np.arange(21, 29)]) - assert_equal(table.column_names, ['a', 'b']) - assert_equal(table.values, expect) - - with pytest.raises(IndexError): - table.delete_row(10) - with pytest.raises(IndexError): - table.delete_row(-11) - - def test_table_delete_rows_indexer_robustness(self): - db = self.db - table = db['test'] - - # Test that the row index is updated correctly after deleting rows - row = table[5] - table.delete_row(4) - assert_equal(row.index, 4) - assert_equal(table[4].values, row.values) - - def test_table_delete_column(self): - db = self.db - table = db['test'] - assert_is_instance(table, SQLTable) - - table.delete_column('a') - expect = np.transpose([np.arange(20, 30)]) - assert_equal(table.column_names, ['b']) - assert_equal(table.values, expect) - - with pytest.raises(KeyError): - table.delete_column('a') - with pytest.raises(KeyError): - table.delete_column('c') - - -class Test_SQLColumn: - @property - def db(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - return db - - def test_column_copy_error(self): - db = self.db - col = db['test']['a'] - with pytest.raises(NotImplementedError, match='Cannot copy'): - copy.copy(col) - with pytest.raises(NotImplementedError, match='Cannot copy'): - copy.deepcopy(col) - - def test_column_basic_properties(self): - db = self.db - table = db['test'] - column = table['a'] - - assert_equal(column.name, 'a') - assert_equal(column.table, 'test') - assert_equal(column.values, np.arange(10, 20)) - - def test_column_len(self): - db = self.db - table = db['test'] - column = table['a'] - assert_equal(len(column), 10) - - def test_column_repr(self): - db = self.db - table = db['test'] - column = table['a'] - assert_equal(repr(column), "SQLColumn a in table 'test' (10 rows)") - - def test_column_contains(self): - db = self.db - table = db['test'] - column = table['a'] - assert_true(15 in column) - assert_false(25 in column) - - def test_column_iter(self): - db = self.db - table = db['test'] - column = table['a'] - - v = 10 - for i in column: - assert_equal(i, v) - v += 1 - - def test_column_getitem_int(self): - db = self.db - table = db['test'] - column = table['a'] - - assert_equal(column[0], 10) - assert_equal(column[-1], 19) - - with pytest.raises(IndexError): - column[10] - with pytest.raises(IndexError): - column[-11] - - def test_column_getitem_list(self): - db = self.db - table = db['test'] - column = table['a'] - - assert_equal(column[[0, 1]], [10, 11]) - assert_equal(column[[-2, -1]], [18, 19]) - - with pytest.raises(IndexError): - column[[10, 11]] - with pytest.raises(IndexError): - column[[-11, -12]] - - def test_column_getitem_slice(self): - db = self.db - table = db['test'] - column = table['a'] - - assert_equal(column[:2], [10, 11]) - assert_equal(column[-2:], [18, 19]) - assert_equal(column[2:5], [12, 13, 14]) - assert_equal(column[::-1], [19, 18, 17, 16, 15, 14, 13, 12, 11, 10]) - - def test_column_getitem_tuple(self): - db = self.db - table = db['test'] - column = table['a'] - with pytest.raises(IndexError): - column[('a',)] - with pytest.raises(IndexError): - column[(1,)] - with pytest.raises(IndexError): - column[1, 2] - - def test_column_setitem_int(self): - db = self.db - table = db['test'] - column = table['a'] - - column[0] = 5 - assert_equal(db.get_row('test', 0).values, [5, 20]) - - column[-1] = -1 - assert_equal(db.get_row('test', -1).values, [-1, 29]) - - def test_column_setitem_list_slice(self): - db = self.db - table = db['test'] - column = table['a'] - - column[:] = -1 - assert_equal(db.get_column('test', 'a').values, [-1]*10) - column[[2, 4]] = 2 - assert_equal(db.get_column('test', 'a').values, [-1, -1, 2, -1, 2, - -1, -1, -1, -1, -1]) - - def test_column_setitem_invalid(self): - db = self.db - table = db['test'] - column = table['a'] - - with pytest.raises(IndexError): - column[10] = 10 - with pytest.raises(IndexError): - column[-11] = 10 - with pytest.raises(IndexError): - column[2, 4] = [10, 11] - - -class Test_SQLTableMapping: - @property - def table(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - db.add_table('mapping') - db.add_column('mapping', 'keywords', ['key a', 'key-b']) - db.add_column('mapping', 'columns', ['a', 'b']) - - map = SQLColumnMap(db, 'mapping', 'keywords', 'columns') - return SQLTable(db, 'test', colmap=map) - - def test_table_select(self): - table = self.table - - a = table.select() - assert_equal(a, list(zip(np.arange(10, 20), - np.arange(20, 30)))) - - a = table.select(order='key a') - assert_equal(a, list(zip(np.arange(10, 20), - np.arange(20, 30)))) - - a = table.select(order='key-b', limit=2) - assert_equal(a, [(10, 20), (11, 21)]) - - a = table.select(order='key a', limit=2, offset=2) - assert_equal(a, [(12, 22), (13, 23)]) - - a = table.select(order='key-b', where={'key a': 15}) - assert_equal(a, [(15, 25)]) - - def test_table_column_names(self): - table = self.table - assert_equal(table.column_names, ['key a', 'key-b']) - - def test_table_getitem_str(self): - table = self.table - - assert_equal(table['key a'].values, np.arange(10, 20)) - assert_equal(table['key-b'].values, np.arange(20, 30)) - - with pytest.raises(KeyError): - table['c'] - - def test_table_getitem_tuple(self): - table = self.table - assert_equal(table[('key a',)].values, np.arange(10, 20)) - assert_is_instance(table[('key a',)], SQLColumn) - assert_equal(table[(1,)].values, (11, 21)) - assert_is_instance(table[(1,)], SQLRow) - - with pytest.raises(KeyError): - table[('c')] - with pytest.raises(IndexError): - table[(11,)] - - def test_table_getitem_tuple_rowcol(self): - table = self.table - assert_equal(table['key a', 0], 10) - assert_equal(table['key a', 1], 11) - assert_equal(table['key-b', 0], 20) - assert_equal(table['key-b', 1], 21) - - assert_equal(table[0, 'key a'], 10) - assert_equal(table[1, 'key a'], 11) - assert_equal(table[0, 'key-b'], 20) - assert_equal(table[1, 'key-b'], 21) - - assert_equal(table['key a', [0, 1, 2]], [10, 11, 12]) - assert_equal(table['key-b', [0, 1, 2]], [20, 21, 22]) - assert_equal(table[[0, 1, 2], 'key-b'], [20, 21, 22]) - assert_equal(table[[0, 1, 2], 'key a'], [10, 11, 12]) - - assert_equal(table['key a', 2:5], [12, 13, 14]) - assert_equal(table['key-b', 2:5], [22, 23, 24]) - assert_equal(table[2:5, 'key-b'], [22, 23, 24]) - assert_equal(table[2:5, 'key a'], [12, 13, 14]) - - with pytest.raises(KeyError): - table['c', 0] - with pytest.raises(IndexError): - table['key a', 11] - - with pytest.raises(KeyError): - table[0, 0] - with pytest.raises(KeyError): - table['key-b', 'key a'] - with pytest.raises(KeyError): - table[0, 1, 2] - with pytest.raises(KeyError): - table[0, 'key a', 'key-b'] - - def test_table_set_row(self): - table = self.table - - table.set_row(0, {'key a': 5, 'key-b': 15}) - expect = np.transpose([np.arange(10, 20), np.arange(20, 30)]) - expect[0] = [5, 15] - assert_equal(table.column_names, ['key a', 'key-b']) - assert_equal(table.values, expect) - - expect[-1] = [-1, -1] - table.set_row(-1, {'key a': -1, 'key-b': -1}) - assert_equal(table.column_names, ['key a', 'key-b']) - assert_equal(table.values, expect) - - expect[-1] = [5, 5] - table.set_row(-1, [5, 5]) - assert_equal(table.column_names, ['key a', 'key-b']) - assert_equal(table.values, expect) - - def test_table_add_row(self): - table = self.table - - table.add_rows({'key a': -1, 'key-b': -1}) - assert_equal(table.column_names, ['key a', 'key-b']) - assert_equal(len(table), 11) - assert_equal(table[-1].values, (-1, -1)) - - table.add_rows({'key a': -2, 'key!c': -2}, add_columns=True) - assert_equal(table.column_names, ['key a', 'key-b', 'key!c']) - assert_equal(len(table), 12) - assert_equal(table[-1].values, (-2, None, -2)) - - table.add_rows({'key a': -3, 'key_d': -3}, add_columns=False) - assert_equal(table.column_names, ['key a', 'key-b', 'key!c']) - assert_equal(len(table), 13) - assert_equal(table[-1].values, (-3, None, None)) - - # defult add_columns must be false - table.add_rows({'key a': -4, 'key-b': -4, 'key!c': -4, 'key_d': -4}) - assert_equal(table.column_names, ['key a', 'key-b', 'key!c']) - assert_equal(len(table), 14) - assert_equal(table[-1].values, (-4, -4, -4)) - - def test_table_get_column(self): - table = self.table - - a = table.get_column('key a') - assert_is_instance(a, SQLColumn) - assert_equal(a.values, np.arange(10, 20)) - assert_equal(a.name, 'a') - - a = table.get_column('key-b') - assert_is_instance(a, SQLColumn) - assert_equal(a.values, np.arange(20, 30)) - assert_equal(a.name, 'b') - - def test_table_set_column(self): - table = self.table - - table.set_column('key a', np.arange(5, 15)) - assert_equal(table.column_names, ['key a', 'key-b']) - assert_equal(table.values, list(zip(np.arange(5, 15), - np.arange(20, 30)))) - - def test_table_set_column_invalid(self): - table = self.table - - with assert_raises(ValueError): - table.set_column('key a', np.arange(5, 16)) - - with assert_raises(KeyError): - table.set_column('key!c', np.arange(5, 15)) - - def test_table_add_column(self): - table = self.table - - table.add_column('key!c', data=np.arange(10, 20)) - assert_equal(table.column_names, ['key a', 'key-b', 'key!c']) - assert_equal(table.values, list(zip(np.arange(10, 20), - np.arange(20, 30), - np.arange(10, 20)))) - - table.add_column('key_d', data=np.arange(20, 30)) - assert_equal(table.column_names, ['key a', 'key-b', 'key!c', 'key_d']) - assert_equal(table.values, list(zip(np.arange(10, 20), - np.arange(20, 30), - np.arange(10, 20), - np.arange(20, 30)))) - - def test_table_contains(self): - table = self.table - - assert_false(10 in table) - assert_false(20 in table) - assert_false('key!c' in table) - assert_true('key a' in table) - assert_true('key-b' in table) - - def test_table_as_table_empty(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_table('mapping') - db.add_column('mapping', 'keywords') - db.add_column('mapping', 'columns') - table = SQLTable(db, 'test', - SQLColumnMap(db, 'mapping', 'keywords', 'columns')) - - a = table.as_table() - assert_is_instance(a, Table) - assert_equal(a.colnames, []) - assert_equal(a, Table()) - - def test_table_as_table(self): - table = self.table - - a = table.as_table() - assert_is_instance(a, Table) - assert_equal(a.colnames, ['key a', 'key-b']) - assert_equal(a, Table(names=['key a', 'key-b'], - data=[np.arange(10, 20), np.arange(20, 30)])) - - def test_table_setitem_int(self): - table = self.table - - table[0] = {'key a': 5, 'key-b': 15} - expect = np.transpose([np.arange(10, 20), np.arange(20, 30)]) - expect[0] = [5, 15] - assert_equal(table.column_names, ['key a', 'key-b']) - assert_equal(table.values, expect) - - table[-1] = {'key a': -1, 'key-b': -1} - expect[-1] = [-1, -1] - assert_equal(table.column_names, ['key a', 'key-b']) - assert_equal(table.values, expect) - - with pytest.raises(IndexError): - table[10] = {'key a': -1, 'key-b': -1} - with pytest.raises(IndexError): - table[-11] = {'key a': -1, 'key-b': -1} - - def test_table_setitem_str(self): - table = self.table - - table['key a'] = np.arange(40, 50) - expect = np.transpose([np.arange(40, 50), np.arange(20, 30)]) - assert_equal(table.column_names, ['key a', 'key-b']) - assert_equal(table.values, expect) - - table['key-b'] = np.arange(10, 20) - expect = np.transpose([np.arange(40, 50), np.arange(10, 20)]) - assert_equal(table.column_names, ['key a', 'key-b']) - assert_equal(table.values, expect) - - with pytest.raises(KeyError): - table['c'] = np.arange(10, 20) - - def test_table_setitem_tuple(self): - table = self.table - - table[('key a',)] = np.arange(40, 50) - expect = np.transpose([np.arange(40, 50), np.arange(20, 30)]) - assert_equal(table.column_names, ['key a', 'key-b']) - assert_equal(table.values, expect) - - table[(1,)] = {'key a': -1, 'key-b': -1} - expect[1] = [-1, -1] - assert_equal(table.column_names, ['key a', 'key-b']) - assert_equal(table.values, expect) - - def test_table_setitem_tuple_multiple(self): - table = self.table - expect = np.transpose([np.arange(10, 20), np.arange(20, 30)]) - - table[('key a', 1)] = 57 - expect[1, 0] = 57 - table['key-b', -1] = 32 - expect[-1, 1] = 32 - table[0, 'key a'] = -1 - expect[0, 0] = -1 - table[5, 'key-b'] = 99 - expect[5, 1] = 99 - table['key a', 3:6] = -999 - expect[3:6, 0] = -999 - table['key-b', [2, 7]] = -888 - expect[[2, 7], 1] = -888 - assert_equal(table.values, expect) - - def test_table_indexof(self): - table = self.table - assert_equal(table.index_of({'key a': 15}), 5) - assert_equal(table.index_of({'key a': 50}), []) - with pytest.raises(TypeError): - assert_equal(table.index_of('"key a" < 13'), [0, 1, 2]) - - -class Test_SQLRowMapping: - @property - def table(self): - db = SQLDatabase(':memory:') - db.add_table('test') - db.add_column('test', 'a', data=np.arange(10, 20)) - db.add_column('test', 'b', data=np.arange(20, 30)) - - db.add_table('mapping') - db.add_column('mapping', 'keywords', ['key a', 'key-b']) - db.add_column('mapping', 'columns', ['a', 'b']) - - map = SQLColumnMap(db, 'mapping', 'keywords', 'columns') - return SQLTable(db, 'test', colmap=map) - - def test_row_column_names(self): - table = self.table - r = table.get_row(5) - assert_equal(r.column_names, ['key a', 'key-b']) - assert_equal(r.keys, ['key a', 'key-b']) - r = table[5] - assert_equal(r.column_names, ['key a', 'key-b']) - assert_equal(r.keys, ['key a', 'key-b']) - - def test_row_invalid_index(self): - table = self.table - - with pytest.raises(IndexError): - table[10] - with pytest.raises(IndexError): - table[-11] - with pytest.raises(IndexError): - table.get_row(10) - with pytest.raises(IndexError): - table.get_row(-11) - - def test_row_getitem(self): - table = self.table - r = table.get_row(5) - assert_equal(r['key a'], 15) - assert_equal(r['key-b'], 25) - - with pytest.raises(KeyError): - r['c'] - - def test_row_setitem(self): - table = self.table - r = table.get_row(5) - r['key a'] = -1 - r['key-b'] = -2 - assert_equal(r.values, [-1, -2]) - assert_equal(table.values[5], [-1, -2]) - - def test_row_as_dict(self): - table = self.table - r = table.get_row(5) - assert_equal(r.as_dict(), {'key a': 15, 'key-b': 25}) - - def test_row_items(self): - table = self.table - r = table.get_row(5) - assert_equal(list(r.items), [('key a', 15), ('key-b', 25)]) From a37ea4c270e1e10d781593479299586d9ad19330 Mon Sep 17 00:00:00 2001 From: Julio Campagnolo Date: Fri, 13 Oct 2023 12:53:04 -0300 Subject: [PATCH 3/7] added dbastable as dependency --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e8059432..b1b617cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,8 @@ dependencies = [ "astroscrappy", "astroalign", "tqdm", - "nest-asyncio" + "nest-asyncio", + "dbastable" ] [project.optional-dependencies] From 893fb02183de66a4d1e9ec837667123892cf432c Mon Sep 17 00:00:00 2001 From: Julio Campagnolo Date: Sat, 14 Oct 2023 23:47:14 -0300 Subject: [PATCH 4/7] file_collection: do not use colmap anymore --- astropop/file_collection.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/astropop/file_collection.py b/astropop/file_collection.py index c109cc1e..65786d8d 100644 --- a/astropop/file_collection.py +++ b/astropop/file_collection.py @@ -54,9 +54,6 @@ def list_fits_files(location, fits_extensions=None, _headers = 'headers' _metadata = 'astropop_metadata' _files_col = '__file' -_keycolstable = 'astropop_keyword2column' -_keywords_col = 'keyword' -_columns_col = 'column' class FitsFileGroup(): @@ -143,9 +140,6 @@ def _read_db(self, files, location, compression, update=False): 'EXT': self._ext}, add_columns=True) self._db.add_column(_metadata, 'FITS_EXT', self._extensions) - self._db.add_table(_keycolstable) - self._db.add_column(_keycolstable, _keywords_col) - self._db.add_column(_keycolstable, _columns_col) self._include = self._db[_metadata, 'glob_include'][0] self._exclude = self._db[_metadata, 'glob_exclude'][0] From 299736c2d6a9047bc0e445b9b9143651c89d8a89 Mon Sep 17 00:00:00 2001 From: Julio Campagnolo Date: Thu, 19 Oct 2023 10:04:17 -0300 Subject: [PATCH 5/7] file_collection: add keys property to fitsfilegroup --- astropop/file_collection.py | 5 +++++ tests/test_filecollection.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/astropop/file_collection.py b/astropop/file_collection.py index 65786d8d..f64bd168 100644 --- a/astropop/file_collection.py +++ b/astropop/file_collection.py @@ -164,6 +164,11 @@ def summary(self): """Get a readonly table with summary of the fits files.""" return self._table.as_table() + @property + def keys(self): + """List the keywords of the headers table.""" + return self._table.column_names + def __copy__(self, indexes=None): """Copy the current instance to a new object.""" if indexes is None: diff --git a/tests/test_filecollection.py b/tests/test_filecollection.py index d2c5c764..be01e8ca 100644 --- a/tests/test_filecollection.py +++ b/tests/test_filecollection.py @@ -370,6 +370,16 @@ def test_fg_remove_file_str(self, tmpdir): fg.remove_file('NonExistingFile') +class Test_FitsFileGroup_Properties(): + def test_fg_keys(self, tmpdir): + tmpdir, flist = tmpdir + fg = FitsFileGroup(location=tmpdir/'fits', compression=False) + assert_equal(fg.keys, + ['__file', 'simple', 'bitpix', 'naxis', 'naxis1', + 'naxis2', 'obstype', 'exptime', 'observer', 'object', + 'filter', 'space key', 'image']) + + class Test_ListFitsFiles(): def test_list_custom_extension(self, tmpdir): tmpdir, flist = tmpdir From d2a1d0d6dd3af43ebe34d92b4de31774f2d54f8c Mon Sep 17 00:00:00 2001 From: Julio Campagnolo Date: Thu, 19 Oct 2023 15:43:33 -0300 Subject: [PATCH 6/7] file_collection: expose list_fits_files --- astropop/file_collection.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/astropop/file_collection.py b/astropop/file_collection.py index f64bd168..74bcc256 100644 --- a/astropop/file_collection.py +++ b/astropop/file_collection.py @@ -16,12 +16,30 @@ from .framedata import check_framedata from .logger import logger -__all__ = ['FitsFileGroup'] +__all__ = ['FitsFileGroup', 'list_fits_files'] def list_fits_files(location, fits_extensions=None, glob_include=None, glob_exclude=None): - """List all fist files in a directory, if compressed or not.""" + """List all fist files in a directory, if compressed or not. + + Parameters + ---------- + location : str + Main directory to look for the files. Files will be listed recursively. + fits_extensions : str or list, optional + FITS file name extension to be used. Default is None, wich means + that the default extensions will be used, like '.fits' and '.fit'. + glob_include : str, optional + Glob pattern to include files. Default is None. + glob_exclude : str, optional + Glob pattern to exclude files. Default is None. + + Returns + ------- + list + List of files found. + """ if fits_extensions is None: fits_extensions = _fits_extensions From e1ff1cb6eb02bc26b493e774140748340f2cb8ef Mon Sep 17 00:00:00 2001 From: Julio Campagnolo Date: Thu, 19 Oct 2023 15:46:44 -0300 Subject: [PATCH 7/7] fixed pep8 --- tests/test_filecollection.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/test_filecollection.py b/tests/test_filecollection.py index be01e8ca..a5fd5316 100644 --- a/tests/test_filecollection.py +++ b/tests/test_filecollection.py @@ -68,12 +68,12 @@ def create_test_files(tmpdir, extension='fits'): warnings.simplefilter('ignore', category=fits.verify.VerifyWarning) hdr = fits.Header({'obstype': 'bias', - 'exptime': 0.0001, - 'observer': 'Galileo Galileo', - 'object': 'bias', - 'filter': '', - 'space key': 1, - 'image': iname}) + 'exptime': 0.0001, + 'observer': 'Galileo Galileo', + 'object': 'bias', + 'filter': '', + 'space key': 1, + 'image': iname}) hdu = fits.PrimaryHDU(np.ones((8, 8), dtype=np.int16), hdr) hdu.writeto(fname) files_list.append(str(fname)) @@ -388,7 +388,8 @@ def test_list_custom_extension(self, tmpdir): assert_equal(sorted(found_files), sorted(flist['myfits'])) found_files = list_fits_files(tmpdir/'custom', - fits_extensions=['.myfits', '.otherfits']) + fits_extensions=['.myfits', + '.otherfits']) assert_equal(sorted(found_files), sorted(flist['custom'])) @pytest.mark.parametrize('ext', ['fits', 'fz', 'fit', 'fts'])