From e4b9b582cdb4e48430865f8739f341bc8017c1e4 Mon Sep 17 00:00:00 2001 From: Taj Khattra Date: Thu, 16 Nov 2023 22:22:05 -0800 Subject: [PATCH] Add more STRICT table support per https://github.com/simonw/sqlite-utils/issues/344#issuecomment-982014776. Make table.transform() preserve STRICT mode. --- docs/cli-reference.rst | 3 +++ docs/cli.rst | 21 +++++++++++++++++++- docs/python-api.rst | 18 ++++++++++++++++- sqlite_utils/cli.py | 19 ++++++++++++++++++ sqlite_utils/db.py | 38 ++++++++++++++++++++++++++++++++++-- tests/test_cli.py | 29 +++++++++++++++++++++++++++ tests/test_create.py | 43 +++++++++++++++++++++++++++++++++++++++++ tests/test_lookup.py | 6 ++++++ tests/test_transform.py | 9 +++++++++ 9 files changed, 182 insertions(+), 4 deletions(-) diff --git a/docs/cli-reference.rst b/docs/cli-reference.rst index 02f66cd12..81bf5e329 100644 --- a/docs/cli-reference.rst +++ b/docs/cli-reference.rst @@ -293,6 +293,7 @@ See :ref:`cli_inserting_data`, :ref:`cli_insert_csv_tsv`, :ref:`cli_insert_unstr --replace Replace records if pk already exists --truncate Truncate table before inserting records, if table already exists + --strict Apply STRICT mode to table -h, --help Show this message and exit. @@ -345,6 +346,7 @@ See :ref:`cli_upsert`. --analyze Run ANALYZE at the end of this operation --load-extension TEXT Path to SQLite extension, with optional :entrypoint --silent Do not show progress bar + --strict Apply STRICT mode to table -h, --help Show this message and exit. @@ -920,6 +922,7 @@ See :ref:`cli_create_table`. --replace If table already exists, replace it --transform If table already exists, try to transform the schema --load-extension TEXT Path to SQLite extension, with optional :entrypoint + --strict Apply STRICT mode to table -h, --help Show this message and exit. diff --git a/docs/cli.rst b/docs/cli.rst index 5d9deb084..fbc0a5b06 100644 --- a/docs/cli.rst +++ b/docs/cli.rst @@ -1972,6 +1972,25 @@ You can specify foreign key relationships between the tables you are creating us [author_id] INTEGER REFERENCES [authors]([id]) ) +You can create a table in STRICT mode using ``--strict``: + +.. code-block:: bash + + sqlite-utils create-table mydb.db mytable id integer name text --strict + +.. code-block:: bash + + sqlite-utils tables mydb.db --schema -t + +.. code-block:: output + + table schema + ------- ------------------------ + mytable CREATE TABLE [mytable] ( + [id] INTEGER, + [name] TEXT + ) STRICT + If a table with the same name already exists, you will get an error. You can choose to silently ignore this error with ``--ignore``, or you can replace the existing table with a new, empty table using ``--replace``. You can also pass ``--transform`` to transform the existing table to match the new schema. See :ref:`python_api_explicit_create` in the Python library documentation for details of how this option works. @@ -2018,7 +2037,7 @@ Use ``--ignore`` to ignore the error if the table does not exist. Transforming tables =================== -The ``transform`` command allows you to apply complex transformations to a table that cannot be implemented using a regular SQLite ``ALTER TABLE`` command. See :ref:`python_api_transform` for details of how this works. +The ``transform`` command allows you to apply complex transformations to a table that cannot be implemented using a regular SQLite ``ALTER TABLE`` command. See :ref:`python_api_transform` for details of how this works. The ``transform`` command preserves a table's ``STRICT`` mode. .. code-block:: bash diff --git a/docs/python-api.rst b/docs/python-api.rst index 9d396c655..2a9caf2cf 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -117,6 +117,12 @@ By default, any :ref:`sqlite-utils plugins ` that implement the :ref:`p db = Database(memory=True, execute_plugins=False) +You can pass ``strict=True`` to enable ``STRICT`` mode for all tables created from a database object: + +.. code-block:: python + + db = Database("my_database.db", strict=True) + .. _python_api_attach: Attaching additional databases @@ -581,6 +587,15 @@ The ``transform=True`` option will update the table schema if any of the followi Changes to ``foreign_keys=`` are not currently detected and applied by ``transform=True``. +You can pass ``strict=True`` to create a table in ``STRICT`` mode: + +.. code-block:: python + + db["cats"].create({ + "id": int, + "name": str, + }, strict=True) + .. _python_api_compound_primary_keys: Compound primary keys @@ -661,7 +676,7 @@ You can set default values for these methods by accessing the table through the # Now you can call .insert() like so: table.insert({"id": 1, "name": "Tracy", "score": 5}) -The configuration options that can be specified in this way are ``pk``, ``foreign_keys``, ``column_order``, ``not_null``, ``defaults``, ``batch_size``, ``hash_id``, ``hash_id_columns``, ``alter``, ``ignore``, ``replace``, ``extracts``, ``conversions``, ``columns``. These are all documented below. +The configuration options that can be specified in this way are ``pk``, ``foreign_keys``, ``column_order``, ``not_null``, ``defaults``, ``batch_size``, ``hash_id``, ``hash_id_columns``, ``alter``, ``ignore``, ``replace``, ``extracts``, ``conversions``, ``columns``, ``strict``. These are all documented below. .. _python_api_defaults_not_null: @@ -1011,6 +1026,7 @@ The first time this is called the record will be created for ``name="Palm"``. An - ``extracts`` - ``conversions`` - ``columns`` +- ``strict`` .. _python_api_extracts: diff --git a/sqlite_utils/cli.py b/sqlite_utils/cli.py index 5821db64c..2dab83bbb 100644 --- a/sqlite_utils/cli.py +++ b/sqlite_utils/cli.py @@ -900,6 +900,12 @@ def inner(fn): ), load_extension_option, click.option("--silent", is_flag=True, help="Do not show progress bar"), + click.option( + "--strict", + is_flag=True, + default=False, + help="Apply STRICT mode to table", + ), ) ): fn = decorator(fn) @@ -942,6 +948,7 @@ def insert_upsert_implementation( silent=False, bulk_sql=None, functions=None, + strict=False, ): db = sqlite_utils.Database(path) _load_extensions(db, load_extension) @@ -1057,6 +1064,7 @@ def insert_upsert_implementation( "replace": replace, "truncate": truncate, "analyze": analyze, + "strict": strict, } if not_null: extra_kwargs["not_null"] = set(not_null) @@ -1177,6 +1185,7 @@ def insert( truncate, not_null, default, + strict, ): """ Insert records from FILE into a table, creating the table if it @@ -1255,6 +1264,7 @@ def insert( silent=silent, not_null=not_null, default=default, + strict=strict, ) except UnicodeDecodeError as ex: raise click.ClickException(UNICODE_ERROR.format(ex)) @@ -1290,6 +1300,7 @@ def upsert( analyze, load_extension, silent, + strict, ): """ Upsert records based on their primary key. Works like 'insert' but if @@ -1334,6 +1345,7 @@ def upsert( analyze=analyze, load_extension=load_extension, silent=silent, + strict=strict, ) except UnicodeDecodeError as ex: raise click.ClickException(UNICODE_ERROR.format(ex)) @@ -1502,6 +1514,11 @@ def create_database(path, enable_wal, init_spatialite, load_extension): help="If table already exists, try to transform the schema", ) @load_extension_option +@click.option( + "--strict", + is_flag=True, + help="Apply STRICT mode to table", +) def create_table( path, table, @@ -1514,6 +1531,7 @@ def create_table( replace, transform, load_extension, + strict, ): """ Add a table with the specified columns. Columns should be specified using @@ -1561,6 +1579,7 @@ def create_table( ignore=ignore, replace=replace, transform=transform, + strict=strict, ) diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py index 1baa32e5b..5e6a72a35 100644 --- a/sqlite_utils/db.py +++ b/sqlite_utils/db.py @@ -300,6 +300,7 @@ class Database: ``sql, parameters`` every time a SQL query is executed :param use_counts_table: set to ``True`` to use a cached counts table, if available. See :ref:`python_api_cached_table_counts` + :param strict: Apply STRICT mode to all created tables (unless overridden) """ _counts_table_name = "_counts" @@ -315,6 +316,7 @@ def __init__( tracer: Optional[Callable] = None, use_counts_table: bool = False, execute_plugins: bool = True, + strict: bool = False, ): assert (filename_or_conn is not None and (not memory and not memory_name)) or ( filename_or_conn is None and (memory or memory_name) @@ -348,6 +350,7 @@ def __init__( self.use_counts_table = use_counts_table if execute_plugins: pm.hook.prepare_connection(conn=self.conn) + self.strict = strict def close(self): "Close the SQLite connection, and the underlying database file" @@ -534,7 +537,11 @@ def table(self, table_name: str, **kwargs) -> Union["Table", "View"]: :param table_name: Name of the table """ - klass = View if table_name in self.view_names() else Table + if table_name in self.view_names(): + klass = View + else: + klass = Table + kwargs.setdefault("strict", self.strict) return klass(self, table_name, **kwargs) def quote(self, value: str) -> str: @@ -821,6 +828,7 @@ def create_table_sql( hash_id_columns: Optional[Iterable[str]] = None, extracts: Optional[Union[Dict[str, str], List[str]]] = None, if_not_exists: bool = False, + strict: bool = False, ) -> str: """ Returns the SQL ``CREATE TABLE`` statement for creating the specified table. @@ -836,6 +844,7 @@ def create_table_sql( :param hash_id_columns: List of columns to be used when calculating the hash ID for a row :param extracts: List or dictionary of columns to be extracted during inserts, see :ref:`python_api_extracts` :param if_not_exists: Use ``CREATE TABLE IF NOT EXISTS`` + :param strict: Apply STRICT mode to table """ if hash_id_columns and (hash_id is None): hash_id = "id" @@ -932,12 +941,13 @@ def sort_key(p): columns_sql = ",\n".join(column_defs) sql = """CREATE TABLE {if_not_exists}[{table}] ( {columns_sql}{extra_pk} -); +){strict}; """.format( if_not_exists="IF NOT EXISTS " if if_not_exists else "", table=name, columns_sql=columns_sql, extra_pk=extra_pk, + strict=" STRICT" if strict and self.supports_strict else "", ) return sql @@ -957,6 +967,7 @@ def create_table( replace: bool = False, ignore: bool = False, transform: bool = False, + strict: bool = False, ) -> "Table": """ Create a table with the specified name and the specified ``{column_name: type}`` columns. @@ -977,6 +988,7 @@ def create_table( :param replace: Drop and replace table if it already exists :param ignore: Silently do nothing if table already exists :param transform: If table already exists transform it to fit the specified schema + :param strict: Apply STRICT mode to table """ # Transform table to match the new definition if table already exists: if self[name].exists(): @@ -1048,6 +1060,7 @@ def create_table( hash_id_columns=hash_id_columns, extracts=extracts, if_not_exists=if_not_exists, + strict=strict, ) self.execute(sql) created_table = self.table( @@ -1416,6 +1429,7 @@ class Table(Queryable): :param extracts: Dictionary or list of column names to extract into a separate table on inserts :param conversions: Dictionary of column names and conversion functions :param columns: Dictionary of column names to column types + :param strict: If True, apply STRICT mode to table """ #: The ``rowid`` of the last inserted, updated or selected row. @@ -1441,6 +1455,7 @@ def __init__( extracts: Optional[Union[Dict[str, str], List[str]]] = None, conversions: Optional[dict] = None, columns: Optional[Dict[str, Any]] = None, + strict: bool = False, ): super().__init__(db, name) self._defaults = dict( @@ -1458,6 +1473,7 @@ def __init__( extracts=extracts, conversions=conversions or {}, columns=columns, + strict=strict, ) def __repr__(self) -> str: @@ -1639,6 +1655,7 @@ def create( replace: bool = False, ignore: bool = False, transform: bool = False, + strict: bool = False, ) -> "Table": """ Create a table with the specified columns. @@ -1658,6 +1675,7 @@ def create( :param replace: Drop and replace table if it already exists :param ignore: Silently do nothing if table already exists :param transform: If table already exists transform it to fit the specified schema + :param strict: Apply STRICT mode to table """ columns = {name: value for (name, value) in columns.items()} with self.db.conn: @@ -1676,6 +1694,7 @@ def create( replace=replace, ignore=ignore, transform=transform, + strict=strict, ) return self @@ -1909,6 +1928,7 @@ def transform_sql( defaults=create_table_defaults, foreign_keys=create_table_foreign_keys, column_order=column_order, + strict=self.strict, ).strip() ) @@ -3111,6 +3131,7 @@ def insert( extracts: Optional[Union[Dict[str, str], List[str], Default]] = DEFAULT, conversions: Optional[Union[Dict[str, str], Default]] = DEFAULT, columns: Optional[Union[Dict[str, Any], Default]] = DEFAULT, + strict: Optional[Union[bool, Default]] = DEFAULT, ) -> "Table": """ Insert a single record into the table. The table will be created with a schema that matches @@ -3143,6 +3164,7 @@ def insert( is being inserted, for example ``{"name": "upper(?)"}``. See :ref:`python_api_conversions`. :param columns: Dictionary over-riding the detected types used for the columns, for example ``{"age": int, "weight": float}``. + :param strict: Boolean, apply STRICT mode if creating the table. """ return self.insert_all( [record], @@ -3159,6 +3181,7 @@ def insert( extracts=extracts, conversions=conversions, columns=columns, + strict=strict, ) def insert_all( @@ -3181,6 +3204,7 @@ def insert_all( columns=DEFAULT, upsert=False, analyze=False, + strict=DEFAULT, ) -> "Table": """ Like ``.insert()`` but takes a list of records and ensures that the table @@ -3202,6 +3226,7 @@ def insert_all( extracts = self.value_or_default("extracts", extracts) conversions = self.value_or_default("conversions", conversions) or {} columns = self.value_or_default("columns", columns) + strict = self.value_or_default("strict", strict) if hash_id_columns and hash_id is None: hash_id = "id" @@ -3257,6 +3282,7 @@ def insert_all( hash_id=hash_id, hash_id_columns=hash_id_columns, extracts=extracts, + strict=strict, ) all_columns_set = set() for record in chunk: @@ -3307,6 +3333,7 @@ def upsert( extracts=DEFAULT, conversions=DEFAULT, columns=DEFAULT, + strict=DEFAULT, ) -> "Table": """ Like ``.insert()`` but performs an ``UPSERT``, where records are inserted if they do @@ -3327,6 +3354,7 @@ def upsert( extracts=extracts, conversions=conversions, columns=columns, + strict=strict, ) def upsert_all( @@ -3345,6 +3373,7 @@ def upsert_all( conversions=DEFAULT, columns=DEFAULT, analyze=False, + strict=DEFAULT, ) -> "Table": """ Like ``.upsert()`` but can be applied to a list of records. @@ -3365,6 +3394,7 @@ def upsert_all( columns=columns, upsert=True, analyze=analyze, + strict=strict, ) def add_missing_columns(self, records: Iterable[Dict[str, Any]]) -> "Table": @@ -3387,6 +3417,7 @@ def lookup( extracts: Optional[Union[Dict[str, str], List[str]]] = None, conversions: Optional[Dict[str, str]] = None, columns: Optional[Dict[str, Any]] = None, + strict: Optional[bool] = False, ): """ Create or populate a lookup table with the specified values. @@ -3409,6 +3440,7 @@ def lookup( :param lookup_values: Dictionary specifying column names and values to use for the lookup :param extra_values: Additional column values to be used only if creating a new record + :param strict: Boolean, apply STRICT mode if creating the table. """ assert isinstance(lookup_values, dict) if extra_values is not None: @@ -3440,6 +3472,7 @@ def lookup( extracts=extracts, conversions=conversions, columns=columns, + strict=strict, ).last_pk else: pk = self.insert( @@ -3452,6 +3485,7 @@ def lookup( extracts=extracts, conversions=conversions, columns=columns, + strict=strict, ).last_pk self.create_index(lookup_values.keys(), unique=True) return pk diff --git a/tests/test_cli.py b/tests/test_cli.py index 0a9350444..18d54dbc6 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2391,3 +2391,32 @@ def test_load_extension(entrypoint, should_pass, should_fail): catch_exceptions=False, ) assert result.exit_code == 1 + + +@pytest.mark.parametrize("strict", (False, True)) +def test_create_table_strict(strict): + runner = CliRunner() + with runner.isolated_filesystem(): + db = Database("test.db") + result = runner.invoke( + cli.cli, + ["create-table", "test.db", "items", "id", "integer"] + + (["--strict"] if strict else []), + ) + assert result.exit_code == 0 + assert db["items"].strict == strict or not db.supports_strict + + +@pytest.mark.parametrize("method", ("insert", "upsert")) +@pytest.mark.parametrize("strict", (False, True)) +def test_insert_upsert_strict(tmpdir, method, strict): + db_path = str(tmpdir / "test.db") + result = CliRunner().invoke( + cli.cli, + [method, db_path, "items", "-", "--csv", "--pk", "id"] + + (["--strict"] if strict else []), + input="id\n1", + ) + assert result.exit_code == 0 + db = Database(db_path) + assert db["items"].strict == strict or not db.supports_strict diff --git a/tests/test_create.py b/tests/test_create.py index a88374f65..3bf1004a7 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1316,3 +1316,46 @@ def test_rename_table(fresh_db): # Should error if table does not exist: with pytest.raises(sqlite3.OperationalError): fresh_db.rename_table("does_not_exist", "renamed") + + +@pytest.mark.parametrize("strict", (False, True)) +def test_database_strict(strict): + db = Database(memory=True, strict=strict) + table = db.table("t", columns={"id": int}) + table.insert({"id": 1}) + assert table.strict == strict or not db.supports_strict + + +@pytest.mark.parametrize("strict", (False, True)) +def test_database_strict_override(strict): + db = Database(memory=True, strict=strict) + table = db.table("t", columns={"id": int}, strict=not strict) + table.insert({"id": 1}) + assert table.strict != strict or not db.supports_strict + + +@pytest.mark.parametrize( + "method_name", ("insert", "upsert", "insert_all", "upsert_all") +) +@pytest.mark.parametrize("strict", (False, True)) +def test_insert_upsert_strict(fresh_db, method_name, strict): + table = fresh_db["t"] + method = getattr(table, method_name) + record = {"id": 1} + if method_name.endswith("_all"): + record = [record] + method(record, pk="id", strict=strict) + assert table.strict == strict or not fresh_db.supports_strict + + +@pytest.mark.parametrize("strict", (False, True)) +def test_create_table_strict(fresh_db, strict): + table = fresh_db.create_table("t", {"id": int}, strict=strict) + assert table.strict == strict or not fresh_db.supports_strict + + +@pytest.mark.parametrize("strict", (False, True)) +def test_create_strict(fresh_db, strict): + table = fresh_db["t"] + table.create({"id": int}, strict=strict) + assert table.strict == strict or not fresh_db.supports_strict diff --git a/tests/test_lookup.py b/tests/test_lookup.py index 31be414ce..c3855d935 100644 --- a/tests/test_lookup.py +++ b/tests/test_lookup.py @@ -151,3 +151,9 @@ def test_lookup_with_extra_insert_parameters(fresh_db): columns=["name", "type"], ) ] + + +@pytest.mark.parametrize("strict", (False, True)) +def test_lookup_new_table(fresh_db, strict): + fresh_db["species"].lookup({"name": "Palm"}, strict=strict) + assert fresh_db["species"].strict == strict or not fresh_db.supports_strict diff --git a/tests/test_transform.py b/tests/test_transform.py index 1894494a5..111236ddd 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -530,3 +530,12 @@ def test_transform_preserves_rowids(fresh_db, table_type): tuple(row) for row in fresh_db.execute("select rowid, id, name from places") ) assert previous_rows == next_rows + + +@pytest.mark.parametrize("strict", (False, True)) +def test_transform_strict(fresh_db, strict): + dogs = fresh_db.table("dogs", strict=strict) + dogs.insert({"id": 1, "name": "Cleo"}) + assert dogs.strict == strict or not fresh_db.supports_strict + dogs.transform(not_null={"name"}) + assert dogs.strict == strict or not fresh_db.supports_strict