diff --git a/.travis.yml b/.travis.yml index d1ec304e..b64b6f28 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,3 +6,7 @@ install: - pip install -r requirements.txt - pip install -r tests-requirements.txt script: py.test + +before_script: + - psql -c 'create database eloquent_test;' -U postgres + - mysql -e 'create database eloquent_test;' diff --git a/CHANGELOG.md b/CHANGELOG.md index 442b25d6..5130c784 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +### 0.4 + +(April 28th, 2015) + +- Adds Schema Builder +- Adds scopes support +- Adds support for related name in relationships declaration + +### 0.3.1 + +(April 19th, 2015) + +- Fix MySQLdb compatibiity issues +- Fix wrong default key value for Builder.lists() method + ### 0.3 (April 3th, 2015) diff --git a/README.rst b/README.rst index efff956b..66eb114a 100644 --- a/README.rst +++ b/README.rst @@ -30,7 +30,7 @@ The different dbapi packages are not part of the package dependencies, so you must install them in order to connect to corresponding databases: * Postgres: ``pyscopg2`` -* MySQL: ``PyMySQL`` or ``MySQL-python`` +* MySQL: ``PyMySQL`` or ``mysqlclient`` * Sqlite: The ``sqlite3`` module is bundled with Python by default diff --git a/docs/_static/theme_overrides.css b/docs/_static/theme_overrides.css index d6b51f48..7316ff2e 100644 --- a/docs/_static/theme_overrides.css +++ b/docs/_static/theme_overrides.css @@ -4,6 +4,7 @@ body { font: 400 14px/28px "Roboto",sans-serif; + color: #7A7A7A; } .wy-body-for-nav { @@ -110,6 +111,40 @@ article ul li { background: none; } +.rst-content table.docutils, +.rst-content table.field-list { + border: medium none; + width: 100%; +} + +.rst-content table.docutils thead th, +.rst-content table.field-list thead th +{ + font-family: "Montserrat", serif; + text-transform: uppercase; + color: #7A7A7A; + font-weight: 500; +} + +.rst-content table.docutils thead th, +.rst-content table.field-list thead th, +.rst-content table.docutils tbody td, +.rst-content table.field-list tbody td +{ + border-top: 0; + border-bottom: 1px solid rgba(230, 230, 230, 0.7); + border-left: 0; + border-right: 0; + padding: 20px; + white-space: pre-wrap; +} + +.wy-table-odd td, +.wy-table-striped tr:nth-child(2n-1) td, +.rst-content table.docutils:not(.field-list) tr:nth-child(2n-1) td { + background: none; +} + .rst-content .highlighted { background: rgba(205, 220, 57, 0.7); border: 1px solid rgba(205, 220, 57, 0.9); @@ -448,7 +483,24 @@ green #859900 operators, other keywords .highlight .o { color: #7A7A7A } /* Operator */ .highlight .n { color: #7A7A7A } .highlight .c { color: #B8AFAD; font-style: normal; } +.highlight .sd { color: #B8AFAD; font-style: normal; } .highlight .nl { color: #A89BB9 } /* Name.Label */ .highlight .nn { color: #A89BB9 } /* Name.Namespace */ .highlight .nx { color: #A89BB9 } /* Name.Other */ .highlight .py { color: #A89BB9 } /* Name.Property */ + +/* Inline code */ + +tt.code, code.code { + border: 1px solid #F2F2F2; + font-weight: normal; + display: inline; + padding: 0.3em 0.5em; + background-color: #FAFAFA; +} +tt.code .name, code.code .name { color: #7A7A7A } +tt.code .operator, code.code .operator { color: #7A7A7A } +tt.code .punctuation, code.code .punctuation { color: #7A7A7A } +tt.code .string, code.code .string { color: #7BBDA4 } +tt.code .number, code.code .number { color: #F4BC87 } +tt.code .integer, code.code .integer { color: #F4BC87 } diff --git a/docs/conf.py b/docs/conf.py index 8e758a58..2938df1a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -53,9 +53,9 @@ # built documents. # # The short X.Y version. -version = '0.3' +version = '0.4' # The full version, including alpha/beta/rc tags. -release = '0.3.1' +release = '0.4' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/index.rst b/docs/index.rst index 292a14cb..5ba39136 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -13,3 +13,4 @@ but modified to be more pythonic. basic_usage query_builder orm + schema_builder diff --git a/docs/installation.rst b/docs/installation.rst index 9acf7a53..598ef98b 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -17,5 +17,5 @@ You can install Eloquent in 2 different ways: so you must install them in order to connect to corresponding databases: * PostgreSQL: ``pyscopg2`` - * MySQL: ``PyMySQL`` or ``MySQL-python`` + * MySQL: ``PyMySQL`` or ``mysqlclient`` * SQLite: The ``sqlite3`` module is bundled with Python by default diff --git a/docs/orm.rst b/docs/orm.rst index f3384398..3628ef3a 100644 --- a/docs/orm.rst +++ b/docs/orm.rst @@ -280,8 +280,15 @@ You can also run updates as queries against a set of models: affected_rows = User.where('votes', '>', 100).update(status=2) -.. - TODO: push method +Saving a model and relationships +-------------------------------- + +Sometimes you may wish to save not only a model, but also all of its relationships. +To do so, you can use the ``push`` method: + +.. code-block:: python + + user.push() Deleting an existing model @@ -322,6 +329,48 @@ If you want to only update the timestamps on a model, you can use the ``touch`` user.touch() +Soft deleting +============= + +When soft deleting a model, it is not actually removed from your database. +Instead, a ``deleted_at`` timestamp is set on the record. +To enable soft deletes for a model, make it inherit from the ``SoftDeletes`` mixin: + +.. code-block:: python + + from eloquent import Model, SoftDeletes + + + class User(Model, SoftDeletes): + + __dates__ = ['deleted_at'] + +To add a ``deleted_at`` column to your table, you may use the ``soft_deletes`` method from a migration (see :ref:`SchemaBuilder`): + +.. code-block:: python + + table.soft_deletes() + +Now, when you call the ``delete`` method on the model, the ``deleted_at`` column will be +set to the current timestamp. When querying a model that uses soft deletes, +the "deleted" models will not be included in query results. + +Forcing soft deleted models into results +---------------------------------------- + +To force soft deleted models to appear in a result set, use the ``with_trashed`` method on the query: + +.. code-block:: python + + users = User.with_trashed().where('account_id', 1).get() + +The ``with_trashed`` method may be used on a defined relationship: + +.. code-block:: python + + user.posts().with_trashed().get() + + Relationships ============= @@ -515,19 +564,19 @@ The tables for this relationship would look like this: .. code-block:: yaml - countries: - id: integer - name: string + countries + id - integer + name - string users: - id: integer - country_id: integer - name: string + id - integer + country_id - integer + name - string posts: - id: integer - user_id: integer - title: string + id - integer + user_id - integer + title - string Even though the ``posts`` table does not contain a ``country_id`` column, the ``has_many_through`` relation will allow access a country's posts via ``country.posts``: @@ -911,11 +960,9 @@ You can also pass conditions: .. code-block:: python - books.load( - { - 'author': Author.query().where('name', 'like', '%foo%') - } - ) + books.load({ + 'author': Author.query().where('name', 'like', '%foo%') + }) Inserting related models @@ -1138,10 +1185,137 @@ or the ``to_json`` methods, you can override the ``get_date_format`` method: class User(Model): - def get_date_format(): + def get_date_format(self): return 'DD-MM-YY' +Query Scopes +============ + +Defining a query scope +---------------------- + +Scopes allow you to easily re-use query logic in your models. +To define a scope, simply prefix a model method with ``scope``: + +.. code-block:: python + + class User(Model): + + def scope_popular(self, query): + return query.where('votes', '>', 100) + + def scope_women(self, query): + return query.where_gender('W') + +Using a query scope +----------------------- + +.. code-block:: python + + users = User.popular().women().order_by('created_at').get() + +Dynamic scopes +-------------- + +Sometimes you may wish to define a scope that accepts parameters. +Just add your parameters to your scope function: + +.. code-block:: python + + class User(Model): + + def scope_of_type(self, query, type): + return query.where_type(type) + +Then pass the parameter into the scope call: + +.. code-block:: python + + users = User.of_type('member').get() + + +Global Scopes +============= + +Sometimes you may wish to define a scope that applies to all queries performed on a model. +In essence, this is how Eloquent's own "soft delete" feature works. +Global scopes are defined using a combination of mixins and an implementation of the ``Scope`` class. + +First, let's define a mixin. For this example, we'll use the ``SoftDeletes`` that ships with Eloquent: + +.. code-block:: python + + from eloquent import SoftDeletingScope + + + class SoftDeletes(object): + + @classmethod + def boot_soft_deletes(cls, model_class): + """ + Boot the soft deleting mixin for a model. + """ + model_class.add_global_scope(SoftDeletingScope()) + + +If an Eloquent model inherits from a mixin that has a method matching the ``boot_name_of_trait`` +naming convention, that mixin method will be called when the Eloquent model is booted, +giving you an opportunity to register a global scope, or do anything else you want. +A scope must be an instance of the ``Scope`` class, which specifies two methods: ``apply`` and ``remove``. + +The apply method receives an ``Builder`` query builder object and the ``Model`` it's applied to, +and is responsible for adding any additional ``where`` clauses that the scope wishes to add. +The ``remove`` method also receives a ``Builder`` object and ``Model`` and is responsible +for reversing the action taken by ``apply``. +In other words, ``remove`` should remove the ``where`` clause (or any other clause) that was added. +So, for our ``SoftDeletingScope``, it would look something like this: + +.. code-block:: python + + from eloquent import Scope + + + class SoftDeletingScope(Scope): + + def apply(self, builder, model): + """ + Apply the scope to a given query builder. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + + :param model: The model + :type model: eloquent.orm.Model + """ + builder.where_null(model.get_qualified_deleted_at_column()) + + def remove(self, builder, model): + """ + Remove the scope from a given query builder. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + + :param model: The model + :type model: eloquent.orm.Model + """ + column = model.get_qualified_deleted_at_column() + + query = builder.get_query() + + wheres = [] + for where in query.wheres: + # If the where clause is a soft delete date constraint, + # we will remove it from the query and reset the keys + # on the wheres. This allows the developer to include + # deleted model in a relationship result set that is lazy loaded. + if not self._is_soft_delete_constraint(where, column): + wheres.append(where) + + query.wheres = wheres + + Date mutators ============= @@ -1161,7 +1335,7 @@ by completely overriding the ``get_dates`` method: class User(Model): - def get_dates(): + def get_dates(self): return ['created_at'] When a column is considered a date, you can set its value to a UNIX timestamp, a date string ``YYYY-MM-DD``, @@ -1173,7 +1347,7 @@ To completely disable date mutations, simply return an empty list from the ``get class User(Model): - def get_dates(): + def get_dates(self): return [] diff --git a/docs/schema_builder.rst b/docs/schema_builder.rst new file mode 100644 index 00000000..30e16ec1 --- /dev/null +++ b/docs/schema_builder.rst @@ -0,0 +1,303 @@ +.. _SchemaBuilder: + +Schema Builder +############## + +.. role:: python(code) + :language: python + +Introduction +============ + +The ``Schema`` class provides a database agnostic way of manipulating tables. + +Before getting started, be sure to have configured a ``DatabaseManager`` as seen in the :ref:`BasicUsage` section. + +.. code-block:: python + + from eloquent import DatabaseManager, Schema + + config = { + 'mysql': { + 'driver': 'mysql', + 'host': 'localhost', + 'database': 'database', + 'username': 'root', + 'password': '', + 'prefix': '' + } + } + + db = DatabaseManager(config) + schema = Schema(db) + + +Creating and dropping tables +============================ + +To create a new database table, the ``create`` method is used: + +.. code-block:: python + + with schema.create('users') as table: + table.increments('id') + +The ``table`` variable is a ``Blueprint`` instance which can be used to define the new table. + +To rename an existing database table, the ``rename`` method can be used: + +.. code-block:: python + + schema.rename('from', 'to') + +To specify which connection the schema operation should take place on, use the ``connection`` method: + +.. code-block:: python + + with schema.connection('foo').create('users') as table: + table.increments('id') + +To drop a table, you can use the ``drop`` or ``drop_if_exists`` methods: + +.. code-block:: python + + schema.drop('users') + + schema.drop_if_exists('users') + + +Adding columns +============== + +To update an existing table, you can use the ``table`` method: + +.. code-block:: python + + with schema.table('users') as table: + table.string('email') + +The table builder contains a variety of column types that you may use when building your tables: + +======================================================== ================================================= +Command Description +======================================================== ================================================= +:python:`table.big_increments('id')` Incrementing ID using a "big integer" equivalent +:python:`table.big_integer('votes')` BIGINT equivalent to the table +:python:`table.binary('data')` BLOB equivalent to the table +:python:`table.boolean('confirmed')` BOOLEAN equivalent to the table +:python:`table.char('name', 4)` CHAR equivalent with a length +:python:`table.date('created_on')` DATE equivalent to the table +:python:`table.datetime('created_at')` DATETIME equivalent to the table +:python:`table.decimal('amount', 5, 2)` DECIMAL equivalent to the table with a precision and scale +:python:`table.double('column', 15, 8)` DOUBLE equivalent to the table with precision, 15 digits in total and 8 after the decimal point +:python:`table.enum('choices', ['foo', 'bar'])` ENUM equivalent to the table +:python:`table.float('amount')` FLOAT equivalent to the table +:python:`table.increments('id')` Incrementing ID to the table (primary key) +:python:`table.integer('votes')` INTEGER equivalent to the table +:python:`table.json('options')` JSON equivalent to the table +:python:`table.long_text('description')` LONGTEXT equivalent to the table +:python:`table.medium_integer('votes')` MEDIUMINT equivalent to the table +:python:`table.medium_text('description')` MEDIUMTEXT equivalent to the table +:python:`table.morphs('taggable')` Adds INTEGER :python:`taggable_id` and STRING :python:`taggable_type` +:python:`table.nullable_timestamps()` Same as :python:`timestamps()`, except allows NULLs +:python:`table.small_integer('votes')` SMALLINT equivalent to the table +:python:`table.soft_deletes()` Adds **deleted_at** column for soft deletes +:python:`table.string('email')` VARCHAR equivalent column +:python:`table.string('votes', 100)` VARCHAR equivalent with a length +:python:`table.text('description')` TEXT equivalent to the table +:python:`table.time('sunrise')` TIME equivalent to the table +:python:`table.timestamp('added_at')` TIMESTAMP equivalent to the table +:python:`table.timestamps()` Adds **created_at** and **updated_at** columns +:python:`.nullable()` Designate that the column allows NULL values +:python:`.default(value)` Declare a default value for a column +:python:`.unsigned()` Set INTEGER to UNSIGNED +======================================================== ================================================= + + +Changing columns +================ + +Sometimes you may need to modify an existing column. +For example, you may wish to increase the size of a string column. +To do so, you can use the ``change`` method. +For example, let's increase the size of the ``name`` column from 25 to 50: + +.. code-block:: python + + with schema.table('users') as table: + table.string('name', 50).change() + +You could also modify the column to be nullable: + +.. code-block:: python + + with schema.table('user') as table: + table.string('name', 50).nullable().change() + + +.. warning:: + + The column change feature, while tested, is still considered in **beta** stage. + Please report any encountered issue or bug on the `Github project `_ + + +Renaming columns +================ + +To rename a column, you can use use the ``rename_column`` method on the Schema builder: + +.. code-block:: python + + with schema.table('users') as table: + table.rename('from', 'to') + +.. warning:: + + Prior to **MySQL 5.6.6**, foreign keys are **NOT** automatically updated when renaming columns. + Therefore, you will need to **drop** the foreign key constraint, **rename** the column and **recreate** + the constraint to avoid an error. + + .. code-block:: python + + with schema.table('posts') as table: + table.drop_foreign('posts_user_id_foreign') + table.rename('user_id', 'author_id') + table.foreign('author_id').references('id').on('users') + + In future versions, Eloquent **might** handle this automatically. + +.. warning:: + + The rename column feature, while tested, is still considered in **beta** stage (especially for SQLite). + Please report any encountered issue or bug on the `Github project `_ + + +Dropping columns +================ + +To drop a column, you can use use the ``drop_column`` method on the Schema builder: + +Dropping a column from a database table +--------------------------------------- + +.. code-block:: python + + with schema.table('users') as table: + table.drop_column + +Dropping multiple columns from a database table +------------------------------------------------ + +.. code-block:: python + + with schema.table('users') as table: + table.drop_column('votes', 'avatar', 'location') + + +Checking existence +================== + +You can easily check for the existence of a table or column using the ``has_table`` and ``has_column`` methods: + +Checking for existence of a table +--------------------------------- + +.. code-block:: python + + if schema.has_table('users'): + # ... + +Checking for existence of a column: + +.. code-block:: python + + if schema.has_column('users', 'email'): + # ... + + +Adding indexes +============== + +The schema builder supports several types of indexes. There are two ways to add them. +First, you may fluently define them on a column definition: + +.. code-block:: python + + table.string('email').unique() + +Or, you may choose to add the indexes on separate lines. Below is a list of all available index types: + +======================================================== ================================================= +Command Description +======================================================== ================================================= +:python:`table.primary('id')` Adds a primary key +:python:`table.primary(['first', 'last'])` Adds composite keys +:python:`table.unique('email')` Adds a unique index +:python:`table.index('state')` Adds a basic index +======================================================== ================================================= + + +Dropping indexes +================ + +To drop an index you must specify the index's name. +Eloquent assigns a reasonable name to the indexes by default. +Simply concatenate the table name, the names of the column in the index, and the index type. +Here are some examples: + +======================================================== ================================================= +Command Description +======================================================== ================================================= +:python:`table.drop_primary('user_id_primary')` Drops a primary key from the "users" table +:python:`table.drop_unique('user_email_unique')` Drops a unique index from the "users" table +:python:`table.drop_index('geo_state_index')` Drops a basic index from the "geo" table +======================================================== ================================================= + + +Foreign keys +============ + +Eloquent also provides support for adding foreign key constraints to your tables: + +.. code-block:: python + + table.integer('user_id').unsigned() + table.foreign('user_id').references('id').on('users') + +In this example, we are stating that the ``user_id`` +column references the ``id`` column on the ``users`` table. +Make sure to create the foreign key column first! + +You may also specify options for the "on delete" and "on update" actions of the constraint: + +.. code-block:: python + + table.foreign('user_id')\ + .references('id').on('users')\ + .on_delete('cascade') + +To drop a foreign key, you may use the ``drop_foreign`` method. +A similar naming convention is used for foreign keys as is used for other indexes: + +.. code-block:: python + + table.drop_foreign('posts_user_id_foreign') + +.. note:: + + When creating a foreign key that references an incrementing integer, + remember to always make the foreign key column ``unsigned``. + + +Dropping timestamps and soft deletes +==================================== + +To drop the ``timestamps``, ``nullable_timestamps`` or ``soft_deletes`` column types, +you may use the following methods: + +======================================================== ================================================= +Command Description +======================================================== ================================================= +:python:`table.drop_timestamps()` Drops the **created_at** and **deleted_at** columns +:python:`table.drop_soft_deletes()` Drops the **deleted_at** column +======================================================== ================================================= diff --git a/eloquent/__init__.py b/eloquent/__init__.py index 881d4c82..1b3be71e 100644 --- a/eloquent/__init__.py +++ b/eloquent/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -from .orm import Model +from .orm import Model, SoftDeletes, Collection from .database_manager import DatabaseManager from .query.expression import QueryExpression +from .schema import Schema diff --git a/eloquent/connections/connection.py b/eloquent/connections/connection.py index b2c48584..9ce214d2 100644 --- a/eloquent/connections/connection.py +++ b/eloquent/connections/connection.py @@ -8,6 +8,8 @@ from ..query.builder import QueryBuilder from ..query.expression import QueryExpression from ..query.processors.processor import QueryProcessor +from ..schema.builder import SchemaBuilder +from ..dbal.schema_manager import SchemaManager from ..exceptions.query import QueryException @@ -66,12 +68,29 @@ def use_default_query_grammar(self): def get_default_query_grammar(self): return QueryGrammar() + def use_default_schema_grammar(self): + self._schema_grammar = self.get_default_schema_grammar() + + def get_default_schema_grammar(self): + pass + def use_default_post_processor(self): self._post_processor = self.get_default_post_processor() def get_default_post_processor(self): return QueryProcessor() + def get_schema_builder(self): + """ + Retturn the underlying schema builder. + + :rtype: eloquent.schema.SchemaBuilder + """ + if not self._schema_grammar: + self.use_default_schema_grammar() + + return SchemaBuilder(self) + def table(self, table): """ Begin a fluent query against a database table @@ -368,6 +387,9 @@ def set_reconnector(self, reconnector): def get_name(self): return self._config.get('name') + def get_config(self, option): + return self._config.get(option) + def get_query_grammar(self): return self._query_grammar @@ -426,6 +448,14 @@ def with_table_prefix(self, grammar): return grammar + def get_column(self, table, column): + schema = self.get_schema_manager() + + return schema.list_table_details(table).get_column(column) + + def get_schema_manager(self): + return SchemaManager(self) + def __enter__(self): self.begin_transaction() diff --git a/eloquent/connections/mysql_connection.py b/eloquent/connections/mysql_connection.py index 8c4133eb..03320ba7 100644 --- a/eloquent/connections/mysql_connection.py +++ b/eloquent/connections/mysql_connection.py @@ -1,8 +1,13 @@ # -*- coding: utf-8 -*- +from ..utils import PY2 from .connection import Connection from ..query.grammars.mysql_grammar import MySqlQueryGrammar from ..query.processors.mysql_processor import MySqlQueryProcessor +from ..schema.grammars import MySqlSchemaGrammar +from ..schema import MySqlSchemaBuilder +from ..dbal.platforms.mysql_platform import MySqlPlatform +from ..dbal.mysql_schema_manager import MySqlSchemaManager class MySqlConnection(Connection): @@ -13,6 +18,26 @@ def get_default_query_grammar(self): def get_default_post_processor(self): return MySqlQueryProcessor() + def get_schema_builder(self): + """ + Retturn the underlying schema builder. + + :rtype: eloquent.schema.SchemaBuilder + """ + if not self._schema_grammar: + self.use_default_schema_grammar() + + return MySqlSchemaBuilder(self) + + def get_default_schema_grammar(self): + return self.with_table_prefix(MySqlSchemaGrammar()) + + def get_database_platform(self): + return MySqlPlatform() + + def get_schema_manager(self): + return MySqlSchemaManager(self) + def begin_transaction(self): self._connection.autocommit(False) @@ -38,4 +63,7 @@ def _get_cursor_query(self, query, bindings): if not hasattr(self._cursor, '_last_executed'): return super(MySqlConnection, self)._get_cursor_query(query, bindings) - return self._cursor._last_executed + if PY2: + return self._cursor._last_executed + + return self._cursor._last_executed.decode() diff --git a/eloquent/connections/postgres_connection.py b/eloquent/connections/postgres_connection.py index f9477c9b..80a7046c 100644 --- a/eloquent/connections/postgres_connection.py +++ b/eloquent/connections/postgres_connection.py @@ -4,6 +4,9 @@ from .connection import Connection from ..query.grammars.postgres_grammar import PostgresQueryGrammar from ..query.processors.postgres_processor import PostgresQueryProcessor +from ..schema.grammars import PostgresSchemaGrammar +from ..dbal.platforms.postgres_platform import PostgresPlatform +from ..dbal.postgres_schema_manager import PostgresSchemaManager class PostgresConnection(Connection): @@ -14,6 +17,15 @@ def get_default_query_grammar(self): def get_default_post_processor(self): return PostgresQueryProcessor() + def get_default_schema_grammar(self): + return self.with_table_prefix(PostgresSchemaGrammar()) + + def get_database_platform(self): + return PostgresPlatform() + + def get_schema_manager(self): + return PostgresSchemaManager(self) + def begin_transaction(self): self._connection.autocommit = False diff --git a/eloquent/connections/sqlite_connection.py b/eloquent/connections/sqlite_connection.py index 429724c7..7c02a193 100644 --- a/eloquent/connections/sqlite_connection.py +++ b/eloquent/connections/sqlite_connection.py @@ -3,13 +3,29 @@ from ..utils import PY2, decode from .connection import Connection from ..query.processors.sqlite_processor import SQLiteQueryProcessor +from ..query.grammars.sqlite_grammar import SQLiteQueryGrammar +from ..schema.grammars.sqlite_grammar import SQLiteSchemaGrammar +from ..dbal.platforms.sqlite_platform import SQLitePlatform +from ..dbal.sqlite_schema_manager import SQLiteSchemaManager class SQLiteConnection(Connection): + def get_default_query_grammar(self): + return self.with_table_prefix(SQLiteQueryGrammar()) + def get_default_post_processor(self): return SQLiteQueryProcessor() + def get_default_schema_grammar(self): + return self.with_table_prefix(SQLiteSchemaGrammar()) + + def get_database_platform(self): + return SQLitePlatform() + + def get_schema_manager(self): + return SQLiteSchemaManager(self) + def begin_transaction(self): self._connection.isolation_level = 'DEFERRED' diff --git a/eloquent/connectors/connection_factory.py b/eloquent/connectors/connection_factory.py index 78241086..6346f49e 100644 --- a/eloquent/connectors/connection_factory.py +++ b/eloquent/connectors/connection_factory.py @@ -28,7 +28,7 @@ def _create_single_connection(self, config): config['driver'], conn, config['database'], - config.get('prefix'), + config.get('prefix', ''), config ) diff --git a/eloquent/connectors/connector.py b/eloquent/connectors/connector.py index d4c99ccf..ab37bfce 100644 --- a/eloquent/connectors/connector.py +++ b/eloquent/connectors/connector.py @@ -4,7 +4,7 @@ class Connector(object): RESERVED_KEYWORDS = [ - 'log_queries', 'driver' + 'log_queries', 'driver', 'prefix' ] def get_api(self): diff --git a/eloquent/connectors/mysql_connector.py b/eloquent/connectors/mysql_connector.py index 1acb6ed7..ef15d0fa 100644 --- a/eloquent/connectors/mysql_connector.py +++ b/eloquent/connectors/mysql_connector.py @@ -20,6 +20,11 @@ class MySqlConnector(Connector): + RESERVED_KEYWORDS = [ + 'log_queries', 'driver', 'prefix', + 'engine', 'charset', 'collation' + ] + def connect(self, config): config = dict(config.items()) for key, value in keys_fix.items(): diff --git a/eloquent/dbal/__init__.py b/eloquent/dbal/__init__.py new file mode 100644 index 00000000..633f8661 --- /dev/null +++ b/eloquent/dbal/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- + diff --git a/eloquent/dbal/column.py b/eloquent/dbal/column.py new file mode 100644 index 00000000..33541471 --- /dev/null +++ b/eloquent/dbal/column.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- + +from ..utils import basestring + + +class Column(object): + + def __init__(self, name, type, options=None): + self._name = name + self._type = type + + self._length = None + self._precision = 10 + self._scale = 0 + self._unsigned = False + self._fixed = False + self._notnull = True + self._default = None + self._autoincrement = False + self._platform_options = {} + + self.set_options(options or {}) + + def set_options(self, options): + for key, value in options.items(): + method = 'set_%s' % key + if hasattr(self, method): + getattr(self, method)(value) + + return self + + def set_platform_options(self, platform_options): + self._platform_options = platform_options + + return self + + def set_platform_option(self, name, value): + self._platform_options[name] = value + + return self + + def get_platform_options(self): + return self._platform_options + + def has_platform_option(self, option): + return option in self._platform_options + + def get_platform_option(self, option): + return self._platform_options[option] + + def set_length(self, length): + if length is not None: + self._length = int(length) + else: + self._length = None + + return self + + def set_precision(self, precision): + if precision is None or isinstance(precision, basestring) and not precision.isdigit(): + precision = 10 + + self._precision = int(precision) + + return self + + def set_scale(self, scale): + if scale is None or isinstance(scale, basestring) and not scale.isdigit(): + scale = 0 + + self._scale = int(scale) + + return self + + def set_unsigned(self, unsigned): + self._unsigned = bool(unsigned) + + def set_fixed(self, fixed): + self._fixed = bool(fixed) + + def set_notnull(self, notnull): + self._notnull = bool(notnull) + + def set_default(self, default): + self._default = default + + def set_autoincrement(self, flag): + self._autoincrement = flag + + return self + + def set_type(self, type): + self._type = type + + def get_name(self): + return self._name + + def get_type(self): + return self._type + + def get_autoincrement(self): + return self._autoincrement + + def get_notnull(self): + return self._notnull + + def get_default(self): + return self._default + + def to_dict(self): + d = { + 'name': self._name, + 'type': self._type, + 'default': self._default, + 'notnull': self._notnull, + 'length': self._length, + 'precision': self._precision, + 'scale': self._scale, + 'fixed': self._fixed, + 'unsigned': self._unsigned, + 'autoincrement': self._autoincrement + } + + d.update(self._platform_options) + + return d + + diff --git a/eloquent/dbal/column_diff.py b/eloquent/dbal/column_diff.py new file mode 100644 index 00000000..4ff5f431 --- /dev/null +++ b/eloquent/dbal/column_diff.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + + +class ColumnDiff(object): + + def __init__(self, old_column_name, column, changed_properties=None, from_column=None): + self.old_column_name = old_column_name + self.column = column + self.changed_properties = changed_properties + self.from_column = from_column + + def has_changed(self, property_name): + return property_name in self.changed_properties + + def get_old_column_name(self): + return self.old_column_name diff --git a/eloquent/dbal/comparator.py b/eloquent/dbal/comparator.py new file mode 100644 index 00000000..1db9f130 --- /dev/null +++ b/eloquent/dbal/comparator.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- + +from .table_diff import TableDiff +from .column_diff import ColumnDiff + + +class Comparator(object): + """ + Compares two Schemas and return an instance of SchemaDiff. + """ + + def diff_table(self, table1, table2): + """ + Returns the difference between the tables table1 and table2. + + :type table1: Table + :type table2: Table + + :rtype: TableDiff + """ + changes = 0 + table_differences = TableDiff(table1.get_name()) + table_differences.from_table = table1 + + table1_columns = table1.get_columns() + table2_columns = table2.get_columns() + + # See if all the fields in table1 exist in table2 + for column_name, column in table2_columns.items(): + if not table1.has_column(column_name): + table_differences.added_columns[column_name] = column + changes += 1 + + # See if there are any removed fields in table2 + for column_name, column in table1_columns.items(): + if not table2.has_column(column_name): + table_differences.removed_columns[column_name] = column + changes += 1 + continue + + # See if column has changed properties in table2 + changed_properties = self.diff_column(column, table2.get_column(column_name)) + + if changed_properties: + column_diff = ColumnDiff(column.get_name(), + table2.get_column(column_name), + changed_properties) + column_diff.from_column = column + table_differences.changed_columns[column.get_name()] = column_diff + changes += 1 + + self.detect_column_renamings(table_differences) + + # table1_indexes = table1.get_indexes() + # table2_indexes = table2.get_indexes() + # + # # See if all the fields in table1 exist in table2 + # for index_name, index in table2_indexes.items(): + # if (index.is_primary() and table1.has_primary_key()) or table1.has_index(index_name): + # continue + # + # table_differences.added_indexes[index_name] = index + # changes += 1 + # + # # See if there are any removed fields in table2 + # for index_name, index in table1_indexes.items(): + # if (index.is_primary() and not table2.has_primary_key())\ + # or (not index.is_primary() and not table2.has_index(index_name)): + # table_differences.removed_indexes[index_name] = index + # changes += 1 + # continue + # + # if index.is_primary(): + # table2_index = table2.get_primary_key() + # else: + # table2_index = table2.get_index(index_name) + # + # if self.diff_index(index, table2_index): + # table_differences.changed_indexes[index_name] = index + # changes += 1 + # + # self.detect_index_renamings(table_differences) + # + # from_fkeys = table1.get_foreign_keys() + # to_fkeys = table2.get_foreign_keys() + # + # for key1, constraint1 in from_fkeys.items(): + # for key2, constraint2 in to_fkeys.items(): + # if self.diff_foreign_key(constraint1, constraint2) is False: + # del from_fkeys[key1] + # del to_fkeys[key2] + # else: + # if constraint1.get_name().lower() == constraint2.get_name().lower(): + # table_differences.changed_foreign_keys.append(constraint2) + # changes += 1 + # del from_fkeys[key1] + # del to_fkeys[key2] + # + # for constraint1 in from_fkeys.values(): + # table_differences.removed_foreign_keys.append(constraint1) + # changes += 1 + # + # for constraint2 in to_fkeys.values(): + # table_differences.added_foreign_keys.append(constraint2) + # changes += 1 + + if changes: + return table_differences + + return False + + def detect_column_renamings(self, table_differences): + """ + Try to find columns that only changed their names. + + :type table_differences: TableDiff + """ + rename_candidates = {} + + for added_column_name, added_column in table_differences.added_columns.items(): + for removed_column in table_differences.removed_columns.values(): + if len(self.diff_column(added_column, removed_column)) == 0: + if added_column.get_name() not in rename_candidates: + rename_candidates[added_column.get_name()] = [] + + rename_candidates[added_column.get_name()] = (removed_column, added_column, added_column_name) + + for candidate_columns in rename_candidates.values(): + if len(candidate_columns) == 1: + removed_column, added_column, _ = candidate_columns[0] + removed_column_name = removed_column.get_name().lower() + added_column_name = added_column.get_name().lower() + + if removed_column_name not in table_differences.renamed_columns: + table_differences.renamed_columns[removed_column_name] = added_column + del table_differences.added_columns[added_column_name] + del table_differences.removed_columns[removed_column_name] + + def diff_column(self, column1, column2): + """ + Returns the difference between column1 and column2 + + :type column1: eloquent.dbal.column.Column + :type column2: eloquent.dbal.column.Column + + :rtype: list + """ + properties1 = column1.to_dict() + properties2 = column2.to_dict() + + changed_properties = [] + + for prop in ['type', 'notnull', 'unsigned', 'autoincrement']: + if properties1[prop] != properties2[prop]: + changed_properties.append(prop) + + if properties1['default'] != properties2['default']\ + or (properties1['default'] is None and properties2['default'] is not None)\ + or (properties2['default'] is None and properties1['default'] is not None): + changed_properties.append('default') + + if properties1['type'] == 'string' and properties1['type'] != 'guid'\ + or properties1['type'] in ['binary', 'blob']: + length1 = properties1['length'] or 255 + length2 = properties2['length'] or 255 + + if length1 != length2: + changed_properties.append('length') + + if properties1['fixed'] != properties2['fixed']: + changed_properties.append('fixed') + elif properties1['type'] in ['decimal', 'float', 'double precision']: + precision1 = properties1['precision'] or 10 + precision2 = properties2['precision'] or 10 + + if precision1 != precision2: + changed_properties.append('precision') + + if properties1['scale'] != properties2['scale']: + changed_properties.append('scale') + + return list(set(changed_properties)) + + diff --git a/eloquent/dbal/mysql_schema_manager.py b/eloquent/dbal/mysql_schema_manager.py new file mode 100644 index 00000000..9e837dd4 --- /dev/null +++ b/eloquent/dbal/mysql_schema_manager.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- + +import re +from .column import Column +from .schema_manager import SchemaManager +from .platforms.mysql_platform import MySqlPlatform + + +class MySqlSchemaManager(SchemaManager): + + def _get_portable_table_column_definition(self, table_column): + db_type = table_column['type'].lower() + match = re.match('(.+)\((.*)\).*', db_type) + if match: + db_type = match.group(1) + + if 'length' in table_column: + length = table_column['length'] + else: + if match and match.group(2) and ',' not in match.group(2): + length = int(match.group(2)) + else: + length = 0 + + fixed = None + + if 'name' not in table_column: + table_column['name'] = '' + + precision = None + scale = None + + type = self._platform.get_type_mapping(db_type) + + if db_type in ['char', 'binary']: + fixed = True + elif db_type in ['float', 'double', 'real', 'decimal', 'numeric']: + match = re.match('([A-Za-z]+\(([0-9]+),([0-9]+)\))', table_column['type']) + if match: + precision = match.group(1) + scale = match.group(2) + length = None + elif db_type == 'tinytext': + length = MySqlPlatform.LENGTH_LIMIT_TINYTEXT + elif db_type == 'text': + length = MySqlPlatform.LENGTH_LIMIT_TEXT + elif db_type == 'mediumtext': + length = MySqlPlatform.LENGTH_LIMIT_MEDIUMTEXT + elif db_type == 'tinyblob': + length = MySqlPlatform.LENGTH_LIMIT_TINYBLOB + elif db_type == 'blob': + length = MySqlPlatform.LENGTH_LIMIT_BLOB + elif db_type == 'mediumblob': + length = MySqlPlatform.LENGTH_LIMIT_MEDIUMBLOB + elif db_type in ['tinyint', 'smallint', 'mediumint', 'int', 'bigint', 'year']: + length = None + + if length is None or length == 0: + length = None + + options = { + 'length': length, + 'unsigned': table_column['type'].find('unsigned') != -1, + 'fixed': fixed, + 'notnull': table_column['null'] != 'YES', + 'default': table_column.get('default'), + 'precision': None, + 'scale': None, + 'autoincrement': table_column['extra'].find('auto_increment') != -1 + } + + if scale is not None and precision is not None: + options['scale'] = scale + options['precision'] = precision + + column = Column(table_column['field'], type, options) + + if 'collation' in table_column: + column.set_platform_option('collation', table_column['collation']) + + return column + + def get_database_platform(self): + return MySqlPlatform() diff --git a/eloquent/dbal/platforms/__init__.py b/eloquent/dbal/platforms/__init__.py new file mode 100644 index 00000000..633f8661 --- /dev/null +++ b/eloquent/dbal/platforms/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- + diff --git a/eloquent/dbal/platforms/mysql_platform.py b/eloquent/dbal/platforms/mysql_platform.py new file mode 100644 index 00000000..fae4907e --- /dev/null +++ b/eloquent/dbal/platforms/mysql_platform.py @@ -0,0 +1,236 @@ +# -*- coding: utf-8 -*- + +from .platform import Platform +from ..table import Table +from ..column import Column + + +class MySqlPlatform(Platform): + + LENGTH_LIMIT_TINYTEXT = 255 + LENGTH_LIMIT_TEXT = 65535 + LENGTH_LIMIT_MEDIUMTEXT = 16777215 + + LENGTH_LIMIT_TINYBLOB = 255 + LENGTH_LIMIT_BLOB = 65535 + LENGTH_LIMIT_MEDIUMBLOB = 16777215 + + INTERNAL_TYPE_MAPPING = { + 'tinyint': 'boolean', + 'smallint': 'smallint', + 'mediumint': 'integer', + 'int': 'integer', + 'integer': 'integer', + 'bigint': 'bigint', + 'int8': 'bigint', + 'bool': 'boolean', + 'boolean': 'boolean', + 'tinytext': 'text', + 'mediumtext': 'text', + 'longtext': 'text', + 'text': 'text', + 'varchar': 'string', + 'string': 'string', + 'char': 'string', + 'date': 'date', + 'datetime': 'datetime', + 'timestamp': 'datetime', + 'time': 'time', + 'float': 'float', + 'double': 'float', + 'real': 'float', + 'decimal': 'decimal', + 'numeric': 'decimal', + 'year': 'date', + 'longblob': 'blob', + 'blob': 'blob', + 'mediumblob': 'blob', + 'tinyblob': 'blob', + 'binary': 'binary', + 'varbinary': 'binary', + 'set': 'simple_array' + } + + def get_list_table_columns_sql(self, table, database=None): + if database: + database = "'%s'" % database + else: + database = 'DATABASE()' + + return 'SELECT COLUMN_NAME AS field, COLUMN_TYPE AS type, IS_NULLABLE AS `null`, ' \ + 'COLUMN_KEY AS `key`, COLUMN_DEFAULT AS `default`, EXTRA AS extra, COLUMN_COMMENT AS comment, ' \ + 'CHARACTER_SET_NAME AS character_set, COLLATION_NAME AS collation ' \ + 'FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = %s AND TABLE_NAME = \'%s\''\ + % (database, table) + + def get_list_table_indexes_sql(self, table, current_database=None): + return 'SHOW INDEX FROM %s' % table + + def get_list_table_foreign_keys_sql(self, table, database=None): + sql = ("SELECT DISTINCT k.`CONSTRAINT_NAME` AS `name`, k.`COLUMN_NAME`, k.`REFERENCED_TABLE_NAME`, " + "k.`REFERENCED_COLUMN_NAME` /*!50116 , c.update_rule, c.delete_rule */ " + "FROM information_schema.key_column_usage k /*!50116 " + "INNER JOIN information_schema.referential_constraints c ON " + " c.constraint_name = k.constraint_name AND " + " c.table_name = '%s' */ WHERE k.table_name = '%s'" % (table, table)) + + if database: + sql += " AND k.table_schema = '%s' /*!50116 AND c.constraint_schema = '%s' */"\ + % (database, database) + + sql += " AND k.`REFERENCED_COLUMN_NAME` IS NOT NULL" + + return sql + + def get_alter_table_sql(self, diff): + """ + Get the ALTER TABLE SQL statement + + :param diff: The table diff + :type diff: eloquent.dbal.table_diff.TableDiff + + :rtype: list + """ + column_sql = [] + query_parts = [] + + if diff.new_name is not False: + query_parts.append('RENAME TO %s' % diff.new_name) + + # Added columns? + + # Removed columns? + + for column_diff in diff.changed_columns.values(): + column = column_diff.column + column_dict = column.to_dict() + + # Don't propagate default value changes for unsupported column types. + if column_diff.has_changed('default') \ + and len(column_diff.changed_properties) == 1 \ + and (column_dict['type'] == 'text' or column_dict['type'] == 'blob'): + continue + + query_parts.append('CHANGE %s %s' + % (column_diff.get_old_column_name(), + self.get_column_declaration_sql(column.get_name(), column_dict))) + + for old_column_name, column in diff.renamed_columns.items(): + column_dict = column.to_dict() + query_parts.append('CHANGE %s %s' + % (self.quote(old_column_name), + self.get_column_declaration_sql(self.quote(column.get_name()), column_dict))) + + sql = [] + + if len(query_parts) > 0: + sql.append('ALTER TABLE %s %s' % (diff.name, ', '.join(query_parts))) + + return sql + + def convert_booleans(self, item): + if isinstance(item, list): + for i, value in enumerate(item): + if isinstance(value, bool): + item[i] = str(value).lower() + elif isinstance(item, bool): + item = str(item).lower() + + return item + + def get_boolean_type_sql_declaration(self, column): + return 'TINYINT(1)' + + def get_integer_type_sql_declaration(self, column): + return 'INT ' + self._get_common_integer_type_declaration_sql(column) + + def get_bigint_type_sql_declaration(self, column): + return 'BIGINT ' + self._get_common_integer_type_declaration_sql(column) + + def get_smallint_type_sql_declaration(self, column): + return 'SMALLINT ' + self._get_common_integer_type_declaration_sql(column) + + def get_guid_type_sql_declaration(self, column): + return 'UUID' + + def get_datetime_type_sql_declaration(self, column): + if 'version' in column and column['version'] == True: + return 'TIMESTAMP' + + return 'DATETIME' + + def get_date_type_sql_declaration(self, column): + return 'DATE' + + def get_time_type_sql_declaration(self, column): + return 'TIME' + + def get_varchar_type_declaration_sql_snippet(self, length, fixed): + if fixed: + return 'CHAR(%s)' % length if length else 'CHAR(255)' + else: + return 'VARCHAR(%s)' % length if length else 'VARCHAR(255)' + + def get_binary_type_declaration_sql_snippet(self, length, fixed): + if fixed: + return 'BINARY(%s)' % (length or 255) + else: + return 'VARBINARY(%s)' % (length or 255) + + def get_text_type_sql_declaration(self, column): + length = column.get('length') + if length: + if length <= self.LENGTH_LIMIT_TINYTEXT: + return 'TINYTEXT' + + if length <= self.LENGTH_LIMIT_TEXT: + return 'TEXT' + + if length <= self.LENGTH_LIMIT_MEDIUMTEXT: + return 'MEDIUMTEXT' + + return 'LONGTEXT' + + def get_blob_type_sql_declaration(self, column): + length = column.get('length') + if length: + if length <= self.LENGTH_LIMIT_TINYBLOB: + return 'TINYBLOB' + + if length <= self.LENGTH_LIMIT_BLOB: + return 'BLOB' + + if length <= self.LENGTH_LIMIT_MEDIUMBLOB: + return 'MEDIUMBLOB' + + return 'LONGBLOB' + + def get_decimal_type_sql_declaration(self, column): + decl = super(MySqlPlatform, self).get_decimal_type_sql_declaration(column) + + return decl + self.get_unsigned_declaration(column) + + def get_unsigned_declaration(self, column): + if column.get('unsigned'): + return ' UNSIGNED' + + return '' + + def _get_common_integer_type_declaration_sql(self, column): + autoinc = '' + if column.get('autoincrement'): + autoinc = ' AUTO_INCREMENT' + + return self.get_unsigned_declaration(column) + autoinc + + def get_float_type_sql_declaration(self, column): + return 'DOUBLE PRECISION' + self.get_unsigned_declaration(column) + + def supports_foreign_key_constraints(self): + return True + + def supports_column_collation(self): + return False + + def quote(self, name): + return '`%s`' % name.replace('`', '``') diff --git a/eloquent/dbal/platforms/platform.py b/eloquent/dbal/platforms/platform.py new file mode 100644 index 00000000..3133ca9b --- /dev/null +++ b/eloquent/dbal/platforms/platform.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- + + +class Platform(object): + + INTERNAL_TYPE_MAPPING = {} + + def get_default_value_declaration_sql(self, field): + default = '' + + if not field.get('notnull'): + default = ' DEFAULT NULL' + + if 'default' in field and field['default'] is not None: + default = ' DEFAULT \'%s\'' % field['default'] + + if 'type' in field: + type = field['type'] + + if type in ['integer', 'bigint', 'smallint']: + default = ' DEFAULT %s' % field['default'] + elif type in ['datetime', 'datetimetz'] \ + and field['default'] in [self.get_current_timestamp_sql(), 'NOW', 'now']: + default = ' DEFAULT %s' % self.get_current_timestamp_sql() + elif type in ['time'] \ + and field['default'] in [self.get_current_time_sql(), 'NOW', 'now']: + default = ' DEFAULT %s' % self.get_current_time_sql() + elif type in ['date'] \ + and field['default'] in [self.get_current_date_sql(), 'NOW', 'now']: + default = ' DEFAULT %s' % self.get_current_date_sql() + elif type in ['boolean']: + default = ' DEFAULT \'%s\'' % self.convert_booleans(field['default']) + + return default + + def convert_booleans(self, item): + if isinstance(item, list): + for i, value in enumerate(item): + if isinstance(value, bool): + item[i] = int(value) + elif isinstance(item, bool): + item = int(item) + + return item + + def get_current_date_sql(self): + return 'CURRENT_DATE' + + def get_current_time_sql(self): + return 'CURRENT_TIME' + + def get_current_timestamp_sql(self): + return 'CURRENT_TIMESTAMP' + + def get_sql_type_declaration(self, column): + internal_type = column['type'] + + return getattr(self, 'get_%s_type_sql_declaration' % internal_type)(column) + + def get_column_declaration_sql(self, name, field): + if 'column_definition' in field: + column_def = self.get_custom_type_declaration_sql(field) + else: + default = self.get_default_value_declaration_sql(field) + + charset = field.get('charset', '') + if charset: + charset = ' ' + self.get_column_charset_declaration_sql(charset) + + collation = field.get('collation', '') + if charset: + charset = ' ' + self.get_column_collation_declaration_sql(charset) + + notnull = field.get('notnull', '') + if notnull: + notnull = ' NOT NULL' + else: + notnull = '' + + unique = field.get('unique', '') + if unique: + unique = ' ' + self.get_unique_field_declaration_sql() + else: + unique = '' + + check = field.get('check', '') + + type_decl = self.get_sql_type_declaration(field) + column_def = type_decl + charset + default + notnull + unique + check + collation + + return name + ' ' + column_def + + def get_custom_type_declaration_sql(self, column_def): + return column_def['column_definition'] + + def get_column_charset_declaration_sql(self, charset): + return '' + + def get_column_collation_declaration_sql(self, collation): + if self.supports_column_collation(): + return 'COLLATE %s' % collation + + return '' + + def supports_column_collation(self): + return False + + def get_unique_field_declaration_sql(self): + return 'UNIQUE' + + def get_string_type_sql_declaration(self, column): + if 'length' not in column: + column['length'] = self.get_varchar_default_length() + + fixed = column.get('fixed', False) + + if column['length'] > self.get_varchar_max_length(): + return self.get_text_type_sql_declaration(column) + + return self.get_varchar_type_declaration_sql_snippet(column['length'], fixed) + + def get_binary_type_sql_declaration(self, column): + if 'length' not in column: + column['length'] = self.get_binary_default_length() + + fixed = column.get('fixed', False) + + if column['length'] > self.get_binary_max_length(): + return self.get_blob_type_sql_declaration(column) + + return self.get_binary_type_declaration_sql_snippet(column['length'], fixed) + + def get_varchar_type_declaration_sql_snippet(self, length, fixed): + raise NotImplementedError('VARCHARS not supported by Platform') + + def get_binary_type_declaration_sql_snippet(self, length, fixed): + raise NotImplementedError('BINARY/VARBINARY not supported by Platform') + + def get_decimal_type_sql_declaration(self, column): + if 'precision' not in column or not column['precision']: + column['precision'] = 10 + + if 'scale' not in column or not column['scale']: + column['precision'] = 0 + + return 'NUMERIC(%s, %s)' % (column['precision'], column['scale']) + + def get_varchar_default_length(self): + return 255 + + def get_varchar_max_length(self): + return 4000 + + def get_binary_default_length(self): + return 255 + + def get_binary_max_length(self): + return 4000 + + def get_column_options(self): + return [] + + def get_type_mapping(self, db_type): + return self.INTERNAL_TYPE_MAPPING[db_type] diff --git a/eloquent/dbal/platforms/postgres_platform.py b/eloquent/dbal/platforms/postgres_platform.py new file mode 100644 index 00000000..d7b5107a --- /dev/null +++ b/eloquent/dbal/platforms/postgres_platform.py @@ -0,0 +1,280 @@ +# -*- coding: utf-8 -*- + +from .platform import Platform +from ..table import Table +from ..column import Column + + +class PostgresPlatform(Platform): + + INTERNAL_TYPE_MAPPING = { + 'smallint': 'smallint', + 'int2': 'smallint', + 'serial': 'integer', + 'serial4': 'integer', + 'int': 'integer', + 'int4': 'integer', + 'integer': 'integer', + 'bigserial': 'bigint', + 'serial8': 'bigint', + 'bigint': 'bigint', + 'int8': 'bigint', + 'bool': 'boolean', + 'boolean': 'boolean', + 'text': 'text', + 'tsvector': 'text', + 'varchar': 'string', + 'interval': 'string', + '_varchar': 'string', + 'char': 'string', + 'bpchar': 'string', + 'inet': 'string', + 'date': 'date', + 'datetime': 'datetime', + 'timestamp': 'datetime', + 'timestamptz': 'datetimez', + 'time': 'time', + 'timetz': 'time', + 'float': 'float', + 'float4': 'float', + 'float8': 'float', + 'double': 'float', + 'double precision': 'float', + 'real': 'float', + 'decimal': 'decimal', + 'money': 'decimal', + 'numeric': 'decimal', + 'year': 'date', + 'uuid': 'guid', + 'bytea': 'blob' + } + + def get_list_table_columns_sql(self, table): + sql = """SELECT + a.attnum, + quote_ident(a.attname) AS field, + t.typname AS type, + format_type(a.atttypid, a.atttypmod) AS complete_type, + (SELECT t1.typname FROM pg_catalog.pg_type t1 WHERE t1.oid = t.typbasetype) AS domain_type, + (SELECT format_type(t2.typbasetype, t2.typtypmod) FROM + pg_catalog.pg_type t2 WHERE t2.typtype = 'd' AND t2.oid = a.atttypid) AS domain_complete_type, + a.attnotnull AS isnotnull, + (SELECT 't' + FROM pg_index + WHERE c.oid = pg_index.indrelid + AND pg_index.indkey[0] = a.attnum + AND pg_index.indisprimary = 't' + ) AS pri, + (SELECT pg_get_expr(adbin, adrelid) + FROM pg_attrdef + WHERE c.oid = pg_attrdef.adrelid + AND pg_attrdef.adnum=a.attnum + ) AS default, + (SELECT pg_description.description + FROM pg_description WHERE pg_description.objoid = c.oid AND a.attnum = pg_description.objsubid + ) AS comment + FROM pg_attribute a, pg_class c, pg_type t, pg_namespace n + WHERE %s + AND a.attnum > 0 + AND a.attrelid = c.oid + AND a.atttypid = t.oid + AND n.oid = c.relnamespace + ORDER BY a.attnum""" % self.get_table_where_clause(table) + + return sql + + def get_list_table_indexes_sql(self, table): + table = table.replace('.', '__') + + return 'PRAGMA index_list(\'%s\')' % table + + def get_list_table_foreign_keys_sql(self, table): + return 'SELECT quote_ident(r.conname) AS name, ' \ + 'pg_catalog.pg_get_constraintdef(r.oid, true) AS condef ' \ + 'FROM pg_catalog.pg_constraint r ' \ + 'WHERE r.conrelid = ' \ + '(' \ + 'SELECT c.oid ' \ + 'FROM pg_catalog.pg_class c, pg_catalog.pg_namespace n ' \ + 'WHERE ' + self.get_table_where_clause(table) + ' AND n.oid = c.relnamespace' \ + ')' \ + ' AND r.contype = \'f\'' + + def get_table_where_clause(self, table, class_alias='c', namespace_alias='n'): + where_clause = namespace_alias + '.nspname NOT IN (\'pg_catalog\', \'information_schema\', \'pg_toast\') AND ' + if table.find('.') >= 0: + split = table.split('.') + schema, table = split[0], split[1] + schema = "'%s'" % schema + else: + schema = 'ANY(string_to_array((select replace(replace(setting, \'"$user"\', user), \' \', \'\')' \ + ' from pg_catalog.pg_settings where name = \'search_path\'),\',\'))' + + where_clause += '%s.relname = \'%s\' AND %s.nspname = %s' % (class_alias, table, namespace_alias, schema) + + return where_clause + + def get_alter_table_sql(self, diff): + """ + Get the ALTER TABLE SQL statement + + :param diff: The table diff + :type diff: eloquent.dbal.table_diff.TableDiff + + :rtype: list + """ + sql = [] + + for column_diff in diff.changed_columns.values(): + if self.is_unchanged_binary_column(column_diff): + continue + + old_column_name = column_diff.old_column_name + column = column_diff.column + + if any([column_diff.has_changed('type'), + column_diff.has_changed('precision'), + column_diff.has_changed('scale'), + column_diff.has_changed('fixed')]): + query = 'ALTER ' + old_column_name + ' TYPE ' + self.get_sql_type_declaration(column.to_dict()) + sql.append('ALTER TABLE ' + diff.name + ' ' + query) + + if column_diff.has_changed('default') or column_diff.has_changed('type'): + if column.get_default() is None: + default_clause = ' DROP DEFAULT' + else: + default_clause = ' SET' + self.get_default_value_declaration_sql(column.to_dict()) + + query = 'ALTER ' + old_column_name + default_clause + sql.append('ALTER TABLE ' + diff.name + ' ' + query) + + if column_diff.has_changed('notnull'): + op = 'DROP' + if column.get_notnull(): + op = 'SET' + + query = 'ALTER ' + old_column_name + ' ' + op + ' NOT NULL' + sql.append('ALTER TABLE ' + diff.name + ' ' + query) + + if column_diff.has_changed('autoincrement'): + if column.get_autoincrement(): + seq_name = self.get_identity_sequence_name(diff.name, old_column_name) + + sql.append('CREATE SEQUENCE ' + seq_name) + sql.append('SELECT setval(\'' + seq_name + '\', ' + '(SELECT MAX(' + old_column_name + ') FROM ' + diff.name + '))') + query = 'ALTER ' + old_column_name + ' SET DEFAULT nextval(\'' + seq_name + '\')' + sql.append('ALTER TABLE ' + diff.name + ' ' + query) + else: + query = 'ALTER ' + old_column_name + ' DROP DEFAULT' + sql.append('ALTER TABLE ' + diff.name + ' ' + query) + + if column_diff.has_changed('length'): + query = 'ALTER ' + old_column_name + ' TYPE ' + self.get_sql_type_declaration(column.to_dict()) + sql.append('ALTER TABLE ' + diff.name + ' ' + query) + + for old_column_name, column in diff.renamed_columns.items(): + sql.append('ALTER TABLE ' + diff.name + ' ' + 'RENAME COLUMN ' + old_column_name + ' TO ' + column.get_name()) + + return sql + + def is_unchanged_binary_column(self, column_diff): + column_type = column_diff.column.get_type() + + if column_type not in ['blob', 'binary']: + return False + + if isinstance(column_diff.from_column, Column): + from_column = column_diff.from_column + else: + from_column = None + + if from_column: + from_column_type = self.INTERNAL_TYPE_MAPPING[from_column.get_type()] + + if from_column_type in ['blob', 'binary']: + return False + + return len([x for x in column_diff.changed_properties if x not in ['type', 'length', 'fixed']]) == 0 + + if column_diff.has_changed('type'): + return False + + return len([x for x in column_diff.changed_properties if x not in ['length', 'fixed']]) == 0 + + def convert_booleans(self, item): + if isinstance(item, list): + for i, value in enumerate(item): + if isinstance(value, bool): + item[i] = str(value).lower() + elif isinstance(item, bool): + item = str(item).lower() + + return item + + def get_boolean_type_sql_declaration(self, column): + return 'BOOLEAN' + + def get_integer_type_sql_declaration(self, column): + if column.get('autoincrement'): + return 'SERIAL' + + return 'INT' + + def get_bigint_type_sql_declaration(self, column): + if column.get('autoincrement'): + return 'BIGSERIAL' + + return 'BIGINT' + + def get_smallint_type_sql_declaration(self, column): + return 'SMALLINT' + + def get_guid_type_sql_declaration(self, column): + return 'UUID' + + def get_datetime_type_sql_declaration(self, column): + return 'TIMESTAMP(0) WITHOUT TIME ZONE' + + def get_datetimetz_type_sql_declaration(self, column): + return 'TIMESTAMP(0) WITH TIME ZONE' + + def get_date_type_sql_declaration(self, column): + return 'DATE' + + def get_time_type_sql_declaration(self, column): + return 'TIME(0) WITHOUT TIME ZONE' + + def get_string_type_sql_declaration(self, column): + length = column.get('length', '255') + fixed = column.get('fixed') + + if fixed: + return 'CHAR(%s)' % length + else: + return 'VARCHAR(%s)' % length + + def get_binary_type_sql_declaration(self, column): + return 'BYTEA' + + def get_blob_type_sql_declaration(self, column): + return 'BYTEA' + + def get_text_type_sql_declaration(self, column): + return 'TEXT' + + def get_decimal_type_sql_declaration(self, column): + if 'precision' not in column or not column['precision']: + column['precision'] = 10 + + if 'scale' not in column or not column['scale']: + column['precision'] = 0 + + return 'DECIMAL(%s, %s)' % (column['precision'], column['scale']) + + def get_float_type_sql_declaration(self, column): + return 'DOUBLE PRECISION' + + def supports_foreign_key_constraints(self): + return True diff --git a/eloquent/dbal/platforms/sqlite_platform.py b/eloquent/dbal/platforms/sqlite_platform.py new file mode 100644 index 00000000..f432b8c4 --- /dev/null +++ b/eloquent/dbal/platforms/sqlite_platform.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- + +from .platform import Platform +from ..table import Table +from ..column import Column + + +class SQLitePlatform(Platform): + + TYPE_MAPPING = { + 'boolean': 'boolean', + 'tinyint': 'boolean', + 'smallint': 'small_integer', + 'mediumint': 'integer', + 'int': 'integer', + 'integer': 'integer', + 'serial': 'integer', + 'bigint': 'big_integer', + 'bigserial': 'big_integer', + 'clob': 'text', + 'tinytext': 'text', + 'mediumtext': 'text', + 'longtext': 'text', + 'text': 'text', + 'varchar': 'string', + 'longvarchar': 'string', + 'varchar2': 'string', + 'nvarchar': 'string', + 'image': 'string', + 'ntext': 'string', + 'char': 'string', + 'date': 'date', + 'datetime': 'datetime', + 'timestamp': 'datetime', + 'time': 'time', + 'float': 'float', + 'double': 'float', + 'double precision': 'float', + 'real': 'float', + 'decimal': 'decimal', + 'numeric': 'decimal', + 'blob': 'binary' + } + + def get_list_table_columns_sql(self, table): + table = table.replace('.', '__') + + return 'PRAGMA table_info(\'%s\')' % table + + def get_list_table_indexes_sql(self, table): + table = table.replace('.', '__') + + return 'PRAGMA index_list(\'%s\')' % table + + def get_list_table_foreign_keys_sql(self, table): + table = table.replace('.', '__') + + return 'PRAGMA foreign_key_list(\'%s\')' % table + + def get_alter_table_sql(self, diff): + """ + Get the ALTER TABLE SQL statement + + :param diff: The table diff + :type diff: eloquent.dbal.table_diff.TableDiff + + :rtype: list + """ + #sql = self._get_simple_alter_table_sql(diff) + + from_table = diff.from_table + if not isinstance(from_table, Table): + raise Exception('SQLite platform requires for the alter table the table diff ' + 'referencing the original table') + + table = from_table.clone() + columns = {} + old_column_names = {} + new_column_names = {} + column_sql = [] + for column_name, column in table.get_columns().items(): + column_name = column_name.lower() + columns[column_name] = column + old_column_names[column_name] = column.get_name() + new_column_names[column_name] = column.get_name() + + for column_name, column in diff.removed_columns.items(): + column_name = column_name.lower() + if column_name in columns: + del columns[column_name] + del old_column_names[column_name] + del new_column_names[column_name] + + for old_column_name, column in diff.renamed_columns.items(): + old_column_name = old_column_name.lower() + if old_column_name in columns: + del columns[old_column_name] + + columns[column.get_name().lower()] = column + + if old_column_name in new_column_names: + new_column_names[old_column_name] = column.get_name() + + for old_column_name, column_diff in diff.changed_columns.items(): + if old_column_name in columns: + del columns[old_column_name] + + columns[column_diff.column.get_name().lower()] = column_diff.column + + if old_column_name in new_column_names: + new_column_names[old_column_name] = column_diff.column.get_name() + + for column_name, column in diff.added_columns.items(): + columns[column_name.lower()] = column + + sql = [] + table_sql = [] + + data_table = Table('__temp__' + table.get_name()) + + new_table = Table(table.get_name(), columns, + self.get_primary_index_in_altered_table(diff), + self.get_foreign_keys_in_altered_table(diff)) + new_table.add_option('alter', True) + + sql = self.get_pre_alter_table_index_foreign_key_sql(diff) + sql.append('CREATE TEMPORARY TABLE %s AS SELECT %s FROM %s' + % (data_table.get_name(), ', '.join(old_column_names.values()), table.get_name())) + sql.append(self.get_drop_table_sql(from_table)) + + sql += self.get_create_table_sql(new_table) + sql.append('INSERT INTO %s (%s) SELECT %s FROM %s' + % (new_table.get_name(), + ', '.join(new_column_names.values()), + ', '.join(old_column_names.values()), + data_table.get_name())) + sql.append(self.get_drop_table_sql(data_table)) + + sql += self.get_post_alter_table_index_foreign_key_sql(diff) + + return sql + + def _get_simple_alter_table_sql(self, diff): + for old_column_name, column_diff in diff.changed_columns.items(): + if not isinstance(column_diff.from_column, Column)\ + or not isinstance(column_diff.column, Column)\ + or not column_diff.column.get_autoincrement()\ + or column_diff.column.get_type().lower() != 'integer': + continue + + if not column_diff.has_changed('type') and not column_diff.has_changed('unsigned'): + del diff.changed_columns[old_column_name] + + continue + + from_column_type = column_diff.column.get_type() + + if from_column_type == 'smallint' or from_column_type == 'bigint': + del diff.changed_columns[old_column_name] + + if any([not diff.renamed_columns, not diff.added_foreign_keys, not diff.added_indexes, + not diff.changed_columns, not diff.changed_foreign_keys, not diff.changed_indexes, + not diff.removed_columns, not diff.removed_foreign_keys, not diff.removed_indexes, + not diff.renamed_indexes]): + return False + + table = Table(diff.name) + + sql = [] + table_sql = [] + column_sql = [] + + for column in diff.added_columns.values(): + field = { + 'unique': None, + 'autoincrement': None, + 'default': None + } + field.update(column.to_dict()) + + def get_foreign_keys_in_altered_table(self, diff): + """ + :param diff: The table diff + :type diff: eloquent.dbal.table_diff.TableDiff + + :rtype: list + """ + foreign_keys = diff.from_table.get_foreign_keys() + column_names = self.get_column_names_in_altered_table(diff) + + for key, constraint in foreign_keys.items(): + changed = False + local_columns = [] + for column_name in constraint.get_local_columns(): + normalized_column_name = column_name.lower() + if normalized_column_name not in column_names: + del foreign_keys[key] + break + else: + local_columns.append(column_names[normalized_column_name]) + if column_name != column_names[normalized_column_name]: + changed = True + + if changed: + pass + + return foreign_keys + + def supports_foreign_key_constraints(self): + return False + + def get_column_options(self): + return ['pk'] diff --git a/eloquent/dbal/postgres_schema_manager.py b/eloquent/dbal/postgres_schema_manager.py new file mode 100644 index 00000000..e024a411 --- /dev/null +++ b/eloquent/dbal/postgres_schema_manager.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- + +import re +from .table import Table +from .column import Column +from .schema_manager import SchemaManager +from .platforms.postgres_platform import PostgresPlatform + + +class PostgresSchemaManager(SchemaManager): + + def _get_portable_table_column_definition(self, table_column): + if table_column['type'].lower() == 'varchar' or table_column['type'] == 'bpchar': + length = re.sub('.*\(([0-9]*)\).*', '\\1', table_column['complete_type']) + table_column['length'] = length + + autoincrement = False + match = re.match("^nextval\('?(.*)'?(::.*)?\)$", str(table_column['default'])) + if match: + table_column['sequence'] = match.group(1) + table_column['default'] = None + autoincrement = True + + match = re.match("^'?([^']*)'?::.*$", str(table_column['default'])) + if match: + table_column['default'] = match.group(1) + + if str(table_column['default']).find('NULL') == 0: + table_column['default'] = None + + if 'length' in table_column: + length = table_column['length'] + else: + length = None + + if length == '-1' and 'atttypmod' in table_column: + length = table_column['atttypmod'] - 4 + + if length is None or int(length) <= 0: + length = None + + fixed = None + + if 'name' not in table_column: + table_column['name'] = '' + + precision = None + scale = None + + db_type = table_column['type'].lower() + + type = self._platform.get_type_mapping(db_type) + + if db_type in ['smallint', 'int2']: + length = None + elif db_type in ['int', 'int4', 'integer']: + length = None + elif db_type in ['int8', 'bigint']: + length = None + elif db_type in ['bool', 'boolean']: + if table_column['default'] == 'true': + table_column['default'] = True + + if table_column['default'] == 'false': + table_column['default'] = False + + length = None + elif db_type == 'text': + fixed = False + elif db_type in ['varchar', 'interval', '_varchar']: + fixed = False + elif db_type in ['char', 'bpchar']: + fixed = True + elif db_type in ['float', 'float4', 'float8', + 'double', 'double precision', + 'real', 'decimal', 'money', 'numeric']: + match = re.match('([A-Za-z]+\(([0-9]+),([0-9]+)\))', table_column['complete_type']) + if match: + precision = match.group(1) + scale = match.group(2) + length = None + elif db_type == 'year': + length = None + + if table_column['default']: + match = re.match("('?([^']+)'?::)", str(table_column['default'])) + if match: + table_column['default'] = match.group(1) + + options = { + 'length': length, + 'notnull': table_column['isnotnull'], + 'default': table_column['default'], + 'primary': table_column['pri'] == 't', + 'precision': precision, + 'scale': scale, + 'fixed': fixed, + 'unsigned': False, + 'autoincrement': autoincrement + } + + column = Column(table_column['field'], type, options) + + return column + + def get_database_platform(self): + return PostgresPlatform() diff --git a/eloquent/dbal/schema_manager.py b/eloquent/dbal/schema_manager.py new file mode 100644 index 00000000..32b22455 --- /dev/null +++ b/eloquent/dbal/schema_manager.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- + +from .table import Table +from .column import Column + + +class SchemaManager(object): + + def __init__(self, connection, platform=None): + """ + :param connection: The connection to use + :type connection: eloquent.connection.Connection + + :param platform: The platform + :type platform: eloquent.dbal.platforms.Platform + """ + self._connection = connection + if not platform: + self._platform = self._connection.get_database_platform() + else: + self._platform = platform + + def list_table_columns(self, table): + sql = self._platform.get_list_table_columns_sql(table) + + cursor = self._connection.get_connection().cursor() + cursor.execute(sql) + table_columns = map(lambda x: dict(x.items()), cursor.fetchall()) + + return self._get_portable_table_columns_list(table, table_columns) + + def list_table_indexes(self, table): + sql = self._platform.get_list_table_indexes_sql(table) + + cursor = self._connection.get_connection().cursor() + table_indexes = cursor.execute(sql).fetchall() + + return table_indexes + + def list_table_foreign_keys(self, table): + sql = self._platform.get_list_table_foreign_keys_sql(table) + + cursor = self._connection.get_connection().cursor() + cursor.execute(sql) + table_foreign_keys = cursor.fetchall() + + return table_foreign_keys + + def list_table_details(self, table_name): + columns = self.list_table_columns(table_name) + + foreign_keys = {} + if self._platform.supports_foreign_key_constraints(): + foreign_keys = self.list_table_foreign_keys(table_name) + + #indexes = self.list_table_indexes(table_name) + + return Table(table_name, columns, [], foreign_keys) + + def _get_portable_table_columns_list(self, table, table_columns): + columns_list = {} + + for table_column in table_columns: + column = self._get_portable_table_column_definition(table_column) + + if column: + name = column.get_name().lower() + columns_list[name] = column + + return columns_list + + def _get_portable_table_column_definition(self, table_column): + raise NotImplementedError + + def get_database_platform(self): + raise NotImplementedError diff --git a/eloquent/dbal/sqlite_schema_manager.py b/eloquent/dbal/sqlite_schema_manager.py new file mode 100644 index 00000000..136a7743 --- /dev/null +++ b/eloquent/dbal/sqlite_schema_manager.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- + +from .schema_manager import SchemaManager +from .platforms.sqlite_platform import SQLitePlatform +from .column import Column + + +class SQLiteSchemaManager(SchemaManager): + + def list_table_columns(self, table): + sql = self._platform.get_list_table_columns_sql(table) + + cursor = self._connection.get_connection().cursor() + options = self._platform.get_column_options() + table_columns = [] + for column_info in cursor.execute(sql).fetchall(): + column_info = dict(column_info.items()) + column_info['default'] = column_info['dflt_value'] + + column = Column(column_info['name'], column_info['type'], column_info) + + column.set_platform_options({x: column_info[x] for x in options}) + + table_columns.append(column) + + return table_columns + + def list_table_indexes(self, table): + sql = self._platform.get_list_table_indexes_sql(table) + + cursor = self._connection.get_connection().cursor() + table_indexes = cursor.execute(sql).fetchall() + + indexes = [] + for index in table_indexes: + table_index = dict(index.items()) + index_info = cursor.execute('PRAGMA index_info(%s)' % index['name']).fetchall() + columns = [] + for column in index_info: + columns.append(column['name']) + + table_index['columns'] = columns + + indexes.append(table_index) + + return indexes + + def get_database_platform(self): + return SQLitePlatform() diff --git a/eloquent/dbal/table.py b/eloquent/dbal/table.py new file mode 100644 index 00000000..122fe0ec --- /dev/null +++ b/eloquent/dbal/table.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- + +from collections import OrderedDict +from .column import Column + + +class Table(object): + + def __init__(self, table_name, columns=None, indexes=None, fk_constraints=None): + self._name = table_name + self._columns = OrderedDict() + self._indexes = {} + self._fk_constraints = {} + + columns = columns or [] + indexes = indexes or [] + fk_constraints = fk_constraints or [] + + columns = columns.values() if isinstance(columns, dict) else columns + for column in columns: + self._add_column(column) + + for index in indexes: + self._add_index(index) + + for constraint in fk_constraints: + self._add_foreign_key_constraint(constraint) + + def get_columns(self): + columns = self._columns + + return columns + + def has_column(self, column): + return column in self._columns + + def get_column(self, column): + if self.has_column(column): + return self._columns[column] + + def change_column(self, column_name, options): + column = self.get_column(column_name) + column.set_options(options) + + return self + + def _add_column(self, column): + column_name = column.get_name() + + if column_name in self._columns: + raise Exception('Column %s already exists.' % column_name) + + self._columns[column_name] = column + + def _add_index(self, index): + index_name = index['name'] + + self._indexes[index_name] = index + + def _add_foreign_key_constraint(self, constraint): + name = constraint['name'] + + self._fk_constraints[name] = constraint + + def get_name(self): + return self._name + + def clone(self): + columns = [] + + for column in self._columns.values(): + columns.append(Column(column.get_name(), column.get_type(), column.to_dict())) + + table = Table(self._name, columns) + + return table diff --git a/eloquent/dbal/table_diff.py b/eloquent/dbal/table_diff.py new file mode 100644 index 00000000..d896940f --- /dev/null +++ b/eloquent/dbal/table_diff.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- + + +class TableDiff(object): + + def __init__(self, table_name, added_columns=None, + changed_columns=None, removed_columns=None, added_indexes=None, + changed_indexes=None, removed_indexes=None, from_table=None): + self.name = table_name + self.new_name = False + self.added_columns = added_columns or {} + self.changed_columns = changed_columns or {} + self.removed_columns = removed_columns or {} + self.added_indexes = added_indexes or [] + self.changed_indexes = changed_indexes or [] + self.removed_indexes = removed_indexes or [] + self.added_foreign_keys = [] + self.changed_foreign_keys = [] + self.removed_foreign_keys = [] + self.renamed_columns = {} + self.renamed_indexes = {} + self.from_table = from_table diff --git a/eloquent/dbal/types/__init__.py b/eloquent/dbal/types/__init__.py new file mode 100644 index 00000000..633f8661 --- /dev/null +++ b/eloquent/dbal/types/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- + diff --git a/eloquent/exceptions/orm.py b/eloquent/exceptions/orm.py index 853ee8bd..e52dbf4b 100644 --- a/eloquent/exceptions/orm.py +++ b/eloquent/exceptions/orm.py @@ -14,3 +14,14 @@ def __str__(self): class MassAssignmentError(RuntimeError): pass + + +class RelatedClassNotFound(RuntimeError): + + def __init__(self, related): + self._related = related + + self.message = 'The related class for "%s" does not exists' % related + + def __str__(self): + return self.message diff --git a/eloquent/orm/__init__.py b/eloquent/orm/__init__.py index 78a85f0b..2ba8e248 100644 --- a/eloquent/orm/__init__.py +++ b/eloquent/orm/__init__.py @@ -2,3 +2,5 @@ from .builder import Builder from .model import Model +from .mixins import SoftDeletes +from .collection import Collection diff --git a/eloquent/orm/builder.py b/eloquent/orm/builder.py index fea4953b..85726e34 100644 --- a/eloquent/orm/builder.py +++ b/eloquent/orm/builder.py @@ -23,7 +23,7 @@ def __init__(self, query): self._model = None self._eager_load = {} - self._macros = [] + self._macros = {} self._on_delete = None @@ -755,6 +755,17 @@ def _parse_nested(self, name, results): return results + def _call_scope(self, scope, *args, **kwargs): + """ + Call the given model scope. + + :param scope: The scope to call + :type scope: str + """ + result = getattr(self._model, scope)(self, *args, **kwargs) + + return result or self + def get_query(self): """ Get the underlying query instance. @@ -816,10 +827,47 @@ def set_model(self, model): return self + def macro(self, name, callback): + """ + Extend the builder with the given callback. + + :param name: The extension name + :type name: str + + :param callback: The callback + :type callback: callable + """ + self._macros[name] = callback + + def get_macro(self, name): + """ + Get the given macro by name + + :param name: The macro name + :type name: str + :return: + """ + return self._macros.get(name) + def __dynamic(self, method): - attribute = getattr(self._query, method) + scope = 'scope_%s' % method + is_scope = False + is_macro = False + if hasattr(self._model, scope): + is_scope = True + attribute = getattr(self._model, scope) + elif method in self._macros: + is_macro = True + attribute = self._macros[method] + else: + attribute = getattr(self._query, method) def call(*args, **kwargs): + if is_scope: + return self._call_scope(scope, *args, **kwargs) + if is_macro: + return attribute(self, *args, **kwargs) + result = attribute(*args, **kwargs) if method in self._passthru: @@ -836,5 +884,5 @@ def __getattr__(self, item, *args): try: object.__getattribute__(self, item) except AttributeError: - # TODO: macros and scopes + # TODO: macros return self.__dynamic(item) diff --git a/eloquent/orm/mixins/__init__.py b/eloquent/orm/mixins/__init__.py new file mode 100644 index 00000000..ae958d67 --- /dev/null +++ b/eloquent/orm/mixins/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +from .soft_deletes import SoftDeletes diff --git a/eloquent/orm/mixins/soft_deletes.py b/eloquent/orm/mixins/soft_deletes.py new file mode 100644 index 00000000..55049f84 --- /dev/null +++ b/eloquent/orm/mixins/soft_deletes.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- + +from ..scopes import SoftDeletingScope + + +class SoftDeletes(object): + + _force_deleting = False + + @classmethod + def boot_soft_deletes(cls, klass): + """ + Boot the soft deleting mixin for a model. + """ + klass.add_global_scope(SoftDeletingScope()) + + def force_delete(self): + """ + Force a hard delete on a soft deleted model. + """ + self._force_deleting = True + + self.delete() + + self._force_deleting = False + + def _do_perform_delete_on_model(self): + """ + Perform the actual delete query on this model instance. + """ + if self._force_deleting: + return self.with_trashed().where(self.get_key_name(), self.get_key()).force_delete() + + return self._run_soft_delete() + + def _run_soft_delete(self): + """ + Perform the actual delete query on this model instance. + """ + query = self.new_query().where(self.get_key_name(), self.get_key()) + + time = self.fresh_timestamp() + setattr(self, self.get_deleted_at_column(), time) + + query.update({ + self.get_deleted_at_column(): self.from_datetime(time) + }) + + def restore(self): + """ + Restore a soft-deleted model instance. + """ + setattr(self, self.get_deleted_at_column(), None) + + self.set_exists(True) + + result = self.save() + + return result + + def trashed(self): + """ + Determine if the model instance has been soft-deleted + + :rtype: bool + """ + return getattr(self, self.get_deleted_at_column()) is not None + + @classmethod + def with_trashed(cls): + """ + Get a new query builder that includes soft deletes. + + :rtype: eloquent.orm.builder.Builder + """ + return cls().new_query_without_scope(SoftDeletingScope()) + + @classmethod + def only_trashed(cls): + """ + Get a new query builder that only includes soft deletes + + :type cls: eloquent.orm.model.Model + + :rtype: eloquent.orm.builder.Builder + """ + instance = cls() + + column = instance.get_qualified_deleted_at_column() + + return instance.new_query_without_scope(SoftDeletingScope()).where_not_null(column) + + def get_deleted_at_column(self): + """ + Get the name of the "deleted at" column. + + :rtype: str + """ + return getattr(self, 'DELETED_AT', 'deleted_at') + + def get_qualified_deleted_at_column(self): + """ + Get the fully qualified "deleted at" column. + + :rtype: str + """ + return '%s.%s' % (self.get_table(), self.get_deleted_at_column()) diff --git a/eloquent/orm/model.py b/eloquent/orm/model.py index 78e41eb5..62ab958e 100644 --- a/eloquent/orm/model.py +++ b/eloquent/orm/model.py @@ -5,7 +5,8 @@ import inflection import inspect from six import add_metaclass -from ..exceptions.orm import MassAssignmentError +from ..utils import basestring +from ..exceptions.orm import MassAssignmentError, RelatedClassNotFound from ..query import QueryBuilder from .builder import Builder from .collection import Collection @@ -46,6 +47,7 @@ class Model(object): __visible__ = [] __timestamps__ = True + __dates__ = [] __casts__ = {} @@ -56,6 +58,7 @@ class Model(object): _with = [] _booted = {} + _global_scopes = {} _registered = [] __resolver = None @@ -70,7 +73,6 @@ def __init__(self, **attributes): :param attributes: The instance attributes """ self.__exists = False - self.__dates = [] self.__original = {} self.__attributes = {} self.__relations = {} @@ -97,7 +99,62 @@ def _boot(cls): """ The booting method of the model. """ - # TODO + # TODO: mutators + + cls._boot_mixins() + + @classmethod + def _boot_mixins(cls): + """ + Boot the mixins + """ + for mixin in cls.__bases__: + method = 'boot_%s' % inflection.underscore(mixin.__name__) + if hasattr(mixin, method): + getattr(mixin, method)(cls) + + @classmethod + def add_global_scope(cls, scope): + """ + Register a new global scope on the model. + + :param scope: The scope to register + :type scope: eloquent.orm.scopes.scope.Scope + """ + if cls not in cls._global_scopes: + cls._global_scopes[cls] = {} + + cls._global_scopes[cls][scope.__class__] = scope + + @classmethod + def has_global_scope(cls, scope): + """ + Determine if a model has a global scope. + + :param scope: The scope to register + :type scope: eloquent.orm.scopes.scope.Scope + """ + return cls.get_global_scope(scope) is not None + + @classmethod + def get_global_scope(cls, scope): + """ + Get a global scope registered with the model. + + :param scope: The scope to register + :type scope: eloquent.orm.scopes.scope.Scope + """ + for key, value in cls._global_scopes[cls].items(): + if isinstance(scope, key): + return value + + def get_global_scopes(self): + """ + Get the global scopes for this class instance. + + :rtype: dict + """ + return self.__class__._global_scopes.get(self.__class__, {}) def fill(self, **attributes): """ @@ -496,7 +553,7 @@ def has_one(self, related, foreign_key=None, local_key=None): if not foreign_key: foreign_key = self.get_foreign_key() - instance = related() + instance = self._get_related(related)() if not local_key: local_key = self.get_key_name() @@ -526,7 +583,7 @@ def morph_one(self, related, name, type_column=None, id_column=None, local_key=N if name in self.__relations: return self.__relations[name] - instance = related() + instance = self._get_related(related)() type_column, id_column = self.get_morphs(name, type_column, id_column) @@ -565,7 +622,7 @@ def belongs_to(self, related, foreign_key=None, other_key=None, relation=None): if foreign_key is None: foreign_key = '%s_id' % inflection.underscore(relation) - instance = related() + instance = self._get_related(related)() query = instance.new_query() @@ -635,7 +692,7 @@ def has_many(self, related, foreign_key=None, local_key=None): if not foreign_key: foreign_key = self.get_foreign_key() - instance = related() + instance = self._get_related(related)() if not local_key: local_key = self.get_key_name() @@ -673,7 +730,8 @@ def has_many_through(self, related, through, first_key=None, second_key=None): if not second_key: second_key = through.get_foreign_key() - return HasManyThrough(related().new_query(), self, through, first_key, second_key) + return HasManyThrough(self._get_related(related)().new_query(), + self, through, first_key, second_key) def morph_many(self, related, name, type_column=None, id_column=None, local_key=None): """ @@ -693,7 +751,7 @@ def morph_many(self, related, name, type_column=None, id_column=None, local_key= :rtype: MorphMany """ - instance = related() + instance = self._get_related(related)() if name in self.__relations: return self.__relations[name] @@ -738,7 +796,7 @@ def belongs_to_many(self, related, table=None, foreign_key=None, other_key=None, if not foreign_key: foreign_key = self.get_foreign_key() - instance = related() + instance = self._get_related(related)() if not other_key: other_key = instance.get_foreign_key() @@ -779,7 +837,7 @@ def morph_to_many(self, related, name, table=None, foreign_key=None, other_key=N if not foreign_key: foreign_key = name + '_id' - instance = related() + instance = self._get_related(related)() if not other_key: other_key = instance.get_foreign_key() @@ -821,6 +879,25 @@ def morphed_by_many(self, related, name, table=None, foreign_key=None, other_key return self.morph_to_many(related, name, table, foreign_key, other_key, True) + def _get_related(self, related): + """ + Get the related class. + + :param related: The related model or table + :type related: Model or str + + :rtype: Model class + """ + if not isinstance(related, basestring) and issubclass(related, Model): + return related + + for cls in Model.__subclasses__(): + table = cls.__table__ or inflection.tableize(cls.__name__) + if table == related: + return cls + + raise RelatedClassNotFound(related) + def joining_table(self, related): """ Get the joining table name for a many-to-many relation @@ -896,6 +973,9 @@ def _perform_delete_on_model(self): """ Perform the actual delete query on this model instance. """ + if hasattr(self, '_do_perform_delete_on_model'): + return self._do_perform_delete_on_model() + return self.new_query().where(self.get_key_name(), self.get_key()).delete() # TODO: events @@ -1234,6 +1314,29 @@ def new_query(self): """ Get a new query builder for the model's table + :return: A Builder instance + :rtype: Builder + """ + builder = self.new_query_without_scopes() + + return self.apply_global_scopes(builder) + + def new_query_without_scope(self, scope): + """ + Get a new query builder for the model's table without a given scope + + :return: A Builder instance + :rtype: Builder + """ + builder = self.new_query() + self.get_global_scope(scope).remove(builder, self) + + return builder + + def new_query_without_scopes(self): + """ + Get a new query builder without any scopes. + :return: A Builder instance :rtype: Builder """ @@ -1243,6 +1346,34 @@ def new_query(self): return builder.set_model(self).with_(*self._with) + def apply_global_scopes(self, builder): + """ + Apply all of the global scopes to a builder. + + :param builder: A Builder instance + :type builder: Builder + + :rtype: Builder + """ + for scope in self.get_global_scopes().values(): + scope.apply(builder, self) + + return builder + + def remove_global_scopes(self, builder): + """ + Remove all of the global scopes from a builder. + + :param builder: A Builder instance + :type builder: Builder + + :rtype: Builder + """ + for scope in self.get_global_scopes().values(): + scope.remove(builder, self) + + return builder + @classmethod def query(cls): return cls().new_query() @@ -1737,7 +1868,12 @@ def _get_relationship_from_method(self, method, relations=None): if not isinstance(relations, Relation): raise RuntimeError('Relationship method must return an object of type Relation') - self.__relations[method] = DynamicProperty(relations.get_results, relations) + def results_getter(): + relations() + + return relations.get_results() + + self.__relations[method] = DynamicProperty(results_getter, relations) return self.__relations[method] @@ -1827,7 +1963,7 @@ def get_dates(self): """ defaults = [self.CREATED_AT, self.UPDATED_AT] - return self.__dates + defaults + return self.__dates__ + defaults def from_datetime(self, value): """ @@ -2117,7 +2253,7 @@ def set_connection_resolver(cls, resolver): cls.__resolver = resolver @classmethod - def unset_connection_resolver(cls, resolver): + def unset_connection_resolver(cls): """ Unset the connection resolver instance. """ diff --git a/eloquent/orm/relations/belongs_to_many.py b/eloquent/orm/relations/belongs_to_many.py index 676891d5..4f37be71 100644 --- a/eloquent/orm/relations/belongs_to_many.py +++ b/eloquent/orm/relations/belongs_to_many.py @@ -225,7 +225,7 @@ def get_relation_count_hash(self): :rtype: str """ - return 'self_%s' % (hashlib.md5(time.time()).hexdigest()) + return 'self_%s' % (hashlib.md5(str(time.time()).encode()).hexdigest()) def _get_select_columns(self, columns=None): """ diff --git a/eloquent/orm/relations/dynamic_property.py b/eloquent/orm/relations/dynamic_property.py index 6e2edab2..f13ae01d 100644 --- a/eloquent/orm/relations/dynamic_property.py +++ b/eloquent/orm/relations/dynamic_property.py @@ -20,7 +20,16 @@ def __init__(self, results_getter, relation): self._results = None self._relation = relation - def get_results(self): + def refresh(self): + self._results = self._results_getter() + + return self._results + + @property + def instance(self): + if not self._results: + self._results = self._results_getter() + return self._results def __getitem__(self, item): @@ -35,6 +44,12 @@ def __iter__(self): return iter(self._results) + def __len__(self): + if not self._results: + self._results = self._results_getter() + + return len(self._results) + def __getattr__(self, item): if not self._results: self._results = self._results_getter() diff --git a/eloquent/orm/scopes/__init__.py b/eloquent/orm/scopes/__init__.py new file mode 100644 index 00000000..f0249ba4 --- /dev/null +++ b/eloquent/orm/scopes/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- + +from .scope import Scope +from .soft_deleting import SoftDeletingScope diff --git a/eloquent/orm/scopes/scope.py b/eloquent/orm/scopes/scope.py new file mode 100644 index 00000000..c39554dc --- /dev/null +++ b/eloquent/orm/scopes/scope.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- + + +class Scope(object): + + def apply(self, builder, model): + """ + Apply the scope to a given query builder. + + :param builder: The query builder + :type builder: eloquent.orm.Builder + + :param model: The model + :type model: eloquent.orm.Model + """ + raise NotImplementedError + + def remove(self, builder, model): + """ + Remove the scope from a given query builder. + + :param builder: The query builder + :type builder: eloquent.orm.Builder + + :param model: The model + :type model: eloquent.orm.Model + """ + raise NotImplementedError diff --git a/eloquent/orm/scopes/soft_deleting.py b/eloquent/orm/scopes/soft_deleting.py new file mode 100644 index 00000000..6460c8d0 --- /dev/null +++ b/eloquent/orm/scopes/soft_deleting.py @@ -0,0 +1,182 @@ +# -*- coding: utf-8 -*- + +from .scope import Scope + + +class SoftDeletingScope(Scope): + + _extensions = ['force_delete', 'restore', 'with_trashed', 'only_trashed'] + + def apply(self, builder, model): + """ + Apply the scope to a given query builder. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + + :param model: The model + :type model: eloquent.orm.Model + """ + builder.where_null(model.get_qualified_deleted_at_column()) + + self.extend(builder) + + def remove(self, builder, model): + """ + Remove the scope from a given query builder. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + + :param model: The model + :type model: eloquent.orm.Model + """ + column = model.get_qualified_deleted_at_column() + + query = builder.get_query() + + wheres = [] + for where in query.wheres: + # If the where clause is a soft delete date constraint, + # we will remove it from the query and reset the keys + # on the wheres. This allows the developer to include + # deleted model in a relationship result set that is lazy loaded. + if not self._is_soft_delete_constraint(where, column): + wheres.append(where) + + query.wheres = wheres + + def extend(self, builder): + """ + Extend the query builder with the needed functions. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + """ + for extension in self._extensions: + getattr(self, '_add_%s' % extension)(builder) + + builder.on_delete(self._on_delete) + + def _on_delete(self, builder): + """ + The delete replacement function. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + """ + column = self._get_deleted_at_column(builder) + + return builder.update({ + column: builder.get_model().fresh_timestamp() + }) + + def _get_deleted_at_column(self, builder): + """ + Get the "deleted at" column for the builder. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + + :rtype: str + """ + if len(builder.get_query().joins) > 0: + return builder.get_model().get_qualified_deleted_at_column() + else: + return builder.get_model().get_deleted_at_column() + + def _add_force_delete(self, builder): + """ + Add the force delete extension to the builder. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + """ + builder.macro('force_delete', self._force_delete) + + def _force_delete(self, builder): + """ + The forece delete extension. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + """ + return builder.get_query().delete() + + def _add_restore(self, builder): + """ + Add the restore extension to the builder. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + """ + builder.macro('restore', self._restore) + + def _restore(self, builder): + """ + The restore extension. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + """ + builder.with_trashed() + + return builder.update({ + builder.get_model().get_deleted_at_column(): None + }) + + def _add_with_trashed(self, builder): + """ + Add the with-trashed extension to the builder. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + """ + builder.macro('with_trashed', self._with_trashed) + + def _with_trashed(self, builder): + """ + The with-trashed extension. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + """ + self.remove(builder, builder.get_model()) + + return builder + + def _add_only_trashed(self, builder): + """ + Add the only-trashed extension to the builder. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + """ + builder.macro('only_trashed', self._only_trashed) + + def _only_trashed(self, builder): + """ + The only-trashed extension. + + :param builder: The query builder + :type builder: eloquent.orm.builder.Builder + """ + model = builder.get_model() + + self.remove(builder, model) + + builder.get_query().where_not_null(model.get_qualified_deleted_at_column()) + + def _is_soft_delete_constraint(self, where, column): + """ + Determine if the given where clause is a soft delete constraint. + + :param where: The where clause + :type where: dict + + :param column: The column + :type column: str + + :rtype: bool + """ + return where['type'] == 'null' and where['column'] == column diff --git a/eloquent/query/grammars/__init__.py b/eloquent/query/grammars/__init__.py index 3c452fa0..71dd7096 100644 --- a/eloquent/query/grammars/__init__.py +++ b/eloquent/query/grammars/__init__.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from .base import BaseQueryGrammar from .grammar import QueryGrammar from .postgres_grammar import PostgresQueryGrammar from .mysql_grammar import MySqlQueryGrammar diff --git a/eloquent/query/grammars/grammar.py b/eloquent/query/grammars/grammar.py index 0c915f0c..1afdfe0b 100644 --- a/eloquent/query/grammars/grammar.py +++ b/eloquent/query/grammars/grammar.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- import re -from .base import BaseQueryGrammar +from ...support.grammar import Grammar from ..builder import QueryBuilder from ...utils import basestring -class QueryGrammar(BaseQueryGrammar): +class QueryGrammar(Grammar): _select_components = [ 'aggregate_', diff --git a/eloquent/schema/__init__.py b/eloquent/schema/__init__.py new file mode 100644 index 00000000..50239fad --- /dev/null +++ b/eloquent/schema/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .builder import SchemaBuilder +from .mysql_builder import MySqlSchemaBuilder +from .blueprint import Blueprint +from .schema import Schema diff --git a/eloquent/schema/blueprint.py b/eloquent/schema/blueprint.py new file mode 100644 index 00000000..5dbeea00 --- /dev/null +++ b/eloquent/schema/blueprint.py @@ -0,0 +1,782 @@ +# -*- coding: utf-8 -*- + +from ..support.fluent import Fluent + + +class Blueprint(object): + + def __init__(self, table): + """ + :param table: The table to operate on + :type table: str + """ + self._table = table + self._columns = [] + self._commands = [] + self.engine = None + self.charset = None + self.collation = None + + def build(self, connection, grammar): + """ + Execute the blueprint against the database. + + :param connection: The connection to use + :type connection: eloquent.connections.Connection + + :param grammar: The grammar to user + :type grammar: eloquent.query.grammars.QueryGrammar + """ + for statement in self.to_sql(connection, grammar): + connection.statement(statement) + + def to_sql(self, connection, grammar): + """ + Get the raw SQL statements for the blueprint. + + :param connection: The connection to use + :type connection: eloquent.connections.Connection + + :param grammar: The grammar to user + :type grammar: eloquent.schema.grammars.SchemaGrammar + + :rtype: list + """ + self._add_implied_commands() + + statements = [] + + for command in self._commands: + method = 'compile_%s' % command.name + + if hasattr(grammar, method): + sql = getattr(grammar, method)(self, command, connection) + if sql is not None: + if isinstance(sql, list): + statements += sql + else: + statements.append(sql) + + return statements + + def _add_implied_commands(self): + """ + Add the commands that are implied by the blueprint. + """ + if len(self.get_added_columns()) and not self._creating(): + self._commands.insert(0, self._create_command('add')) + + if len(self.get_changed_columns()) and not self._creating(): + self._commands.insert(0, self._create_command('change')) + + return self._add_fluent_indexes() + + def _add_fluent_indexes(self): + """ + Add the index commands fluently specified on columns: + """ + for column in self._columns: + for index in ['primary', 'unique', 'index']: + column_index = column.get(index) + + if column_index is True: + getattr(self, index)(column.name) + + break + elif column_index: + getattr(self, index)(column.name, column_index) + + break + + def _creating(self): + """ + Determine if the blueprint has a create command. + + :rtype: bool + """ + for command in self._commands: + if command.name == 'create': + return True + + return False + + def create(self): + """ + Indicates that the table needs to be created. + + :rtype: Fluent + """ + return self._add_command('create') + + def drop(self): + """ + Indicates that the table needs to be dropped. + + :rtype: Fluent + """ + self._add_command('drop') + + return self + + def drop_if_exists(self): + """ + Indicates that the table should be dropped if it exists. + + :rtype: Fluent + """ + return self._add_command('drop_if_exists') + + def drop_column(self, *columns): + """ + Indicates that the given columns should be dropped. + + :param columns: The columns to drop + :type columns: tuple + + :rtype: Fluent + """ + columns = list(columns) + + return self._add_command('drop_column', columns=columns) + + def rename_column(self, from_, to): + """ + Indicates that the given columns should be renamed. + + :param from_: The original column name + :type from_: str + :param to: The new name of the column + :type to: str + + :rtype: Fluent + """ + return self._add_command('rename_column', **{'from_': from_, 'to': to}) + + def drop_primary(self, index=None): + """ + Indicate that the given primary key should be dropped. + + :param index: The index + :type index: str + + :rtype: dict + """ + return self._drop_index_command('drop_primary', 'primary', index) + + def drop_unique(self, index): + """ + Indicate that the given unique key should be dropped. + + :param index: The index + :type index: str + + :rtype: Fluent + """ + return self._drop_index_command('drop_unique', 'unique', index) + + def drop_index(self, index): + """ + Indicate that the given index should be dropped. + + :param index: The index + :type index: str + + :rtype: Fluent + """ + return self._drop_index_command('drop_index', 'index', index) + + def drop_foreign(self, index): + """ + Indicate that the given foreign key should be dropped. + + :param index: The index + :type index: str + + :rtype: dict + """ + return self._drop_index_command('drop_foreign', 'foreign', index) + + def drop_timestamps(self): + """ + Indicate that the timestamp columns should be dropped. + + :rtype: Fluent + """ + return self.drop_column('created_at', 'updated_at') + + def drop_soft_deletes(self): + """ + Indicate that the soft delete column should be dropped + + :rtype: Fluent + """ + return self.drop_column('deleted_at') + + def rename(self, to): + """ + Rename the table to a given name + + :param to: The new table name + :type to: str + + :rtype: Fluent + """ + return self._add_command('rename', to=to) + + def primary(self, columns, name=None): + """ + Specify the primary key(s) for the table + + :param columns: The primary key(s) columns + :type columns: str or list + + :param name: The name of the primary key + :type name: str + + :rtype: Fluent + """ + return self._index_command('primary', columns, name) + + def unique(self, columns, name=None): + """ + Specify a unique index on the table + + :param columns: The primary key(s) columns + :type columns: str or list + + :param name: The name of the primary key + :type name: str + + :rtype: Fluent + """ + return self._index_command('unique', columns, name) + + def index(self, columns, name=None): + """ + Specify an index on the table + + :param columns: The primary key(s) columns + :type columns: str or list + + :param name: The name of the primary key + :type name: str + + :rtype: Fluent + """ + return self._index_command('index', columns, name) + + def foreign(self, columns, name=None): + """ + Specify an foreign key on the table + + :param columns: The foreign key(s) columns + :type columns: str or list + + :param name: The name of the foreign key + :type name: str + + :rtype: Fluent + """ + return self._index_command('foreign', columns, name) + + def increments(self, column): + """ + Create a new auto-incrementing integer column on the table. + + :param column: The auto-incrementing column + :type column: str + + :rtype: Fluent + """ + return self.unsigned_integer(column, True) + + def big_increments(self, column): + """ + Create a new auto-incrementing big integer column on the table. + + :param column: The auto-incrementing column + :type column: str + + :rtype: Fluent + """ + return self.unsigned_big_integer(column, True) + + def char(self, column, length=255): + """ + Create a new char column on the table. + + :param column: The column + :type column: str + + :rtype: Fluent + """ + return self._add_command('char', column, length=length) + + def string(self, column, length=255): + """ + Create a new string column on the table. + + :param column: The column + :type column: str + + :rtype: Fluent + """ + return self._add_column('string', column, length=length) + + def text(self, column): + """ + Create a new text column on the table. + + :param column: The column + :type column: str + + :rtype: Fluent + """ + return self._add_column('text', column) + + def medium_text(self, column): + """ + Create a new medium text column on the table. + + :param column: The column + :type column: str + + :rtype: Fluent + """ + return self._add_column('medium_text', column) + + def long_text(self, column): + """ + Create a new long text column on the table. + + :param column: The column + :type column: str + + :rtype: Fluent + """ + return self._add_column('long_text', column) + + def integer(self, column, auto_increment=False, unsigned=False): + """ + Create a new integer column on the table. + + :param column: The column + :type column: str + + :type auto_increment: bool + + :type unsigned: bool + + :rtype: Fluent + """ + return self._add_column('integer', column, + auto_increment=auto_increment, + unsigned=unsigned) + + def big_integer(self, column, auto_increment=False, unsigned=False): + """ + Create a new big integer column on the table. + + :param column: The column + :type column: str + + :type auto_increment: bool + + :type unsigned: bool + + :rtype: Fluent + """ + return self._add_column('big_integer', column, + auto_increment=auto_increment, + unsigned=unsigned) + + def medium_integer(self, column, auto_increment=False, unsigned=False): + """ + Create a new medium integer column on the table. + + :param column: The column + :type column: str + + :type auto_increment: bool + + :type unsigned: bool + + :rtype: Fluent + """ + return self._add_column('medium_integer', column, + auto_increment=auto_increment, + unsigned=unsigned) + + def tiny_integer(self, column, auto_increment=False, unsigned=False): + """ + Create a new tiny integer column on the table. + + :param column: The column + :type column: str + + :type auto_increment: bool + + :type unsigned: bool + + :rtype: Fluent + """ + return self._add_column('tiny_integer', column, + auto_increment=auto_increment, + unsigned=unsigned) + + def small_integer(self, column, auto_increment=False, unsigned=False): + """ + Create a new small integer column on the table. + + :param column: The column + :type column: str + + :type auto_increment: bool + + :type unsigned: bool + + :rtype: Fluent + """ + return self._add_column('small_integer', column, + auto_increment=auto_increment, + unsigned=unsigned) + + def unsigned_integer(self, column, auto_increment=False): + """ + Create a new unisgned integer column on the table. + + :param column: The column + :type column: str + + :type auto_increment: bool + + :rtype: Fluent + """ + return self.integer(column, auto_increment, True) + + def unsigned_big_integer(self, column, auto_increment=False): + """ + Create a new unsigned big integer column on the table. + + :param column: The column + :type column: str + + :type auto_increment: bool + + :rtype: Fluent + """ + return self.big_integer(column, auto_increment, True) + + def float(self, column, total=8, places=2): + """ + Create a new float column on the table. + + :param column: The column + :type column: str + + :type total: int + + :type places: 2 + + :rtype: Fluent + """ + return self._add_column('float', column, total=total, places=places) + + def double(self, column, total=None, places=None): + """ + Create a new double column on the table. + + :param column: The column + :type column: str + + :type total: int + + :type places: 2 + + :rtype: Fluent + """ + return self._add_column('double', column, total=total, places=places) + + def decimal(self, column, total=8, places=2): + """ + Create a new decimal column on the table. + + :param column: The column + :type column: str + + :type total: int + + :type places: 2 + + :rtype: Fluent + """ + return self._add_column('decimal', column, total=total, places=places) + + def boolean(self, column): + """ + Create a new decimal column on the table. + + :param column: The column + :type column: str + + :rtype: Fluent + """ + return self._add_column('boolean', column) + + def enum(self, column, allowed): + """ + Create a new enum column on the table. + + :param column: The column + :type column: str + + :type allowed: list + + :rtype: Fluent + """ + return self._add_column('enum', column, allowed=allowed) + + def json(self, column): + """ + Create a new json column on the table. + + :param column: The column + :type column: str + + :rtype: Fluent + """ + return self._add_column('json', column) + + def date(self, column): + """ + Create a new date column on the table. + + :param column: The column + :type column: str + + :rtype: Fluent + """ + return self._add_column('date', column) + + def datetime(self, column): + """ + Create a new datetime column on the table. + + :param column: The column + :type column: str + + :rtype: Fluent + """ + return self._add_column('datetime', column) + + def time(self, column): + """ + Create a new time column on the table. + + :param column: The column + :type column: str + + :rtype: Fluent + """ + return self._add_column('time', column) + + def timestamp(self, column): + """ + Create a new timestamp column on the table. + + :param column: The column + :type column: str + + :rtype: Fluent + """ + return self._add_column('timestamp', column) + + def nullable_timestamps(self): + """ + Create nullable creation and update timestamps to the table. + + :rtype: Fluent + """ + self.timestamp('created_at').nullable() + self.timestamp('updated_at').nullable() + + def timestamps(self): + """ + Create creation and update timestamps to the table. + + :rtype: Fluent + """ + self.timestamp('created_at') + self.timestamp('updated_at') + + def soft_deletes(self): + """ + Add a "deleted at" timestamp to the table. + + :rtype: Fluent + """ + return self.timestamp('deleted_at').nullable() + + def binary(self, column): + """ + Create a new binary column on the table. + + :param column: The column + :type column: str + + :rtype: Fluent + """ + return self._add_column('binary', column) + + def morphs(self, name, index_name=None): + """ + Add the proper columns for a polymorphic table. + + :type name: str + + :type index_name: str + """ + self.unsigned_integer('%s_id' % name) + self.string('%s_type' % name) + self.index(['%s_id' % name, '%s_type' % name], index_name) + + def _drop_index_command(self, command, type, index): + """ + Create a new drop index command on the blueprint. + + :param command: The command + :type command: str + + :param type: The index type + :type type: str + + :param index: The index name + :type index: str + + :rtype: Fluent + """ + columns = [] + + if isinstance(index, list): + columns = index + + index = self._create_index_name(type, columns) + + return self._index_command(command, columns, index) + + def _index_command(self, type, columns, index): + """ + Add a new index command to the blueprint. + + :param type: The index type + :type type: str + + :param columns: The index columns + :type columns: list or str + + :param index: The index name + :type index: str + + :rtype: Fluent + """ + if not isinstance(columns, list): + columns = [columns] + + if not index: + index = self._create_index_name(type, columns) + + return self._add_command(type, index=index, columns=columns) + + def _create_index_name(self, type, columns): + if not isinstance(columns, list): + columns = [columns] + + index = '%s_%s_%s' % (self._table, '_'.join(columns), type) + + return index.lower().replace('-', '_').replace('.', '_') + + def _add_column(self, type, name, **parameters): + """ + Add a new column to the blueprint. + + :param type: The column type + :type type: str + + :param name: The column name + :type name: str + + :param parameters: The column parameters + :type parameters: dict + + :rtype: Fluent + """ + parameters.update({ + 'type': type, + 'name': name + }) + + column = Fluent(**parameters) + self._columns.append(column) + + return column + + def _remove_column(self, name): + """ + Removes a column from the blueprint. + + :param name: The column name + :type name: str + + :rtype: Blueprint + """ + self._columns = filter(lambda c: c.name != name, self._columns) + + return self + + def _add_command(self, name, **parameters): + """ + Add a new command to the blueprint. + + :param name: The command name + :type name: str + + :param parameters: The command parameters + :type parameters: dict + + :rtype: Fluent + """ + command = self._create_command(name, **parameters) + self._commands.append(command) + + return command + + def _create_command(self, name, **parameters): + """ + Create a new command. + + :param name: The command name + :type name: str + + :param parameters: The command parameters + :type parameters: dict + + :rtype: Fluent + """ + parameters.update({'name': name}) + + return Fluent(**parameters) + + def get_table(self): + return self._table + + def get_columns(self): + return self._columns + + def get_commands(self): + return self._commands + + def get_added_columns(self): + return list(filter(lambda column: not column.get('change'), self._columns)) + + def get_changed_columns(self): + return list(filter(lambda column: column.get('change'), self._columns)) diff --git a/eloquent/schema/builder.py b/eloquent/schema/builder.py new file mode 100644 index 00000000..6bf8b32e --- /dev/null +++ b/eloquent/schema/builder.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- + +from contextlib import contextmanager +from .blueprint import Blueprint + + +class SchemaBuilder(object): + + def __init__(self, connection): + """ + :param connection: The schema connection + :type connection: eloquent.connections.Connection + """ + self._connection = connection + self._grammar = connection.get_schema_grammar() + + def has_table(self, table): + """ + Determine if the given table exists. + + :param table: The table + :type table: str + + :rtype: bool + """ + sql = self._grammar.compile_table_exists() + + table = self._connection.get_table_prefix() + table + + return len(self._connection.select(sql, [table])) > 0 + + def has_column(self, table, column): + """ + Determine if the given table has a given column. + + :param table: The table + :type table: str + + :type column: str + + :rtype: bool + """ + column = column.lower() + + return column in list(map(lambda x: x.lower(), self.get_colummn_listing())) + + def get_column_listing(self, table): + """ + Get the column listing for a given table. + + :param table: The table + :type table: str + + :rtype: list + """ + table = self._connection.get_table_prefix() + table + + results = self._connection.select(self._grammar.compile_column_exists(table)) + + return self._connection.get_post_processor().process_column_listing(results) + + @contextmanager + def table(self, table): + """ + Modify a table on the schema. + + :param table: The table + """ + try: + blueprint = self._create_blueprint(table) + + yield blueprint + except Exception as e: + raise + + try: + self._build(blueprint) + except Exception: + raise + + @contextmanager + def create(self, table): + """ + Create a new table on the schema. + + :param table: The table + :type table: str + + :rtype: Blueprint + """ + try: + blueprint = self._create_blueprint(table) + blueprint.create() + + yield blueprint + except Exception as e: + raise + + try: + self._build(blueprint) + except Exception: + raise + + def drop(self, table): + """ + Drop a table from the schema. + + :param table: The table + :type table: str + """ + blueprint = self._create_blueprint(table) + + blueprint.drop() + + self._build(blueprint) + + def drop_if_exists(self, table): + """ + Drop a table from the schema. + + :param table: The table + :type table: str + """ + blueprint = self._create_blueprint(table) + + blueprint.drop_if_exists() + + self._build(blueprint) + + def rename(self, from_, to): + """ + Rename a table on the schema. + """ + blueprint = self._create_blueprint(from_) + + blueprint.rename(to) + + self._build(blueprint) + + def _build(self, blueprint): + """ + Execute the blueprint to build / modify the table. + + :param blueprint: The blueprint + :type blueprint: eloquent.schema.Blueprint + """ + blueprint.build(self._connection, self._grammar) + + def _create_blueprint(self, table): + return Blueprint(table) + + def get_connection(self): + return self._connection + + def set_connection(self, connection): + self._connection = connection + + return self diff --git a/eloquent/schema/grammars/__init__.py b/eloquent/schema/grammars/__init__.py new file mode 100644 index 00000000..99d16e91 --- /dev/null +++ b/eloquent/schema/grammars/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .grammar import SchemaGrammar +from .sqlite_grammar import SQLiteSchemaGrammar +from .postgres_grammar import PostgresSchemaGrammar +from .mysql_grammar import MySqlSchemaGrammar diff --git a/eloquent/schema/grammars/grammar.py b/eloquent/schema/grammars/grammar.py new file mode 100644 index 00000000..1943c03b --- /dev/null +++ b/eloquent/schema/grammars/grammar.py @@ -0,0 +1,317 @@ +# -*- coding: utf-8 -*- + +from ...support.grammar import Grammar +from ...support.fluent import Fluent +from ...query.expression import QueryExpression +from ...dbal.column import Column +from ...dbal.table_diff import TableDiff +from ...dbal.comparator import Comparator +from ..blueprint import Blueprint + + +class SchemaGrammar(Grammar): + + def compile_rename_column(self, blueprint, command, connection): + """ + Compile a rename column command. + + :param blueprint: The blueprint + :type blueprint: Blueprint + + :param command: The command + :type command: Fluent + + :param connection: The connection + :type connection: eloquent.connections.Connection + + :rtype: list + """ + schema = connection.get_schema_manager() + + table = self.get_table_prefix() + blueprint.get_table() + + column = connection.get_column(table, command.from_) + + table_diff = self._get_renamed_diff(blueprint, command, column, schema) + + return schema.get_database_platform().get_alter_table_sql(table_diff) + + def _get_renamed_diff(self, blueprint, command, column, schema): + """ + Get a new column instance with the new column name. + + :param blueprint: The blueprint + :type blueprint: Blueprint + + :param command: The command + :type command: Fluent + + :param column: The column + :type column: eloquent.dbal.Column + + :param schema: The schema + :type schema: eloquent.dbal.SchemaManager + + :rtype: eloquent.dbal.TableDiff + """ + table_diff = self._get_table_diff(blueprint, schema) + + return self._set_renamed_columns(table_diff, command, column) + + def _set_renamed_columns(self, table_diff, command, column): + """ + Set the renamed columns on the table diff. + + :rtype: eloquent.dbal.TableDiff + """ + new_column = Column(command.to, column.get_type(), column.to_dict()) + + table_diff.renamed_columns = {command.from_: new_column} + + return table_diff + + def compile_foreign(self, blueprint, command, _): + """ + Compile a foreign key command. + + :param blueprint: The blueprint + :type blueprint: Blueprint + + :param command: The command + :type command: Fluent + + :rtype: str + """ + table = self.wrap_table(blueprint) + + on = self.wrap_table(command.on) + + columns = self.columnize(command.columns) + + on_columns = self.columnize(command.references + if isinstance(command.references, list) + else [command.references]) + + sql = 'ALTER TABLE %s ADD CONSTRAINT %s ' % (table, command.index) + + sql += 'FOREIGN KEY (%s) REFERENCES %s (%s)' % (columns, on, on_columns) + + if command.get('on_delete'): + sql += ' ON DELETE %s' % command.on_delete + + if command.get('on_update'): + sql += ' ON UPDATE %s' % command.on_update + + return sql + + def _get_columns(self, blueprint): + """ + Get the blueprint's columns definitions. + + :param blueprint: The blueprint + :type blueprint: Blueprint + + :rtype: list + """ + columns = [] + + for column in blueprint.get_added_columns(): + sql = self.wrap(column) + ' ' + self._get_type(column) + + columns.append(self._add_modifiers(sql, blueprint, column)) + + return columns + + def _add_modifiers(self, sql, blueprint, column): + """ + Add the column modifiers to the deifinition + """ + for modifier in self._modifiers: + method = '_modify_%s' % modifier + + if hasattr(self, method): + sql += getattr(self, method)(blueprint, column) + + return sql + + def _get_command_by_name(self, blueprint, name): + """ + Get the primary key command it it exists. + """ + commands = self._get_commands_by_name(blueprint, name) + + if len(commands): + return commands[0] + + def _get_commands_by_name(self, blueprint, name): + """ + Get all of the commands with a given name. + """ + return list(filter(lambda value: value.name == name, blueprint.get_commands())) + + def _get_type(self, column): + """ + Get the SQL for the column data type. + + :param column: The column + :type column: Fluent + + :rtype sql + """ + return getattr(self, '_type_%s' % column.type)(column) + + def prefix_list(self, prefix, values): + """ + Add a prefix to a list of values. + """ + return list(map(lambda value: prefix + ' ' + value, values)) + + def wrap_table(self, table): + if isinstance(table, Blueprint): + table = table.get_table() + + return super(SchemaGrammar, self).wrap_table(table) + + def wrap(self, value, prefix_alias=False): + if isinstance(value, Fluent): + value = value.name + + return super(SchemaGrammar, self).wrap(value, prefix_alias) + + def _get_default_value(self, value): + """ + Format a value so that it can be used in "default" clauses. + """ + if isinstance(value, QueryExpression): + return value + + if isinstance(value, bool): + return "'%s'" % int(value) + + return "'%s'" % value + + def _get_table_diff(self, blueprint, schema): + table = self.get_table_prefix() + blueprint.get_table() + + table_diff = TableDiff(table) + + table_diff.from_table = schema.list_table_details(table) + + return table_diff + + def compile_change(self, blueprint, command, connection): + """ + Compile a change column command into a series of SQL statement. + + :param blueprint: The blueprint + :type blueprint: Blueprint + + :param command: The command + :type command: Fluent + + :param connection: The connection + :type connection: eloquent.connections.Connection + + :rtype: list + """ + schema = connection.get_schema_manager() + + table_diff = self._get_changed_diff(blueprint, schema) + + if table_diff: + sql = schema.get_database_platform().get_alter_table_sql(table_diff) + + if isinstance(sql, list): + return sql + + return [sql] + + return [] + + def _get_changed_diff(self, blueprint, schema): + """ + Get the table diffrence for the given changes. + + :param blueprint: The blueprint + :type blueprint: Blueprint + + :param schema: The schema + :type schema: eloquent.dbal.SchemaManager + + :rtype: eloquent.dbal.TableDiff + """ + table = schema.list_table_details(self.get_table_prefix() + blueprint.get_table()) + + return Comparator().diff_table(table, self._get_table_with_column_changes(blueprint, table)) + + def _get_table_with_column_changes(self, blueprint, table): + """ + Get a copy of the given table after making the column changes. + + :param blueprint: The blueprint + :type blueprint: Blueprint + + :type table: eloquent.dbal.table.Table + + :rtype: eloquent.dbal.table.Table + """ + table = table.clone() + + for fluent in blueprint.get_changed_columns(): + column = self._get_column_for_change(table, fluent) + + for key, value in fluent.get_attributes().items(): + option = self._map_fluent_option(key) + + if option is not None: + method = 'set_%s' % option + + if hasattr(column, method): + getattr(column, method)(self._map_fluent_value(option, value)) + + return table + + def _get_column_for_change(self, table, fluent): + """ + Get the column instance for a column change. + + :type table: eloquent.dbal.table.Table + + :rtype: eloquent.dbal.column.Column + """ + return table.change_column( + fluent.name, self._get_column_change_options(fluent) + ).get_column(fluent.name) + + def _get_column_change_options(self, fluent): + """ + Get the column change options. + """ + options = { + 'name': fluent.name, + 'type': fluent.type, + 'default': fluent.get('default') + } + + if fluent.type in ['string']: + options['length'] = fluent.length + + return options + + def _map_fluent_option(self, attribute): + if attribute in ['type', 'name']: + return + elif attribute == 'nullable': + return 'notnull' + elif attribute == 'total': + return 'precision' + elif attribute == 'places': + return 'scale' + else: + return + + def _map_fluent_value(self, option, value): + if option == 'notnull': + return not value + + return value diff --git a/eloquent/schema/grammars/mysql_grammar.py b/eloquent/schema/grammars/mysql_grammar.py new file mode 100644 index 00000000..0d627bec --- /dev/null +++ b/eloquent/schema/grammars/mysql_grammar.py @@ -0,0 +1,251 @@ +# -*- coding: utf-8 -*- + +from .grammar import SchemaGrammar +from ..blueprint import Blueprint +from ...query.expression import QueryExpression +from ...support.fluent import Fluent + + +class MySqlSchemaGrammar(SchemaGrammar): + + _modifiers = [ + 'unsigned', 'charset', 'collate', 'nullable', + 'default', 'increment', 'comment', 'after' + ] + + _serials = ['big_integer', 'integer', + 'medium_integer', 'small_integer', 'tiny_integer'] + + def compile_table_exists(self): + """ + Compile the query to determine if a table exists + + :rtype: str + """ + return 'SELECT * FROM information_schema.tables WHERE table_name = %s' + + def compile_column_exists(self, table): + """ + Compile the query to determine the list of columns. + """ + return 'SELECT column_name FROM information_schema.columns WHERE table_name = %s' % table + + def compile_create(self, blueprint, command, connection): + """ + Compile a create table command. + """ + columns = ', '.join(self._get_columns(blueprint)) + + sql = 'CREATE TABLE %s (%s)' % (self.wrap_table(blueprint), columns) + + sql = self._compile_create_encoding(sql, connection, blueprint) + + if blueprint.engine: + sql += ' ENGINE = %s' % blueprint.engine + + return sql + + def _compile_create_encoding(self, sql, connection, blueprint): + """ + Append the character set specifications to a command. + + :type sql: str + :type connection: eloquent.connections.Connection + :type blueprint: Blueprint + + :rtype: str + """ + charset = blueprint.charset or connection.get_config('charset') + if charset: + sql += ' DEFAULT CHARACTER SET %s' % charset + + collation = blueprint.collation or connection.get_config('collation') + if collation: + sql += ' COLLATE %s' % collation + + return sql + + def compile_add(self, blueprint, command, _): + table = self.wrap_table(blueprint) + + columns = self.prefix_list('ADD', self._get_columns(blueprint)) + + return 'ALTER TABLE %s %s' % (table, ', '.join(columns)) + + def compile_primary(self, blueprint, command, _): + command.name = None + + return self._compile_key(blueprint, command, 'PRIMARY KEY') + + def compile_unique(self, blueprint, command, _): + return self._compile_key(blueprint, command, 'UNIQUE') + + def compile_index(self, blueprint, command, _): + return self._compile_key(blueprint, command, 'INDEX') + + def _compile_key(self, blueprint, command, type): + columns = self.columnize(command.columns) + + table = self.wrap_table(blueprint) + + return 'ALTER TABLE %s ADD %s %s(%s)' % (table, type, command.index, columns) + + def compile_drop(self, blueprint, command, _): + return 'DROP TABLE %s' % self.wrap_table(blueprint) + + def compile_drop_if_exists(self, blueprint, command, _): + return 'DROP TABLE IF EXISTS %s' % self.wrap_table(blueprint) + + def compile_drop_column(self, blueprint, command, connection): + columns = self.prefix_list('DROP', self.wrap_list(command.columns)) + + table = self.wrap_table(blueprint) + + return 'ALTER TABLE %s %s' % (table, ', '.join(columns)) + + def compile_drop_primary(self, blueprint, command, _): + return 'ALTER TABLE %s DROP PRIMARY KEY'\ + % self.wrap_table(blueprint) + + def compile_drop_unique(self, blueprint, command, _): + table = self.wrap_table(blueprint) + + return 'ALTER TABLE %s DROP INDEX %s' % (table, command.index) + + def compile_drop_index(self, blueprint, command, _): + table = self.wrap_table(blueprint) + + return 'ALTER TABLE %s DROP INDEX %s' % (table, command.index) + + def compile_drop_foreign(self, blueprint, command, _): + table = self.wrap_table(blueprint) + + return 'ALTER TABLE %s DROP FOREIGN KEY %s' % (table, command.index) + + def compile_rename(self, blueprint, command, _): + from_ = self.wrap_table(blueprint) + + return 'RENAME TABLE %s TO %s' % (from_, self.wrap_table(command.to)) + + def _type_char(self, column): + return "CHAR(%s)" % column.length + + def _type_string(self, column): + return "VARCHAR(%s)" % column.length + + def _type_text(self, column): + return 'TEXT' + + def _type_medium_text(self, column): + return 'MEDIUMTEXT' + + def _type_long_text(self, column): + return 'LONGTEXT' + + def _type_integer(self, column): + return 'INT' + + def _type_big_integer(self, column): + return 'BIGINT' + + def _type_medium_integer(self, column): + return 'MEDIUMINT' + + def _type_tiny_integer(self, column): + return 'TINYINT' + + def _type_small_integer(self, column): + return 'SMALLINT' + + def _type_float(self, column): + return self._type_double(column) + + def _type_double(self, column): + if column.total and column.places: + return 'DOUBLE(%s, %s)' % (column.total, column.places) + + return 'DOUBLE' + + def _type_decimal(self, column): + return 'DECIMAL(%s, %s)' % (column.total, column.places) + + def _type_boolean(self, column): + return 'TINYINT(1)' + + def _type_enum(self, column): + return 'ENUM(\'%s\')' % '\', \''.join(column.allowed) + + def _type_json(self, column): + return 'TEXT' + + def _type_date(self, column): + return 'DATE' + + def _type_datetime(self, column): + return 'DATETIME' + + def _type_time(self, column): + return 'TIME' + + def _type_timestamp(self, column): + if getattr(column, 'nullable', False): + return 'TIMESTAMP DEFAULT 0' + + return 'TIMESTAMP' + + def _type_binary(self, column): + return 'BLOB' + + def _modify_unsigned(self, blueprint, column): + if column.get('unsigned', False): + return ' UNSIGNED' + + return '' + + def _modify_charset(self, blueprint, column): + if column.get('charset'): + return ' CHARACTER SET ' + column.charset + + return '' + + def _modify_collate(self, blueprint, column): + if column.get('collation'): + return ' COLLATE ' + column.collation + + return '' + + def _modify_nullable(self, blueprint, column): + if column.get('nullable'): + return ' NULL' + + return ' NOT NULL' + + def _modify_default(self, blueprint, column): + if column.get('default') is not None: + return ' DEFAULT %s' % self._get_default_value(column.default) + + return '' + + def _modify_increment(self, blueprint, column): + if column.type in self._serials and column.auto_increment: + return ' AUTO_INCREMENT PRIMARY KEY' + + return '' + + def _modify_after(self, blueprint, column): + if column.get('after') is not None: + return ' AFTER ' + self.wrap(column.after) + + return '' + + def _modify_comment(self, blueprint, column): + if column.get('comment') is not None: + return ' COMMENT "%s"' % column.comment + + return '' + + def _wrap_value(self, value): + if value == '*': + return value + + return '`%s`' % value.replace('`', '``') diff --git a/eloquent/schema/grammars/postgres_grammar.py b/eloquent/schema/grammars/postgres_grammar.py new file mode 100644 index 00000000..c8dbf119 --- /dev/null +++ b/eloquent/schema/grammars/postgres_grammar.py @@ -0,0 +1,206 @@ +# -*- coding: utf-8 -*- + +from .grammar import SchemaGrammar +from ..blueprint import Blueprint +from ...query.expression import QueryExpression +from ...support.fluent import Fluent + + +class PostgresSchemaGrammar(SchemaGrammar): + + _modifiers = ['increment', 'nullable', 'default'] + + _serials = ['big_integer', 'integer', + 'medium_integer', 'small_integer', 'tiny_integer'] + + def compile_rename_column(self, blueprint, command, connection): + """ + Compile a rename column command. + + :param blueprint: The blueprint + :type blueprint: Blueprint + + :param command: The command + :type command: Fluent + + :param connection: The connection + :type connection: eloquent.connections.Connection + + :rtype: list + """ + table = self.get_table_prefix() + blueprint.get_table() + + column = self.wrap(command.from_) + + return 'ALTER TABLE %s RENAME COLUMN %s TO %s'\ + % (table, column, self.wrap(command.to)) + + def compile_table_exists(self): + """ + Compile the query to determine if a table exists + + :rtype: str + """ + return 'SELECT * FROM information_schema.tables WHERE table_name = %s' + + def compile_column_exists(self, table): + """ + Compile the query to determine the list of columns. + """ + return 'SELECT column_name FROM information_schema.columns WHERE table_name = %s' % table + + def compile_create(self, blueprint, command, _): + """ + Compile a create table command. + """ + columns = ', '.join(self._get_columns(blueprint)) + + return 'CREATE TABLE %s (%s)' % (self.wrap_table(blueprint), columns) + + def compile_add(self, blueprint, command, _): + table = self.wrap_table(blueprint) + + columns = self.prefix_list('ADD COLUMN', self._get_columns(blueprint)) + + return 'ALTER TABLE %s %s' % (table, ', '.join(columns)) + + def compile_primary(self, blueprint, command, _): + columns = self.columnize(command.columns) + + return 'ALTER TABLE %s ADD PRIMARY KEY (%s)'\ + % (self.wrap_table(blueprint), columns) + + def compile_unique(self, blueprint, command, _): + columns = self.columnize(command.columns) + + table = self.wrap_table(blueprint) + + return 'ALTER TABLE %s ADD CONSTRAINT %s UNIQUE (%s)'\ + % (table, command.index, columns) + + def compile_index(self, blueprint, command, _): + columns = self.columnize(command.columns) + + table = self.wrap_table(blueprint) + + return 'CREATE INDEX %s ON %s (%s)' % (command.index, table, columns) + + def compile_drop(self, blueprint, command, _): + return 'DROP TABLE %s' % self.wrap_table(blueprint) + + def compile_drop_if_exists(self, blueprint, command, _): + return 'DROP TABLE IF EXISTS %s' % self.wrap_table(blueprint) + + def compile_drop_column(self, blueprint, command, connection): + columns = self.prefix_list('DROP COLUMN', self.wrap_list(command.columns)) + + table = self.wrap_table(blueprint) + + return 'ALTER TABLE %s %s' % (table, ', '.join(columns)) + + def compile_drop_primary(self, blueprint, command, _): + table = blueprint.get_table() + + return 'ALTER TABLE %s DROP CONSTRAINT %s_pkey'\ + % (self.wrap_table(blueprint), table) + + def compile_drop_unique(self, blueprint, command, _): + table = self.wrap_table(blueprint) + + return 'ALTER TABLE %s DROP CONSTRAINT %s' % (table, command.index) + + def compile_drop_index(self, blueprint, command, _): + return 'DROP INDEX %s' % command.index + + def compile_drop_foreign(self, blueprint, command, _): + table = self.wrap_table(blueprint) + + return 'ALTER TABLE %s DROP CONSTRAINT %s' % (table, command.index) + + def compile_rename(self, blueprint, command, _): + from_ = self.wrap_table(blueprint) + + return 'ALTER TABLE %s RENAME TO %s' % (from_, self.wrap_table(command.to)) + + def _type_char(self, column): + return "CHAR(%s)" % column.length + + def _type_string(self, column): + return "VARCHAR(%s)" % column.length + + def _type_text(self, column): + return 'TEXT' + + def _type_medium_text(self, column): + return 'TEXT' + + def _type_long_text(self, column): + return 'TEXT' + + def _type_integer(self, column): + return 'SERIAL' if column.auto_increment else 'INTEGER' + + def _type_big_integer(self, column): + return 'BIGSERIAL' if column.auto_increment else 'BIGINT' + + def _type_medium_integer(self, column): + return 'SERIAL' if column.auto_increment else 'INTEGER' + + def _type_tiny_integer(self, column): + return 'SMALLSERIAL' if column.auto_increment else 'SMALLINT' + + def _type_small_integer(self, column): + return 'SMALLSERIAL' if column.auto_increment else 'SMALLINT' + + def _type_float(self, column): + return self._type_double(column) + + def _type_double(self, column): + return 'DOUBLE PRECISION' + + def _type_decimal(self, column): + return 'DECIMAL(%s, %s)' % (column.total, column.places) + + def _type_boolean(self, column): + return 'BOOLEAN' + + def _type_enum(self, column): + allowed = list(map(lambda a: "'%s'" % a, column.allowed)) + + return 'VARCHAR(255) CHECK ("%s" IN (%s))' % (column.name, ', '.join(allowed)) + + def _type_json(self, column): + return 'JSON' + + def _type_date(self, column): + return 'DATE' + + def _type_datetime(self, column): + return 'TIMESTAMP(0) WITHOUT TIME ZONE' + + def _type_time(self, column): + return 'TIME(0) WITHOUT TIME ZONE' + + def _type_timestamp(self, column): + return 'TIMESTAMP(0) WITHOUT TIME ZONE' + + def _type_binary(self, column): + return 'BYTEA' + + def _modify_nullable(self, blueprint, column): + if column.get('nullable'): + return ' NULL' + + return ' NOT NULL' + + def _modify_default(self, blueprint, column): + if column.get('default') is not None: + return ' DEFAULT %s' % self._get_default_value(column.default) + + return '' + + def _modify_increment(self, blueprint, column): + if column.type in self._serials and column.auto_increment: + return ' PRIMARY KEY' + + return '' diff --git a/eloquent/schema/grammars/sqlite_grammar.py b/eloquent/schema/grammars/sqlite_grammar.py new file mode 100644 index 00000000..b1d2782d --- /dev/null +++ b/eloquent/schema/grammars/sqlite_grammar.py @@ -0,0 +1,561 @@ +# -*- coding: utf-8 -*- + +from .grammar import SchemaGrammar +from ..blueprint import Blueprint +from ...query.expression import QueryExpression +from ...support.fluent import Fluent + + +class SQLiteSchemaGrammar(SchemaGrammar): + + _modifiers = ['nullable', 'default', 'increment'] + + _serials = ['big_integer', 'integer'] + + def compile_rename_column(self, blueprint, command, connection): + """ + Compile a rename column command. + + :param blueprint: The blueprint + :type blueprint: Blueprint + + :param command: The command + :type command: Fluent + + :param connection: The connection + :type connection: eloquent.connections.Connection + + :rtype: list + """ + # The code is a little complex. It will propably change + # if we support complete diffs in dbal + sql = [] + + schema = connection.get_schema_manager() + table = self.get_table_prefix() + blueprint.get_table() + + column = connection.get_column(table, command.from_) + + columns = schema.list_table_columns(table) + indexes = schema.list_table_indexes(table) + foreign_keys = schema.list_table_foreign_keys(table) + + diff = self._get_renamed_diff(blueprint, command, column, schema) + renamed_columns = diff.renamed_columns + + old_column_names = list(map(lambda x: x.get_name(), columns)) + + # We build the new column names + new_column_names = [] + for column_name in old_column_names: + if column_name in renamed_columns: + new_column_names.append(renamed_columns[column_name].get_name()) + else: + new_column_names.append(column_name) + + # We create a temporary table and insert the data into it + temp_table = '__temp__' + self.get_table_prefix() + blueprint.get_table() + sql.append('CREATE TEMPORARY TABLE %s AS SELECT %s FROM %s' + % (temp_table, self.columnize(old_column_names), table)) + + # We drop the current table + sql += Blueprint(table).drop().to_sql(None, self) + + # Building the list a new columns + new_columns = [] + for column in columns: + for column_name, changed_column in renamed_columns.items(): + if column_name == column.get_name(): + new_columns.append(changed_column) + + # Here we will try to rebuild a new blueprint to create a new table + # with the original name + new_blueprint = Blueprint(table) + new_blueprint.create() + primary = [] + for column in columns: + # Mapping the database type to the blueprint type + type = schema.get_database_platform().TYPE_MAPPING[column.get_type().lower()] + + # If the column is a primary, we will add it to the blueprint later + if column.get_platform_option('pk'): + primary.append(column.get_name()) + + # If the column is not one that's been renamed we reinsert it into the blueprint + if column.get_name() not in renamed_columns.keys(): + col = getattr(new_blueprint, type)(column.get_name()) + + # If the column is nullable, we flag it + if not column.get_notnull(): + col.nullable() + + # If the column has a default value, we add it + if column.get_default() is not None: + col.default(QueryExpression(column.get_default())) + + # Inserting the renamed columns into the blueprint + for column in new_columns: + type = schema.get_database_platform().TYPE_MAPPING[column.get_type().lower()] + + col = getattr(new_blueprint, type)(column.get_name()) + if not column.get_notnull(): + col.nullable() + + if column.get_default() is not None: + col.default(QueryExpression(column.get_default())) + + # We add the primary keys + if primary: + new_blueprint.primary(primary) + + # We rebuild the indexes + for index in indexes: + index_columns = index['columns'] + new_index_columns = [] + index_name = index['name'] + + for column_name in index_columns: + if column_name in renamed_columns: + new_index_columns.append(renamed_columns[column_name].get_name()) + else: + new_index_columns.append(column_name) + + if index_columns != new_index_columns: + index_name = None + + if index['unique']: + new_blueprint.unique(new_index_columns, index_name) + else: + new_blueprint.index(index['columns'], index_name) + + for foreign_key in foreign_keys: + fkey_from = foreign_key['from'] + if fkey_from in renamed_columns: + fkey_from = renamed_columns[fkey_from].get_name() + + new_blueprint.foreign(fkey_from)\ + .references(foreign_key['to'])\ + .on(foreign_key['table'])\ + .on_delete(foreign_key['on_delete'])\ + .on_update(foreign_key['on_update']) + + # We create the table + sql += new_blueprint.to_sql(None, self) + + # We reinsert the data into the new table + sql.append('INSERT INTO %s (%s) SELECT %s FROM %s' + % (self.wrap_table(table), + ', '.join(new_column_names), + self.columnize(old_column_names), + self.wrap_table(temp_table) + )) + + # Finally we drop the temporary table + sql += Blueprint(temp_table).drop().to_sql(None, self) + + return sql + + def compile_change(self, blueprint, command, connection): + """ + Compile a change column command into a series of SQL statement. + + :param blueprint: The blueprint + :type blueprint: eloquent.schema.Blueprint + + :param command: The command + :type command: Fluent + + :param connection: The connection + :type connection: eloquent.connections.Connection + + :rtype: list + """ + sql = [] + + schema = connection.get_schema_manager() + table = self.get_table_prefix() + blueprint.get_table() + + columns = schema.list_table_columns(table) + indexes = schema.list_table_indexes(table) + foreign_keys = schema.list_table_foreign_keys(table) + + diff = self._get_changed_diff(blueprint, schema) + blueprint_changed_columns = blueprint.get_changed_columns() + changed_columns = diff.changed_columns + + temp_table = '__temp__' + self.get_table_prefix() + blueprint.get_table() + sql.append('CREATE TEMPORARY TABLE %s AS SELECT %s FROM %s' + % (temp_table, self.columnize(list(map(lambda x: x.get_name(), columns))), table)) + sql += Blueprint(table).drop().to_sql(None, self) + + new_columns = [] + for column in columns: + for column_name, changed_column in changed_columns.items(): + if column_name == column.get_name(): + for blueprint_column in blueprint_changed_columns: + if blueprint_column.name == column_name: + new_columns.append(blueprint_column) + break + + break + + new_blueprint = Blueprint(table) + new_blueprint.create() + primary = [] + new_column_names = [] + for column in columns: + type = schema.get_database_platform().TYPE_MAPPING[column.get_type().lower()] + + if column.get_platform_option('pk'): + primary.append(column.get_name()) + + if column.get_name() not in changed_columns: + col = getattr(new_blueprint, type)(column.get_name()) + if not column.get_notnull(): + col.nullable() + + new_column_names.append(column.get_name()) + + for column in new_columns: + column.change = False + new_blueprint._add_column(**column.get_attributes()) + new_column_names.append(column.name) + + if primary: + new_blueprint.primary(primary) + + for index in indexes: + if index['unique']: + new_blueprint.unique(index['columns'], index['name']) + else: + new_blueprint.index(index['columns'], index['name']) + + for foreign_key in foreign_keys: + new_blueprint.foreign(foreign_key['from'])\ + .references(foreign_key['to'])\ + .on(foreign_key['table'])\ + .on_delete(foreign_key['on_delete'])\ + .on_update(foreign_key['on_update']) + + sql += new_blueprint.to_sql(None, self) + sql.append('INSERT INTO %s (%s) SELECT %s FROM %s' + % (self.wrap_table(table), + ', '.join(sorted(new_column_names)), + self.columnize(sorted(list(map(lambda x: x.get_name(), columns)))), + self.wrap_table(temp_table) + )) + sql += Blueprint(temp_table).drop().to_sql(None, self) + + return sql + + def compile_table_exists(self): + """ + Compile the query to determine if a table exists + + :rtype: str + """ + return "SELECT * FROM sqlite_master WHERE type = 'table' AND name = ?" + + def compile_column_exists(self, table): + """ + Compile the query to determine the list of columns. + """ + return 'PRAGMA table_info(%s)' % table.replace('.', '__') + + def compile_create(self, blueprint, command, _): + """ + Compile a create table command. + """ + columns = ', '.join(self._get_columns(blueprint)) + + sql = 'CREATE TABLE %s (%s' % (self.wrap_table(blueprint), columns) + + sql += self._add_foreign_keys(blueprint) + + sql += self._add_primary_keys(blueprint) + + return sql + ')' + + def _add_foreign_keys(self, blueprint): + sql = '' + + foreigns = self._get_commands_by_name(blueprint, 'foreign') + + for foreign in foreigns: + sql += self._get_foreign_key(foreign) + + if foreign.get('on_delete'): + sql += ' ON DELETE %s' % foreign.on_delete + + if foreign.get('on_update'): + sql += ' ON UPDATE %s' % foreign.on_delete + + return sql + + def _get_foreign_key(self, foreign): + on = self.wrap_table(foreign.on) + + columns = self.columnize(foreign.columns) + + references = foreign.references + if not isinstance(references, list): + references = [references] + + on_columns = self.columnize(references) + + return ', FOREIGN KEY(%s) REFERENCES %s(%s)' % (columns, on, on_columns) + + def _add_primary_keys(self, blueprint): + primary = self._get_command_by_name(blueprint, 'primary') + + if primary: + columns = self.columnize(primary.columns) + + return ', PRIMARY KEY (%s)' % columns + + return '' + + def compile_add(self, blueprint, command, _): + table = self.wrap_table(blueprint) + + columns = self.prefix_list('ADD COLUMN', self._get_columns(blueprint)) + + statements = [] + + for column in columns: + statements.append('ALTER TABLE %s %s' % (table, column)) + + return statements + + def compile_unique(self, blueprint, command, _): + columns = self.columnize(command.columns) + + table = self.wrap_table(blueprint) + + return 'CREATE UNIQUE INDEX %s ON %s (%s)' % (command.index, table, columns) + + def compile_index(self, blueprint, command, _): + columns = self.columnize(command.columns) + + table = self.wrap_table(blueprint) + + return 'CREATE INDEX %s ON %s (%s)' % (command.index, table, columns) + + def compile_foreign(self, blueprint, command, _): + pass + + def compile_drop(self, blueprint, command, _): + return 'DROP TABLE %s' % self.wrap_table(blueprint) + + def compile_drop_if_exists(self, blueprint, command, _): + return 'DROP TABLE IF EXISTS %s' % self.wrap_table(blueprint) + + def compile_drop_column(self, blueprint, command, connection): + # The code is a little complex. It will propably change + # if we support complete diffs in dbal + sql = [] + + schema = connection.get_schema_manager() + table = self.get_table_prefix() + blueprint.get_table() + + columns = schema.list_table_columns(table) + indexes = schema.list_table_indexes(table) + foreign_keys = schema.list_table_foreign_keys(table) + + diff = self._get_table_diff(blueprint, schema) + + for name in command.columns: + column = connection.get_column(blueprint.get_table(), name) + + diff.removed_columns[name] = column + + removed_columns = diff.removed_columns + + old_column_names = list(map(lambda x: x.get_name(), columns)) + + # We build the new column names + new_column_names = [] + for column_name in old_column_names: + if column_name not in removed_columns: + new_column_names.append(column_name) + + # We create a temporary table and insert the data into it + temp_table = '__temp__' + self.get_table_prefix() + blueprint.get_table() + sql.append('CREATE TEMPORARY TABLE %s AS SELECT %s FROM %s' + % (temp_table, self.columnize(old_column_names), table)) + + # We drop the current table + sql += Blueprint(table).drop().to_sql(None, self) + + # Here we will try to rebuild a new blueprint to create a new table + # with the original name + new_blueprint = Blueprint(table) + new_blueprint.create() + primary = [] + for column in columns: + # If the column is not one that's been removed we reinsert it into the blueprint + if column.get_name() in new_column_names: + # Mapping the database type to the blueprint type + type = schema.get_database_platform().TYPE_MAPPING[column.get_type().lower()] + + # If the column is a primary, we will add it to the blueprint later + if column.get_platform_option('pk'): + primary.append(column.get_name()) + + col = getattr(new_blueprint, type)(column.get_name()) + + # If the column is nullable, we flag it + if not column.get_notnull(): + col.nullable() + + # If the column has a default value, we add it + if column.get_default() is not None: + col.default(QueryExpression(column.get_default())) + + # We add the primary keys + if primary: + new_blueprint.primary(primary) + + # We rebuild the indexes + for index in indexes: + index_columns = index['columns'] + new_index_columns = [] + index_name = index['name'] + + removed = False + for column_name in index_columns: + if column_name not in removed_columns: + new_index_columns.append(column_name) + else: + removed = True + break + + if removed: + continue + + if index_columns != new_index_columns: + index_name = None + + if index['unique']: + new_blueprint.unique(new_index_columns, index_name) + else: + new_blueprint.index(index['columns'], index_name) + + for foreign_key in foreign_keys: + fkey_from = foreign_key['from'] + if fkey_from in removed_columns: + continue + + new_blueprint.foreign(fkey_from)\ + .references(foreign_key['to'])\ + .on(foreign_key['table'])\ + .on_delete(foreign_key['on_delete'])\ + .on_update(foreign_key['on_update']) + + # We create the table + sql += new_blueprint.to_sql(None, self) + + # We reinsert the data into the new table + sql.append('INSERT INTO %s (%s) SELECT %s FROM %s' + % (self.wrap_table(table), + self.columnize(new_column_names), + self.columnize(new_column_names), + self.wrap_table(temp_table) + )) + + # Finally we drop the temporary table + sql += Blueprint(temp_table).drop().to_sql(None, self) + + return sql + + def compile_drop_unique(self, blueprint, command, _): + return 'DROP INDEX %s' % command.index + + def compile_drop_index(self, blueprint, command, _): + return 'DROP INDEX %s' % command.index + + def compile_rename(self, blueprint, command, _): + from_ = self.wrap_table(blueprint) + + return 'ALTER TABLE %s RENAME TO %s' % (from_, self.wrap_table(command.to)) + + def _type_char(self, column): + return 'VARCHAR' + + def _type_string(self, column): + return 'VARCHAR' + + def _type_text(self, column): + return 'TEXT' + + def _type_medium_text(self, column): + return 'TEXT' + + def _type_long_text(self, column): + return 'TEXT' + + def _type_integer(self, column): + return 'INTEGER' + + def _type_big_integer(self, column): + return 'INTEGER' + + def _type_medium_integer(self, column): + return 'INTEGER' + + def _type_tiny_integer(self, column): + return 'TINYINT' + + def _type_small_integer(self, column): + return 'INTEGER' + + def _type_float(self, column): + return 'FLOAT' + + def _type_double(self, column): + return 'FLOAT' + + def _type_decimal(self, column): + return 'NUMERIC' + + def _type_boolean(self, column): + return 'TINYINT' + + def _type_enum(self, column): + return 'VARCHAR' + + def _type_json(self, column): + return 'TEXT' + + def _type_date(self, column): + return 'DATE' + + def _type_datetime(self, column): + return 'DATETIME' + + def _type_time(self, column): + return 'TIME' + + def _type_timestamp(self, column): + return 'DATETIME' + + def _type_binary(self, column): + return 'BLOB' + + def _modify_nullable(self, blueprint, column): + if column.get('nullable'): + return ' NULL' + + return ' NOT NULL' + + def _modify_default(self, blueprint, column): + if column.get('default') is not None: + return ' DEFAULT %s' % self._get_default_value(column.default) + + return '' + + def _modify_increment(self, blueprint, column): + if column.type in self._serials and column.auto_increment: + return ' PRIMARY KEY AUTOINCREMENT' + + return '' diff --git a/eloquent/schema/mysql_builder.py b/eloquent/schema/mysql_builder.py new file mode 100644 index 00000000..e439ff2e --- /dev/null +++ b/eloquent/schema/mysql_builder.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +from .builder import SchemaBuilder + + +class MySqlSchemaBuilder(SchemaBuilder): + + def has_table(self, table): + """ + Determine if the given table exists. + + :param table: The table + :type table: str + + :rtype: bool + """ + sql = self._grammar.compile_table_exists() + database = self._connection.get_database_name() + table = self._connection.get_table_prefix() + table + + return len(self._connection.select(sql, [database, table])) > 0 + + def get_column_listing(self, table): + """ + Get the column listing for a given table. + + :param table: The table + :type table: str + + :rtype: list + """ + sql = self._grammar.compile_column_exists() + database = self._connection.get_database_name() + table = self._connection.get_table_prefix() + table + + results = self._connection.select(sql, [database, table]) + + return self._connection.get_post_processor().process_column_listing(results) diff --git a/eloquent/schema/schema.py b/eloquent/schema/schema.py new file mode 100644 index 00000000..e8cbd40a --- /dev/null +++ b/eloquent/schema/schema.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- + + +class Schema(object): + + def __init__(self, manager): + """ + :param manager: The database manager + :type manager: eloquent.DatabaseManager + """ + self.db = manager + + def connection(self, connection=None): + """ + Get a schema builder instance for a connection. + + :param connection: The connection to user + :type connection: str + + :rtype: eloquent.schema.SchemaBuilder + """ + return self.db.connection(connection).get_schema_builder() + + def __getattr__(self, item): + return getattr(self.db.connection().get_schema_builder(), item) diff --git a/eloquent/support/fluent.py b/eloquent/support/fluent.py new file mode 100644 index 00000000..f31322f5 --- /dev/null +++ b/eloquent/support/fluent.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- + +import simplejson as json + + +class Fluent(object): + + def __init__(self, **attributes): + self.__attributes = {} + + for key, value in attributes.items(): + self.__attributes[key] = value + + def get(self, key, default=None): + return self.__attributes.get(key, default) + + def get_attributes(self): + return self.__attributes + + def to_dict(self): + return self.__attributes + + def to_json(self): + return json.dumps(self.to_dict()) + + def __contains__(self, item): + return item in self.__attributes + + def __getitem__(self, item): + return self.__attributes[item] + + def __setitem__(self, key, value): + self.__attributes[key] = value + + def __delitem__(self, key): + del self.__attributes[key] + + def __dynamic(self, method): + def call(*args, **kwargs): + if len(args): + self.__attributes[method] = args[0] + else: + self.__attributes[method] = True + + return self + + return call + + def __getattr__(self, item): + if item in self.__attributes: + return self.__attributes[item] + + return self.__dynamic(item) + + def __setattr__(self, key, value): + if key.startswith(('_Fluent__', '_%s__' % self.__class__.__name__, '__')): + super(Fluent, self).__setattr__(key, value) + elif callable(getattr(self, key, None)): + return super(Fluent, self).__setattr__(key, value) + else: + self.__attributes[key] = value + + def __delattr__(self, item): + del self.__attributes[item] diff --git a/eloquent/query/grammars/base.py b/eloquent/support/grammar.py similarity index 93% rename from eloquent/query/grammars/base.py rename to eloquent/support/grammar.py index 65cdfa6c..73572c4e 100644 --- a/eloquent/query/grammars/base.py +++ b/eloquent/support/grammar.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- -from ..expression import QueryExpression +from ..query.expression import QueryExpression -class BaseQueryGrammar(object): +class Grammar(object): def __init__(self): self._table_prefix = '' @@ -76,8 +76,13 @@ def is_expression(self, value): def get_date_format(self): return 'Y-m-d H:i:s' + def get_table_prefix(self): + return self._table_prefix + def set_table_prefix(self, prefix): self._table_prefix = prefix + return self + def get_marker(self): return '?' diff --git a/setup.py b/setup.py index 0f6562fb..eec9fe44 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages -__version__ = '0.3.1' +__version__ = '0.4' setup( name='eloquent', diff --git a/tests-requirements.txt b/tests-requirements.txt index d72d9896..ec05da93 100644 --- a/tests-requirements.txt +++ b/tests-requirements.txt @@ -1,3 +1,5 @@ pytest pytest-mock flexmock +psycopg2 +mysqlclient diff --git a/tests/orm/mixins/__init__.py b/tests/orm/mixins/__init__.py new file mode 100644 index 00000000..633f8661 --- /dev/null +++ b/tests/orm/mixins/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- + diff --git a/tests/orm/mixins/test_soft_deletes.py b/tests/orm/mixins/test_soft_deletes.py new file mode 100644 index 00000000..2791cd00 --- /dev/null +++ b/tests/orm/mixins/test_soft_deletes.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- + +import datetime +import arrow +from flexmock import flexmock, flexmock_teardown +from eloquent import Model, SoftDeletes +from eloquent.orm import Builder +from eloquent.query import QueryBuilder +from ... import EloquentTestCase + + +t = arrow.get().naive + + +class SoftDeletesTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_delete_sets_soft_deleted_column(self): + model = flexmock(SoftDeleteModelStub()) + model.set_exists(True) + builder = flexmock(Builder) + query_builder = flexmock(QueryBuilder(None, None, None)) + query = Builder(query_builder) + model.should_receive('new_query').and_return(query) + builder.should_receive('where').once().with_args('id', 1).and_return(query) + builder.should_receive('update').once().with_args({'deleted_at': t}) + model.delete() + + self.assertIsInstance(model.deleted_at, datetime.datetime) + + def test_restore(self): + model = flexmock(SoftDeleteModelStub()) + model.set_exists(True) + model.should_receive('save').once() + + model.restore() + + self.assertIsNone(model.deleted_at) + + +class SoftDeleteModelStub(Model, SoftDeletes): + + def get_key(self): + return 1 + + def get_key_name(self): + return 'id' + + def from_datetime(self, value): + return t diff --git a/tests/orm/scopes/__init__.py b/tests/orm/scopes/__init__.py new file mode 100644 index 00000000..633f8661 --- /dev/null +++ b/tests/orm/scopes/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- + diff --git a/tests/orm/scopes/test_soft_deleting.py b/tests/orm/scopes/test_soft_deleting.py new file mode 100644 index 00000000..7efc0515 --- /dev/null +++ b/tests/orm/scopes/test_soft_deleting.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- + +from flexmock import flexmock, flexmock_teardown +from eloquent.orm.scopes import SoftDeletingScope +from eloquent.orm import Builder, Model +from eloquent.query import QueryBuilder +from ... import EloquentTestCase + + +class SoftDeletingScopeTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_apply_scope_to_a_builder(self): + scope = SoftDeletingScope() + query = flexmock(QueryBuilder(None, None, None)) + builder = Builder(query) + model = flexmock(ModelStub()) + model.should_receive('get_qualified_deleted_at_column').once().and_return('table.deleted_at') + builder.get_query().should_receive('where_null').once().with_args('table.deleted_at') + + scope.apply(builder, model) + + def test_scope_can_remove_deleted_at_constraint(self): + scope = SoftDeletingScope() + query = flexmock(QueryBuilder(None, None, None)) + model = flexmock(ModelStub()) + builder = Builder(query) + builder.set_model(model) + model.should_receive('get_qualified_deleted_at_column').once().and_return('table.deleted_at') + query.wheres = [{ + 'type': 'null', + 'column': 'foo' + }, { + 'type': 'null', + 'column': 'table.deleted_at' + }] + scope.remove(builder, model) + + self.assertEqual( + query.wheres, + [{ + 'type': 'null', + 'column': 'foo' + }] + ) + + def test_force_delete_extension(self): + scope = SoftDeletingScope() + builder = Builder(None) + scope.extend(builder) + callback = builder.get_macro('force_delete') + query = flexmock(QueryBuilder(None, None, None)) + given_builder = Builder(query) + query.should_receive('delete').once() + + callback(given_builder) + + def test_restore_extension(self): + scope = SoftDeletingScope() + builder = Builder(None) + scope.extend(builder) + callback = builder.get_macro('restore') + query = flexmock(QueryBuilder(None, None, None)) + builder_mock = flexmock(BuilderWithTrashedStub) + given_builder = BuilderWithTrashedStub(query) + builder_mock.should_receive('with_trashed').once() + builder_mock.should_receive('get_model').once().and_return(ModelStub()) + builder_mock.should_receive('update').once().with_args({'deleted_at': None}) + + callback(given_builder) + + def test_with_trashed_extension(self): + scope = flexmock(SoftDeletingScope()) + builder = Builder(None) + scope.extend(builder) + callback = builder.get_macro('with_trashed') + query = flexmock(QueryBuilder(None, None, None)) + given_builder = Builder(query) + model = flexmock(ModelStub()) + given_builder.set_model(model) + scope.should_receive('remove').once().with_args(given_builder, model) + result = callback(given_builder) + + self.assertEqual(given_builder, result) + + +class ModelStub(Model): + + def get_qualified_deleted_at_column(self): + return 'table.deleted_at' + + def get_deleted_at_column(self): + return 'deleted_at' + + +class BuilderWithTrashedStub(Builder): + + def with_trashed(self): + pass diff --git a/tests/orm/test_builder.py b/tests/orm/test_builder.py index 8e793cce..7c76360e 100644 --- a/tests/orm/test_builder.py +++ b/tests/orm/test_builder.py @@ -10,6 +10,8 @@ from eloquent.orm.model import Model from eloquent.exceptions.orm import ModelNotFound from eloquent.orm.collection import Collection +from eloquent.connections import Connection +from eloquent.query.processors import QueryProcessor class BuilderTestCase(EloquentTestCase): @@ -265,6 +267,24 @@ def test_get_models_hydrates_models(self): records, 'foo_connection' ) + def test_macros_are_called_on_builder(self): + builder = Builder(QueryBuilder( + flexmock(Connection), + flexmock(QueryGrammar), + flexmock(QueryProcessor) + )) + + def foo_bar(builder): + builder.foobar = True + + return builder + + builder.macro('foo_bar', foo_bar) + result = builder.foo_bar() + + self.assertEqual(result, builder) + self.assertTrue(builder.foobar) + def test_eager_load_relations_load_top_level_relationships(self): flexmock(Builder) builder = Builder(flexmock(QueryBuilder(None, None, None))) @@ -347,7 +367,15 @@ def test_query_passthru(self): builder.get_query().insert.assert_called_once_with(['bar']) - # TODO: test query scopes + def test_query_scopes(self): + builder = self.get_builder() + builder.get_query().from_ = mock.MagicMock() + builder.get_query().where = mock.MagicMock() + model = OrmBuilderTestModelScopeStub() + builder.set_model(model) + result = builder.approved() + + self.assertEqual(result, builder) def test_simple_where(self): builder = self.get_builder() @@ -423,6 +451,12 @@ class OrmBuilderTestModelFarRelatedStub(Model): pass +class OrmBuilderTestModelScopeStub(Model): + + def scope_approved(self, query): + query.where('foo', 'bar') + + class OrmBuilderTestModelCloseRelated(Model): @property diff --git a/tests/schema/__init__.py b/tests/schema/__init__.py new file mode 100644 index 00000000..40a96afc --- /dev/null +++ b/tests/schema/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/tests/schema/grammars/__init__.py b/tests/schema/grammars/__init__.py new file mode 100644 index 00000000..633f8661 --- /dev/null +++ b/tests/schema/grammars/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- + diff --git a/tests/schema/grammars/test_mysql_grammar.py b/tests/schema/grammars/test_mysql_grammar.py new file mode 100644 index 00000000..c16ab609 --- /dev/null +++ b/tests/schema/grammars/test_mysql_grammar.py @@ -0,0 +1,538 @@ +# -*- coding: utf-8 -*- + +from flexmock import flexmock, flexmock_teardown +from eloquent.connections import Connection +from eloquent.schema.grammars import MySqlSchemaGrammar +from eloquent.schema.blueprint import Blueprint +from ... import EloquentTestCase + + +class MySqlSchemaGrammarTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_basic_create(self): + blueprint = Blueprint('users') + blueprint.create() + blueprint.increments('id') + blueprint.string('email') + + conn = self.get_connection() + conn.should_receive('get_config').once().with_args('charset').and_return('utf8') + conn.should_receive('get_config').once().with_args('collation').and_return('utf8_unicode_ci') + + statements = blueprint.to_sql(conn, self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'CREATE TABLE `users` (' + '`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, ' + '`email` VARCHAR(255) NOT NULL) ' + 'DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.create() + blueprint.increments('id') + blueprint.string('email') + + conn = self.get_connection() + conn.should_receive('get_config').and_return(None) + + statements = blueprint.to_sql(conn, self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'CREATE TABLE `users` (' + '`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, ' + '`email` VARCHAR(255) NOT NULL)', + statements[0] + ) + + def test_charset_collation_create(self): + blueprint = Blueprint('users') + blueprint.create() + blueprint.increments('id') + blueprint.string('email') + blueprint.charset = 'utf8mb4' + blueprint.collation = 'utf8mb4_unicode_ci' + + conn = self.get_connection() + + statements = blueprint.to_sql(conn, self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'CREATE TABLE `users` (' + '`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, ' + '`email` VARCHAR(255) NOT NULL) ' + 'DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci', + statements[0] + ) + + def test_basic_create_with_prefix(self): + blueprint = Blueprint('users') + blueprint.create() + blueprint.increments('id') + blueprint.string('email') + grammar = self.get_grammar() + grammar.set_table_prefix('prefix_') + + conn = self.get_connection() + conn.should_receive('get_config').and_return(None) + + statements = blueprint.to_sql(conn, grammar) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'CREATE TABLE `prefix_users` (' + '`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, ' + '`email` VARCHAR(255) NOT NULL)', + statements[0] + ) + + def test_drop_table(self): + blueprint = Blueprint('users') + blueprint.drop() + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('DROP TABLE `users`', statements[0]) + + def test_drop_table_if_exists(self): + blueprint = Blueprint('users') + blueprint.drop_if_exists() + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('DROP TABLE IF EXISTS `users`', statements[0]) + + def test_drop_column(self): + blueprint = Blueprint('users') + blueprint.drop_column('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE `users` DROP `foo`', statements[0]) + + blueprint = Blueprint('users') + blueprint.drop_column('foo', 'bar') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE `users` DROP `foo`, DROP `bar`', statements[0]) + + def test_drop_primary(self): + blueprint = Blueprint('users') + blueprint.drop_primary('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE `users` DROP PRIMARY KEY', statements[0]) + + def test_drop_unique(self): + blueprint = Blueprint('users') + blueprint.drop_unique('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE `users` DROP INDEX foo', statements[0]) + + def test_drop_index(self): + blueprint = Blueprint('users') + blueprint.drop_index('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE `users` DROP INDEX foo', statements[0]) + + def test_drop_foreign(self): + blueprint = Blueprint('users') + blueprint.drop_foreign('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE `users` DROP FOREIGN KEY foo', statements[0]) + + def test_drop_timestamps(self): + blueprint = Blueprint('users') + blueprint.drop_timestamps() + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE `users` DROP `created_at`, DROP `updated_at`', statements[0]) + + def test_rename_table(self): + blueprint = Blueprint('users') + blueprint.rename('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('RENAME TABLE `users` TO `foo`', statements[0]) + + def test_adding_primary_key(self): + blueprint = Blueprint('users') + blueprint.primary('foo', 'bar') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE `users` ADD PRIMARY KEY bar(`foo`)', statements[0]) + + def test_adding_foreign_key(self): + blueprint = Blueprint('users') + blueprint.foreign('order_id').references('id').on('orders') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + expected = [ + 'ALTER TABLE `users` ADD CONSTRAINT users_order_id_foreign ' + 'FOREIGN KEY (`order_id`) REFERENCES `orders` (`id`)' + ] + self.assertEqual(expected, statements) + + def test_adding_unique_key(self): + blueprint = Blueprint('users') + blueprint.unique('foo', 'bar') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD UNIQUE bar(`foo`)', + statements[0] + ) + + def test_adding_index(self): + blueprint = Blueprint('users') + blueprint.index(['foo', 'bar'], 'baz') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD INDEX baz(`foo`, `bar`)', + statements[0] + ) + + def test_adding_incrementing_id(self): + blueprint = Blueprint('users') + blueprint.increments('id') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY', + statements[0] + ) + + def test_adding_big_incrementing_id(self): + blueprint = Blueprint('users') + blueprint.big_increments('id') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY', + statements[0] + ) + + def test_adding_column_after_another(self): + blueprint = Blueprint('users') + blueprint.string('name').after('foo') + + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL AFTER `foo`', + statements[0] + ) + + def test_adding_string(self): + blueprint = Blueprint('users') + blueprint.string('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` VARCHAR(255) NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.string('foo', 100) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` VARCHAR(100) NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.string('foo', 100).nullable().default('bar') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` VARCHAR(100) NULL DEFAULT \'bar\'', + statements[0] + ) + + def test_adding_text(self): + blueprint = Blueprint('users') + blueprint.text('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` TEXT NOT NULL', + statements[0] + ) + + def test_adding_big_integer(self): + blueprint = Blueprint('users') + blueprint.big_integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` BIGINT NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.big_integer('foo', True) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY', + statements[0] + ) + + def test_adding_integer(self): + blueprint = Blueprint('users') + blueprint.integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` INT NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.integer('foo', True) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` INT NOT NULL AUTO_INCREMENT PRIMARY KEY', + statements[0] + ) + + def test_adding_medium_integer(self): + blueprint = Blueprint('users') + blueprint.medium_integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` MEDIUMINT NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.medium_integer('foo', True) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` MEDIUMINT NOT NULL AUTO_INCREMENT PRIMARY KEY', + statements[0] + ) + + def test_adding_tiny_integer(self): + blueprint = Blueprint('users') + blueprint.tiny_integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` TINYINT NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.tiny_integer('foo', True) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` TINYINT NOT NULL AUTO_INCREMENT PRIMARY KEY', + statements[0] + ) + + def test_adding_small_integer(self): + blueprint = Blueprint('users') + blueprint.small_integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` SMALLINT NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.small_integer('foo', True) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` SMALLINT NOT NULL AUTO_INCREMENT PRIMARY KEY', + statements[0] + ) + + def test_adding_float(self): + blueprint = Blueprint('users') + blueprint.float('foo', 5, 2) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` DOUBLE(5, 2) NOT NULL', + statements[0] + ) + + def test_adding_double(self): + blueprint = Blueprint('users') + blueprint.double('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` DOUBLE NOT NULL', + statements[0] + ) + + def test_adding_double_with_precision(self): + blueprint = Blueprint('users') + blueprint.double('foo', 15, 8) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` DOUBLE(15, 8) NOT NULL', + statements[0] + ) + + def test_adding_decimal(self): + blueprint = Blueprint('users') + blueprint.decimal('foo', 5, 2) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` DECIMAL(5, 2) NOT NULL', + statements[0] + ) + + def test_adding_boolean(self): + blueprint = Blueprint('users') + blueprint.boolean('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` TINYINT(1) NOT NULL', + statements[0] + ) + + def test_adding_enum(self): + blueprint = Blueprint('users') + blueprint.enum('foo', ['bar', 'baz']) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` ENUM(\'bar\', \'baz\') NOT NULL', + statements[0] + ) + + def test_adding_date(self): + blueprint = Blueprint('users') + blueprint.date('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` DATE NOT NULL', + statements[0] + ) + + def test_adding_datetime(self): + blueprint = Blueprint('users') + blueprint.datetime('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` DATETIME NOT NULL', + statements[0] + ) + + def test_adding_time(self): + blueprint = Blueprint('users') + blueprint.time('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` TIME NOT NULL', + statements[0] + ) + + def test_adding_timestamp(self): + blueprint = Blueprint('users') + blueprint.timestamp('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` TIMESTAMP DEFAULT 0 NOT NULL', + statements[0] + ) + + def test_adding_timestamps(self): + blueprint = Blueprint('users') + blueprint.timestamps() + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + expected = [ + 'ALTER TABLE `users` ADD `created_at` TIMESTAMP DEFAULT 0 NOT NULL, ' + 'ADD `updated_at` TIMESTAMP DEFAULT 0 NOT NULL' + ] + self.assertEqual( + expected[0], + statements[0] + ) + + def test_adding_binary(self): + blueprint = Blueprint('users') + blueprint.binary('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE `users` ADD `foo` BLOB NOT NULL', + statements[0] + ) + + def get_connection(self): + return flexmock(Connection(None)) + + def get_grammar(self): + return MySqlSchemaGrammar() diff --git a/tests/schema/grammars/test_postgres_grammar.py b/tests/schema/grammars/test_postgres_grammar.py new file mode 100644 index 00000000..aee3493f --- /dev/null +++ b/tests/schema/grammars/test_postgres_grammar.py @@ -0,0 +1,464 @@ +# -*- coding: utf-8 -*- + +from flexmock import flexmock, flexmock_teardown +from eloquent.connections import Connection +from eloquent.schema.grammars import PostgresSchemaGrammar +from eloquent.schema.blueprint import Blueprint +from ... import EloquentTestCase + + +class PostgresSchemaGrammarTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_basic_create(self): + blueprint = Blueprint('users') + blueprint.create() + blueprint.increments('id') + blueprint.string('email') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'CREATE TABLE "users" ("id" SERIAL PRIMARY KEY NOT NULL, "email" VARCHAR(255) NOT NULL)', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.increments('id') + blueprint.string('email') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + expected = [ + 'ALTER TABLE "users" ADD COLUMN "id" SERIAL PRIMARY KEY NOT NULL, ' + 'ADD COLUMN "email" VARCHAR(255) NOT NULL' + ] + self.assertEqual(expected[0], statements[0]) + + def test_drop_table(self): + blueprint = Blueprint('users') + blueprint.drop() + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('DROP TABLE "users"', statements[0]) + + def test_drop_table_if_exists(self): + blueprint = Blueprint('users') + blueprint.drop_if_exists() + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('DROP TABLE IF EXISTS "users"', statements[0]) + + def test_drop_column(self): + blueprint = Blueprint('users') + blueprint.drop_column('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE "users" DROP COLUMN "foo"', statements[0]) + + blueprint = Blueprint('users') + blueprint.drop_column('foo', 'bar') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE "users" DROP COLUMN "foo", DROP COLUMN "bar"', statements[0]) + + def test_drop_primary(self): + blueprint = Blueprint('users') + blueprint.drop_primary('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE "users" DROP CONSTRAINT users_pkey', statements[0]) + + def test_drop_unique(self): + blueprint = Blueprint('users') + blueprint.drop_unique('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE "users" DROP CONSTRAINT foo', statements[0]) + + def test_drop_index(self): + blueprint = Blueprint('users') + blueprint.drop_index('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('DROP INDEX foo', statements[0]) + + def test_drop_foreign(self): + blueprint = Blueprint('users') + blueprint.drop_unique('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE "users" DROP CONSTRAINT foo', statements[0]) + + def test_drop_timestamps(self): + blueprint = Blueprint('users') + blueprint.drop_timestamps() + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE "users" DROP COLUMN "created_at", DROP COLUMN "updated_at"', statements[0]) + + def test_rename_table(self): + blueprint = Blueprint('users') + blueprint.rename('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE "users" RENAME TO "foo"', statements[0]) + + def test_adding_primary_key(self): + blueprint = Blueprint('users') + blueprint.primary('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE "users" ADD PRIMARY KEY ("foo")', statements[0]) + + def test_adding_foreign_key(self): + blueprint = Blueprint('users') + blueprint.create() + blueprint.string('foo').primary() + blueprint.string('order_id') + blueprint.foreign('order_id').references('id').on('orders') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(3, len(statements)) + expected = [ + 'CREATE TABLE "users" ("foo" VARCHAR(255) NOT NULL, "order_id" VARCHAR(255) NOT NULL)', + 'ALTER TABLE "users" ADD CONSTRAINT users_order_id_foreign' + ' FOREIGN KEY ("order_id") REFERENCES "orders" ("id")', + 'ALTER TABLE "users" ADD PRIMARY KEY ("foo")' + ] + self.assertEqual(expected, statements) + + def test_adding_unique_key(self): + blueprint = Blueprint('users') + blueprint.unique('foo', 'bar') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD CONSTRAINT bar UNIQUE ("foo")', + statements[0] + ) + + def test_adding_index(self): + blueprint = Blueprint('users') + blueprint.index(['foo', 'bar'], 'baz') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'CREATE INDEX baz ON "users" ("foo", "bar")', + statements[0] + ) + + def test_adding_incrementing_id(self): + blueprint = Blueprint('users') + blueprint.increments('id') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "id" SERIAL PRIMARY KEY NOT NULL', + statements[0] + ) + + def test_adding_big_incrementing_id(self): + blueprint = Blueprint('users') + blueprint.big_increments('id') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "id" BIGSERIAL PRIMARY KEY NOT NULL', + statements[0] + ) + + def test_adding_string(self): + blueprint = Blueprint('users') + blueprint.string('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(255) NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.string('foo', 100) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(100) NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.string('foo', 100).nullable().default('bar') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(100) NULL DEFAULT \'bar\'', + statements[0] + ) + + def test_adding_text(self): + blueprint = Blueprint('users') + blueprint.text('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL', + statements[0] + ) + + def test_adding_big_integer(self): + blueprint = Blueprint('users') + blueprint.big_integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" BIGINT NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.big_integer('foo', True) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" BIGSERIAL PRIMARY KEY NOT NULL', + statements[0] + ) + + def test_adding_integer(self): + blueprint = Blueprint('users') + blueprint.integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.integer('foo', True) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" SERIAL PRIMARY KEY NOT NULL', + statements[0] + ) + + def test_adding_medium_integer(self): + blueprint = Blueprint('users') + blueprint.medium_integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.medium_integer('foo', True) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" SERIAL PRIMARY KEY NOT NULL', + statements[0] + ) + + def test_adding_tiny_integer(self): + blueprint = Blueprint('users') + blueprint.tiny_integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" SMALLINT NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.tiny_integer('foo', True) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" SMALLSERIAL PRIMARY KEY NOT NULL', + statements[0] + ) + + def test_adding_small_integer(self): + blueprint = Blueprint('users') + blueprint.small_integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" SMALLINT NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.small_integer('foo', True) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" SMALLSERIAL PRIMARY KEY NOT NULL', + statements[0] + ) + + def test_adding_float(self): + blueprint = Blueprint('users') + blueprint.float('foo', 5, 2) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" DOUBLE PRECISION NOT NULL', + statements[0] + ) + + def test_adding_double(self): + blueprint = Blueprint('users') + blueprint.double('foo', 15, 8) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" DOUBLE PRECISION NOT NULL', + statements[0] + ) + + def test_adding_decimal(self): + blueprint = Blueprint('users') + blueprint.decimal('foo', 5, 2) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" DECIMAL(5, 2) NOT NULL', + statements[0] + ) + + def test_adding_boolean(self): + blueprint = Blueprint('users') + blueprint.boolean('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" BOOLEAN NOT NULL', + statements[0] + ) + + def test_adding_enum(self): + blueprint = Blueprint('users') + blueprint.enum('foo', ['bar', 'baz']) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(255) CHECK ("foo" IN (\'bar\', \'baz\')) NOT NULL', + statements[0] + ) + + def test_adding_date(self): + blueprint = Blueprint('users') + blueprint.date('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" DATE NOT NULL', + statements[0] + ) + + def test_adding_datetime(self): + blueprint = Blueprint('users') + blueprint.datetime('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL', + statements[0] + ) + + def test_adding_time(self): + blueprint = Blueprint('users') + blueprint.time('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" TIME(0) WITHOUT TIME ZONE NOT NULL', + statements[0] + ) + + def test_adding_timestamp(self): + blueprint = Blueprint('users') + blueprint.timestamp('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL', + statements[0] + ) + + def test_adding_timestamps(self): + blueprint = Blueprint('users') + blueprint.timestamps() + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + expected = [ + 'ALTER TABLE "users" ADD COLUMN "created_at" TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL, ' + 'ADD COLUMN "updated_at" TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL' + ] + self.assertEqual( + expected[0], + statements[0] + ) + + def test_adding_binary(self): + blueprint = Blueprint('users') + blueprint.binary('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" BYTEA NOT NULL', + statements[0] + ) + + def get_connection(self): + return flexmock(Connection(None)) + + def get_grammar(self): + return PostgresSchemaGrammar() diff --git a/tests/schema/grammars/test_sqlite_grammar.py b/tests/schema/grammars/test_sqlite_grammar.py new file mode 100644 index 00000000..227bfc4e --- /dev/null +++ b/tests/schema/grammars/test_sqlite_grammar.py @@ -0,0 +1,383 @@ +# -*- coding: utf-8 -*- + +from flexmock import flexmock, flexmock_teardown +from eloquent.connections import Connection +from eloquent.schema.grammars import SQLiteSchemaGrammar +from eloquent.schema.blueprint import Blueprint +from ... import EloquentTestCase + + +class SqliteSchemaGrammarTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_basic_create(self): + blueprint = Blueprint('users') + blueprint.create() + blueprint.increments('id') + blueprint.string('email') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'CREATE TABLE "users" ("id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, "email" VARCHAR NOT NULL)', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.increments('id') + blueprint.string('email') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(2, len(statements)) + expected = [ + 'ALTER TABLE "users" ADD COLUMN "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT', + 'ALTER TABLE "users" ADD COLUMN "email" VARCHAR NOT NULL' + ] + self.assertEqual(expected, statements) + + def test_drop_table(self): + blueprint = Blueprint('users') + blueprint.drop() + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('DROP TABLE "users"', statements[0]) + + def test_drop_table_if_exists(self): + blueprint = Blueprint('users') + blueprint.drop_if_exists() + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('DROP TABLE IF EXISTS "users"', statements[0]) + + def test_drop_unique(self): + blueprint = Blueprint('users') + blueprint.drop_unique('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('DROP INDEX foo', statements[0]) + + def test_drop_index(self): + blueprint = Blueprint('users') + blueprint.drop_index('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('DROP INDEX foo', statements[0]) + + def test_rename_table(self): + blueprint = Blueprint('users') + blueprint.rename('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual('ALTER TABLE "users" RENAME TO "foo"', statements[0]) + + def test_adding_foreign_key(self): + blueprint = Blueprint('users') + blueprint.create() + blueprint.string('foo').primary() + blueprint.string('order_id') + blueprint.foreign('order_id').references('id').on('orders') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + expected = 'CREATE TABLE "users" ("foo" VARCHAR NOT NULL, "order_id" VARCHAR NOT NULL, ' \ + 'FOREIGN KEY("order_id") REFERENCES "orders"("id"), PRIMARY KEY ("foo"))' + self.assertEqual(expected, statements[0]) + + def test_adding_unique_key(self): + blueprint = Blueprint('users') + blueprint.unique('foo', 'bar') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'CREATE UNIQUE INDEX bar ON "users" ("foo")', + statements[0] + ) + + def test_adding_index(self): + blueprint = Blueprint('users') + blueprint.index('foo', 'bar') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'CREATE INDEX bar ON "users" ("foo")', + statements[0] + ) + + def test_adding_incrementing_id(self): + blueprint = Blueprint('users') + blueprint.increments('id') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT', + statements[0] + ) + + def test_adding_big_incrementing_id(self): + blueprint = Blueprint('users') + blueprint.big_increments('id') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT', + statements[0] + ) + + def test_adding_string(self): + blueprint = Blueprint('users') + blueprint.string('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.string('foo', 100) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.string('foo', 100).nullable().default('bar') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NULL DEFAULT \'bar\'', + statements[0] + ) + + def test_adding_text(self): + blueprint = Blueprint('users') + blueprint.text('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL', + statements[0] + ) + + def test_adding_big_integer(self): + blueprint = Blueprint('users') + blueprint.big_integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.big_integer('foo', True) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT', + statements[0] + ) + + def test_adding_integer(self): + blueprint = Blueprint('users') + blueprint.integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', + statements[0] + ) + + blueprint = Blueprint('users') + blueprint.integer('foo', True) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT', + statements[0] + ) + + def test_adding_medium_integer(self): + blueprint = Blueprint('users') + blueprint.integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', + statements[0] + ) + + def test_adding_tiny_integer(self): + blueprint = Blueprint('users') + blueprint.integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', + statements[0] + ) + + def test_adding_small_integer(self): + blueprint = Blueprint('users') + blueprint.integer('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', + statements[0] + ) + + def test_adding_float(self): + blueprint = Blueprint('users') + blueprint.float('foo', 5, 2) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" FLOAT NOT NULL', + statements[0] + ) + + def test_adding_double(self): + blueprint = Blueprint('users') + blueprint.double('foo', 15, 8) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" FLOAT NOT NULL', + statements[0] + ) + + def test_adding_decimal(self): + blueprint = Blueprint('users') + blueprint.decimal('foo', 5, 2) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" NUMERIC NOT NULL', + statements[0] + ) + + def test_adding_boolean(self): + blueprint = Blueprint('users') + blueprint.boolean('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" TINYINT NOT NULL', + statements[0] + ) + + def test_adding_enum(self): + blueprint = Blueprint('users') + blueprint.enum('foo', ['bar', 'baz']) + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL', + statements[0] + ) + + def test_adding_date(self): + blueprint = Blueprint('users') + blueprint.date('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" DATE NOT NULL', + statements[0] + ) + + def test_adding_datetime(self): + blueprint = Blueprint('users') + blueprint.datetime('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" DATETIME NOT NULL', + statements[0] + ) + + def test_adding_time(self): + blueprint = Blueprint('users') + blueprint.time('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" TIME NOT NULL', + statements[0] + ) + + def test_adding_timestamp(self): + blueprint = Blueprint('users') + blueprint.timestamp('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" DATETIME NOT NULL', + statements[0] + ) + + def test_adding_timestamps(self): + blueprint = Blueprint('users') + blueprint.timestamps() + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(2, len(statements)) + expected = [ + 'ALTER TABLE "users" ADD COLUMN "created_at" DATETIME NOT NULL', + 'ALTER TABLE "users" ADD COLUMN "updated_at" DATETIME NOT NULL' + ] + self.assertEqual( + expected, + statements + ) + + def test_adding_binary(self): + blueprint = Blueprint('users') + blueprint.binary('foo') + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + + self.assertEqual(1, len(statements)) + self.assertEqual( + 'ALTER TABLE "users" ADD COLUMN "foo" BLOB NOT NULL', + statements[0] + ) + + def get_connection(self): + return flexmock(Connection(None)) + + def get_grammar(self): + return SQLiteSchemaGrammar() diff --git a/tests/schema/integrations/__init__.py b/tests/schema/integrations/__init__.py new file mode 100644 index 00000000..633f8661 --- /dev/null +++ b/tests/schema/integrations/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- + diff --git a/tests/schema/integrations/test_mysql.py b/tests/schema/integrations/test_mysql.py new file mode 100644 index 00000000..2da82407 --- /dev/null +++ b/tests/schema/integrations/test_mysql.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- + +import os +from ... import EloquentTestCase +from eloquent import Model +from eloquent.connections import MySqlConnection +from eloquent.connectors.mysql_connector import MySqlConnector +from eloquent.query.expression import QueryExpression +from eloquent.exceptions.query import QueryException + + +class SchemaBuilderMySqlIntegrationTestCase(EloquentTestCase): + + @classmethod + def setUpClass(cls): + Model.set_connection_resolver(DatabaseIntegrationConnectionResolver()) + + @classmethod + def tearDownClass(cls): + Model.unset_connection_resolver() + + def setUp(self): + self.schema().drop_if_exists('photos') + self.schema().drop_if_exists('posts') + self.schema().drop_if_exists('friends') + self.schema().drop_if_exists('users') + + with self.schema().create('users') as table: + table.increments('id') + table.string('email').unique() + table.timestamps() + + with self.schema().create('friends') as table: + table.unsigned_integer('user_id') + table.unsigned_integer('friend_id') + + table.foreign('user_id').references('id').on('users') + table.foreign('friend_id').references('id').on('users') + + with self.schema().create('posts') as table: + table.increments('id') + table.unsigned_integer('user_id') + table.string('name').unique() + table.timestamps() + + table.foreign('user_id').references('id').on('users') + + with self.schema().create('photos') as table: + table.increments('id') + table.morphs('imageable') + table.string('name') + table.timestamps() + + self.connection().commit() + + for i in range(10): + user = User.create(email='user%d@foo.com' % (i + 1)) + + for j in range(10): + post = Post(name='User %d Post %d' % (user.id, j + 1)) + user.posts().save(post) + + def tearDown(self): + post = Post.first() + if hasattr(post, 'user_id'): + with self.schema().table('posts') as table: + table.drop_foreign('posts_user_id_foreign') + elif hasattr(post, 'my_user_id'): + with self.schema().table('posts') as table: + table.drop_foreign('posts_my_user_id_foreign') + + with self.schema().table('friends') as table: + table.drop_foreign('friends_user_id_foreign') + table.drop_foreign('friends_friend_id_foreign') + + self.schema().drop('users') + self.schema().drop('friends') + self.schema().drop('posts') + self.schema().drop('photos') + + def test_add_columns(self): + with self.schema().table('posts') as table: + table.text('content') + table.integer('votes').default(QueryExpression(0)) + + user = User.find(1) + post = user.posts().order_by('id', 'asc').first() + + self.assertEqual('User 1 Post 1', post.name) + self.assertEqual('', post.content) + self.assertEqual(0, post.votes) + + def test_remove_columns(self): + with self.schema().table('posts') as table: + table.drop_column('name') + + self.assertIsNone(self.connection().get_column('posts', 'name')) + + user = User.find(1) + post = user.posts().order_by('id', 'asc').first() + + self.assertFalse(hasattr(post, 'name')) + + def test_rename_columns(self): + with self.schema().table('posts') as table: + table.rename_column('name', 'title') + + self.assertIsNone(self.connection().get_column('posts', 'name')) + self.assertIsNotNone(self.connection().get_column('posts', 'title')) + + user = User.find(1) + post = user.posts().order_by('id', 'asc').first() + + self.assertEqual('User 1 Post 1', post.title) + + def test_rename_columns_with_index(self): + with self.schema().table('users') as table: + table.rename_column('email', 'email_address') + + self.assertIsNone(self.connection().get_column('users', 'email')) + self.assertIsNotNone(self.connection().get_column('users', 'email_address')) + + def test_rename_columns_with_foreign_keys(self): + with self.schema().table('posts') as table: + table.drop_foreign('posts_user_id_foreign') + table.rename_column('user_id', 'my_user_id') + table.foreign('my_user_id').references('id').on('users') + + self.assertIsNone(self.connection().get_column('posts', 'user_id')) + self.assertIsNotNone(self.connection().get_column('posts', 'my_user_id')) + + def test_change_columns(self): + with self.schema().table('posts') as table: + table.integer('votes').default(0) + + post = Post.find(1) + self.assertEqual(0, post.votes) + + with self.schema().table('posts') as table: + table.string('name').nullable().change() + table.string('votes').default('0').change() + + name_column = self.connection().get_column('posts', 'name') + votes_column = self.connection().get_column('posts', 'votes') + self.assertFalse(name_column.get_notnull()) + self.assertTrue(votes_column.get_notnull()) + self.assertEqual('0', votes_column.get_default()) + + post = Post.find(1) + self.assertEqual('0', post.votes) + + def connection(self): + return Model.get_connection_resolver().connection() + + def schema(self): + """ + :rtype: eloquent.schema.SchemaBuilder + """ + return self.connection().get_schema_builder() + + +class User(Model): + + __guarded__ = [] + + @property + def friends(self): + return self.belongs_to_many(User, 'friends', 'user_id', 'friend_id') + + @property + def posts(self): + return self.has_many(Post, 'user_id') + + @property + def post(self): + return self.has_one(Post, 'user_id') + + @property + def photos(self): + return self.morph_many(Photo, 'imageable') + + +class Post(Model): + + __guarded__ = [] + + @property + def user(self): + return self.belongs_to(User, 'user_id') + + @property + def photos(self): + return self.morph_many(Photo, 'imageable') + + +class Photo(Model): + + __guarded__ = [] + + @property + def imageable(self): + return self.morph_to() + + +class DatabaseIntegrationConnectionResolver(object): + + _connection = None + + def connection(self, name=None): + if self._connection: + return self._connection + + database = os.environ.get('ELOQUENT_MYSQL_TEST_DATABASE', 'eloquent_test') + user = os.environ.get('ELOQUENT_MYSQL_TEST_USER', 'root') + password = os.environ.get('ELOQUENT_MYSQL_TEST_PASSWORD', '') + + self._connection = MySqlConnection( + MySqlConnector().connect({ + 'database': database, + 'user': user, + 'password': password + }) + ) + + return self._connection + + def get_default_connection(self): + return 'default' + + def set_default_connection(self, name): + pass diff --git a/tests/schema/integrations/test_postgres.py b/tests/schema/integrations/test_postgres.py new file mode 100644 index 00000000..f6499c6a --- /dev/null +++ b/tests/schema/integrations/test_postgres.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- + +import os +from ... import EloquentTestCase +from eloquent import Model +from eloquent.connections import PostgresConnection +from eloquent.connectors.postgres_connector import PostgresConnector +from eloquent.query.expression import QueryExpression + + +class SchemaBuilderPostgresIntegrationTestCase(EloquentTestCase): + + @classmethod + def setUpClass(cls): + Model.set_connection_resolver(DatabaseIntegrationConnectionResolver()) + + @classmethod + def tearDownClass(cls): + Model.unset_connection_resolver() + + def setUp(self): + self.schema().drop_if_exists('photos') + self.schema().drop_if_exists('posts') + self.schema().drop_if_exists('friends') + self.schema().drop_if_exists('users') + + with self.schema().create('users') as table: + table.increments('id') + table.string('email').unique() + table.timestamps() + + with self.schema().create('friends') as table: + table.integer('user_id') + table.integer('friend_id') + + table.foreign('user_id').references('id').on('users') + table.foreign('friend_id').references('id').on('users') + + with self.schema().create('posts') as table: + table.increments('id') + table.integer('user_id') + table.string('name').unique() + table.timestamps() + + table.foreign('user_id').references('id').on('users') + + with self.schema().create('photos') as table: + table.increments('id') + table.morphs('imageable') + table.string('name') + table.timestamps() + + self.connection().commit() + + for i in range(10): + user = User.create(email='user%d@foo.com' % (i + 1)) + + for j in range(10): + post = Post(name='User %d Post %d' % (user.id, j + 1)) + user.posts().save(post) + + def tearDown(self): + with self.schema().table('posts') as table: + table.drop_foreign('posts_user_id_foreign') + + with self.schema().table('friends') as table: + table.drop_foreign('friends_user_id_foreign') + table.drop_foreign('friends_friend_id_foreign') + + self.schema().drop('users') + self.schema().drop('friends') + self.schema().drop('posts') + self.schema().drop('photos') + + def test_add_columns(self): + with self.schema().table('posts') as table: + table.text('content').default('Test') + table.integer('votes').default(QueryExpression(0)) + + user = User.find(1) + post = user.posts().order_by('id', 'asc').first() + + self.assertEqual('User 1 Post 1', post.name) + self.assertEqual('Test', post.content) + self.assertEqual(0, post.votes) + + def test_remove_columns(self): + with self.schema().table('posts') as table: + table.drop_column('name') + + self.assertIsNone(self.connection().get_column('posts', 'name')) + + user = User.find(1) + post = user.posts().order_by('id', 'asc').first() + + self.assertFalse(hasattr(post, 'name')) + + def test_rename_columns(self): + with self.schema().table('posts') as table: + table.rename_column('name', 'title') + + self.assertIsNone(self.connection().get_column('posts', 'name')) + self.assertIsNotNone(self.connection().get_column('posts', 'title')) + + user = User.find(1) + post = user.posts().order_by('id', 'asc').first() + + self.assertEqual('User 1 Post 1', post.title) + + def test_rename_columns_with_index(self): + with self.schema().table('users') as table: + table.rename_column('email', 'email_address') + + self.assertIsNone(self.connection().get_column('users', 'email')) + self.assertIsNotNone(self.connection().get_column('users', 'email_address')) + + def test_rename_columns_with_foreign_keys(self): + with self.schema().table('posts') as table: + table.rename_column('user_id', 'my_user_id') + + self.assertIsNone(self.connection().get_column('posts', 'user_id')) + self.assertIsNotNone(self.connection().get_column('posts', 'my_user_id')) + + def test_change_columns(self): + with self.schema().table('posts') as table: + table.integer('votes').default(0) + + post = Post.find(1) + self.assertEqual(0, post.votes) + + with self.schema().table('posts') as table: + table.string('name').nullable().change() + table.string('votes').default('0').change() + + name_column = self.connection().get_column('posts', 'name') + votes_column = self.connection().get_column('posts', 'votes') + self.assertFalse(name_column.get_notnull()) + self.assertTrue(votes_column.get_notnull()) + self.assertEqual('0', votes_column.get_default()) + + post = Post.find(1) + self.assertEqual('0', post.votes) + + def connection(self): + return Model.get_connection_resolver().connection() + + def schema(self): + """ + :rtype: eloquent.schema.SchemaBuilder + """ + return self.connection().get_schema_builder() + + +class User(Model): + + __guarded__ = [] + + @property + def friends(self): + return self.belongs_to_many(User, 'friends', 'user_id', 'friend_id') + + @property + def posts(self): + return self.has_many(Post, 'user_id') + + @property + def post(self): + return self.has_one(Post, 'user_id') + + @property + def photos(self): + return self.morph_many(Photo, 'imageable') + + +class Post(Model): + + __guarded__ = [] + + @property + def user(self): + return self.belongs_to(User, 'user_id') + + @property + def photos(self): + return self.morph_many(Photo, 'imageable') + + +class Photo(Model): + + __guarded__ = [] + + @property + def imageable(self): + return self.morph_to() + + +class DatabaseIntegrationConnectionResolver(object): + + _connection = None + + def connection(self, name=None): + if self._connection: + return self._connection + + database = os.environ.get('ELOQUENT_POSTGRES_TEST_DATABASE', 'eloquent_test') + user = os.environ.get('ELOQUENT_POSTGRES_TEST_USER', 'postgres') + password = os.environ.get('ELOQUENT_POSTGRES_TEST_PASSWORD', None) + + self._connection = PostgresConnection( + PostgresConnector().connect({ + 'database': database, + 'user': user, + 'password': password + }) + ) + + return self._connection + + def get_default_connection(self): + return 'default' + + def set_default_connection(self, name): + pass diff --git a/tests/schema/integrations/test_sqlite.py b/tests/schema/integrations/test_sqlite.py new file mode 100644 index 00000000..1f86ddae --- /dev/null +++ b/tests/schema/integrations/test_sqlite.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- + +from ... import EloquentTestCase +from eloquent import Model +from eloquent.connections import SQLiteConnection +from eloquent.connectors.sqlite_connector import SQLiteConnector +from eloquent.query.expression import QueryExpression + + +class SchemaBuilderSQLiteIntegrationTestCase(EloquentTestCase): + + @classmethod + def setUpClass(cls): + Model.set_connection_resolver(DatabaseIntegrationConnectionResolver()) + + @classmethod + def tearDownClass(cls): + Model.unset_connection_resolver() + + def setUp(self): + with self.schema().create('users') as table: + table.increments('id') + table.string('email').unique() + table.timestamps() + + with self.schema().create('friends') as table: + table.integer('user_id') + table.integer('friend_id') + + table.foreign('user_id').references('id').on('users') + table.foreign('friend_id').references('id').on('users') + + with self.schema().create('posts') as table: + table.increments('id') + table.integer('user_id') + table.string('name').unique() + table.timestamps() + + table.foreign('user_id').references('id').on('users') + + with self.schema().create('photos') as table: + table.increments('id') + table.morphs('imageable') + table.string('name') + table.timestamps() + + for i in range(10): + user = User.create(email='user%d@foo.com' % (i + 1)) + + for j in range(10): + post = Post(name='User %d Post %d' % (user.id, j + 1)) + user.posts().save(post) + + def tearDown(self): + self.schema().drop('users') + self.schema().drop('friends') + self.schema().drop('posts') + self.schema().drop('photos') + + def test_add_columns(self): + with self.schema().table('posts') as table: + table.text('content').default('Test') + table.integer('votes').default(QueryExpression(0)) + + user = User.find(1) + post = user.posts().order_by('id', 'asc').first() + + self.assertEqual('User 1 Post 1', post.name) + self.assertEqual('Test', post.content) + self.assertEqual(0, post.votes) + + def test_remove_columns(self): + with self.schema().table('posts') as table: + table.drop_column('name') + + self.assertIsNone(self.connection().get_column('posts', 'name')) + + user = User.find(1) + post = user.posts().order_by('id', 'asc').first() + + self.assertFalse(hasattr(post, 'name')) + + def test_rename_columns(self): + old_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + + with self.schema().table('posts') as table: + table.rename_column('name', 'title') + + self.assertIsNone(self.connection().get_column('posts', 'name')) + self.assertIsNotNone(self.connection().get_column('posts', 'title')) + + foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + + self.assertEqual(len(foreign_keys), len(old_foreign_keys)) + + user = User.find(1) + post = user.posts().order_by('id', 'asc').first() + + self.assertEqual('User 1 Post 1', post.title) + + def test_rename_columns_with_index(self): + indexes = self.connection().get_schema_manager().list_table_indexes('users') + old_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + + self.assertEqual('users_email_unique', indexes[0]['name']) + self.assertEqual(['email'], indexes[0]['columns']) + self.assertTrue(indexes[0]['unique']) + + with self.schema().table('users') as table: + table.rename_column('email', 'email_address') + + self.assertIsNone(self.connection().get_column('users', 'email')) + self.assertIsNotNone(self.connection().get_column('users', 'email_address')) + foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + + self.assertEqual(len(foreign_keys), len(old_foreign_keys)) + + indexes = self.connection().get_schema_manager().list_table_indexes('users') + + self.assertEqual('users_email_address_unique', indexes[0]['name']) + self.assertEqual(['email_address'], indexes[0]['columns']) + self.assertTrue(indexes[0]['unique']) + + def test_rename_columns_with_foreign_keys(self): + old_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + + self.assertEqual('user_id', old_foreign_keys[0]['from']) + self.assertEqual('id', old_foreign_keys[0]['to']) + self.assertEqual('users', old_foreign_keys[0]['table']) + + with self.schema().table('posts') as table: + table.rename_column('user_id', 'my_user_id') + + self.assertIsNone(self.connection().get_column('posts', 'user_id')) + self.assertIsNotNone(self.connection().get_column('posts', 'my_user_id')) + foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + + self.assertEqual(len(foreign_keys), len(old_foreign_keys)) + + self.assertEqual('my_user_id', foreign_keys[0]['from']) + self.assertEqual('id', foreign_keys[0]['to']) + self.assertEqual('users', foreign_keys[0]['table']) + + def test_change_columns(self): + with self.schema().table('posts') as table: + table.integer('votes').default(0) + + indexes = self.connection().get_schema_manager().list_table_indexes('posts') + old_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + + self.assertEqual('posts_name_unique', indexes[0]['name']) + self.assertEqual(['name'], indexes[0]['columns']) + self.assertTrue(indexes[0]['unique']) + + post = Post.find(1) + self.assertEqual(0, post.votes) + + with self.schema().table('posts') as table: + table.string('name').nullable().change() + table.string('votes').default('0').change() + + name_column = self.connection().get_column('posts', 'name') + votes_column = self.connection().get_column('posts', 'votes') + self.assertFalse(name_column.get_notnull()) + self.assertTrue(votes_column.get_notnull()) + self.assertEqual("'0'", votes_column.get_default()) + foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + + self.assertEqual(len(foreign_keys), len(old_foreign_keys)) + + indexes = self.connection().get_schema_manager().list_table_indexes('posts') + + self.assertEqual('posts_name_unique', indexes[0]['name']) + self.assertEqual(['name'], indexes[0]['columns']) + self.assertTrue(indexes[0]['unique']) + + post = Post.find(1) + self.assertEqual('0', post.votes) + + def connection(self): + return Model.get_connection_resolver().connection() + + def schema(self): + return self.connection().get_schema_builder() + + +class User(Model): + + __guarded__ = [] + + @property + def friends(self): + return self.belongs_to_many(User, 'friends', 'user_id', 'friend_id') + + @property + def posts(self): + return self.has_many(Post, 'user_id') + + @property + def post(self): + return self.has_one(Post, 'user_id') + + @property + def photos(self): + return self.morph_many(Photo, 'imageable') + + +class Post(Model): + + __guarded__ = [] + + @property + def user(self): + return self.belongs_to(User, 'user_id') + + @property + def photos(self): + return self.morph_many(Photo, 'imageable') + + +class Photo(Model): + + __guarded__ = [] + + @property + def imageable(self): + return self.morph_to() + + +class DatabaseIntegrationConnectionResolver(object): + + _connection = None + + def connection(self, name=None): + if self._connection: + return self._connection + + self._connection = SQLiteConnection(SQLiteConnector().connect({'database': ':memory:'})) + + return self._connection + + def get_default_connection(self): + return 'default' + + def set_default_connection(self, name): + pass diff --git a/tests/schema/test_blueprint.py b/tests/schema/test_blueprint.py new file mode 100644 index 00000000..aece5f6d --- /dev/null +++ b/tests/schema/test_blueprint.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- + +from flexmock import flexmock, flexmock_teardown +from eloquent.schema import Blueprint +from eloquent.schema.grammars import SchemaGrammar +from eloquent.connections import Connection +from .. import EloquentTestCase + + +class SchemaBuilderTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_to_sql_runs_commands_from_blueprint(self): + conn = flexmock(Connection(None)) + conn.should_receive('statement').once().with_args('foo') + conn.should_receive('statement').once().with_args('bar') + grammar = flexmock(SchemaGrammar()) + blueprint = flexmock(Blueprint('table')) + blueprint.should_receive('to_sql').once().with_args(conn, grammar).and_return(['foo', 'bar']) + + blueprint.build(conn, grammar) + + def test_index_default_names(self): + blueprint = Blueprint('users') + blueprint.unique(['foo', 'bar']) + commands = blueprint.get_commands() + self.assertEqual('users_foo_bar_unique', commands[0].index) + + blueprint = Blueprint('users') + blueprint.index('foo') + commands = blueprint.get_commands() + self.assertEqual('users_foo_index', commands[0].index) + + def test_drop_index_default_names(self): + blueprint = Blueprint('users') + blueprint.drop_unique(['foo', 'bar']) + commands = blueprint.get_commands() + self.assertEqual('users_foo_bar_unique', commands[0].index) + + blueprint = Blueprint('users') + blueprint.drop_index(['foo']) + commands = blueprint.get_commands() + self.assertEqual('users_foo_index', commands[0].index) diff --git a/tests/schema/test_builder.py b/tests/schema/test_builder.py new file mode 100644 index 00000000..595918e7 --- /dev/null +++ b/tests/schema/test_builder.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- + +from flexmock import flexmock, flexmock_teardown +from eloquent.connections import Connection +from eloquent.schema import SchemaBuilder +from .. import EloquentTestCase + + +class SchemaBuilderTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_has_table_correctly_calls_grammar(self): + connection = flexmock(Connection(None)) + grammar = flexmock() + connection.should_receive('get_schema_grammar').and_return(grammar) + builder = SchemaBuilder(connection) + grammar.should_receive('compile_table_exists').once().and_return('sql') + connection.should_receive('get_table_prefix').once().and_return('prefix_') + connection.should_receive('select').once().with_args('sql', ['prefix_table']).and_return(['prefix_table']) + + self.assertTrue(builder.has_table('table')) diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 00000000..f2dd7e7a --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,287 @@ +# -*- coding: utf-8 -*- + +import arrow +from . import EloquentTestCase +from eloquent import Model, Collection +from eloquent.connections import SQLiteConnection +from eloquent.connectors.sqlite_connector import SQLiteConnector +from eloquent.exceptions.orm import ModelNotFound + + +class EloquentIntegrationTestCase(EloquentTestCase): + + @classmethod + def setUpClass(cls): + Model.set_connection_resolver(DatabaseIntegrationConnectionResolver()) + + @classmethod + def tearDownClass(cls): + Model.unset_connection_resolver() + + def setUp(self): + with self.schema().create('users') as table: + table.increments('id') + table.string('email').unique() + table.timestamps() + + with self.schema().create('friends') as table: + table.integer('user_id') + table.integer('friend_id') + + with self.schema().create('posts') as table: + table.increments('id') + table.integer('user_id') + table.string('name') + table.timestamps() + + with self.schema().create('photos') as table: + table.increments('id') + table.morphs('imageable') + table.string('name') + table.timestamps() + + def tearDown(self): + self.schema().drop('users') + self.schema().drop('friends') + self.schema().drop('posts') + self.schema().drop('photos') + + def test_basic_model_retrieval(self): + EloquentTestUser.create(email='john@doe.com') + model = EloquentTestUser.where('email', 'john@doe.com').first() + self.assertEqual('john@doe.com', model.email) + + def test_basic_model_collection_retrieval(self): + EloquentTestUser.create(id=1, email='john@doe.com') + EloquentTestUser.create(id=2, email='jane@doe.com') + + models = EloquentTestUser.oldest('id').get() + + self.assertEqual(2, len(models)) + self.assertIsInstance(models, Collection) + self.assertIsInstance(models[0], EloquentTestUser) + self.assertIsInstance(models[1], EloquentTestUser) + self.assertEqual('john@doe.com', models[0].email) + self.assertEqual('jane@doe.com', models[1].email) + + def test_lists_retrieval(self): + EloquentTestUser.create(id=1, email='john@doe.com') + EloquentTestUser.create(id=2, email='jane@doe.com') + + simple = EloquentTestUser.oldest('id').lists('email') + keyed = EloquentTestUser.oldest('id').lists('email', 'id') + + self.assertEqual(['john@doe.com', 'jane@doe.com'], simple) + self.assertEqual({1: 'john@doe.com', 2: 'jane@doe.com'}, keyed) + + def test_find_or_fail(self): + EloquentTestUser.create(id=1, email='john@doe.com') + EloquentTestUser.create(id=2, email='jane@doe.com') + + single = EloquentTestUser.find_or_fail(1) + multiple = EloquentTestUser.find_or_fail([1, 2]) + + self.assertIsInstance(single, EloquentTestUser) + self.assertEqual('john@doe.com', single.email) + self.assertIsInstance(multiple, Collection) + self.assertIsInstance(multiple[0], EloquentTestUser) + self.assertIsInstance(multiple[1], EloquentTestUser) + + def test_find_or_fail_with_single_id_raises_model_not_found_exception(self): + self.assertRaises( + ModelNotFound, + EloquentTestUser.find_or_fail, + 1 + ) + + def test_find_or_fail_with_multiple_ids_raises_model_not_found_exception(self): + self.assertRaises( + ModelNotFound, + EloquentTestUser.find_or_fail, + [1, 2] + ) + + def test_one_to_one_relationship(self): + user = EloquentTestUser.create(email='john@doe.com') + user.post().create(name='First Post') + + post = user.post + user = post.user + + self.assertEqual('john@doe.com', user.email) + self.assertEqual('First Post', post.name) + + def test_one_to_many_relationship(self): + user = EloquentTestUser.create(email='john@doe.com') + user.posts().create(name='First Post') + user.posts().create(name='Second Post') + + posts = user.posts + post2 = user.posts().where('name', 'Second Post').first() + + self.assertEqual(2, len(posts)) + #self.assertIsInstance(posts[0], EloquentTestPost) + #self.assertIsInstance(posts[1], EloquentTestPost) + #self.assertIsInstance(post2, EloquentTestPost) + self.assertEqual('Second Post', post2.name) + #self.assertIsInstance(post2.user.instance, EloquentTestUser) + self.assertEqual('john@doe.com', post2.user.email) + + def test_basic_model_hydrate(self): + EloquentTestUser.create(id=1, email='john@doe.com') + EloquentTestUser.create(id=2, email='jane@doe.com') + + models = EloquentTestUser.hydrate_raw( + 'SELECT * FROM users WHERE email = ?', + ['jane@doe.com'], + 'foo_connection' + ) + self.assertIsInstance(models, Collection) + self.assertIsInstance(models[0], EloquentTestUser) + self.assertEqual('jane@doe.com', models[0].email) + self.assertEqual('foo_connection', models[0].get_connection_name()) + self.assertEqual(1, len(models)) + + def test_has_on_self_referencing_belongs_to_many_relationship(self): + user = EloquentTestUser.create(id=1, email='john@doe.com') + friend = user.friends().create(email='jane@doe.com') + + results = EloquentTestUser.has('friends').get() + + self.assertEqual(1, len(results)) + self.assertEqual('john@doe.com', results.first().email) + + def test_basic_has_many_eager_loading(self): + user = EloquentTestUser.create(id=1, email='john@doe.com') + user.posts().create(name='First Post') + user = EloquentTestUser.with_('posts').where('email', 'john@doe.com').first() + + self.assertEqual('First Post', user.posts.first().name) + + post = EloquentTestPost.with_('user').where('name', 'First Post').get() + self.assertEqual('john@doe.com', post.first().user.email) + + def test_basic_morph_many_relationship(self): + user = EloquentTestUser.create(id=1, email='john@doe.com') + user.photos().create(name='Avatar 1') + user.photos().create(name='Avatar 2') + post = user.posts().create(name='First Post') + post.photos().create(name='Hero 1') + post.photos().create(name='Hero 2') + + self.assertIsInstance(user.photos.instance, Collection) + #self.assertIsInstance(user.photos[0], EloquentTestPhoto) + self.assertIsInstance(post.photos.instance, Collection) + #self.assertIsInstance(post.photos[0], EloquentTestPhoto) + self.assertEqual(2, len(user.photos)) + self.assertEqual(2, len(post.photos)) + self.assertEqual('Avatar 1', user.photos[0].name) + self.assertEqual('Avatar 2', user.photos[1].name) + self.assertEqual('Hero 1', post.photos[0].name) + self.assertEqual('Hero 2', post.photos[1].name) + + photos = EloquentTestPhoto.order_by('name').get() + + self.assertIsInstance(photos, Collection) + self.assertEqual(4, len(photos)) + #self.assertIsInstance(photos[0].imageable.instance, EloquentTestUser) + #self.assertIsInstance(photos[2].imageable.instance, EloquentTestPost) + self.assertEqual('john@doe.com', photos[1].imageable.email) + self.assertEqual('First Post', photos[3].imageable.name) + + def test_multi_insert_with_different_values(self): + date = arrow.utcnow().naive + result = EloquentTestPost.insert([ + { + 'user_id': 1, 'name': 'Post', 'created_at': date, 'updated_at': date + }, { + 'user_id': 2, 'name': 'Post', 'created_at': date, 'updated_at': date + } + ]) + + self.assertTrue(result) + self.assertEqual(2, EloquentTestPost.count()) + + def test_multi_insert_with_same_values(self): + date = arrow.utcnow().naive + result = EloquentTestPost.insert([ + { + 'user_id': 1, 'name': 'Post', 'created_at': date, 'updated_at': date + }, { + 'user_id': 1, 'name': 'Post', 'created_at': date, 'updated_at': date + } + ]) + + self.assertTrue(result) + self.assertEqual(2, EloquentTestPost.count()) + + def connection(self): + return Model.get_connection_resolver().connection() + + def schema(self): + return self.connection().get_schema_builder() + + +class EloquentTestUser(Model): + + __table__ = 'users' + __guarded__ = [] + + @property + def friends(self): + return self.belongs_to_many(EloquentTestUser, 'friends', 'user_id', 'friend_id') + + @property + def posts(self): + return self.has_many('posts', 'user_id') + + @property + def post(self): + return self.has_one(EloquentTestPost, 'user_id') + + @property + def photos(self): + return self.morph_many('photos', 'imageable') + + +class EloquentTestPost(Model): + + __table__ = 'posts' + __guarded__ = [] + + @property + def user(self): + return self.belongs_to(EloquentTestUser, 'user_id') + + @property + def photos(self): + return self.morph_many('photos', 'imageable') + + +class EloquentTestPhoto(Model): + + __table__ = 'photos' + __guarded__ = [] + + @property + def imageable(self): + return self.morph_to() + + +class DatabaseIntegrationConnectionResolver(object): + + _connection = None + + def connection(self, name=None): + if self._connection: + return self._connection + + self._connection = SQLiteConnection(SQLiteConnector().connect({'database': ':memory:'})) + + return self._connection + + def get_default_connection(self): + return 'default' + + def set_default_connection(self, name): + pass