diff --git a/.gitignore b/.gitignore index 7b548069..92345a9c 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ nosetests.xml .idea/* test.py +test_models.py diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..a876a348 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,14 @@ +### 0.2 + +(March 31th, 2015) + +- Adds actual ORM with relationships and eager loading +- Adds chunk support for QueryBuilder +- Properly close connections when using reconnect() and disconnect() methods + + +### 0.1 + +(March 18th, 2015) + +- Initial release diff --git a/docs/_static/theme_overrides.css b/docs/_static/theme_overrides.css new file mode 100644 index 00000000..3d446ab8 --- /dev/null +++ b/docs/_static/theme_overrides.css @@ -0,0 +1,408 @@ +@import url(http://fonts.googleapis.com/css?family=Roboto:400,400italic,500,700); +@import url(http://fonts.googleapis.com/css?family=Montserrat:400,700); + +body { + font-family: "Roboto", sans-serif; +} + +.wy-body-for-nav { + background: #FFFFFF; +} + +.wy-nav-content { + background: #FFFFFF; +} + +.wy-nav-content-wrap { + background: #FFFFFF; +} + +.rst-content h1:hover .headerlink, +.rst-content h2:hover .headerlink, +.rst-content h3:hover .headerlink, +.rst-content h4:hover .headerlink, +.rst-content h5:hover .headerlink, +.rst-content h6:hover .headerlink, +.rst-content dl dt:hover .headerlink { + display: inline; +} + +.wy-plain-list-disc li, +.rst-content .section ul li, +.rst-content .toctree-wrapper ul li, +article ul li { + font: 400 14px/28px "Roboto",sans-serif; + color: #7A7A7A; +} + +.wy-nav-side { + background: #2B303B; +} + +.wy-side-nav-search { + background-color: rgb(39, 43, 53); +} + +.wy-menu-vertical li.current { + background: rgb(33, 37, 45); +} + +.wy-menu-vertical li.on a, .wy-menu-vertical li.current > a { + background: #2B303B; + border: 0; +} + +.wy-menu-vertical li.current a { + font-size: 13px; + border-right: 0; +} + +.wy-menu-vertical a, +.wy-menu-vertical a:visited { + color: #788195; + line-height: 24px; +} + +.wy-menu-vertical a:hover { + color: #FFFFFF; + background: none; +} + +.wy-menu-vertical li.current a, +.wy-menu-vertical li.current a:visited { + color: #788195; +} + +.wy-menu-vertical li.current a:hover { + color: #FFFFFF; + background: none; +} + +.wy-menu-vertical li.current > a { + color: #FFFFFF !important; +} + +.wy-menu-vertical li.current > a:hover { + background: #2B303B !important; +} + +.wy-side-nav-search input[type="text"] { + border: 1px solid #2B303B; + background: rgb(23, 27, 35); + box-shadow: none; + color: #FFFFFF; + border-radius: 4px; +} + +.wy-side-nav-search > a, +.wy-side-nav-search > a:visited { + color: #FFFFFF; +} + +.wy-side-nav-search > a:hover { + background: none; +} + +.rst-content .highlighted { + background: rgba(205, 220, 57, 0.7); + border: 1px solid rgba(205, 220, 57, 0.9); + padding: 1px 3px; + border-radius: 3px; + color: rgba(135, 150, 0, 1); +} + +h1, h2, h3, h4, h5, h6 { + color: #7A7A7A; + padding: 10px 0px; + font-family: "Montserrat",serif; + line-height: 1; + margin: 40px 0 30px 0; + font-weight: 400; +} + +h1 { + font-size: 42px; +} + +h2 { + font-size: 32px; + color: #2C3E50; + position: relative; +} + +h2:before { + content: "#"; + margin-left: -30px; + font-size: 32px; + top: 10px; + color: #03A9F4; + opacity: 0.6; + position: absolute; +} + +h3 { + font-size: 24px; +} + +h4 { + font-size: 20px; +} + + +h1, +h1 a, +.header h1, +.header h1 a { + font: 700 42px/66px "Roboto", "Montserrat", sans-serif; + margin-bottom: 30px; + font-weight: 700; + color: #2C3E50; +} + +a, a:visited { + color: #00BCD4; +} + +a > em { + font-style: normal; +} + +p { + font: 400 14px/28px "Roboto",sans-serif; + color: #7A7A7A; +} + +pre, pre code, .rst-content tt { + color: #03A9F4; + background-color: #FAFAFA; + font-size: 14px; +} + +pre, code, .rst-content tt { + font-family: "Consolas", "Menlo", "Monaco", "Courier New", Courier, monospace; + font-weight: 400; + border-radius: 3px; + font-size: 1em; + margin: 0px 2px; +} + +div[class^="highlight"] { + border: 0; +} + +.highlight > pre, div[class^="highlight"] pre { + border: 1px solid #F2F2F2; + font-size: 13px; + padding: 4%; +} + +code { + background: #FAFAFA; +} + +p > code, +.rst-content tt { + color: #03A9F4; + padding: 0.3em 0.5em; + display: inline; +} + +.rst-content .admonition-title { + display: none; +} + +.rst-content .note, +.rst-content .warning { + padding: 0px 40px; + margin: 30px 0px; + background: #FFFFFF; +} + +.rst-content .note { + border-left: 1px solid rgba(60, 147, 206, 0.7); + color: rgb(99, 169, 216) !important; +} + +.rst-content .warning { + border-left: 1px solid #FF9800; + color: #FFC107 !important; +} + +.rst-content .note > p { + padding: 12px 12px 12px 0; + color: rgb(99, 169, 216) !important; +} + +.rst-content .warning > p { + padding: 12px 12px 12px 0; + color: #FF9800 !important; +} + +.note code { + background-color: rgba(49, 112, 143, 0.07); + border-color: rgba(49, 112, 143, 0.1); + padding: 0.3em 0.5em; + display: inline; +} + +.note ul li, +.warning ul li { + color: inherit !important; +} + +.warning code, +.warning tt { + background-color: rgba(255, 152, 0, 0.07); + border-color: rgba(255, 152, 0, 0.1); + padding: 0.3em 0.5em; + display: inline; +} + +.btn-neutral, +.btn-neutral:visited { + color: #808080 !important; + font: 400 12px/20px Montserrat,Arial,sans-serif; + text-transform: uppercase; + padding: 6px 14px; + border-radius: 4px; + box-shadow: none; + border: 1px solid rgba(0, 0, 0, 0.05); + transition: all 0.3s; +} + +.btn-neutral:hover { + background-color: rgba(0, 188, 212, 0.6) !important; + color: rgba(255, 255, 255, 0.9) !important; + /*color: rgba(0, 168, 192, 1) !important;*/ +} + +.btn-neutral:focus { + background-color: rgba(0, 188, 212, 0.6) !important; + color: rgba(255, 255, 255, 0.7) !important; + box-shadow: none; + padding: 6px 14px; +} + +.btn-neutral span.fa { + transition: all 0.3s; +} + +.btn-neutral:hover span.fa-arrow-circle-left { + padding-right: 5px; +} + +.btn-neutral:hover span.fa-arrow-circle-right { + padding-left: 5px; +} + +.btn-neutral span.fa-arrow-circle-left:before { + content: '' +} + +.btn-neutral span.fa-arrow-circle-right:before { + content: '' +} + + +/* Solarized Light + +For use with Jekyll and Pygments + +http://ethanschoonover.com/solarized + +SOLARIZED HEX ROLE +--------- -------- ------------------------------------------ +base01 #586e75 body text / default code / primary content +base1 #93a1a1 comments / secondary content +base3 #fdf6e3 background +orange #cb4b16 constants +red #dc322f regex, special keywords +blue #268bd2 reserved keywords +cyan #2aa198 strings, numbers +green #859900 operators, other keywords +*/ + +.highlight { background-color: #fdf6e3; color: #586e75 } +.highlight .c { color: #93a1a1 } /* Comment */ +.highlight .err { color: #586e75 } /* Error */ +.highlight .g { color: #586e75 } /* Generic */ +.highlight .k { color: #859900 } /* Keyword */ +.highlight .l { color: #586e75 } /* Literal */ +.highlight .n { color: #586e75 } /* Name */ +.highlight .o { color: #859900 } /* Operator */ +.highlight .x { color: #cb4b16 } /* Other */ +.highlight .p { color: #586e75 } /* Punctuation */ +.highlight .cm { color: #93a1a1 } /* Comment.Multiline */ +.highlight .cp { color: #859900 } /* Comment.Preproc */ +.highlight .c1 { color: #93a1a1 } /* Comment.Single */ +.highlight .cs { color: #859900 } /* Comment.Special */ +.highlight .gd { color: #2aa198 } /* Generic.Deleted */ +.highlight .ge { color: #586e75; font-style: italic } /* Generic.Emph */ +.highlight .gr { color: #dc322f } /* Generic.Error */ +.highlight .gh { color: #cb4b16 } /* Generic.Heading */ +.highlight .gi { color: #859900 } /* Generic.Inserted */ +.highlight .go { color: #586e75 } /* Generic.Output */ +.highlight .gp { color: #586e75 } /* Generic.Prompt */ +.highlight .gs { color: #586e75; font-weight: bold } /* Generic.Strong */ +.highlight .gu { color: #cb4b16 } /* Generic.Subheading */ +.highlight .gt { color: #586e75 } /* Generic.Traceback */ +.highlight .kc { color: #cb4b16 } /* Keyword.Constant */ +.highlight .kd { color: #268bd2 } /* Keyword.Declaration */ +.highlight .kn { color: #859900 } /* Keyword.Namespace */ +.highlight .kp { color: #859900 } /* Keyword.Pseudo */ +.highlight .kr { color: #268bd2 } /* Keyword.Reserved */ +.highlight .kt { color: #dc322f } /* Keyword.Type */ +.highlight .ld { color: #586e75 } /* Literal.Date */ +.highlight .m { color: #2aa198 } /* Literal.Number */ +.highlight .s { color: #2aa198 } /* Literal.String */ +.highlight .na { color: #586e75 } /* Name.Attribute */ +.highlight .nb { color: #B58900 } /* Name.Builtin */ +.highlight .nc { color: #268bd2 } /* Name.Class */ +.highlight .no { color: #cb4b16 } /* Name.Constant */ +/*.highlight .nd { color: #268bd2 }*/ /* Name.Decorator */ +.highlight .nd { color: #cb4b16 } +.highlight .ni { color: #cb4b16 } /* Name.Entity */ +.highlight .ne { color: #cb4b16 } /* Name.Exception */ +.highlight .nf { color: #268bd2 } /* Name.Function */ +.highlight .nl { color: #586e75 } /* Name.Label */ +.highlight .nn { color: #586e75 } /* Name.Namespace */ +.highlight .nx { color: #586e75 } /* Name.Other */ +.highlight .py { color: #586e75 } /* Name.Property */ +.highlight .nt { color: #268bd2 } /* Name.Tag */ +.highlight .nv { color: #268bd2 } /* Name.Variable */ +.highlight .ow { color: #859900 } /* Operator.Word */ +.highlight .w { color: #586e75 } /* Text.Whitespace */ +.highlight .mf { color: #ff9800 } /* Literal.Number.Float */ +.highlight .mh { color: #ff9800 } /* Literal.Number.Hex */ +.highlight .mi { color: #ff9800 } /* Literal.Number.Integer */ +.highlight .mo { color: #ff9800 } /* Literal.Number.Oct */ +.highlight .sb { color: #93a1a1 } /* Literal.String.Backtick */ +.highlight .sc { color: #2aa198 } /* Literal.String.Char */ +.highlight .sd { color: #586e75 } /* Literal.String.Doc */ +.highlight .s2 { color: #2aa198 } /* Literal.String.Double */ +.highlight .se { color: #cb4b16 } /* Literal.String.Escape */ +.highlight .sh { color: #586e75 } /* Literal.String.Heredoc */ +.highlight .si { color: #2aa198 } /* Literal.String.Interpol */ +.highlight .sx { color: #2aa198 } /* Literal.String.Other */ +.highlight .sr { color: #dc322f } /* Literal.String.Regex */ +.highlight .s1 { color: #2aa198 } /* Literal.String.Single */ +.highlight .ss { color: #2aa198 } /* Literal.String.Symbol */ +.highlight .bp { color: #268bd2 } /* Name.Builtin.Pseudo */ +.highlight .vc { color: #268bd2 } /* Name.Variable.Class */ +.highlight .vg { color: #268bd2 } /* Name.Variable.Global */ +.highlight .vi { color: #268bd2 } /* Name.Variable.Instance */ +.highlight .il { color: #2aa198 } /* Literal.Number.Integer.Long */ + +.highlight .s { color: rgb(0, 188, 212) } /* Literal.String */ +.highlight .s1 { color: rgb(0, 188, 212) } /* Literal.String */ +.highlight .k { color: #8bc34a; font-weight: normal; } /* Keyword */ +.highlight .kn { color: #8bc34a; font-weight: normal; } /* Keyword */ +.highlight .ow { color: #8bc34a; font-weight: normal } /* Operator.Word */ +.highlight .nc { color: rgb(3, 169, 244) } /* Name.Class */ +.highlight .nf { color: rgb(3, 169, 244); font-weight: normal; } /* Name.Function */ +.highlight .bp { color: rgb(3, 169, 244) } /* Name.Builtin.Pseudo */ +.highlight .nd { color: #ff5252 } /* Name.Decorator */ +.highlight .o { color: #7A7A7A } /* Operator */ +.highlight .n { color: #7A7A7A } + + diff --git a/docs/basic_usage.rst b/docs/basic_usage.rst index a37160d5..b6c84f01 100644 --- a/docs/basic_usage.rst +++ b/docs/basic_usage.rst @@ -1,3 +1,5 @@ +.. _BasicUsage: + Basic Usage =========== @@ -25,6 +27,8 @@ and passing it to a ``DatabaseManager`` instance. db = DatabaseManager(config) +.. _read_write_connections: + Read / Write connections ------------------------ diff --git a/docs/conf.py b/docs/conf.py index ba2685c5..d3e145e9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -53,9 +53,9 @@ # built documents. # # The short X.Y version. -version = '0.1' +version = '0.2' # The full version, including alpha/beta/rc tags. -release = '0.1' +release = '0.2' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -100,7 +100,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'sphinx_rtd_theme' +html_theme = 'default' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -258,3 +258,19 @@ # If true, do not generate a @detailmenu in the "Top" node's menu. #texinfo_no_detailmenu = False + + +on_rtd = os.environ.get('READTHEDOCS', None) == 'True' + +if not on_rtd: + html_theme = 'sphinx_rtd_theme' + def setup(app): + app.add_stylesheet('theme_overrides.css') +else: + html_context = { + 'css_files': [ + 'https://media.readthedocs.org/css/sphinx_rtd_theme.css', + 'https://media.readthedocs.org/css/readthedocs-doc-embed.css', + '_static/theme_overrides.css', + ], + } diff --git a/docs/index.rst b/docs/index.rst index c7fc2aa7..292a14cb 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,12 +12,4 @@ but modified to be more pythonic. installation basic_usage query_builder - - - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` + orm diff --git a/docs/installation.rst b/docs/installation.rst index a427908e..9acf7a53 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -16,6 +16,6 @@ You can install Eloquent in 2 different ways: The different dbapi packages are not part of the package dependencies, so you must install them in order to connect to corresponding databases: - * Postgres: ``pyscopg2`` + * PostgreSQL: ``pyscopg2`` * MySQL: ``PyMySQL`` or ``MySQL-python`` - * Sqlite: The ``sqlite3`` module is bundled with Python by default + * SQLite: The ``sqlite3`` module is bundled with Python by default diff --git a/docs/orm.rst b/docs/orm.rst new file mode 100644 index 00000000..3b0eb9dd --- /dev/null +++ b/docs/orm.rst @@ -0,0 +1,1052 @@ +The ORM +####### + +Introduction +============ + +The ORM provides a simple ActiveRecord implementation for working with your databases. +Each database table has a corresponding Model which is used to interact with that table. + +Before getting started, be sure to have configured a ``DatabaseManager`` as seen in the :ref:`BasicUsage` section. + +.. code-block:: python + + from eloquent import DatabaseManager + + config = { + 'mysql': { + 'driver': 'mysql', + 'host': 'localhost', + 'database': 'database', + 'username': 'root', + 'password': '', + 'prefix': '' + } + } + + db = DatabaseManager(config) + + +Basic Usage +=========== + +To actually get started, you need to tell the ORM to use the configured ``DatabaseManager`` for all models +inheriting from the ``Model`` class: + +.. code-block:: python + + from eloquent import Model + + Model.set_connection_resolver(db) + +And that's pretty much it. You can now define your models. + + +Defining a Model +---------------- + +.. code-block:: python + + class User(Model): + pass + +Note that we did not tell the ORM which table to use for the ``User`` model. The plural "snake case" name of the +class name will be used as the table name unless another name is explicitly specified. +In this case, the ORM will assume the ``User`` model stores records in the ``users`` table. +You can specify a custom table by defining a ``__table__`` property on your model: + +.. code-block:: python + + class User(Model): + + __table__ = 'my_users' + +.. note:: + + The ORM will also assume that each table has a primary key column named ``id``. + You can define a ``__primary_key__`` property to override this convention. + Likewise, you can define a ``__connection__`` property to override the name of the database + connection that should be used when using the model. + +Once a model is defined, you are ready to start retrieving and creating records in your table. +Note that you will need to place ``updated_at`` and ``created_at`` columns on your table by default. +If you do not wish to have these columns automatically maintained, +set the ``__timestamps__`` property on your model to ``False``. + + +Retrieving all models +--------------------- + +.. code-block:: python + + users = User.all() + + +Retrieving a record by primary key +---------------------------------- + +.. code-block:: python + + user = User.find(1) + + print(user.name) + +.. note:: + + All methods available on the :ref:`QueryBuilder` are also available when querying models. + + +Retrieving a Model by primary key or raise an exception +------------------------------------------------------- + +Sometimes it may be useful to throw an exception if a model is not found. +You can use the ``find_or_fail`` method for that, which will raise a ``ModelNotFound`` exception. + +.. code-block:: python + + model = User.find_or_fail(1) + + model = User.where('votes', '>', 100).first_or_fail() + + +Querying using models +--------------------- + +.. code-block:: python + + users = User.where('votes', '>', 100).take(10).get() + + for user in users: + print(user.name) + + +Aggregates +---------- + +You can also use the query builder aggregate functions: + +.. code-block:: python + + count = User.where('votes', '>', 100).count() + +If you feel limited by the builder's fluent interface, you can use the ``where_raw`` method: + +.. code-block:: python + + users = User.where_raw('age > ? and votes = 100', [25]).get() + + +Chunking Results +---------------- + +If you need to process a lot of records, you can use the ``chunk`` method to avoid +consuming a lot of RAM: + +.. code-block:: python + + for users in User.chunk(100): + for user in users: + # ... + + +Specifying the query connection +------------------------------- + +You can specify which database connection to use when querying a model by using the ``on`` method: + +.. code-block:: python + + user = User.on('connection-name').find(1) + +If you are using :ref:`read_write_connections`, you can force the query to use the "write" connection +with the following method: + +.. code-block:: python + + user = User.on_write_connection().find(1) + + +Mass assignment +=============== + +When creating a new model, you pass attributes to the model constructor. +These attributes are then assigned to the model via mass-assignment. +Though convenient, this can be a serious security concern when passing user input into a model, +since the user is then free to modify **any** and **all** of the model's attributes. +For this reason, all models protect against mass-assignment by default. + +To get started, set the ``__fillable__`` or ``__guarded__`` properties on your model. + + +Defining fillable attributes on a model +--------------------------------------- + +The ``__fillable__`` property specifies which attributes can be mass-assigned. + +.. code-block:: python + + class User(Model): + + __fillable__ = ['first_name', 'last_name', 'email'] + + +Defining guarded attributes on a model +-------------------------------------- + +The ``__guarded__`` is the inverse and acts as "blacklist". + +.. code-block:: python + + class User(Model): + + __guarded__ = ['id', 'password'] + +.. warning:: + + When using ``__guarded__``, you should still never pass any user input directly since + any attribute that is not guarded can be mass-assigned. + + +You can also block **all** attributes from mass-assignment: + +.. code-block:: python + + __guarded__ = ['*'] + + +Insert, update and delete +========================= + + +Saving a new model +------------------ + +To create a new record in the database, simply create a new model instance and call the ``save`` method. + +.. code-block:: python + + user = User() + + user.name = 'John' + + user.save() + +.. note:: + + Your models will probably have auto-incrementing primary keys. However, if you wish to maintain + your own primary keys, set the ``__autoincrementing__`` property to ``False``. + +You can also use the ``create`` method to save a model in a single line, but you will need to specify +either the ``__fillable__`` or ``__guarded__`` property on the model since all models are protected against +mass-assigment by default. + +After saving or creating a new model with auto-incrementing IDs, you can retrieve the ID by accessing +the object's ``id`` attribute: + +.. code-block:: python + + inserted_id = user.id + + +Using the create method +----------------------- + +.. code-block:: python + + # Create a new user in the database + user = User.create(name='John') + + # Retrieve the user by attributes, or create it if it does not exist + user = User.first_or_create(name='John') + + # Retrieve the user by attributes, or instantiate it if it does not exist + user = User.first_or_new(name='John') + + +Updating a retrieved model +-------------------------- + +.. code-block:: python + + user = User.find(1) + + user.name = 'Foo' + + user.save() + +You can also run updates as queries against a set of models: + +.. code-block:: python + + affected_rows = User.where('votes', '>', 100).update(status=2) + +.. + TODO: push method + + +Deleting an existing model +-------------------------- + +To delete a model, simply call the ``delete`` model: + +.. code-block:: python + + user = User.find(1) + + user.delete() + + +Deleting an existing model by key +--------------------------------- + +.. code-block:: python + + User.destroy(1) + + User.destroy(1, 2, 3) + +You can alsoe run a delete query on a set of models: + +.. code-block:: python + + affected_rows = User.where('votes', '>' 100).delete() + + +Updating only the model's timestamps +------------------------------------ + +If you want to only update the timestamps on a model, you can use the ``touch`` method: + +.. code-block:: python + + user.touch() + + +Relationships +============= + +Eloquent makes managing and working with relationships easy. It supports many types of relationships: + +* :ref:`OneToOne` +* :ref:`OneToMany` +* :ref:`ManyToMany` +* :ref:`HasManyThrough` + +.. _OneToOne: + +One To One +---------- + +Defining a One To One relationship +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A one-to-one relationship is a very basic relation. For instance, a ``User`` model might have a ``Phone``. +We can define this relation with the ORM: + +.. code-block:: python + + class User(Model): + + @property + def phone(self): + return self.has_one(Phone) + +The first argument passed to the ``has_one`` method is the class of the related model. +Once the relationship is defined, we can retrieve it using :ref:`dynamic_properties`: + +.. code-block:: python + + phone = User.find(1).phone + +The SQL performed by this statement will be as follow: + +.. code-block:: sql + + SELECT * FROM users WHERE id = 1 + + SELECT * FROM phones WHERE user_id = 1 + +The Eloquent ORM assumes the foreign key of the relationship based on the model name. In this case, +``Phone`` model is assumed to use a ``user_id`` foreign key. If you want to override this convention, +you can pass a second argument to the ``has_one`` method. Furthermore, you may pass a third argument +to the method to specify which local column should be used for the association: + +.. code-block:: python + + return self.has_one(Phone, 'foreign_key') + + return self.has_one(Phone, 'foreign_key', 'local_key') + + +Defining the inverse of the relation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To define the inverse of the relationship on the ``Phone`` model, you can use the ``belongs_to`` method: + +.. code-block:: python + + class Phone(Model): + + @property + def user(self): + return self.belongs_to(User) + +In the example above, the Eloquent ORM will look for a ``user_id`` column on the ``phones`` table. You can +define a different foreign key column, you can pass it as the second argument of the ``belongs_to`` method: + +.. code-block:: python + + return self.belongs_to(User, 'local_key') + +Additionally, you pass the third parameter which specifies the name of the associated column on the parent table: + +.. code-block:: python + + return self.belongs_to(User, 'local_key', 'parent_key') + + +.. _OneToMany: + +One To Many +----------- + +An example of a one-to-many relation is a blog post that has many comments: + +.. code-block:: python + + class Post(Model): + + @property + def comments(self): + return self.has_many(Comment) + +Now you can access the post's comments via :ref:`dynamic_properties`: + +.. code-block:: python + + comments = Post.find(1).comments + +Again, you may override the conventional foreign key by passing a second argument to the ``has_many`` method. +And, like the ``has_one`` relation, the local column may also be specified: + +.. code-block:: python + + return self.has_many(Comment, 'foreign_key') + + return self.has_many(Comment, 'foreign_key', 'local_key') + +Defining the inverse of the relation: +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To define the inverse of the relationship on the ``Comment`` model, we use the ``belongs_to`` method: + +.. code-block:: python + + class Comment(Model): + + @property + def post(self): + return self.belongs_to(Post) + + +.. _ManyToMany: + +Many To Many +------------ + +Many-to-many relations are a more complicated relationship type. +An example of such a relationship is a user with many roles, where the roles are also shared by other users. +For example, many users may have the role of "Admin". Three database tables are needed for this relationship: +``users``, ``roles``, and ``roles_users``. +The ``roles_users`` table is derived from the alphabetical order of the related table names, +and should have the ``user_id`` and ``role_id`` columns. + +We can define a many-to-many relation using the ``belongs_to_many`` method: + +.. code-block:: python + + class User(Model): + + @property + def roles(self): + return self.belongs_to_many(Role) + +Now, we can retrieve the roles through the ``User`` model: + +.. code-block:: python + + roles = User.find(1).roles + +If you want to use an unconventional table name for your pivot table, you can pass it as the second argument +to the ``belongs_to_many`` method: + +.. code-block:: python + + return self.belongs_to_many(Role, 'user_role') + +You can also override the conventional associated keys: + +.. code-block:: python + + return self.belongs_to_many(Role, 'user_role', 'user_id', 'foo_id') + +Of course, you also can define the inverse og the relationship on the ``Role`` model: + +.. code-block:: python + + class Role(Model): + + @property + def users(self): + return self.belongs_to_many(Role) + + +.. _HasManyThrough: + +Has Many Through +---------------- + +The "has many through" relation provides a convenient short-cut +for accessing distant relations via an intermediate relation. +For example, a ``Country`` model might have many ``Post`` through a ``User`` model. +The tables for this relationship would look like this: + +.. code-block:: yaml + + countries: + id: integer + name: string + + users: + id: integer + country_id: integer + name: string + + posts: + id: integer + user_id: integer + title: string + +Even though the ``posts`` table does not contain a ``country_id`` column, the ``has_many_through`` relation +will allow access a country's posts via ``country.posts``: + +.. code-block:: python + + class Country(Model): + + @property + def posts(self): + return self.has_many_through(Post, User) + +If you want to manually specify the keys of the relationship, +you can pass them as the third and fourth arguments to the method: + +.. code-block:: python + + return self.has_many_through(Post, User, 'country_id', 'user_id') + + +Querying relations +================== + +.. _dynamic_properties: + +Dynamic properties +------------------ + +The Eloquent ORM allows you to access your relations via dynamic properties. +It will automatically load the relationship for you. It will then be accessible via +a dynamic property by the same name as the relation. For example, with the following model ``Post``: + +.. code-block:: python + + class Phone(Model): + + @property + def user(self): + return self.belongs_to(User) + + phone = Phone.find(1) + + +You can then print the user's email like this: + +.. code-block:: python + + print(phone.user.email) + +Now, for one-to-many relationships: + +.. code-block:: python + + class Post(Model): + + @property + def comments(self): + return self.has_many(Comment) + + post = Post.find(1) + +You can then access the post's comments like this: + +.. code-block:: python + + comments = post.comments + +If you need to add further constraints to which comments are retrieved, +you may call the ``comments`` method and continue chaining conditions: + +.. code-block:: python + + comments = post.comments().where('title', 'foo').first() + +.. note:: + + Relationships that return many results will return an instance of the ``Collection`` class. + + +Eager loading +============= + +Eager loading exists to alleviate the N + 1 query problem. For example, consider a ``Book`` that is related +to an ``Author``: + +.. code-block:: python + + class Book(Model): + + @property + def author(self): + return self.belongs_to(Author) + +Now, consider the following code: + +.. code-block:: python + + for book in Book.all(): + print(book.author.name) + +This loop will execute 1 query to retrieve all the books on the table, then another query for each book +to retrieve the author. So, if we have 25 books, this loop will run 26 queries. + +To drastically reduce the number of queries you can use eager loading. The relationships that should be +eager loaded can be specified via the ``with_`` method. + +.. code-block:: python + + for book in Book.with_('author').get(): + print(book.author.name) + +In this loop, only two queries will be executed: + +.. code-block:: sql + + SELECT * FROM books + + SELECT * FROM authors WHERE id IN (1, 2, 3, 4, 5, ...) + +You can eager load multiple relationships at one time: + +.. code-block:: python + + books = Book.with_('author', 'publisher').get() + +You can even eager load nested relationships: + +.. code-block:: python + + books = Book.with_('author.contacts').get() + +In this example, the ``author`` relationship will be eager loaded as well as the author's ``contacts`` +relation. + +Eager load constraints +---------------------- + +Sometimes you may wish to eager load a relationship but also specify a condition for the eager load. +Here's an example: + +.. code-block:: python + + users = User.with_( + { + 'posts': Post.query().where('title', 'like', '%first%')) + } + ).get() + +In this example, we're eager loading the user's posts only if the post's title contains the word "first". + +Lazy eager loading +------------------ + +It is also possible to eagerly load related models directly from an already existing model collection. +This may be useful when dynamically deciding whether to load related models or not, or in combination with caching. + +.. code-block:: python + + books = Book.all() + + books.load('author', 'publisher') + +You can also pass conditions: + +.. code-block:: python + + books.load( + { + 'author': Author.query().where('name', 'like', '%foo%') + } + ) + + +Inserting related models +======================== + +You will often need to insert new related models, like inserting a new comment for a post. +Instead of manually setting the ``post_id`` foreign key, you can insert the new comment from its parent ``Post``model +directly: + +.. code-block:: python + + comment = Comment(message='A new comment') + + post = Post.find(1) + + comment = post.comments().save(comment) + +If you need to save multiple comments: + +.. code-block:: python + + comments = [ + Comment(message='Comment 1'), + Comment(message='Comment 2'), + Comment(message='Comment 3') + ] + + post = Post.find(1) + + post.comments().save_many(comments) + +Associating models (Belongs To) +------------------------------- + +When updatings a ``belongs_to`` relationship, you can use the associate method: + +.. code-block:: python + + account = Account.find(1) + + user.account().associate(account) + + user.save() + +Inserting related models (Many to Many) +--------------------------------------- + +You can also insert related models when working with many-to-many relationship. +For example, with ``User`` and ``Roles`` models: + +Attaching many to many models +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + user = User.find(1) + role = Roles.find(3) + + user.roles().attach(role) + + # or + user.roles().attach(3) + + +You can also pass a dictionary of attributes that should be stored on the pivot table for the relation: + +.. code-block:: python + + user.roles().attach(3, {'expires': expires}) + +The opposite of ``attach`` is ``detach``: + +.. code-block:: python + + user.roles().detach(3) + +Both ``attach`` and ``detach`` also take list of IDs as input: + +.. code-block:: python + + user = User.find(1) + + user.roles().detach([1, 2, 3]) + + user.roles().attach([{1: {'attribute1': 'value1'}}, 2, 3]) + + +Using sync to attach many to many models +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can also use the ``sync`` method to attach related models. The ``sync`` method accepts a list of IDs +to place on the pivot table. After this operation, only the IDs in the list will be on the pivot table: + +.. code-block:: python + + user.roles().sync([1, 2, 3]) + + +Adding pivot data when syncing +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can also associate other pivot table values with the given IDs: + +.. code-block:: python + + user.roles().sync([{1: {'expires': True}}]) + +Sometimes you might want to create a new related model and attach it in a single command. +For that, you can use the save method: + +.. code-block:: python + + role = Role(name='Editor') + + User.find(1).roles().save(role) + +You can also pass attributes to place on the pivot table: + +.. code-block:: python + + User.find(1).roles().save(role, {'expires': True}) + + +Touching parent timestamps +========================== + +When a model ``belongs_to`` another model, like a ``Comment`` belonging to a ``Post``, it is often helpful +to update the parent's timestamp when the chil model is updated. For instance, when a ``Comment`` model is updated, +you may want to automatically touch the ``updated_at`` timestamp of the owning ``Post``. For this to actually happen, +you just have to add a ``__touches__`` property containing the names of the relationships: + +.. code-block:: python + + class Comment(Model): + + __touches__ = ['posts'] + + @property + def post(self): + return self.belongs_to(Post) + +Now, when you update a ``Comment``, the owning ``Post`` will have its ``updated_at`` column updated. + + +Working with pivot table +======================== + +Working with many-to-many reationships requires the presence of an intermediate table. Eloquent makes it easy to +interact with this table. Let's take the ``User`` and ``Roles`` models and see how you can access the ``pivot`` table: + +.. code-block:: python + + user = User.find(1) + + for role in user.roles: + print(role.pivot.created_at) + +Note that each retrieved ``Role`` model is automatically assigned a ``pivot`` attribute. This attribute contains e model +instance representing the intermediate table, and can be used as any other model. + +By default, only the keys will be present on the ``pivot`` object. If your pivot table contains extra attributes, +you must specify them when defining the relationship: + +.. code-block:: python + + return self.belongs_to_many(Role).with_pivot('foo', 'bar') + +Now the ``foo`` and ``bar`` attributes will be accessible on the ``pivot`` object for the ``Role`` model. + +If you want your pivot table to have automatically maintained ``created_at`` and ``updated_at`` timestamps, +use the ``with_timestamps`` method on the relationship definition: + +.. code-block:: python + + return self.belongs_to_many(Role).with_timestamps() + + +Deleting records on a pivot table +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To delete all records on the pivot table for a model, you can use the ``detach`` method: + +.. code-block:: python + + User.find(1).roles().detach() + + +Updating a record on the pivot table +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Sometimes you may need to update your pivot table, but not detach it. +If you wish to update your pivot table in place you may use ``update_existing_pivot`` method like so: + +.. code-block:: python + + User.find(1).roles().update_existing_pivot(role_id, attributes) + + +Timestamps +========== + +By default, the ORM will maintain the ``created_at`` and ``updated_at`` columns on your database table +automatically. Simply add these ``timestamp`` columns to your table. If you do not wish for the ORM to maintain +these columns, just add the ``__timestamps__`` property: + +.. code-block:: python + + class User(Model): + + __timestamps__ = False + + +Providing a custom timestamp format +----------------------------------- + +If you whish to customize the format of your timestamps (the default is the ISO Format) that will be returned when using the ``to_dict`` +or the ``to_json`` methods, you can override the ``get_date_format`` method: + +.. code-block:: python + + class User(Model): + + def get_date_format(): + return 'DD-MM-YY' + + +Date mutators +============= + +By default, the ORM will convert the ``created_at`` and ``updated_at`` columns to instances of `Arrow `_, +which eases date and datetime manipulation while behaving pretty much like the native Python date and datetime. + +You can customize which fields are automatically mutated, by either adding them with the ``__dates__`` property or +by completely overriding the ``get_dates`` method: + +.. code-block:: python + + class User(Model): + + __dates__ = ['synchronized_at'] + +.. code-block:: python + + class User(Model): + + def get_dates(): + return ['created_at'] + +When a column is considered a date, you can set its value to a UNIX timestamp, a date string ``YYYY-MM-DD``, +a datetime string, a native ``date`` or ``datetime`` and of course an ``Arrow`` instance. + +To completely disable date mutations, simply return an empty list from the ``get_dates`` method. + +.. code-block:: python + + class User(Model): + + def get_dates(): + return [] + + +Attributes casting +================== + +If you have some attributes that you want to always convert to another data-type, +you may add the attribute to the ``__casts__`` property of your model. +Otherwise, you will have to define a mutator for each of the attributes, which can be time consuming. +Here is an example of using the ``__casts__`` property: + +.. code-block:: python + + __casts__ = { + 'is_admin': 'bool' + } + +Now the ``is_admin`` attribute will always be cast to a boolean when you access it, +even if the underlying value is stored in the database as an integer. +Other supported cast types are: ``int``, ``float``, ``str``, ``bool``, ``dict``, ``list``. + +The ``dict`` cast is particularly useful for working with columns that are stored as serialized JSON. +For example, if your database has a TEXT type field that contains serialized JSON, +adding the ``dict`` cast to that attribute will automatically deserialize the attribute +to a dictionary when you access it on your model: + +.. code-block:: python + + __casts__ = { + 'options': 'dict' + } + +Now, when you utilize the model: + +.. code-block:: python + + user = User.find(1) + + # options is a dict + options = user.options + + # options is automatically serialized back to JSON + user.options = {'foo': 'bar'} + + +Converting to dictionaries / JSON +================================= + +Converting a model to a dictionary +---------------------------------- + +When building JSON APIs, you may often need to convert your models and relationships to dictionaries or JSON. +So, Eloquent includes methods for doing so. To convert a model and its loaded relationship to a dictionary, +you may use the ``to_dict`` method: + +.. code-block:: python + + user = User.with('roles').first() + + return user.to_dict() + +Note that entire collections of models can also be converted to dictionaries: + +.. code-block:: python + + return User.all().to_dict() + + +Converting a model to JSON +-------------------------- + +To convert a model to JSON, you can use the ``to_json`` method! + +.. code-block:: python + + return User.find(1).to_json() + + +Hiding attributes from dictionary or JSON conversion +---------------------------------------------------- + +Sometimes you may wish to limit the attributes that are included in you model's dictionary or JSON form, +such as passwords. To do so, add a ``__hidden__`` property definition to you model: + +.. code-block:: python + + class User(model): + + __hidden__ = ['password'] + +Alternatively, you may use the ``__visible__`` property to define a whitelist: + +.. code-block:: python + + __visible__ = ['first_name', 'last_name'] diff --git a/docs/query_builder.rst b/docs/query_builder.rst index 9089a777..e94f2126 100644 --- a/docs/query_builder.rst +++ b/docs/query_builder.rst @@ -1,3 +1,5 @@ +.. _QueryBuilder: + Query Builder ============= @@ -33,6 +35,17 @@ Retrieving all row from a table for user in users: print(user['name']) + +Chunking results from a table +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + for users in db.table('users').chunk(100): + for user in users: + # ... + + Retrieving a single row from a table ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -108,7 +121,7 @@ Using Where Not Between .. code-block:: python - users = db.table('users').where_not_between('age', [25, 35]).get()~ + users = db.table('users').where_not_between('age', [25, 35]).get() Using Where In ~~~~~~~~~~~~~~ diff --git a/eloquent/__init__.py b/eloquent/__init__.py index 633f8661..0be683cf 100644 --- a/eloquent/__init__.py +++ b/eloquent/__init__.py @@ -1,2 +1,4 @@ # -*- coding: utf-8 -*- +from .orm import Model +from .database_manager import DatabaseManager diff --git a/eloquent/connections/connection.py b/eloquent/connections/connection.py index 1948dfd7..e1eb6222 100644 --- a/eloquent/connections/connection.py +++ b/eloquent/connections/connection.py @@ -285,6 +285,12 @@ def _caused_by_lost_connection(self, e): return True def disconnect(self): + if self._connection: + self._connection.close() + + if self._read_connection: + self._read_connection.close() + self.set_connection(None).set_read_connection(None) def reconnect(self): diff --git a/eloquent/database_manager.py b/eloquent/database_manager.py index 84be0ad4..149bfa3f 100644 --- a/eloquent/database_manager.py +++ b/eloquent/database_manager.py @@ -30,7 +30,7 @@ def connection(self, name=None): :type name: str :return: A Connection instance - :rtype: Connection + :rtype: eloquent.connections.Connection """ name, type = self._parse_connection_name(name) @@ -97,8 +97,8 @@ def _refresh_api_connections(self, name): fresh = self._make_connection(name) return self._connections[name]\ - .set_api(fresh.get_api())\ - .set_read_api(fresh.get_read_api()) + .set_connection(fresh.get_connection())\ + .set_read_connection(fresh.get_read_connection()) def _make_connection(self, name): config = self._get_config(name) diff --git a/eloquent/exceptions/orm.py b/eloquent/exceptions/orm.py new file mode 100644 index 00000000..853ee8bd --- /dev/null +++ b/eloquent/exceptions/orm.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + + +class ModelNotFound(RuntimeError): + + def __init__(self, model): + self._model = model + + self.message = 'No query results found for model [%s]' % self._model.__name__ + + def __str__(self): + return self.message + + +class MassAssignmentError(RuntimeError): + pass diff --git a/eloquent/orm/__init__.py b/eloquent/orm/__init__.py new file mode 100644 index 00000000..78a85f0b --- /dev/null +++ b/eloquent/orm/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- + +from .builder import Builder +from .model import Model diff --git a/eloquent/orm/builder.py b/eloquent/orm/builder.py new file mode 100644 index 00000000..5665e983 --- /dev/null +++ b/eloquent/orm/builder.py @@ -0,0 +1,604 @@ +# -*- coding: utf-8 -*- + +from ..exceptions.orm import ModelNotFound +from ..utils import Null + + +class Builder(object): + + _passthru = [ + 'to_sql', 'lists', 'insert', 'insert_get_id', 'pluck', 'count', + 'min', 'max', 'avg', 'sum', 'exists', 'get_bindings' + ] + + def __init__(self, query): + """ + Constructor + + :param query: The underlying query builder + :type query: QueryBuilder + """ + self._query = query + + self._model = None + self._eager_load = {} + self._macros = [] + + self._on_delete = None + + def find(self, id, columns=None): + """ + Find a model by its primary key + + :param id: The primary key value + :type id: mixed + + :param columns: The columns to retrieve + :type columns: list + + :return: The found model + :rtype: eloquent.Model + """ + if columns is None: + columns = ['*'] + + if isinstance(id, list): + return self.find_many(id, columns) + + self._query.where(self._model.get_qualified_key_name(), '=', id) + + return self.first(columns) + + def find_many(self, id, columns=None): + """ + Find a model by its primary key + + :param id: The primary key values + :type id: list + + :param columns: The columns to retrieve + :type columns: list + + :return: The found model + :rtype: eloquent.Collection + """ + if columns is None: + columns = ['*'] + + if not id: + return self._model.new_collection() + + self._query.where_in(self._model.get_qualified_key_name(), id) + + return self.get(columns) + + def find_or_fail(self, id, columns=None): + """ + Find a model by its primary key or raise an exception + + :param id: The primary key value + :type id: mixed + + :param columns: The columns to retrieve + :type columns: list + + :return: The found model + :rtype: eloquent.Model + + :raises: ModelNotFound + """ + result = self.find(id, columns) + + if isinstance(id, list): + if len(result) == len(set(id)): + return result + elif result: + return result + + raise ModelNotFound(self._model.__class__) + + def first(self, columns=None): + """ + Execute the query and get the first result + + :param columns: The columns to get + :type columns: list + + :return: The result + :rtype: mixed + """ + if columns is None: + columns = ['*'] + + return self.take(1).get(columns).first() + + def first_or_fail(self, columns=None): + """ + Execute the query and get the first result or raise an exception + + :param columns: The columns to get + :type columns: list + + :return: The result + :rtype: mixed + """ + model = self.first(columns) + + if model is not None: + return model + + raise ModelNotFound(self._model.__class__) + + def get(self, columns=None): + """ + Execute the query as a "select" statement. + + :param columns: The columns to get + :type columns: list + + :rtype: eloquent.Collection + """ + models = self.get_models(columns) + + # If we actually found models we will also eager load any relationships that + # have been specified as needing to be eager loaded, which will solve the + # n+1 query issue for the developers to avoid running a lot of queries. + if len(models) > 0: + models = self.eager_load_relations(models) + + return self._model.new_collection(models) + + def pluck(self, column): + """ + Pluck a single column from the database. + + :param column: THe column to pluck + :type column: str + + :return: The column value + :rtype: mixed + """ + result = self.first([column]) + + if result: + return result[column] + + def chunk(self, count): + """ + Chunk the results of the query + + :param count: The chunk size + :type count: int + + :return: The current chunk + :rtype: list + """ + page = 1 + results = self.for_page(page, count).get() + + while results: + yield results + + page += 1 + + results = self.for_page(page, count).get() + + def lists(self, column, key=''): + """ + Get a list with the values of a given column + + :param column: The column to get the values for + :type column: str + + :param key: The key + :type key: str + + :return: The list of values + :rtype: list or dict + """ + results = self._query.lists(column, key) + + if self._model.has_get_mutator(column): + if isinstance(results, dict): + for key, value in results.items(): + fill = {column: value} + + results[key] = self._model.new_from_builder(fill).column + else: + for i, value in enumerate(results): + fill = {column: value} + + results[i] = self._model.new_from_builder(fill).column + + return results + + def paginate(self, per_page=None, columns=None): + """ + Paginate the given query. + + :param per_page: The number of records per page + :type per_page: int + + :param columns: The columns to return + :type columns: list + + :return: The paginator + """ + # TODO + + def simple_paginate(self, per_page=None, columns=None): + """ + Paginate the given query. + + :param per_page: The number of records per page + :type per_page: int + + :param columns: The columns to return + :type columns: list + + :return: The paginator + """ + # TODO + + def update(self, _values=None, **values): + """ + Update a record in the database + + :param values: The values of the update + :type values: dict + + :return: The number of records affected + :rtype: int + """ + if _values is not None: + values.update(_values) + + return self._query.update(self._add_updated_at_column(values)) + + def increment(self, column, amount=1, extras=None): + """ + Increment a column's value by a given amount + + :param column: The column to increment + :type column: str + + :param amount: The amount by which to increment + :type amount: int + + :param extras: Extra columns + :type extras: dict + + :return: The number of rows affected + :rtype: int + """ + extras = self._add_updated_at_column(extras) + + return self._query.increment(column, amount, extras) + + def decrement(self, column, amount=1, extras=None): + """ + Decrement a column's value by a given amount + + :param column: The column to increment + :type column: str + + :param amount: The amount by which to increment + :type amount: int + + :param extras: Extra columns + :type extras: dict + + :return: The number of rows affected + :rtype: int + """ + extras = self._add_updated_at_column(extras) + + return self._query.decrement(column, amount, extras) + + def _add_updated_at_column(self, values): + """ + Add the "updated_at" column to a dictionary of values. + + :param values: The values to update + :type values: dict + + :return: The new dictionary of values + :rtype: dict + """ + if not self._model.uses_timestamps(): + return values + + column = self._model.get_updated_at_column() + + values.update({column: self._model.fresh_timestamp()}) + + return values + + def delete(self): + """ + Delete a record from the database. + """ + if self._on_delete is not None: + return self._on_delete(self) + + return self._query.delete() + + def force_delete(self): + """ + Run the default delete function on the builder. + """ + return self._query.delete() + + def on_delete(self, callback): + """ + Register a replacement for the default delete function. + + :param callback: A replacement for the default delete function + :type callback: callable + """ + self._on_delete = callback + + def get_models(self, columns=None): + """ + Get the hydrated models without eager loading. + + :param columns: The columns to get + :type columns: list + + :return: A list of models + :rtype: list + """ + results = self._query.get(columns) + + connection = self._model.get_connection_name() + + return self._model.hydrate(results, connection).all() + + def eager_load_relations(self, models): + """ + Eager load the relationship of the models. + + :param models: + :type models: list + + :return: The models + :rtype: list + """ + for name, constraints in self._eager_load.items(): + if name.find('.') == -1: + models = self._load_relation(models, name, constraints) + + return models + + def _load_relation(self, models, name, constraints): + """ + Eagerly load the relationship on a set of models. + + :rtype: list + """ + relation = self.get_relation(name) + + relation.add_eager_constraints(models) + + relation.merge_query(constraints) + + models = relation.init_relation(models, name) + + results = relation.get_eager() + + return relation.match(models, results, name) + + def get_relation(self, relation): + """ + Get the relation instance for the given relation name. + + :rtype: eloquent.orm.relations.Relation + """ + from .relations import Relation + + query = Relation.no_constraints(lambda: getattr(self.get_model(), relation)()) + + nested = self._nested_relations(relation) + + if len(nested) > 0: + query.get_query().with_(nested) + + return query + + def _nested_relations(self, relation): + """ + Get the deeply nested relations for a given top-level relation. + + :rtype: dict + """ + nested = {} + + for name, constraints in self._eager_load.items(): + if self._is_nested(name, relation): + nested[name[len(relation + '.')]:] = constraints + + return nested + + def _is_nested(self, name, relation): + """ + Determine if the relationship is nested. + + :type name: str + :type relation: str + + :rtype: bool + """ + dots = name.find('.') + + return dots and name.startswith(relation + '.') + + def where(self, column, operator=Null(), value=None, boolean='and'): + """ + Add a where clause to the query + + :param column: The column of the where clause, can also be a QueryBuilder instance for sub where + :type column: str|Builder + + :param operator: The operator of the where clause + :type operator: str + + :param value: The value of the where clause + :type value: mixed + + :param boolean: The boolean of the where clause + :type boolean: str + + :return: The current Builder instance + :rtype: Builder + """ + if isinstance(column, Builder): + self._query.add_nested_where_query(column.get_query(), boolean) + else: + self._query.where(column, operator, value, boolean) + + return self + + def or_where(self, column, operator=None, value=None): + """ + Add an "or where" clause to the query. + + :param column: The column of the where clause, can also be a QueryBuilder instance for sub where + :type column: str or Builder + + :param operator: The operator of the where clause + :type operator: str + + :param value: The value of the where clause + :type value: mixed + + :return: The current Builder instance + :rtype: Builder + """ + return self.where(column, operator, value, 'or') + + def with_(self, *relations): + """ + Set the relationships that should be eager loaded. + + :return: The current Builder instance + :rtype: Builder + """ + if not relations: + return self + + eagers = self._parse_relations(list(relations)) + + self._eager_load.update(eagers) + + return self + + def _parse_relations(self, relations): + """ + Parse a list of relations into individuals. + + :param relations: The relation to parse + :type relations: list + + :rtype: dict + """ + results = {} + + for relation in relations: + if isinstance(relation, dict): + name = list(relation.keys())[0] + constraints = relation[name] + else: + name = relation + constraints = self.__class__(self.get_query().new_query()) + + results = self._parse_nested(name, results) + + results[name] = constraints + + return results + + def _parse_nested(self, name, results): + """ + Parse the nested relationship in a relation. + + :param name: The name of the relationship + :type name: str + + :type results: dict + + :rtype: dict + """ + progress = [] + + for segment in name.split('.'): + progress.append(segment) + + last = '.'.join(progress) + if last not in results: + results[last] = self.__class__(self.get_query().new_query()) + + return results + + def get_query(self): + """ + Get the underlying query instance. + + :rtype: QueryBuilder + """ + return self._query + + def set_query(self, query): + """ + Set the underlying query instance. + + :param query: A QueryBuilder instance + :type query: QueryBuilder + """ + self._query = query + + def get_model(self): + """ + Get the model instance of the model being queried + + :rtype: eloquent.Model + """ + return self._model + + def set_model(self, model): + """ + Set a model instance for the model being queried. + + :param model: The model instance + :type model: eloquent.orm.Model + + :return: The current Builder instance + :rtype: Builder + """ + self._model = model + + self._query.from_(model.get_table()) + + return self + + def __dynamic(self, method): + attribute = getattr(self._query, method) + + def call(*args, **kwargs): + result = attribute(*args, **kwargs) + + if method in self._passthru: + return result + else: + return self + + if not callable(attribute): + return attribute + + return call + + def __getattr__(self, item, *args): + try: + object.__getattribute__(self, item) + except AttributeError: + # TODO: macros and scopes + return self.__dynamic(item) diff --git a/eloquent/orm/collection.py b/eloquent/orm/collection.py new file mode 100644 index 00000000..a923fa84 --- /dev/null +++ b/eloquent/orm/collection.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +from ..support.collection import Collection as BaseCollection + + +class Collection(BaseCollection): + + def load(self, *relations): + """ + Load a set of relationships onto the collection. + """ + if len(self._items) > 0: + query = self.first().new_query().with_(*relations) + + self._items = query.eager_load_relations(self._items) + + return self + + def lists(self, value, key=None): + """ + Get a list with the values of a given key + + :rtype: list + """ + results = map(lambda x: getattr(x, value), self._items) + + return list(results) + + def model_keys(self): + """ + Get the list of primary keys. + + :rtype: list + """ + return map(lambda m: m.get_key(), self._items) diff --git a/eloquent/orm/model.py b/eloquent/orm/model.py new file mode 100644 index 00000000..6ed849e3 --- /dev/null +++ b/eloquent/orm/model.py @@ -0,0 +1,1921 @@ +# -*- coding: utf-8 -*- + +import simplejson as json +import arrow +import inflection +import inspect +from six import add_metaclass +from ..exceptions.orm import MassAssignmentError +from ..query import QueryBuilder +from .builder import Builder +from .collection import Collection +from .relations import Relation, HasOne, HasMany, BelongsTo, BelongsToMany, HasManyThrough +from .relations.dynamic_property import DynamicProperty + + +class MetaModel(type): + + def __getattr__(cls, item): + try: + return type.__getattribute__(cls, item) + except AttributeError: + query = cls.query() + + return getattr(query, item) + + +@add_metaclass(MetaModel) +class Model(object): + + __connection__ = None + + __table__ = None + + __primary_key__ = 'id' + + __incrementing__ = True + + __fillable__ = [] + __guarded__ = ['*'] + __unguarded__ = False + + __hidden__ = [] + __visible__ = [] + + __timestamps__ = True + + __casts__ = {} + + __touches__ = [] + + _with = [] + + _booted = {} + + __resolver = None + + many_methods = ['belongs_to_many', 'morph_to_many', 'morphed_by_many'] + + CREATED_AT = 'created_at' + UPDATED_AT = 'updated_at' + + def __init__(self, **attributes): + """ + :param attributes: The instance attributes + """ + self.__exists = False + self.__dates = [] + self.__original = {} + self.__attributes = {} + self.__relations = {} + + self._boot_if_not_booted() + + self.sync_original() + + self.fill(**attributes) + + def _boot_if_not_booted(self): + """ + Check if the model needs to be booted and if so, do it. + """ + klass = self.__class__ + + if not klass._booted.get(klass): + klass._booted[klass] = True + + klass._boot() + + @classmethod + def _boot(cls): + """ + The booting method of the model. + """ + # TODO + + def fill(self, **attributes): + """ + Fill the model with attributes. + + :param attributes: The instance attributes + :type attributes: dict + + :return: The model instance + :rtype: Model + + :raises: MassAssignmentError + """ + totally_guarded = self.totally_guarded() + + for key, value in self._fillable_from_dict(attributes).items(): + key = self._remove_table_from_key(key) + + if self.is_fillable(key): + self.set_attribute(key, value) + elif totally_guarded: + raise MassAssignmentError(key) + + return self + + def force_fill(self, **attributes): + """ + Fill the model with attributes. Force mass assignment. + + :param attributes: The instance attributes + :type attributes: dict + + :return: The model instance + :rtype: Model + """ + self.unguard() + + self.fill(**attributes) + + self.reguard() + + return self + + def _fillable_from_dict(self, attributes): + """ + Get the fillable attributes from a given dictionary. + + :type attributes: dict + + :return: The fillable attributes + :rtype: dict + """ + if self.__fillable__ and not self.__unguarded__: + return {x: attributes[x] for x in attributes if x in self.__fillable__} + + return attributes + + def new_instance(self, attributes=None, exists=False): + """ + Create a new instance for the given model. + + :param attributes: The instance attributes + :type attributes: dict + + :param exists: + :type exists: bool + + :return: A new instance for the current model + :rtype: Model + """ + if attributes is None: + attributes = {} + + model = self.__class__(**attributes) + + model.set_exists(exists) + + return model + + def new_from_builder(self, attributes=None, connection=None): + """ + Create a new model instance that is existing. + + :param attributes: The model attributes + :type attributes: dict + + :param connection: The connection name + :type connection: str + + :return: A new instance for the current model + :rtype: Model + """ + model = self.new_instance({}, True) + + if attributes is None: + attributes = {} + + model.set_raw_attributes(attributes, True) + + model.set_connection(connection or self.__connection__) + + return model + + @classmethod + def hydrate(cls, items, connection=None): + """ + Create a collection of models from plain lists. + + :param items: + :param connection: + :return: + """ + instance = cls().set_connection(connection) + + collection = instance.new_collection(items) + + return Collection(list(map(lambda item: instance.new_from_builder(item), collection))) + + @classmethod + def hydrate_raw(cls, query, bindings=None, connection=None): + """ + Create a collection of models from a raw query. + + :param query: The SQL query + :type query: str + + :param bindings: The query bindings + :type bindings: list + + :param connection: The connection name + + :rtype: Collection + """ + instance = cls().set_connection(connection) + + items = instance.get_connection().select(query, bindings) + + return cls.hydrate(items, connection) + + @classmethod + def create(cls, **attributes): + """ + Save a new model an return the instance. + + :param attributes: The instance attributes + :type attributes: dict + + :return: The new instance + :rtype: Model + """ + model = cls(**attributes) + + model.save() + + return model + + @classmethod + def force_create(cls, **attributes): + """ + Save a new model an return the instance. Allow mass assignment. + + :param attributes: The instance attributes + :type attributes: dict + + :return: The new instance + :rtype: Model + """ + cls.unguard() + + model = cls.create(**attributes) + + cls.reguard() + + return model + + @classmethod + def first_or_create(cls, **attributes): + """ + Get the first record matching the attributes or create it. + + :param attributes: The instance attributes + :type attributes: dict + + :return: The new instance + :rtype: Model + """ + instance = cls.where(attributes).first() + + if instance is not None: + return instance + + return cls.create(**attributes) + + @classmethod + def first_or_new(cls, **attributes): + """ + Get the first record matching the attributes or instantiate it. + + :param attributes: The instance attributes + :type attributes: dict + + :return: The new instance + :rtype: Model + """ + instance = cls.where(attributes).first() + + if instance is not None: + return instance + + return cls(**attributes) + + @classmethod + def update_or_create(cls, attributes, values=None): + """ + Create or update a record matching the attributes, and fill it with values. + + :param attributes: The instance attributes + :type attributes: dict + + :param values: The values + :type values: dict + + :return: The new instance + :rtype: Model + """ + instance = cls.first_or_new(**attributes) + + if values is None: + values = {} + + instance.fill(**values).save() + + return instance + + @classmethod + def query(cls): + """ + Begin querying the model. + + :return: A Builder instance + :rtype: eloquent.orm.Builder + """ + return cls().new_query() + + @classmethod + def on(cls, connection=None): + """ + Begin querying the model on a given connection. + + :param connection: The connection name + :type connection: str + + :return: A Builder instance + :rtype: eloquent.orm.Builder + """ + instance = cls() + + instance.set_connection(connection) + + return instance.new_query() + + @classmethod + def on_write_connection(cls): + """ + Begin querying the model on the write connection. + + :return: A Builder instance + :rtype: QueryBuilder + """ + instance = cls() + + return instance.new_query().use_write_connection() + + @classmethod + def all(cls, columns=None): + """ + Get all og the models from the database. + + :param columns: The columns to retrieve + :type columns: list + + :return: A Collection instance + :rtype: Collection + """ + instance = cls() + + return instance.new_query().get(columns) + + @classmethod + def find(cls, id, columns=None): + """ + Find a model by its primary key. + + :param id: The id of the model + :type id: mixed + + :param columns: The columns to retrieve + :type columns: list + + :return: Either a Model instance or a Collection + :rtype: Model or Collection + """ + instance = cls() + + if isinstance(id, list) and not id: + return instance.new_collection() + + if columns is None: + columns = ['*'] + + return instance.new_query().find(id, columns) + + @classmethod + def find_or_new(cls, id, columns=None): + """ + Find a model by its primary key or return new instance. + + :param id: The id of the model + :type id: mixed + + :param columns: The columns to retrieve + :type columns: list + + :return: A Model instance + :rtype: Model + """ + instance = cls.find(id, columns) + + if instance is not None: + return instance + + return cls() + + def fresh(self, with_=None): + """ + Reload a fresh instance from the database. + + :param with_: The list of relations to eager load + :type with_: list + + :return: The current model instance + :rtype: Model + """ + key = self.get_key_name() + + if self.exists: + return self.__class__.with_(with_).where(key, self.get_key()).first() + + def load(self, relations): + """ + Eager load relations on the model + + :param relations: The relations to eager load + :type relations: str or list + + :return: The current model instance + :rtype: Model + """ + # TODO + + @classmethod + def with_(cls, *relations): + """ + Begin querying a model with eager loading + + :param relations: The relations to eager load + :type relations: str or list + + :return: A Builder instance + :rtype: Builder + """ + instance = cls() + + return instance.new_query().with_(*relations) + + def has_one(self, related, foreign_key=None, local_key=None): + """ + Define a one to one relationship. + + :param related: The related model: + :type related: Model class + + :param foreign_key: The foreign key + :type foreign_key: str + + :param local_key: The local key + :type local_key: str + + :rtype: HasOne + """ + if not foreign_key: + foreign_key = self.get_foreign_key() + + instance = related() + + if not local_key: + local_key = self.get_key_name() + + return HasOne(instance.new_query(), self, '%s.%s' % (instance.get_table(), foreign_key), local_key) + + def belongs_to(self, related, foreign_key=None, other_key=None, relation=None): + """ + Define an inverse one to one or many relationship. + + :param related: The related model: + :type related: Model class + + :param foreign_key: The foreign key + :type foreign_key: str + + :param other_key: The other key + :type other_key: str + + :type relation: str + + :rtype: BelongsTo + """ + if relation is None: + relation = inspect.stack()[1][3] + + if foreign_key is None: + foreign_key = '%s_id' % inflection.underscore(relation) + + instance = related() + + query = instance.new_query() + + if not other_key: + other_key = instance.get_key_name() + + return BelongsTo(query, self, foreign_key, other_key, relation) + + def has_many(self, related, foreign_key=None, local_key=None): + """ + Define a one to many relationship. + + :param related: The related model + :type related: Model class + + :param foreign_key: The foreign key + :type foreign_key: str + + :param local_key: The local key + :type local_key: str + + :rtype: HasOne + """ + if not foreign_key: + foreign_key = self.get_foreign_key() + + instance = related() + + if not local_key: + local_key = self.get_key_name() + + return HasMany(instance.new_query(), self, '%s.%s' % (instance.get_table(), foreign_key), local_key) + + def has_many_through(self, related, through, first_key=None, second_key=None): + """ + Define a has-many-through relationship. + + :param related: The related model + :type related: Model class + + :param through: The through model + :type through: Model class + + :param first_key: The first key + :type first_key: str + + :param second_key: The second_key + :type second_key: str + + :rtype: HasManyThrough + """ + through = through() + + if not first_key: + first_key = self.get_foreign_key() + + if not second_key: + second_key = through.get_foreign_key() + + return HasManyThrough(related().new_query(), self, through, first_key, second_key) + + def belongs_to_many(self, related, table=None, foreign_key=None, other_key=None, relation=None): + """ + Define a many-to-many relationship. + + :param related: The related model: + :type related: Model + + :param table: The pivot table + :type table: str + + :param foreign_key: The foreign key + :type foreign_key: str + + :param other_key: The other key + :type other_key: str + + :type relation: str + + :rtype: BelongsToMany + """ + if relation is None: + relation = inspect.stack()[1][3] + + if not foreign_key: + foreign_key = self.get_foreign_key() + + instance = related() + + if not other_key: + other_key = instance.get_foreign_key() + + if table is None: + table = self.joining_table(instance) + + query = instance.new_query() + + return BelongsToMany(query, self, table, foreign_key, other_key, relation) + + def joining_table(self, related): + """ + Get the joining table name for a many-to-many relation + + :param related: The related model + :type related: Model + + :rtype: str + """ + base = self.get_table() + + related = related.get_table() + + models = sorted([related, base]) + + return '_'.join(models) + + @classmethod + def destroy(cls, *ids): + """ + Destroy the models for the given IDs + + :param ids: The ids of the models to destroy + :type ids: tuple + + :return: The number of models destroyed + :rtype: int + """ + count = 0 + + if len(ids) == 1 and isinstance(ids[0], list): + ids = ids[0] + + ids = list(ids) + + instance = cls() + + key = instance.get_key_name() + + for model in instance.new_query().where_in(key, ids).get(): + if model.delete(): + count += 1 + + return count + + def delete(self): + """ + Delete the model from the database. + + :rtype: bool or None + + :raises: Exception + """ + if self.__primary_key__ is None: + raise Exception('No primary key defined on the model.') + + if self.__exists: + self._touch_owners() + + self._perform_delete_on_model() + + self.__exists = False + + return True + + def force_delete(self): + """ + Force a hard delete on a soft deleted model. + """ + return self.delete() + + def _perform_delete_on_model(self): + """ + Perform the actual delete query on this model instance. + """ + return self.new_query().where(self.get_key_name(), self.get_key()).delete() + + # TODO: events + + def _increment(self, column, amount=1): + """ + Increment a column's value + + :param column: The column to increment + :type column: str + + :param amount: The amount by which to increment + :type amount: int + + :return: The new column value + :rtype: int + """ + return self._increment_or_decrement(column, amount, 'increment') + + def _decrement(self, column, amount=1): + """ + Decrement a column's value + + :param column: The column to increment + :type column: str + + :param amount: The amount by which to increment + :type amount: int + + :return: The new column value + :rtype: int + """ + return self._increment_or_decrement(column, amount, 'decrement') + + def _increment_or_decrement(self, column, amount, method): + """ + Runthe increment or decrement method on the model + + :param column: The column to increment or decrement + :type column: str + + :param amount: The amount by which to increment or decrement + :type amount: int + + :param method: The method + :type method: str + + :return: The new column value + :rtype: int + """ + query = self.new_query() + + if not self.__exists: + return getattr(query, method)(column, amount) + + self._increment_or_decrement_attribute_value(column, amount, method) + + query = query.where(self.get_key_name(), self.get_key()) + + return getattr(query, method)(column, amount) + + def _increment_or_decrement_attribute_value(self, column, amount, method): + """ + Increment the underlying attribute value and sync with original. + + :param column: The column to increment or decrement + :type column: str + + :param amount: The amount by which to increment or decrement + :type amount: int + + :param method: The method + :type method: str + + :return: None + """ + setattr(self, column, getattr(self, column) + (amount if method == 'increment' else amount * -1)) + + self.sync_original_attribute(column) + + def update(self, **attributes): + """ + Update the model in the database. + + :param attributes: The model attributes + :type attributes: dict + + :return: The number of rows affected + :rtype: int + """ + if not self.__exists: + return self.new_query().update(**attributes) + + return self.fill(**attributes).save() + + def push(self): + """ + Save the model and all of its relationship. + """ + if not self.save(): + return False + + for models in self.__relations.values(): + if isinstance(models, Collection): + models = models.all() + else: + models = [models] + + for model in models: + if not model: + continue + + if not model.push(): + return False + + return True + + def save(self, options=None): + """ + Save the model to the database. + """ + if options is None: + options = {} + + query = self.new_query() + + if self.__exists: + saved = self._perform_update(query, options) + else: + saved = self._perform_insert(query, options) + + if saved: + self._finish_save(options) + + return saved + + def _finish_save(self, options): + """ + Finish processing on a successful save operation. + """ + self.sync_original() + + if options.get('touch', True): + self._touch_owners() + + def _perform_update(self, query, options=None): + """ + Perform a model update operation. + + :param query: A Builder instance + :type query: Builder + + :param options: Extra options + :type options: dict + """ + if options is None: + options = {} + + dirty = self.get_dirty() + + if len(dirty): + # TODO: "updating" event + if self.__timestamps__ and options.get('timestamps', True): + self._update_timestamps() + + dirty = self.get_dirty() + + if len(dirty): + self._set_keys_for_save_query(query).update(dirty) + + # TODO: "updated" event + + return True + + def _perform_insert(self, query, options=None): + """ + Perform a model update operation. + + :param query: A Builder instance + :type query: Builder + + :param options: Extra options + :type options: dict + """ + if options is None: + options = {} + + # TODO: "creating" event + + if self.__timestamps__ and options.get('timestamps', True): + self._update_timestamps() + + attributes = self.__attributes + + if self.__incrementing__: + self._insert_and_set_id(query, attributes) + else: + query.insert(attributes) + + self.__exists = True + + # TODO: "created" event + + return True + + def _insert_and_set_id(self, query, attributes): + """ + Insert the given attributes and set the ID on the model. + + :param query: A Builder instance + :type query: Builder + + :param attributes: The attributes to insert + :type attributes: dict + """ + key_name = self.get_key_name() + + id = query.insert_get_id(attributes, key_name) + + self.set_attribute(key_name, id) + + def _touch_owners(self): + """ + Touch the owning relations of the model. + """ + for relation in self.__touches__: + if hasattr(self, relation): + _relation = getattr(self, relation) + _relation().touch() + + if _relation is not None: + _relation.touch_owners() + + def touches(self, relation): + """ + Determine if a model touches a given relation. + + :param relation: The relation to check. + :type relation: str + + :rtype: bool + """ + return relation in self.__touches__ + + def _set_keys_for_save_query(self, query): + """ + Set the keys for a save update query. + + :param query: A Builder instance + :type query: Builder + + :return: The Builder instance + :rtype: Builder + """ + query.where(self.get_key_name(), self._get_key_for_save_query()) + + return query + + def _get_key_for_save_query(self): + """ + Get the primary key value for a save query. + """ + if self.get_key_name() in self.__original: + return self.__original[self.get_key_name()] + + return self.__attributes[self.get_key_name()] + + def touch(self): + """ + Update the model's timestamps. + + :rtype: bool + """ + if not self.__timestamps__: + return False + + self._update_timestamps() + + return self.save() + + def _update_timestamps(self): + """ + Update the model's timestamps. + """ + time = self.fresh_timestamp() + + if not self.is_dirty(self.UPDATED_AT): + self.set_updated_at(time) + + if not self.__exists and not self.is_dirty(self.CREATED_AT): + self.set_created_at(time) + + def set_created_at(self, value): + """ + Set the value of the "created at" attribute. + + :param value: The value + :type value: datetime + """ + setattr(self, self.CREATED_AT, value) + + def set_updated_at(self, value): + """ + Set the value of the "updated at" attribute. + + :param value: The value + :type value: datetime + """ + setattr(self, self.UPDATED_AT, value) + + def get_created_at_column(self): + """ + Get the name of the "created at" column. + + :rtype: str + """ + return self.CREATED_AT + + def get_updated_at_column(self): + """ + Get the name of the "updated at" column. + + :rtype: str + """ + return self.UPDATED_AT + + def fresh_timestamp(self): + """ + Get a fresh timestamp for the model. + + :return: arrow.Arrow + """ + return arrow.get().naive + + def new_query(self): + """ + Get a new query builder for the model's table + + :return: A Builder instance + :rtype: Builder + """ + builder = self.new_orm_builder( + self._new_base_query_builder() + ) + + return builder.set_model(self).with_(*self._with) + + @classmethod + def query(cls): + return cls().new_query() + + def new_orm_builder(self, query): + """ + Create a new orm query builder for the model + + :param query: A QueryBuilder instance + :type query: QueryBuilder + + :return: A Builder instance + :rtype: Builder + """ + return Builder(query) + + def _new_base_query_builder(self): + """ + Get a new query builder instance for the connection. + + :return: A QueryBuilder instance + :rtype: QueryBuilder + """ + conn = self.get_connection() + + grammar = conn.get_query_grammar() + + return QueryBuilder(conn, grammar, conn.get_post_processor()) + + def new_collection(self, models=None): + """ + Create a new Collection instance. + + :param models: A list of models + :type models: list + + :return: A new Collection instance + :rtype: Collection + """ + if models is None: + models = [] + + return Collection(models) + + def new_pivot(self, parent, attributes, table, exists): + """ + Create a new pivot model instance. + + :param parent: The parent model + :type parent: Model + + :param attributes: The pivot attributes + :type attributes: dict + + :param table: the pivot table + :type table: str + + :param exists: Whether the pivot exists or not + :type exists: bool + + :rtype: Pivot + """ + from .relations.pivot import Pivot + + return Pivot(parent, attributes, table, exists) + + def get_table(self): + """ + Get the table associated with the model. + + :return: The name of the table + :rtype: str + """ + if self.__table__ is not None: + return self.__table__ + + return inflection.tableize(self.__class__.__name__) + + def set_table(self, table): + """ + Set the table associated with the model. + + :param table: The table name + :type table: str + """ + self.__table__ = table + + def get_key(self): + """ + Get the value of the model's primary key. + """ + return self.get_attribute(self.get_key_name()) + + def get_key_name(self): + """ + Get the primary key for the model. + + :return: The primary key name + :rtype: str + """ + return self.__primary_key__ + + def set_key_name(self, name): + """ + Set the primary key for the model. + + :param name: The primary key name + :type name: str + """ + self.__primary_key__ = name + + def get_qualified_key_name(self): + """ + Get the table qualified key name. + + :rtype: str + """ + return '%s.%s' % (self.get_table(), self.get_key_name()) + + def uses_timestamps(self): + """ + Determine if the model uses timestamps. + + :rtype: bool + """ + return self.__timestamps__ + + def get_foreign_key(self): + """ + Get the default foreign key name for the model + + :rtype: str + """ + return '%s_id' % inflection.singularize(inflection.tableize(self.__class__.__name__)) + + def get_hidden(self): + """ + Get the hidden attributes for the model. + """ + return self.__hidden__ + + def set_hidden(self, hidden): + """ + Set the hidden attributes for the model. + + :param hidden: The attributes to add + :type hidden: list + """ + self.__hidden__ = hidden + + return self + + def add_hidden(self, *attributes): + """ + Add hidden attributes to the model. + + :param attributes: The attributes to hide + :type attributes: list + """ + self.__hidden__ += attributes + + def get_visible(self): + """ + Get the visible attributes for the model. + """ + return self.__visible__ + + def set_visible(self, visible): + """ + Set the visible attributes for the model. + + :param visible: The attributes to make visible + :type visible: list + """ + self.__visible__ = visible + + return self + + def add_visible(self, *attributes): + """ + Add visible attributes to the model. + + :param attributes: The attributes to make visible + :type attributes: list + """ + self.__visible__ += attributes + + def get_fillable(self): + """ + Get the fillable attributes for the model. + + :rtype: list + """ + return self.__fillable__ + + def fillable(self, fillable): + """ + Set the fillable attributes for the model. + + :param fillable: The fillable attributes + :type fillable: list + + :return: The current Model instance + :rtype: Model + """ + self.__fillable__ = fillable + + return self + + def get_guarded(self): + """ + Get the guarded attributes. + """ + return self.__guarded__ + + def guard(self, guarded): + """ + Set the guarded attributes. + + :param guarded: The guarded attributes + :type guarded: list + + :return: The current Model instance + :rtype: Model + """ + self.__guarded__ = guarded + + return self + + @classmethod + def unguard(cls): + """ + Disable the mass assigment restrictions. + """ + cls.__unguarded__ = True + + @classmethod + def reguard(cls): + """ + Enable the mass assignment restrictions. + :return: + """ + cls.__unguarded__ = False + + def is_fillable(self, key): + """ + Determine if the given attribute can be mass assigned. + + :param key: The attribute to check + :type key: str + + :return: Whether the attribute can be mass assigned or not + :rtype: bool + """ + if self.__unguarded__: + return True + + if key in self.__fillable__: + return True + + if self.is_guarded(key): + return False + + return not self.__fillable__ and not key.startswith('_') + + def is_guarded(self, key): + """ + Determine if the given attribute is guarded. + + :param key: The attribute to check + :type key: str + + :return: Whether the attribute is guarded or not + :rtype: bool + """ + return key in self.__guarded__ or self.__guarded__ == ['*'] + + def totally_guarded(self): + """ + Determine if the model is totally guarded. + + :rtype: bool + """ + return len(self.__fillable__) == 0 and self.__guarded__ == ['*'] + + def _remove_table_from_key(self, key): + """ + Remove the table name from a given key. + + :param key: The key to remove the table name from. + :type key: str + + :rtype: str + """ + if '.' not in key: + return key + + return key.split('.')[-1] + + def get_incrementing(self): + return self.__incrementing__ + + def set_incrementing(self, value): + self.__incrementing__ = value + + def to_json(self, **options): + """ + Convert the model instance to JSON. + + :param options: The JSON options + :type options: dict + + :return: The JSON encoded model instance + :rtype: str + """ + return json.dumps(self.to_dict(), **options) + + def json_serialize(self): + """ + Convert the object into something JSON serializable. + + :rtype: dict + """ + return self.to_dict() + + def to_dict(self): + """ + Convert the model instance to a dictionary. + + :return: The dictionary version of the model instance + :rtype: dict + """ + attributes = self.attributes_to_dict() + + attributes.update(self.relations_to_dict()) + + return attributes + + def attributes_to_dict(self): + """ + Convert the model's attributes to a dictionary. + + :rtype: dict + """ + attributes = self._get_dictable_attributes() + + for key in self.get_dates(): + if not key in attributes: + continue + + attributes[key] = self._format_date(self.as_datetime(attributes[key])) + + # TODO: mutators + + for key, value in self.__casts__.items(): + if key not in attributes: # TODO: check mutators + continue + + attributes[key] = self._cast_attribute(key, attributes[key]) + + # TODO: appends + + return attributes + + def _get_dictable_attributes(self): + """ + Get an attribute dictionary of all dictable attributes. + + :rtype: dict + """ + return self._get_dictable_items(self.__attributes) + + def relations_to_dict(self): + """ + Get the model's relationships in dictionary form. + + :rtype: dict + """ + attributes = {} + + for key, value in self._get_dictable_relations().items(): + if key in self.get_hidden(): + continue + + relation = None + if hasattr(value, 'to_dict'): + relation = value.to_dict() + elif value is None: + relation = value + + if relation or value is None: + attributes[key] = relation + + return attributes + + def _get_dictable_relations(self): + """ + Get an attribute dict of all dictable relations. + """ + return self._get_dictable_items(self.__relations) + + def _get_dictable_items(self, values): + """ + Get an attribute dictionary of all dictable values. + + :param values: The values to check + :type values: dict + + :rtype: dict + """ + if len(self.get_visible()) > 0: + return {x: values[x] for x in values.keys() if x in self.get_visible()} + + return {x: values[x] for x in values.keys() if x not in self.get_hidden() and not x.startswith('_')} + + def get_attribute(self, key): + """ + Get an attribute from the model. + + :param key: The attribute to get + :type key: str + """ + in_attributes = key in self.__attributes + + if in_attributes: + return self._get_attribute_value(key) + + if key in self.__relations: + return self.__relations[key] + + relation = super(Model, self).__getattribute__(key) + + if relation: + return self._get_relationship_from_method(key) + + raise AttributeError(key) + + def _get_attribute_value(self, key): + """ + Get a plain attribute. + + :param key: The attribute to get + :type key: str + """ + value = self._get_attribute_from_dict(key) + + # TODO: mutators + + if self._has_cast(key): + value = self._cast_attribute(key, value) + elif key in self.get_dates(): + if value is not None: + return self.as_datetime(value) + + return value + + def _get_attribute_from_dict(self, key): + return self.__attributes.get(key) + + def _get_relationship_from_method(self, method): + """ + Get a relationship value from a method. + + :param method: The method name + :type method: str + + :rtype: mixed + """ + relations = super(Model, self).__getattribute__(method) + + if not isinstance(relations, Relation): + raise RuntimeError('Relationship method must return an object of type Relation') + + self.__relations[method] = DynamicProperty(relations.get_results, relations) + + return self.__relations[method] + + def has_get_mutator(self, key): + """ + Determine if a get mutator exists for an attribute. + + :param key: The attribute name + :type key: str + + :rtype: bool + """ + return hasattr(self, 'get_%s_attribute' % inflection.underscore(key)) + + def _has_cast(self, key): + """ + Determine whether an attribute should be casted to a native type. + + :param key: The attribute to check + :type key: str + + :rtype: bool + """ + return key in self.__casts__ + + def _is_json_castable(self, key): + """ + Determine whether a value is JSON castable. + + :param key: The key to check + :type key: str + + :rtype: bool + """ + if self._has_cast(key): + type = self._get_cast_type(key) + + return type in ['list', 'dict', 'json', 'object'] + + return False + + def _get_cast_type(self, key): + """ + Get the type of the cast for a model attribute. + + :param key: The attribute to get the cast for + :type key: str + + :rtype: str + """ + return self.__casts__[key].lower().strip() + + def _cast_attribute(self, key, value): + """ + Cast an attribute to a native Python type + + :param key: The attribute key + :type key: str + + :param value: The attribute value + :type value: The attribute value + + :rtype: mixed + """ + if value is None: + return None + + type = self._get_cast_type(key) + if type in ['int', 'integer']: + return int(value) + elif type in ['real', 'float', 'double']: + return float(value) + elif type in ['string', 'str']: + return str(value) + elif type in ['bool', 'boolean']: + return bool(value) + elif type in ['dict', 'list', 'json']: + return json.loads(value) + else: + return value + + def get_dates(self): + """ + Get the attributes that should be converted to dates. + + :rtype: list + """ + defaults = [self.CREATED_AT, self.UPDATED_AT] + + return self.__dates + defaults + + def from_datetime(self, value): + """ + Convert datetime to a datetime object + + :rtype: datetime.datetime + """ + if isinstance(value, arrow.Arrow): + return value.naive + + return arrow.get(value).naive + + def as_datetime(self, value): + """ + Return a timestamp as a datetime. + + :rtype: arrow.Arrow + """ + return arrow.get(value) + + def get_date_format(self): + """ + Get the format to use for timestamps and dates. + + :rtype: str + """ + return 'iso' + + def _format_date(self, date): + """ + Format a date or timestamp. + + :param date: The date or timestamp + :type date: datetime.datetime or datetime.date or arrow.Arrow + + :rtype: str + """ + format = self.get_date_format() + + if format == 'iso': + return date.isoformat() + else: + if isinstance(date, arrow.Arrow): + return date.format(format) + + return date.strftime(format) + + def set_attribute(self, key, value): + """ + Set a given attribute on the model. + """ + # TODO: Set mutators + + if key in self.get_dates() and value: + value = self.from_datetime(value) + + if self._is_json_castable(key): + value = json.dumps(value) + + self.__attributes[key] = value + + def replicate(self, except_=None): + """ + Clone the model into a new, non-existing instance. + + :param except_: The attributes that should not be cloned + :type except_: list + + :rtype: Model + """ + if except_ is None: + except_ = [ + self.get_key_name(), + self.get_created_at_column(), + self.get_updated_at_column() + ] + + attributes = {x: self.__attributes[x] for x in self.__attributes if x not in except_} + + instance = self.new_instance(attributes) + + instance.set_relations(dict(**self.__relations)) + + return instance + + def get_attributes(self): + """ + Get all of the current attributes on the model. + + :rtype: dict + """ + return self.__attributes + + def set_raw_attributes(self, attributes, sync=False): + """ + Set the dictionary of model attributes. No checking is done. + + :param attributes: The model attributes + :type attributes: dict + + :param sync: Whether to sync the attributes or not + :type sync: bool + """ + self.__attributes = dict(attributes) + + if sync: + self.sync_original() + + def get_original(self, key=None, default=None): + """ + Get the original values + + :param key: The original key to get + :type key: str + + :param default: The default value if the key does not exist + :type default: mixed + + :rtype: mixed + """ + if key is None: + return self.__original + + return self.__original.get(key, default) + + def sync_original(self): + """ + Sync the original attributes with the current. + + :rtype: Builder + """ + self.__original = dict(self.__attributes.items()) + + return self + + def sync_original_attribute(self, attribute): + """ + Sync a single original attribute with its current value. + + :param attribute: The attribute to sync + :type attribute: str + + :rtype: Model + """ + self.__original[attribute] = self.__attributes[attribute] + + return self + + def is_dirty(self, *attributes): + """ + Determine if the model or given attributes have been modified. + + :param attributes: The attributes to check + :type attributes: list + + :rtype: boolean + """ + dirty = self.get_dirty() + + if not attributes: + return len(dirty) > 0 + + for attribute in attributes: + if attribute in dirty: + return True + + return False + + def get_dirty(self): + """ + Get the attribute that have been change since last sync. + + :rtype: list + """ + dirty = {} + + for key, value in self.__attributes.items(): + if key not in self.__original: + dirty[key] = value + elif value != self.__original[key]: + dirty[key] = value + + return dirty + + @property + def exists(self): + return self.__exists + + def set_exists(self, exists): + self.__exists = exists + + def get_relations(self): + """ + Get all the loaded relations for the instance. + + :rtype: dict + """ + return self.__relations + + def get_relation(self, relation): + """ + Get a specific relation. + + :param relation: The name of the relation. + :type relation: str + + :rtype: mixed + """ + return self.__relations[relation] + + def set_relation(self, relation, value): + """ + Set the specific relation in the model. + + :param relation: The name of the relation + :type relation: str + + :param value: The relation + :type value: mixed + + :return: The current Model instance + :rtype: Model + """ + self.__relations[relation] = value + + return self + + def set_relations(self, relations): + self.__relations = relations + + return self + + def get_connection(self): + """ + Get the database connection for the model + + :rtype: eloquent.connections.Connection + """ + return self.resolve_connection(self.__connection__) + + def get_connection_name(self): + """ + Get the database connection name for the model. + + :rtype: str + """ + return self.__connection__ + + def set_connection(self, name): + """ + Set the connection associated with the model. + + :param name: The connection name + :type name: str + + :return: The current model instance + :rtype: Model + """ + self.__connection__ = name + + return self + + @classmethod + def resolve_connection(cls, connection=None): + """ + Resolve a connection instance. + + :param connection: The connection name + :type connection: str + + :rtype: eloquent.connections.Connection + """ + return cls.__resolver.connection(connection) + + @classmethod + def get_connection_resolver(cls): + """ + Get the connection resolver instance. + """ + return cls.__resolver + + @classmethod + def set_connection_resolver(cls, resolver): + """ + Set the connection resolver instance. + """ + cls.__resolver = resolver + + @classmethod + def unset_connection_resolver(cls, resolver): + """ + Unset the connection resolver instance. + """ + cls._resolver = None + + def __getattribute__(self, item): + try: + attr = super(Model, self).__getattribute__(item) + if isinstance(attr, Relation): + return self.get_attribute(item) + + return attr + except AttributeError: + return self.get_attribute(item) + + def __setattr__(self, key, value): + if key.startswith(('_Model__', '_%s__' % self.__class__.__name__, '__')): + super(Model, self).__setattr__(key, value) + elif callable(getattr(self, key, None)): + return super(Model, self).__setattr__(key, value) + else: + self.set_attribute(key, value) + + def __delattr__(self, item): + try: + super(Model, self).__delattr__(item) + except AttributeError: + del self.__attributes[item] diff --git a/eloquent/orm/relations/__init__.py b/eloquent/orm/relations/__init__.py new file mode 100644 index 00000000..b9f665dd --- /dev/null +++ b/eloquent/orm/relations/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- + +from .relation import Relation +from .has_one import HasOne +from .has_many import HasMany +from .belongs_to import BelongsTo +from .belongs_to_many import BelongsToMany +from .has_many_through import HasManyThrough diff --git a/eloquent/orm/relations/belongs_to.py b/eloquent/orm/relations/belongs_to.py new file mode 100644 index 00000000..4f09ea44 --- /dev/null +++ b/eloquent/orm/relations/belongs_to.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- + +from ...query.expression import QueryExpression +from .relation import Relation + + +class BelongsTo(Relation): + + def __init__(self, query, parent, foreign_key, other_key, relation): + """ + :param query: A Builder instance + :type query: Builder + + :param parent: The parent model + :type parent: Model + + :param foreign_key: The foreign key + :type foreign_key: str + + :param other_key: The other key + :type other_key: str + + :param relation: The relation name + :type relation: str + """ + self._other_key = other_key + self._relation = relation + self._foreign_key = foreign_key + + super(BelongsTo, self).__init__(query, parent) + + def get_results(self): + """ + Get the results of the relationship. + """ + return self._query.first() + + def add_constraints(self): + """ + Set the base constraints on the relation query. + + :rtype: None + """ + if self._constraints: + table = self._related.get_table() + + self._query.where('%s.%s' % (table, self._other_key), '=', getattr(self._parent, self._foreign_key)) + + def get_relation_count_query(self, query, parent): + """ + Add the constraints for a relationship count query. + + :type query: eloquent.orm.Builder + :type parent: eloquent.orm.Builder + + :rtype: Builder + """ + query.select(QueryExpression('COUNT(*)')) + + other_key = self.wrap('%s.%s' % (query.get_model().get_table(), self._other_key)) + + return query.where(self.get_qualified_foreign_key(), '=', QueryExpression(other_key)) + + def add_eager_constraints(self, models): + """ + Set the constraints for an eager load of the relation. + + :type models: list + """ + key = '%s.%s' % (self._related.get_table(), self._other_key) + + self._query.where_in(key, self._get_eager_model_keys(models)) + + def _get_eager_model_keys(self, models): + """ + Gather the keys from a list of related models. + + :type models: list + + :rtype: list + """ + keys = [] + + for model in models: + value = getattr(model, self._foreign_key) + + if value is not None and value not in keys: + keys.append(value) + + if not len(keys): + return [0] + + return keys + + def init_relation(self, models, relation): + """ + Initialize the relation on a set of models. + + :type models: list + :type relation: str + """ + for model in models: + model.set_relation(relation, None) + + return models + + def match(self, models, results, relation): + """ + Match the eagerly loaded results to their parents. + + :type models: list + :type results: Collection + :type relation: str + """ + foreign = self._foreign_key + + other = self._other_key + + dictionary = {} + + for result in results: + dictionary[result.get_attribute(other)] = result + + for model in models: + value = model.get_attribute(foreign) + + if value in dictionary: + model.set_relation(relation, dictionary[value]) + + return models + + def associate(self, model): + """ + Associate the model instance to the given parent. + + :type model: eloquent.Model + + :rtype: eloquent.Model + """ + self._parent.set_attribute(self._foreign_key, model.get_attribute(self._other_key)) + + return self._parent.set_relation(self._relation, model) + + def dissociate(self): + """ + Dissociate previously associated model from the given parent. + + :rtype: eloquent.Model + """ + self._parent.set_attribute(self._foreign_key, None) + + return self._parent.set_relation(self._relation, None) + + def update(self, _attributes=None, **attributes): + """ + Update the parent model on the relationship. + + :param attributes: The update attributes + :type attributes: dict + + :rtype: mixed + """ + if _attributes is not None: + attributes.update(_attributes) + + instance = self.get_results() + + return instance.fill(attributes).save() + + def get_foreign_key(self): + return self._foreign_key + + def get_qualified_foreign_key(self): + return '%s.%s' % (self._parent.get_table(), self._foreign_key) + + def get_other_key(self): + return self._other_key + + def get_qualified_other_key_name(self): + return '%s.%s' % (self._related.get_table(), self._other_key) diff --git a/eloquent/orm/relations/belongs_to_many.py b/eloquent/orm/relations/belongs_to_many.py new file mode 100644 index 00000000..676891d5 --- /dev/null +++ b/eloquent/orm/relations/belongs_to_many.py @@ -0,0 +1,837 @@ +# -*- coding: utf-8 -*- + +import hashlib +import time +import inflection +from ...exceptions.orm import ModelNotFound +from ...query.expression import QueryExpression +from ..collection import Collection +import eloquent.orm.model +from .relation import Relation + + +class BelongsToMany(Relation): + + def __init__(self, query, parent, table, foreign_key, other_key, relation_name=None): + """ + :param query: A Builder instance + :type query: Builder + + :param parent: The parent model + :type parent: Model + + :param table: The pivot table + :type table: str + + :param foreign_key: The foreign key + :type foreign_key: str + + :param other_key: The other key + :type other_key: str + + :param relation_name: The relation name + :type relation_name: str + """ + self._table = table + self._other_key = other_key + self._foreign_key = foreign_key + self._relation_name = relation_name + + self._pivot_columns = [] + self._pivot_wheres = [] + + super(BelongsToMany, self).__init__(query, parent) + + def get_results(self): + """ + Get the results of the relationship. + """ + return self.get() + + def where_pivot(self, column, operator=None, value=None, boolean='and'): + """ + Set a where clause for a pivot table column. + + :param column: The column of the where clause, can also be a QueryBuilder instance for sub where + :type column: str|Builder + + :param operator: The operator of the where clause + :type operator: str + + :param value: The value of the where clause + :type value: mixed + + :param boolean: The boolean of the where clause + :type boolean: str + + :return: self + :rtype: self + """ + self._pivot_wheres.append([column, operator, value, boolean]) + + return self.where('%s.%s' % (self._table, column), operator, value, boolean) + + def or_where_pivot(self, column, operator=None, value=None): + """ + Set an or where clause for a pivot table column. + + :param column: The column of the where clause, can also be a QueryBuilder instance for sub where + :type column: str|Builder + + :param operator: The operator of the where clause + :type operator: str + + :param value: The value of the where clause + :type value: mixed + + :return: self + :rtype: BelongsToMany + """ + return self.where_pivot(column, operator, value, 'or') + + def first(self, columns=None): + """ + Execute the query and get the first result. + + :type columns: list + """ + results = self.take(1).get(columns) + + if len(results) > 0: + return results.first() + + return + + def first_or_fail(self, columns=None): + """ + Execute the query and get the first result or raise an exception. + + :type columns: list + + :raises: ModelNotFound + """ + model = self.first(columns) + if model is not None: + return model + + raise ModelNotFound(self._parent.__class__) + + def get(self, columns=None): + """ + Execute the query as a "select" statement. + + :type columns: list + + :rtype: eloquent.Collection + """ + if columns is None: + columns = ['*'] + + if self._query.get_query().columns: + columns = [] + + select = self._get_select_columns(columns) + + models = self._query.add_select(*select).get_models() + + self._hydrate_pivot_relation(models) + + if len(models) > 0: + models = self._query.eager_load_relations(models) + + return self._related.new_collection(models) + + def _hydrate_pivot_relation(self, models): + """ + Hydrate the pivot table relationship on the models. + + :type models: list + """ + for model in models: + pivot = self.new_existing_pivot(self._clean_pivot_attributes(model)) + + model.set_relation('pivot', pivot) + + def _clean_pivot_attributes(self, model): + """ + Get the pivot attributes from a model. + + :type model: eloquent.Model + """ + values = {} + delete_keys = [] + + for key, value in model.get_attributes().items(): + if key.find('pivot_') == 0: + values[key[6:]] = value + + delete_keys.append(key) + + for key in delete_keys: + delattr(model, key) + + return values + + def add_constraints(self): + """ + Set the base constraints on the relation query. + + :rtype: None + """ + self._set_join() + + if self._constraints: + self._set_where() + + def get_relation_count_query(self, query, parent): + """ + Add the constraints for a relationship count query. + + :type query: eloquent.orm.Builder + :type parent: eloquent.orm.Builder + + :rtype: eloquent.orm.Builder + """ + if parent.get_query().from__ == query.get_query().from__: + return self.get_relation_count_query_for_self_join(query, parent) + + self._set_join(query) + + return super(BelongsToMany, self).get_relation_count_query(query, parent) + + def get_relation_count_query_for_self_join(self, query, parent): + """ + Add the constraints for a relationship count query on the same table. + + :type query: eloquent.orm.Builder + :type parent: eloquent.orm.Builder + + :rtype: eloquent.orm.Builder + """ + query.select(QueryExpression('COUNT(*)')) + + table_prefix = self._query.get_query().get_connection().get_table_prefix() + + hash_ = self.get_relation_count_hash() + query.from_('%s AS %s%s' % (self._table, table_prefix, hash_)) + + key = self.wrap(self.get_qualified_parent_key_name()) + + return query.where('%s.%s' % (hash_, self._foreign_key), '=', QueryExpression(key)) + + def get_relation_count_hash(self): + """ + Get a relationship join table hash. + + :rtype: str + """ + return 'self_%s' % (hashlib.md5(time.time()).hexdigest()) + + def _get_select_columns(self, columns=None): + """ + Set the select clause for the relation query. + + :param columns: The columns + :type columns: list + + :rtype: list + """ + if columns == ['*'] or columns is None: + columns = ['%s.*' % self._related.get_table()] + + return columns + self._get_aliased_pivot_columns() + + def _get_aliased_pivot_columns(self): + """ + Get the pivot columns for the relation. + + :rtype: list + """ + defaults = [self._foreign_key, self._other_key] + + columns = [] + + for column in defaults + self._pivot_columns: + value = '%s.%s AS pivot_%s' % (self._table, column, column) + if value not in columns: + columns.append('%s.%s AS pivot_%s' % (self._table, column, column)) + + return columns + + def _has_pivot_column(self, column): + """ + Determine whether the given column is defined as a pivot column. + + :param column: The column to check + :type column: str + + :rtype: bool + """ + return column in self._pivot_columns + + def _set_join(self, query=None): + """ + Set the join clause for the relation query. + + :param query: The query builder + :type query: eloquent.orm.Builder + + :return: self + :rtype: BelongsToMany + """ + if not query: + query = self._query + + base_table = self._related.get_table() + + key = '%s.%s' % (base_table, self._related.get_key_name()) + + query.join(self._table, key, '=', self.get_other_key()) + + return self + + def _set_where(self): + """ + Set the where clause for the relation query. + + :return: self + :rtype: BelongsToMany + """ + foreign = self.get_foreign_key() + + self._query.where(foreign, '=', self._parent.get_key()) + + return self + + def add_eager_constraints(self, models): + """ + Set the constraints for an eager load of the relation. + + :type models: list + """ + self._query.where_in(self.get_foreign_key(), self.get_keys(models)) + + def init_relation(self, models, relation): + """ + Initialize the relation on a set of models. + + :type models: list + :type relation: str + """ + for model in models: + model.set_relation(relation, self._related.new_collection()) + + return models + + def match(self, models, results, relation): + """ + Match the eagerly loaded results to their parents. + + :type models: list + :type results: Collection + :type relation: str + """ + dictionary = self._build_dictionary(results) + + for model in models: + key = model.get_key() + + if key in dictionary: + collection = self._related.new_collection(dictionary[key]) + + model.set_relation(relation, collection) + + return models + + def _build_dictionary(self, results): + """ + Build model dictionary keyed by the relation's foreign key. + + :param results: The results + :type results: Collection + + :rtype: dict + """ + foreign = self._foreign_key + + dictionary = {} + + for result in results: + key = getattr(result.pivot, foreign) + if key not in dictionary: + dictionary[key] = [] + + dictionary[key].append(result) + + return dictionary + + def touch(self): + """ + Touch all of the related models of the relationship. + """ + key = self.get_related().get_key_name() + + columns = self.get_related_fresh_update() + + ids = self.get_related_ids() + + if len(ids) > 0: + self.get_related().new_query().where_in(key, ids).update(columns) + + def get_related_ids(self): + """ + Get all of the IDs for the related models. + + :rtype: list + """ + related = self.get_related() + + full_key = related.get_qualified_key_name() + + return self.get_query().select(full_key).lists(related.get_key_name()) + + def save(self, model, joining=None, touch=True): + """ + Save a new model and attach it to the parent model. + + :type model: eloquent.Model + :type joining: dict + :type touch: bool + + :rtype: eloquent.Model + """ + if joining is None: + joining = {} + + model.save({'touch': False}) + + self.attach(model.get_key(), joining, touch) + + return model + + def save_many(self, models, joinings=None): + """ + Save a list of new models and attach them to the parent model + + :type models: list + :type joinings: list + + :rtype: list + """ + for key, model in enumerate(models): + self.save(model, joinings[key], False) + + self.touch_if_touching() + + return models + + def find_or_new(self, id, columns=None): + """ + Find a model by its primary key or return new instance of the related model. + + :param id: The primary key + :type id: mixed + + :param columns: The columns to retrieve + :type columns: list + + :rtype: Collection or Model + """ + instance = self.find(id, columns) + if instance is None: + instance = self.get_related().new_instance() + + return instance + + def first_or_new(self, _attributes=None, **attributes): + """ + Get the first related model record matching the attributes or instantiate it. + + :param attributes: The attributes + :type attributes: dict + + :rtype: Model + """ + if _attributes is not None: + attributes.update(_attributes) + + instance = self.where(attributes).first() + if instance is None: + instance = self._related.new_instance() + + return instance + + def first_or_create(self, _attributes=None, _joining=None, _touch=True, **attributes): + """ + Get the first related model record matching the attributes or create it. + + :param attributes: The attributes + :type attributes: dict + + :rtype: Model + """ + if _attributes is not None: + attributes.update(_attributes) + + instance = self.where(attributes).first() + if instance is None: + instance = self.create(attributes, _joining or {}, _touch) + + return instance + + def update_or_create(self, attributes, values=None, joining=None, touch=True): + """ + Create or update a related record matching the attributes, and fill it with values. + + :param attributes: The attributes + :type attributes: dict + + :param values: The values + :type values: dict + + :rtype: Model + """ + if values is None: + values = {} + + instance = self.where(attributes).first() + + if instance is None: + return self.create(values, joining, touch) + + instance.fill(**values) + + instance.save({'touch': False}) + + return instance + + def create(self, _attributes=None, _joining=None, _touch=True, **attributes): + """ + Create a new instance of the related model. + + :param attributes: The attributes + :type attributes: dict + + :rtype: eloquent.orm.Model + """ + if _attributes is not None: + attributes.update(_attributes) + + instance = self._related.new_instance(attributes) + + instance.save({'touch': False}) + + self.attach(instance.get_key(), _joining, _touch) + + return instance + + def create_many(self, records, joinings=None): + """ + Create a list of new instances of the related model. + """ + if joinings is None: + joinings = [] + + instances = [] + + for key, record in enumerate(records): + instances.append(self.create(record), joinings[key], False) + + self.touch_if_touching() + + return instances + + def sync(self, ids, detaching=True): + """ + Sync the intermediate tables with a list of IDs or collection of models + """ + changes = { + 'attached': [], + 'detached': [], + 'updated': [] + } + + if isinstance(ids, Collection): + ids = ids.model_keys() + + current = self._new_pivot_query().lists(self._other_key) + + records = self._format_sync_list(ids) + + detach = [x for x in current if x not in records.keys()] + + if detaching and len(detach) > 0: + self.detach(detach) + + changes['detached'] = detach + + changes.update(self._attach_new(records, current, False)) + + if len(changes['attached']) or len(changes['updated']): + self.touch_if_touching() + + return changes + + def _format_sync_list(self, records): + """ + Format the sync list so that it is keyed by ID. + """ + results = {} + + for attributes in records: + if not isinstance(attributes, dict): + id, attributes = attributes, {} + else: + id = list(attributes.keys())[0] + attributes = attributes[id] + + results[id] = attributes + + return results + + def _attach_new(self, records, current, touch=True): + """ + Attach all of the IDs that aren't in the current dict. + """ + changes = { + 'attached': [], + 'updated': [] + } + + for id, attributes in records.items(): + if id not in current: + self.attach(id, attributes, touch) + + changes['attached'].append(id) + elif len(attributes) > 0 and self.update_existing_pivot(id, attributes, touch): + changes['updated'].append(id) + + return changes + + def update_existing_pivot(self, id, attributes, touch=True): + """ + Update an existing pivot record on the table. + """ + if self.updated_at() in self._pivot_columns: + attributes = self.set_timestamps_on_attach(attributes, True) + + updated = self._new_picot_statement_for_id(id).update(attributes) + + if touch: + self.touch_if_touching() + + return updated + + def attach(self, id, attributes=None, touch=True): + """ + Attach a model to the parent. + """ + if isinstance(id, eloquent.orm.Model): + id = id.get_key() + + query = self.new_pivot_statement() + + if not isinstance(id, list): + id = [id] + + query.insert(self._create_attach_records(id, attributes)) + + if touch: + self.touch_if_touching() + + def _create_attach_records(self, ids, attributes): + """ + Create a list of records to insert into the pivot table. + """ + records = [] + + timed = (self._has_pivot_column(self.created_at()) + or self._has_pivot_column(self.updated_at())) + + for key, value in enumerate(ids): + records.append(self._attacher(key, value, attributes, timed)) + + return records + + def _attacher(self, key, value, attributes, timed): + """ + Create a full attachment record payload. + """ + id, extra = self._get_attach_id(key, value, attributes) + + record = self._create_attach_record(id, timed) + + if extra: + record.update(extra) + + return record + + def _get_attach_id(self, key, value, attributes): + """ + Get the attach record ID and extra attributes. + """ + if isinstance(value, dict): + key = list(value.keys())[0] + attributes.update(value[key]) + + return [key, attributes] + + return value, attributes + + def _create_attach_record(self, id, timed): + """ + Create a new pivot attachement record. + """ + record = {} + + record[self._foreign_key] = self._parent.get_key() + + record[self._other_key] = id + + if timed: + record = self._set_timestamps_on_attach(record) + + return record + + def _set_timestamps_on_attach(self, record, exists=False): + """ + Set the creation an update timestamps on an attach record. + """ + fresh = self._parent.fresh_timestamp() + + if not exists and self._has_pivot_column(self.created_at()): + record[self.created_at()] = fresh + + if self._has_pivot_column(self.updated_at()): + record[self.updated_at()] = fresh + + return record + + def detach(self, ids=None, touch=True): + """ + Detach models from the relationship. + """ + if isinstance(ids, eloquent.orm.model.Model): + ids = ids.get_key() + + if ids is None: + ids = [] + + query = self._new_pivot_query() + + if not isinstance(ids, list): + ids = [ids] + + if len(ids) > 0: + query.where_in(self._other_key, ids) + + if touch: + self.touch_if_touching() + + results = query.delete() + + return results + + def touch_if_touching(self): + """ + Touch if the parent model is being touched. + """ + if self._touching_parent(): + self.get_parent().touch() + + if self.get_parent().touches(self._relation_name): + self.touch() + + def _touching_parent(self): + """ + Determine if we should touch the parent on sync. + """ + return self.get_related().touches(self._guess_inverse_relation()) + + def _guess_inverse_relation(self): + return inflection.camelize(inflection.pluralize(self.get_parent().__class__.__name__)) + + def _new_pivot_query(self): + """ + Create a new query builder for the pivot table. + + :rtype: eloquent.orm.Builder + """ + query = self.new_pivot_statement() + + for where_args in self._pivot_wheres: + query.where(*where_args) + + return query.where(self._foreign_key, self._parent.get_key()) + + def new_pivot_statement(self): + """ + Get a new plain query builder for the pivot table. + """ + return self._query.get_query().new_query().from_(self._table) + + def new_pivot_statement_for_id(self, id): + """ + Get a new pivot statement for a given "other" id. + """ + return self._new_pivot_query().where(self._other_key, id) + + def new_pivot(self, attributes=None, exists=False): + """ + Create a new pivot model instance. + """ + pivot = self._related.new_pivot(self._parent, attributes, self._table, exists) + + return pivot.set_pivot_keys(self._foreign_key, self._other_key) + + def new_existing_pivot(self, attributes): + """ + Create a new existing pivot model instance. + """ + return self.new_pivot(attributes, True) + + def with_pivot(self, *columns): + """ + Set the columns on the pivot table to retrieve. + """ + columns = list(columns) + + self._pivot_columns += columns + + return self + + def with_timestamps(self, created_at=None, updated_at=None): + """ + Specify that the pivot table has creation and update columns. + """ + if not created_at: + created_at = self.created_at() + + if not updated_at: + updated_at = self.updated_at() + + return self.with_pivot(created_at, updated_at) + + def get_related_fresh_update(self): + """ + Get the related model's update at column at + """ + return {self._related.get_updated_at_column(): self._related.fresh_timestamp()} + + def get_has_compare_key(self): + """ + Get the key for comparing against the parent key in "has" query. + """ + return self.get_foreign_key() + + def get_foreign_key(self): + return '%s.%s' % (self._table, self._foreign_key) + + def get_other_key(self): + return '%s.%s' % (self._table, self._other_key) + + def get_table(self): + return self._table + + def get_relation_name(self): + return self._relation_name diff --git a/eloquent/orm/relations/dynamic_property.py b/eloquent/orm/relations/dynamic_property.py new file mode 100644 index 00000000..6e2edab2 --- /dev/null +++ b/eloquent/orm/relations/dynamic_property.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- + + +class DynamicProperty(object): + """ + Relationship dynamic property. + + It provides a simple way to access a property as is, returning the results, + or has a method whihch will start a query on the relation. + + Example: + + >>> user = User.find(1) + >>> user.roles # will return the roles associated with the user + >>> user.roles().first() # Will return the first role + """ + + def __init__(self, results_getter, relation): + self._results_getter = results_getter + self._results = None + self._relation = relation + + def get_results(self): + return self._results + + def __getitem__(self, item): + if not self._results: + self._results = self._results_getter() + + return self._results[item] + + def __iter__(self): + if not self._results: + self._results = self._results_getter() + + return iter(self._results) + + def __getattr__(self, item): + if not self._results: + self._results = self._results_getter() + + return getattr(self._results, item) + + def __call__(self, *args, **kwargs): + return self._relation(*args, **kwargs) diff --git a/eloquent/orm/relations/has_many.py b/eloquent/orm/relations/has_many.py new file mode 100644 index 00000000..44cb1052 --- /dev/null +++ b/eloquent/orm/relations/has_many.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- + +from .has_one_or_many import HasOneOrMany + + +class HasMany(HasOneOrMany): + + def get_results(self): + """ + Get the results of the relationship. + """ + return self._query.get() + + def init_relation(self, models, relation): + """ + Initialize the relation on a set of models. + + :type models: list + :type relation: str + """ + for model in models: + model.set_relation(relation, self._related.new_collection()) + + return models + + def match(self, models, results, relation): + """ + Match the eagerly loaded results to their parents. + + :type models: list + :type results: Collection + :type relation: str + """ + return self.match_many(models, results, relation) diff --git a/eloquent/orm/relations/has_many_through.py b/eloquent/orm/relations/has_many_through.py new file mode 100644 index 00000000..d89b36b3 --- /dev/null +++ b/eloquent/orm/relations/has_many_through.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- + +from ...query.expression import QueryExpression +from .relation import Relation + + +class HasManyThrough(Relation): + + def __init__(self, query, far_parent, parent, first_key, second_key): + """ + :param query: A Builder instance + :type query: Builder + + :param far_parent: The far parent model + :type far_parent: Model + + :param parent: The parent model + :type parent: Model + + :type first_key: str + :type second_key: str + """ + self._first_key = first_key + self._second_key = second_key + self._far_parent = far_parent + + super(HasManyThrough, self).__init__(query, parent) + + def add_constraints(self): + """ + Set the base constraints on the relation query. + + :rtype: None + """ + parent_table = self._parent.get_table() + + self._set_join() + + if self._constraints: + self._query.where('%s.%s' % (parent_table, self._first_key), '=', self._far_parent.get_key()) + + def get_relation_count_query(self, query, parent): + """ + Add the constraints for a relationship count query. + + :type query: Builder + :type parent: Builder + + :rtype: Builder + """ + parent_table = self._parent.get_table() + + self._set_join(query) + + query.select(QueryExpression('COUNT(*)')) + + key = self.wrap('%s.%s' % (parent_table, self._first_key)) + + return query.where(self.get_has_compare_key(), '=', QueryExpression(key)) + + def _set_join(self, query=None): + """ + Set the join clause for the query. + """ + if not query: + query = self._query + + foreign_key = '%s.%s' % (self._related.get_table(), self._second_key) + + query.join(self._parent.get_table(), self.get_qualified_parent_key_name(), '=', foreign_key) + + def add_eager_constraints(self, models): + """ + Set the constraints for an eager load of the relation. + + :type models: list + """ + table = self._parent.get_table() + + self._query.where_in('%s.%s' % (table, self._first_key), self.get_keys(models)) + + def init_relation(self, models, relation): + """ + Initialize the relation on a set of models. + + :type models: list + :type relation: str + """ + for model in models: + model.set_relation(relation, self._related.new_collection()) + + return models + + def match(self, models, results, relation): + """ + Match the eagerly loaded results to their parents. + + :type models: list + :type results: Collection + :type relation: str + """ + dictionary = self._build_dictionary(results) + + for model in models: + key = model.get_key() + + if key in dictionary: + value = self._related.new_collection(dictionary[key]) + + model.set_relation(relation, value) + + return models + + def _build_dictionary(self, results): + """ + Build model dictionary keyed by the relation's foreign key. + + :param results: The results + :type results: Collection + + :rtype: dict + """ + foreign = self._first_key + + dictionary = {} + + for result in results: + key = getattr(result, foreign) + if key not in dictionary: + dictionary[key] = [] + + dictionary[key].append(result) + + return dictionary + + def get_results(self): + """ + Get the results of the relationship. + """ + return self.get() + + def get(self, columns=None): + """ + Execute the query as a "select" statement. + + :type columns: list + + :rtype: eloquent.Collection + """ + if columns is None: + columns = ['*'] + + select = self._get_select_columns(columns) + + models = self._query.add_select(select).get_models() + + if len(models) > 0: + models = self._query.eager_load_relations(models) + + return self._related.new_collection(models) + + def _get_select_columns(self, columns=None): + """ + Set the select clause for the relation query. + + :param columns: The columns + :type columns: list + + :rtype: list + """ + if columns == ['*'] or columns is None: + columns = ['%s.*' % self._related.get_table()] + + return columns + ['%s.%s' % (self._parent.get_table(), self._first_key)] + + def get_has_compare_key(self): + return self._far_parent.get_qualified_key_name() diff --git a/eloquent/orm/relations/has_one.py b/eloquent/orm/relations/has_one.py new file mode 100644 index 00000000..0c84aa8d --- /dev/null +++ b/eloquent/orm/relations/has_one.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- + +from .has_one_or_many import HasOneOrMany + + +class HasOne(HasOneOrMany): + + def get_results(self): + """ + Get the results of the relationship. + """ + return self._query.first() + + def init_relation(self, models, relation): + """ + Initialize the relation on a set of models. + + :type models: list + :type relation: str + """ + for model in models: + model.set_relation(relation, None) + + return models + + def match(self, models, results, relation): + """ + Match the eagerly loaded results to their parents. + + :type models: list + :type results: Collection + :type relation: str + """ + return self.match_one(models, results, relation) diff --git a/eloquent/orm/relations/has_one_or_many.py b/eloquent/orm/relations/has_one_or_many.py new file mode 100644 index 00000000..e18e09de --- /dev/null +++ b/eloquent/orm/relations/has_one_or_many.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- + +from ..collection import Collection +from .relation import Relation + + +class HasOneOrMany(Relation): + + def __init__(self, query, parent, foreign_key, local_key): + """ + :type query: eloquent.orm.Builder + + :param parent: The parent model + :type parent: Model + + :param foreign_key: The foreign key of the parent model + :type foreign_key: str + + :param local_key: The local key of the parent model + :type local_key: str + """ + self._local_key = local_key + self._foreign_key = foreign_key + + super(HasOneOrMany, self).__init__(query, parent) + + def add_constraints(self): + """ + Set the base constraints of the relation query + """ + if self._constraints: + self._query.where(self._foreign_key, '=', self.get_parent_key()) + + def add_eager_constraints(self, models): + """ + Set the constraints for an eager load of the relation. + + :type models: list + """ + return self._query.where_in(self._foreign_key, self.get_keys(models, self._local_key)) + + def match_one(self, models, results, relation): + """ + Match the eargerly loaded resuls to their single parents. + + :param models: The parents + :type models: list + + :param results: The results collection + :type results: Collection + + :param relation: The relation + :type relation: str + + :rtype: list + """ + return self._match_one_or_many(models, results, relation, 'one') + + def match_many(self, models, results, relation): + """ + Match the eargerly loaded resuls to their single parents. + + :param models: The parents + :type models: list + + :param results: The results collection + :type results: Collection + + :param relation: The relation + :type relation: str + + :rtype: list + """ + return self._match_one_or_many(models, results, relation, 'many') + + def _match_one_or_many(self, models, results, relation, type): + """ + Match the eargerly loaded resuls to their single parents. + + :param models: The parents + :type models: list + + :param results: The results collection + :type results: Collection + + :param relation: The relation + :type relation: str + + :param type: The match type + :type type: str + + :rtype: list + """ + dictionary = self._build_dictionary(results) + + for model in models: + key = model.get_attribute(self._local_key) + + if key in dictionary: + value = self._get_relation_value(dictionary, key, type) + + model.set_relation(relation, value) + + return models + + def _get_relation_value(self, dictionary, key, type): + """ + Get the value of the relationship by one or many type. + + :type dictionary: dict + :type key: str + :type type: str + """ + value = dictionary[key] + + if type == 'one': + return value[0] + + return self._related.new_collection(value) + + def _build_dictionary(self, results): + """ + Build model dictionary keyed by the relation's foreign key. + + :param results: The results + :type results: Collection + + :rtype: dict + """ + dictionary = {} + + foreign = self.get_plain_foreign_key() + + for result in results: + key = getattr(result, foreign) + if key not in dictionary: + dictionary[key] = [] + + dictionary[key].append(result) + + return dictionary + + def save(self, model): + """ + Attach a model instance to the parent models. + + :param model: The model instance to attach + :type model: Model + + :rtype: Model + """ + model.set_attribute(self.get_plain_foreign_key(), self.get_parent_key()) + + if model.save(): + return model + + return False + + def save_many(self, models): + """ + Attach a list of models to the parent instance. + + :param models: The models to attach + :type models: list of Model + + :rtype: list + """ + return map(self.save, models) + + def find_or_new(self, id, columns=None): + """ + Find a model by its primary key or return new instance of the related model. + + :param id: The primary key + :type id: mixed + + :param columns: The columns to retrieve + :type columns: list + + :rtype: Collection or Model + """ + if columns is None: + columns = ['*'] + + instance = self.find(id, columns) + + if instance is None: + instance = self._related.new_instance() + instance.set_attribute(self.get_plain_foreign_key(), self.get_parent_key()) + + return instance + + def first_or_new(self, _attributes=None, **attributes): + """ + Get the first related model record matching the attributes or instantiate it. + + :param attributes: The attributes + :type attributes: dict + + :rtype: Model + """ + if _attributes is not None: + attributes.update(_attributes) + + instance = self.where(attributes).first() + + if instance is None: + instance = self._related.new_instance() + instance.set_attribute(self.get_plain_foreign_key(), self.get_parent_key()) + + return instance + + def first_or_create(self, _attributes=None, **attributes): + """ + Get the first related record matching the attributes or create it. + + :param attributes: The attributes + :type attributes: dict + + :rtype: Model + """ + if _attributes is not None: + attributes.update(_attributes) + + instance = self.where(attributes).first() + + if instance is None: + instance = self.create(**attributes) + + return instance + + def update_or_create(self, attributes, values=None): + """ + Create or update a related record matching the attributes, and fill it with values. + + :param attributes: The attributes + :type attributes: dict + + :param values: The values + :type values: dict + + :rtype: Model + """ + instance = self.first_or_new(**attributes) + + instance.fill(values) + + instance.save() + + return instance + + def create(self, _attributes=None, **attributes): + """ + Create a new instance of the related model. + + :param attributes: The attributes + :type attributes: dict + + :rtype: Model + """ + if _attributes is not None: + attributes.update(_attributes) + + instance = self._related.new_instance(attributes) + + instance.set_attribute(self.get_plain_foreign_key(), self.get_parent_key()) + + instance.save() + + return instance + + def create_many(self, records): + """ + Create a list of new instances of the related model. + + :param records: instances attributes + :type records: list + + :rtype: list + """ + instances = [] + + for record in records: + instances.append(self.create(**record)) + + return instances + + def update(self, _attributes=None, **attributes): + """ + Perform an update on all the related models. + + :param attributes: The attributes + :type attributes: dict + + :rtype: int + """ + if _attributes is not None: + attributes.update(_attributes) + + if self._related.uses_timestamps(): + attributes[self.get_related_updated_at()] = self._related.fresh_timestamp() + + return self._query.update(attributes) + + def get_has_compare_key(self): + return self.get_foreign_key() + + def get_foreign_key(self): + return self._foreign_key + + def get_plain_foreign_key(self): + segments = self.get_foreign_key().split('.') + + return segments[-1] + + def get_parent_key(self): + return self._parent.get_attribute(self._local_key) + + def get_qualified_parent_key_name(self): + return '%s.%s' % (self._parent.get_table(), self._local_key) diff --git a/eloquent/orm/relations/pivot.py b/eloquent/orm/relations/pivot.py new file mode 100644 index 00000000..91046dd3 --- /dev/null +++ b/eloquent/orm/relations/pivot.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- + +from ..model import Model + + +class Pivot(Model): + + __guarded__ = [] + + def __init__(self, parent, attributes, table, exists=False): + """ + :param parent: The parent model + :type parent: Model + + :param attributes: The pivot attributes + :type attributes: dict + + :param table: the pivot table + :type table: str + + :param exists: Whether the pivot exists or not + :type exists: bool + """ + if attributes is None: + attributes = {} + + super(Pivot, self).__init__() + + self.set_raw_attributes(attributes, True) + + self.set_table(table) + + self.set_connection(parent.get_connection_name()) + + self.__parent = parent + + self.set_exists(exists) + + self.__timestamps__ = self.has_timestamps_attributes() + + def _set_keys_for_save_query(self, query): + """ + Set the keys for a save update query. + + :param query: A Builder instance + :type query: eloquent.orm.Builder + + :return: The Builder instance + :rtype: eloquent.orm.Builder + """ + query.where(self.__foreign_key, self.get_attribute(self.__foreign_key)) + + return query.where(self.__other_key, self.get_attribute(self.__other_key)) + + def delete(self): + """ + Delete the pivot model record from the database. + + :rtype: int + """ + return self._get_delete_query().delete() + + def _get_delete_query(self): + """ + Get the query builder for a delete operation on the pivot. + + :rtype: eloquent.orm.Builder + """ + foreign = self.get_attribute(self.__foreign_key) + + query = self.new_query().where(self.__foreign_key, foreign) + + return query.where(self.__other_key, self.get_attribute(self.__other_key)) + + def has_timestamps_attributes(self): + """ + Determine if the pivot has timestamps attributes. + + :rtype: bool + """ + return self.get_created_at_column() in self.get_attributes() + + def get_foreign_key(self): + return self.__foreign_key + + def get_other_key(self): + return self.__other_key + + def set_pivot_keys(self, foreign_key, other_key): + """ + Set the key names for the pivot model instance + """ + self.__foreign_key = foreign_key + self.__other_key = other_key + + return self + + def get_created_at_column(self): + return self.__parent.get_created_at_column() + + def get_updated_at_column(self): + return self.__parent.get_updated_at_column() diff --git a/eloquent/orm/relations/relation.py b/eloquent/orm/relations/relation.py new file mode 100644 index 00000000..d2532e72 --- /dev/null +++ b/eloquent/orm/relations/relation.py @@ -0,0 +1,204 @@ +# -*- coding: utf-8 -*- + +from ...query.expression import QueryExpression +from ..collection import Collection + + +class Relation(object): + + _constraints = True + + def __init__(self, query, parent): + """ + :param query: A Builder instance + :type query: orm.eloquent.Builder + + :param parent: The parent model + :type parent: Model + """ + self._query = query + self._parent = parent + self._related = query.get_model() + + self.add_constraints() + + def add_constraints(self): + """ + Set the base constraints on the relation query. + + :rtype: None + """ + raise NotImplementedError + + def add_eager_constraints(self, models): + """ + Set the constraints for an eager load of the relation. + + :type models: list + """ + raise NotImplementedError + + def init_relation(self, models, relation): + """ + Initialize the relation on a set of models. + + :type models: list + :type relation: str + """ + raise NotImplementedError + + def match(self, models, results, relation): + """ + Match the eagerly loaded results to their parents. + + :type models: list + :type results: Collection + :type relation: str + """ + raise NotImplementedError + + def get_results(self): + """ + Get the results of the relationship. + """ + raise NotImplementedError + + def get_eager(self): + """ + Get the relationship for eager loading. + + :rtype: Collection + """ + return self.get() + + def touch(self): + """ + Touch all of the related models for the relationship. + """ + column = self.get_related().get_updated_at_column() + + self.raw_update({column: self.get_related().fresh_timestamp()}) + + def raw_update(self, attributes=None): + """ + Run a raw update against the base query. + + :type attributes: dict + + :rtype: int + """ + if attributes is None: + attributes = {} + + return self._query.update(attributes) + + def get_relation_count_query(self, query, parent): + """ + Add the constraints for a relationship count query. + + :type query: Builder + :type parent: Builder + + :rtype: Builder + """ + query.select(QueryExpression('COUNT(*)')) + + key = self.wrap(self.get_qualified_parent_key_name()) + + return query.where(self.get_has_compare_key(), '=', QueryExpression(key)) + + @classmethod + def no_constraints(cls, callback): + """ + Runs a callback with constraints disabled on the relation. + """ + cls._constraints = False + + results = callback() + + cls._constraints = True + + return results + + def get_keys(self, models, key=None): + """ + Get all the primary keys for an array of models. + + :type models: list + :type key: str + + :rtype: list + """ + return list(set(map(lambda value: value.get_attribute(key) if key else value.get_key(), models))) + + def get_query(self): + return self._query + + def get_base_query(self): + return self._query.get_query() + + def merge_query(self, query): + self._query.merge_wheres(query.wheres, query.get_query().get_raw_bindings()['where']) + + def get_parent(self): + return self._parent + + def get_qualified_parent_key_name(self): + return self._parent.get_qualified_key_name() + + def get_related(self): + return self._related + + def created_at(self): + """ + Get the name of the "created at" column. + + :rtype: str + """ + return self._parent.get_created_at_column() + + def updated_at(self): + """ + Get the name of the "updated at" column. + + :rtype: str + """ + return self._parent.get_updated_at_column() + + def get_related_updated_at(self): + """ + Get the name of the related model's "updated at" column. + + :rtype: str + """ + return self._related.get_updated_at_column() + + def wrap(self, value): + """ + Wrap the given calue with the parent's query grammar. + + :rtype: str + """ + return self._parent.new_query().get_query().get_grammar().wrap(value) + + def __dynamic(self, method): + attribute = getattr(self._query, method) + + def call(*args, **kwargs): + result = attribute(*args, **kwargs) + + return result + + if not callable(attribute): + return attribute + + return call + + def __getattr__(self, item): + return self.__dynamic(item) + + def __call__(self, *args, **kwargs): + self._query = self._related.new_query() + self.add_constraints() + + return self diff --git a/eloquent/query/__init__.py b/eloquent/query/__init__.py index 633f8661..d5849401 100644 --- a/eloquent/query/__init__.py +++ b/eloquent/query/__init__.py @@ -1,2 +1,3 @@ # -*- coding: utf-8 -*- +from .builder import QueryBuilder diff --git a/eloquent/query/builder.py b/eloquent/query/builder.py index 6564182b..bc1e0b62 100644 --- a/eloquent/query/builder.py +++ b/eloquent/query/builder.py @@ -342,9 +342,9 @@ def where(self, column, operator=Null(), value=None, boolean='and'): # and can add them each as a where clause. We will maintain the boolean we # received when the method was called and pass it into the nested where. if isinstance(column, dict): - def nested(query): - for key, value_ in column.items(): - query.where(key, '=', value) + nested = self.new_query() + for key, value in column.items(): + nested.where(key, '=', value) return self.where_nested(nested, boolean) @@ -1024,6 +1024,26 @@ def _run_select(self): not self._use_write_connection ) + def chunk(self, count): + """ + Chunk the results of the query + + :param count: The chunk size + :type count: int + + :return: The current chunk + :rtype: list + """ + page = 1 + results = self.for_page(page, count).get() + + while len(results) > 0: + yield results + + page += 1 + + results = self.for_page(page, count).get() + def lists(self, column, key=None): """ Get a list with the values of a given column diff --git a/eloquent/query/processors/__init__.py b/eloquent/query/processors/__init__.py index 633f8661..a0556a9d 100644 --- a/eloquent/query/processors/__init__.py +++ b/eloquent/query/processors/__init__.py @@ -1,2 +1,6 @@ # -*- coding: utf-8 -*- +from .processor import QueryProcessor +from .mysql_processor import MySqlQueryProcessor +from .postgres_processor import PostgresQueryProcessor +from .sqlite_processor import SQLiteQueryProcessor diff --git a/eloquent/support/collection.py b/eloquent/support/collection.py index 1ed54911..45d80917 100644 --- a/eloquent/support/collection.py +++ b/eloquent/support/collection.py @@ -103,7 +103,27 @@ def diff(self, items): """ pass + def first(self, default=None): + """ + Get the first item of the collection. + + :param default: The default value + :type default: mixed + """ + if len(self._items) > 0: + return self._items[0] + else: + return default + + def lists(self, value, key=None): + """ + Get a list with the values of a given key + :rtype: list + """ + results = map(lambda x: x[value], self._items) + + return list(results) def _get_items(self, items): if isinstance(items, Collection): @@ -114,3 +134,17 @@ def _get_items(self, items): items = items.to_dict() return items + + def to_dict(self): + return list(map(lambda value: value.to_dict() if hasattr(value, 'to_dict') else value, + self._items)) + + def __len__(self): + return len(self._items) + + def __iter__(self): + for item in self._items: + yield item + + def __getitem__(self, item): + return self._items[item] diff --git a/requirements.txt b/requirements.txt index e69de29b..5a086679 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1,4 @@ +simplejson +arrow +inflection +six diff --git a/setup.py b/setup.py index 0225a9b0..a855ad9b 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages -__version__ = '0.1' +__version__ = '0.2' setup( name='eloquent', @@ -15,8 +15,8 @@ url='https://github.com/sdispater/eloquent', download_url='https://github.com/sdispater/eloquent/archive/%s.tar.gz' % __version__, packages=find_packages(), - install_requires=[], - tests_require=['pytest', 'mock'], + install_requires=['simplejson', 'arrow', 'inflection', 'six'], + tests_require=['pytest', 'mock', 'flexmock'], test_suite='nose.collector', classifiers=[ 'Intended Audience :: Developers', diff --git a/tests-requirements.txt b/tests-requirements.txt index 12136494..d72d9896 100644 --- a/tests-requirements.txt +++ b/tests-requirements.txt @@ -1,2 +1,3 @@ pytest pytest-mock +flexmock diff --git a/tests/orm/__init__.py b/tests/orm/__init__.py new file mode 100644 index 00000000..633f8661 --- /dev/null +++ b/tests/orm/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- + diff --git a/tests/orm/relations/__init__.py b/tests/orm/relations/__init__.py new file mode 100644 index 00000000..633f8661 --- /dev/null +++ b/tests/orm/relations/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- + diff --git a/tests/orm/relations/test_belongs_to.py b/tests/orm/relations/test_belongs_to.py new file mode 100644 index 00000000..5340203f --- /dev/null +++ b/tests/orm/relations/test_belongs_to.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- + + +import arrow +from flexmock import flexmock, flexmock_teardown +from ... import EloquentTestCase, mock + +from eloquent.query.builder import QueryBuilder +from eloquent.query.grammars import QueryGrammar +from eloquent.query.expression import QueryExpression +from eloquent.orm.builder import Builder +from eloquent.orm.model import Model +from eloquent.orm.relations import BelongsTo +from eloquent.orm.collection import Collection + + +class OrmBelongsToTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_update_retrieve_model_and_updates(self): + relation = self._get_relation() + mock = flexmock(Model()) + mock.should_receive('fill').once().with_args({'foo': 'bar'}).and_return(mock) + mock.should_receive('save').once().and_return(True) + relation.get_query().should_receive('first').once().and_return(mock) + + self.assertTrue(relation.update({'foo': 'bar'})) + + def test_relation_is_properly_initialized(self): + relation = self._get_relation() + model = flexmock(Model()) + model.should_receive('set_relation').once().with_args('foo', None) + models = relation.init_relation([model], 'foo') + + self.assertEqual([model], models) + + def test_eager_constraints_are_properly_added(self): + relation = self._get_relation() + relation.get_query().get_query().should_receive('where_in').once()\ + .with_args('relation.id', ['foreign.value', 'foreign.value.two']) + + model1 = OrmBelongsToModelStub() + model2 = OrmBelongsToModelStub() + model2.foreign_key = 'foreign.value' + model3 = AnotherOrmBelongsToModelStub() + model3.foreign_key = 'foreign.value.two' + models = [model1, model2, model3] + + relation.add_eager_constraints(models) + + def test_models_are_properly_matched_to_parents(self): + relation = self._get_relation() + + result1 = flexmock() + result1.should_receive('get_attribute').with_args('id').and_return(1) + result2 = flexmock() + result2.should_receive('get_attribute').with_args('id').and_return(2) + + model1 = OrmBelongsToModelStub() + model1.foreign_key = 1 + model2 = OrmBelongsToModelStub() + model2.foreign_key = 2 + + models = relation.match([model1, model2], Collection([result1, result2]), 'foo') + + self.assertEqual(1, models[0].foo.get_attribute('id')) + self.assertEqual(2, models[1].foo.get_attribute('id')) + + def test_associate_sets_foreign_key_on_model(self): + parent = Model() + parent.foreign_key = 'foreign.value' + parent.get_attribute = mock.MagicMock(return_value='foreign.value') + parent.set_attribute = mock.MagicMock() + parent.set_relation = mock.MagicMock() + relation = self._get_relation(parent) + associate = flexmock(Model()) + associate.should_receive('get_attribute').once().with_args('id').and_return(1) + + relation.associate(associate) + + parent.get_attribute.assert_has_calls([ + mock.call('foreign_key'), + mock.call('foreign_key') + ]) + parent.set_attribute.assert_has_calls([ + mock.call('foreign_key', 1) + ]) + parent.set_relation.assert_called_once_with('relation', associate) + + def _get_relation(self, parent=None): + flexmock(Builder) + query = flexmock(QueryBuilder(None, QueryGrammar(), None)) + builder = Builder(query) + builder.should_receive('where').with_args('relation.id', '=', 'foreign.value') + related = flexmock(Model()) + related.should_receive('get_key_name').and_return('id') + related.should_receive('get_table').and_return('relation') + builder.should_receive('get_model').and_return(related) + if parent is None: + parent = OrmBelongsToModelStub() + parent.foreign_key = 'foreign.value' + + return BelongsTo(builder, parent, 'foreign_key', 'id', 'relation') + + +class OrmBelongsToModelStub(Model): + + foreign_key = 'foreign.value' + + +class AnotherOrmBelongsToModelStub(Model): + + foreign_key = 'foreign.value.two' diff --git a/tests/orm/relations/test_belongs_to_many.py b/tests/orm/relations/test_belongs_to_many.py new file mode 100644 index 00000000..f50b51a8 --- /dev/null +++ b/tests/orm/relations/test_belongs_to_many.py @@ -0,0 +1,562 @@ +# -*- coding: utf-8 -*- + + +import arrow +from flexmock import flexmock, flexmock_teardown +from ... import EloquentTestCase +from ...utils import MockConnection + +from eloquent.query.builder import QueryBuilder +from eloquent.query.grammars import QueryGrammar +from eloquent.query.processors import QueryProcessor +from eloquent.query.expression import QueryExpression +from eloquent.orm.builder import Builder +from eloquent.orm.model import Model +from eloquent.orm.relations import BelongsToMany +from eloquent.orm.relations.pivot import Pivot +from eloquent.orm.collection import Collection + + +class OrmBelongsToTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_models_are_properly_hydrated(self): + model1 = OrmBelongsToManyModelStub() + model1.fill(name='john', pivot_user_id=1, pivot_role_id=2) + model2 = OrmBelongsToManyModelStub() + model2.fill(name='jane', pivot_user_id=3, pivot_role_id=4) + models = [model1, model2] + + base_builder = flexmock(Builder(QueryBuilder(MockConnection().prepare_mock(), + QueryGrammar(), QueryProcessor()))) + + relation = self._get_relation() + relation.get_parent().should_receive('get_connection_name').and_return('foo.connection') + relation.get_query().get_query().should_receive('add_select').once()\ + .with_args(*['roles.*', 'user_role.user_id AS pivot_user_id', 'user_role.role_id AS pivot_role_id'])\ + .and_return(relation.get_query()) + relation.get_query().should_receive('get_models').once().and_return(models) + relation.get_query().should_receive('eager_load_relations').once().with_args(models).and_return(models) + relation.get_related().should_receive('new_collection').replace_with(lambda l: Collection(l)) + relation.get_query().should_receive('get_query').once().and_return(base_builder) + + results = relation.get() + + self.assertIsInstance(results, Collection) + + # Make sure the foreign keys were set on the pivot models + self.assertEqual('user_id', results[0].pivot.get_foreign_key()) + self.assertEqual('role_id', results[0].pivot.get_other_key()) + + self.assertEqual('john', results[0].name) + self.assertEqual(1, results[0].pivot.user_id) + self.assertEqual(2, results[0].pivot.role_id) + self.assertEqual('foo.connection', results[0].pivot.get_connection_name()) + + self.assertEqual('jane', results[1].name) + self.assertEqual(3, results[1].pivot.user_id) + self.assertEqual(4, results[1].pivot.role_id) + self.assertEqual('foo.connection', results[1].pivot.get_connection_name()) + + self.assertEqual('user_role', results[0].pivot.get_table()) + self.assertTrue(results[0].pivot.exists) + + def test_timestamps_can_be_retrieved_properly(self): + model1 = OrmBelongsToManyModelStub() + model1.fill(name='john', pivot_user_id=1, pivot_role_id=2) + model2 = OrmBelongsToManyModelStub() + model2.fill(name='jane', pivot_user_id=3, pivot_role_id=4) + models = [model1, model2] + + base_builder = flexmock(Builder(QueryBuilder(MockConnection().prepare_mock(), + QueryGrammar(), QueryProcessor()))) + + relation = self._get_relation().with_timestamps() + relation.get_parent().should_receive('get_connection_name').and_return('foo.connection') + relation.get_query().get_query().should_receive('add_select').once()\ + .with_args( + 'roles.*', + 'user_role.user_id AS pivot_user_id', + 'user_role.role_id AS pivot_role_id', + 'user_role.created_at AS pivot_created_at', + 'user_role.updated_at AS pivot_updated_at' + )\ + .and_return(relation.get_query()) + relation.get_query().should_receive('get_models').once().and_return(models) + relation.get_query().should_receive('eager_load_relations').once().with_args(models).and_return(models) + relation.get_related().should_receive('new_collection').replace_with(lambda l: Collection(l)) + relation.get_query().should_receive('get_query').once().and_return(base_builder) + + results = relation.get() + + def test_models_are_properly_matched_to_parents(self): + relation = self._get_relation() + + result1 = OrmBelongsToManyModelPivotStub() + result1.pivot.user_id = 1 + + result2 = OrmBelongsToManyModelPivotStub() + result2.pivot.user_id = 2 + + result3 = OrmBelongsToManyModelPivotStub() + result3.pivot.user_id = 2 + + model1 = OrmBelongsToManyModelStub() + model1.id = 1 + model2 = OrmBelongsToManyModelStub() + model2.id = 2 + model3 = OrmBelongsToManyModelStub() + model3.id = 3 + + relation.get_related().should_receive('new_collection').replace_with(lambda l: Collection(l)) + models = relation.match([model1, model2, model3], Collection([result1, result2, result3]), 'foo') + + self.assertEqual(1, models[0].foo[0].pivot.user_id) + self.assertEqual(1, len(models[0].foo)) + self.assertEqual(2, models[1].foo[0].pivot.user_id) + self.assertEqual(2, models[1].foo[1].pivot.user_id) + self.assertEqual(2, len(models[1].foo)) + self.assertFalse(hasattr(models[2], 'foo')) + + def test_relation_is_properly_initialized(self): + relation = self._get_relation() + relation.get_related().should_receive('new_collection').replace_with(lambda l=None: Collection(l or [])) + model = flexmock(Model()) + model.should_receive('set_relation').once().with_args('foo', Collection) + models = relation.init_relation([model], 'foo') + + self.assertEqual([model], models) + + def test_eager_constraints_are_properly_added(self): + relation = self._get_relation() + relation.get_query().get_query().should_receive('where_in').once().with_args('user_role.user_id', [1, 2]) + model1 = OrmBelongsToManyModelStub() + model1.id = 1 + model2 = OrmBelongsToManyModelStub() + model2.id = 2 + + relation.add_eager_constraints([model1, model2]) + + def test_attach_inserts_pivot_table_record(self): + flexmock(BelongsToMany, touch_if_touching=lambda: True) + relation = self._get_relation() + query = flexmock() + query.should_receive('from_').once().with_args('user_role').and_return(query) + query.should_receive('insert').once().with_args([{'user_id': 1, 'role_id': 2, 'foo': 'bar'}]).and_return(True) + mock_query_builder = flexmock() + relation.get_query().should_receive('get_query').and_return(mock_query_builder) + mock_query_builder.should_receive('new_query').once().and_return(query) + relation.should_receive('touch_if_touching').once() + + relation.attach(2, {'foo': 'bar'}) + + def test_attach_multiple_inserts_pivot_table_record(self): + flexmock(BelongsToMany, touch_if_touching=lambda: True) + relation = self._get_relation() + query = flexmock() + query.should_receive('from_').once().with_args('user_role').and_return(query) + query.should_receive('insert').once().with_args( + [ + {'user_id': 1, 'role_id': 2, 'foo': 'bar'}, + {'user_id': 1, 'role_id': 3, 'bar': 'baz', 'foo': 'bar'} + ] + ).and_return(True) + mock_query_builder = flexmock() + relation.get_query().should_receive('get_query').and_return(mock_query_builder) + mock_query_builder.should_receive('new_query').once().and_return(query) + relation.should_receive('touch_if_touching').once() + + relation.attach([2, {3: {'bar': 'baz'}}], {'foo': 'bar'}) + + def test_attach_inserts_pivot_table_records_with_timestamps_when_ncessary(self): + flexmock(BelongsToMany, touch_if_touching=lambda: True) + relation = self._get_relation().with_timestamps() + query = flexmock() + query.should_receive('from_').once().with_args('user_role').and_return(query) + now = arrow.get().naive + query.should_receive('insert').once().with_args( + [ + {'user_id': 1, 'role_id': 2, 'foo': 'bar', 'created_at': now, 'updated_at': now} + ] + ).and_return(True) + mock_query_builder = flexmock() + relation.get_query().should_receive('get_query').and_return(mock_query_builder) + mock_query_builder.should_receive('new_query').once().and_return(query) + relation.get_parent().should_receive('fresh_timestamp').once().and_return(now) + relation.should_receive('touch_if_touching').once() + + relation.attach(2, {'foo': 'bar'}) + + def test_attach_inserts_pivot_table_records_with_a_created_at_timestamp(self): + flexmock(BelongsToMany, touch_if_touching=lambda: True) + relation = self._get_relation().with_pivot('created_at') + query = flexmock() + query.should_receive('from_').once().with_args('user_role').and_return(query) + now = arrow.get().naive + query.should_receive('insert').once().with_args( + [ + {'user_id': 1, 'role_id': 2, 'foo': 'bar', 'created_at': now} + ] + ).and_return(True) + mock_query_builder = flexmock() + relation.get_query().should_receive('get_query').and_return(mock_query_builder) + mock_query_builder.should_receive('new_query').once().and_return(query) + relation.get_parent().should_receive('fresh_timestamp').once().and_return(now) + relation.should_receive('touch_if_touching').once() + + relation.attach(2, {'foo': 'bar'}) + + def test_attach_inserts_pivot_table_records_with_an_updated_at_timestamp(self): + flexmock(BelongsToMany, touch_if_touching=lambda: True) + relation = self._get_relation().with_pivot('updated_at') + query = flexmock() + query.should_receive('from_').once().with_args('user_role').and_return(query) + now = arrow.get().naive + query.should_receive('insert').once().with_args( + [ + {'user_id': 1, 'role_id': 2, 'foo': 'bar', 'updated_at': now} + ] + ).and_return(True) + mock_query_builder = flexmock() + relation.get_query().should_receive('get_query').and_return(mock_query_builder) + mock_query_builder.should_receive('new_query').once().and_return(query) + relation.get_parent().should_receive('fresh_timestamp').once().and_return(now) + relation.should_receive('touch_if_touching').once() + + relation.attach(2, {'foo': 'bar'}) + + def test_detach_remove_pivot_table_record(self): + flexmock(BelongsToMany, touch_if_touching=lambda: True) + relation = self._get_relation() + query = flexmock() + query.should_receive('from_').once().with_args('user_role').and_return(query) + query.should_receive('where').once().with_args('user_id', 1).and_return(query) + query.should_receive('where_in').once().with_args('role_id', [1, 2, 3]) + query.should_receive('delete').once().and_return(True) + mock_query_builder = flexmock() + relation.get_query().should_receive('get_query').and_return(mock_query_builder) + mock_query_builder.should_receive('new_query').once().and_return(query) + relation.should_receive('touch_if_touching').once() + + self.assertTrue(relation.detach([1, 2, 3])) + + def test_detach_with_single_id_remove_pivot_table_record(self): + flexmock(BelongsToMany, touch_if_touching=lambda: True) + relation = self._get_relation() + query = flexmock() + query.should_receive('from_').once().with_args('user_role').and_return(query) + query.should_receive('where').once().with_args('user_id', 1).and_return(query) + query.should_receive('where_in').once().with_args('role_id', [1]) + query.should_receive('delete').once().and_return(True) + mock_query_builder = flexmock() + relation.get_query().should_receive('get_query').and_return(mock_query_builder) + mock_query_builder.should_receive('new_query').once().and_return(query) + relation.should_receive('touch_if_touching').once() + + self.assertTrue(relation.detach(1)) + + def test_detach_clears_all_records_when_no_ids(self): + flexmock(BelongsToMany, touch_if_touching=lambda: True) + relation = self._get_relation() + query = flexmock() + query.should_receive('from_').once().with_args('user_role').and_return(query) + query.should_receive('where').once().with_args('user_id', 1).and_return(query) + query.should_receive('where_in').never() + query.should_receive('delete').once().and_return(True) + mock_query_builder = flexmock() + relation.get_query().should_receive('get_query').and_return(mock_query_builder) + mock_query_builder.should_receive('new_query').once().and_return(query) + relation.should_receive('touch_if_touching').once() + + self.assertTrue(relation.detach()) + + def test_create_creates_new_model_and_insert_attachment_record(self): + flexmock(BelongsToMany, attach=lambda: True) + relation = self._get_relation() + model = flexmock() + relation.get_related().should_receive('new_instance').once().and_return(model).with_args({'foo': 'bar'}) + model.should_receive('save').once() + model.should_receive('get_key').and_return('foo') + relation.should_receive('attach').once().with_args('foo', {'bar': 'baz'}, True) + + self.assertEqual(model, relation.create({'foo': 'bar'}, {'bar': 'baz'})) + + def test_find_or_new_finds_model(self): + flexmock(BelongsToMany) + relation = self._get_relation() + model = flexmock() + model.foo = 'bar' + relation.get_query().should_receive('find').with_args('foo', None).and_return(model) + relation.get_related().should_receive('new_instance').never() + + self.assertEqual('bar', relation.find_or_new('foo').foo) + + def test_find_or_new_returns_new_model(self): + flexmock(BelongsToMany) + relation = self._get_relation() + model = flexmock() + model.foo = 'bar' + relation.get_query().should_receive('find').with_args('foo', None).and_return(None) + relation.get_related().should_receive('new_instance').once().and_return(model) + + self.assertEqual('bar', relation.find_or_new('foo').foo) + + def test_first_or_new_finds_first_model(self): + flexmock(BelongsToMany) + relation = self._get_relation() + model = flexmock() + model.foo = 'bar' + relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query()) + relation.get_query().should_receive('first').once().and_return(model) + relation.get_related().should_receive('new_instance').never() + + self.assertEqual('bar', relation.first_or_new({'foo': 'bar'}).foo) + + def test_first_or_new_returns_new_model(self): + flexmock(BelongsToMany) + relation = self._get_relation() + model = flexmock() + model.foo = 'bar' + relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query()) + relation.get_query().should_receive('first').once().and_return(None) + relation.get_related().should_receive('new_instance').once().and_return(model) + + self.assertEqual('bar', relation.first_or_new({'foo': 'bar'}).foo) + + def test_first_or_create_finds_first_model(self): + flexmock(BelongsToMany) + relation = self._get_relation() + model = flexmock() + model.foo = 'bar' + relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query()) + relation.get_query().should_receive('first').once().and_return(model) + relation.should_receive('create').never() + + self.assertEqual('bar', relation.first_or_create({'foo': 'bar'}).foo) + + def test_first_or_create_returns_new_model(self): + flexmock(BelongsToMany) + relation = self._get_relation() + model = flexmock() + model.foo = 'bar' + relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query()) + relation.get_query().should_receive('first').once().and_return(None) + relation.should_receive('create').once().with_args({'foo': 'bar'}, {}, True).and_return(model) + + self.assertEqual('bar', relation.first_or_create({'foo': 'bar'}).foo) + + def test_update_or_create_finds_first_mode_and_updates(self): + flexmock(BelongsToMany) + relation = self._get_relation() + model = flexmock() + model.foo = 'bar' + relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query()) + relation.get_query().should_receive('first').once().and_return(model) + model.should_receive('fill').once() + model.should_receive('save').once() + relation.should_receive('create').never() + + self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}).foo) + + def test_update_or_create_returns_new_model(self): + flexmock(BelongsToMany) + relation = self._get_relation() + model = flexmock() + model.foo = 'bar' + relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query()) + relation.get_query().should_receive('first').once().and_return(None) + relation.should_receive('create').once().with_args({'bar': 'baz'}, None, True).and_return(model) + + self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}, {'bar': 'baz'}).foo) + + def test_sync_syncs_intermediate_table_with_given_list(self): + for list_ in [[2, 3, 4], ['2', '3', '4']]: + flexmock(BelongsToMany) + relation = self._get_relation() + query = flexmock() + query.should_receive('from_').once().with_args('user_role').and_return(query) + query.should_receive('where').once().with_args('user_id', 1).and_return(query) + mock_query_builder = flexmock() + relation.get_query().should_receive('get_query').and_return(mock_query_builder) + mock_query_builder.should_receive('new_query').once().and_return(query) + query.should_receive('lists').once().with_args('role_id').and_return([1, list_[0], list_[1]]) + relation.should_receive('attach').once().with_args(list_[2], {}, False) + relation.should_receive('detach').once().with_args([1]) + relation.get_related().should_receive('touches').and_return(False) + relation.get_parent().should_receive('touches').and_return(False) + + self.assertEqual( + { + 'attached': [list_[2]], + 'detached': [1], + 'updated': [] + }, + relation.sync(list_) + ) + + def test_sync_syncs_intermediate_table_with_given_list_and_attributes(self): + flexmock(BelongsToMany) + relation = self._get_relation() + query = flexmock() + query.should_receive('from_').once().with_args('user_role').and_return(query) + query.should_receive('where').once().with_args('user_id', 1).and_return(query) + mock_query_builder = flexmock() + relation.get_query().should_receive('get_query').and_return(mock_query_builder) + mock_query_builder.should_receive('new_query').once().and_return(query) + query.should_receive('lists').once().with_args('role_id').and_return([1, 2, 3]) + relation.should_receive('attach').once().with_args(4, {'foo': 'bar'}, False) + relation.should_receive('update_existing_pivot').once().with_args(3, {'bar': 'baz'}, False).and_return(True) + relation.should_receive('detach').once().with_args([1]) + relation.should_receive('touch_if_touching').once() + relation.get_related().should_receive('touches').and_return(False) + relation.get_parent().should_receive('touches').and_return(False) + + self.assertEqual( + { + 'attached': [4], + 'detached': [1], + 'updated': [3] + }, + relation.sync([2, {3: {'bar': 'baz'}}, {4: {'foo': 'bar'}}], ) + ) + + def test_sync_does_not_return_values_that_were_not_updated(self): + flexmock(BelongsToMany) + relation = self._get_relation() + query = flexmock() + query.should_receive('from_').once().with_args('user_role').and_return(query) + query.should_receive('where').once().with_args('user_id', 1).and_return(query) + mock_query_builder = flexmock() + relation.get_query().should_receive('get_query').and_return(mock_query_builder) + mock_query_builder.should_receive('new_query').once().and_return(query) + query.should_receive('lists').once().with_args('role_id').and_return([1, 2, 3]) + relation.should_receive('attach').once().with_args(4, {'foo': 'bar'}, False) + relation.should_receive('update_existing_pivot').once().with_args(3, {'bar': 'baz'}, False).and_return(False) + relation.should_receive('detach').once().with_args([1]) + relation.should_receive('touch_if_touching').once() + relation.get_related().should_receive('touches').and_return(False) + relation.get_parent().should_receive('touches').and_return(False) + + self.assertEqual( + { + 'attached': [4], + 'detached': [1], + 'updated': [] + }, + relation.sync([2, {3: {'bar': 'baz'}}, {4: {'foo': 'bar'}}], ) + ) + + def test_touch_method_syncs_timestamps(self): + relation = self._get_relation() + relation.get_related().should_receive('get_updated_at_column').and_return('updated_at') + now = arrow.get().naive + relation.get_related().should_receive('fresh_timestamp').and_return(now) + relation.get_related().should_receive('get_qualified_key_name').and_return('table.id') + relation.get_query().get_query().should_receive('select').once().with_args('table.id')\ + .and_return(relation.get_query().get_query()) + relation.get_query().should_receive('lists').once().and_return([1, 2, 3]) + query = flexmock() + relation.get_related().should_receive('new_query').once().and_return(query) + query.should_receive('where_in').once().with_args('id', [1, 2, 3]).and_return(query) + query.should_receive('update').once().with_args({'updated_at': now}) + + relation.touch() + + def test_touch_if_touching(self): + flexmock(BelongsToMany) + relation = self._get_relation() + relation.should_receive('_touching_parent').once().and_return(True) + relation.get_parent().should_receive('touch').once() + relation.get_parent().should_receive('touches').once().with_args('relation_name').and_return(True) + relation.should_receive('touch').once() + + relation.touch_if_touching() + + def test_sync_method_converts_collection_to_list_of_keys(self): + flexmock(BelongsToMany) + relation = self._get_relation() + query = flexmock() + query.should_receive('from_').once().with_args('user_role').and_return(query) + query.should_receive('where').once().with_args('user_id', 1).and_return(query) + mock_query_builder = flexmock() + relation.get_query().should_receive('get_query').and_return(mock_query_builder) + mock_query_builder.should_receive('new_query').once().and_return(query) + query.should_receive('lists').once().with_args('role_id').and_return([1, 2, 3]) + + collection = flexmock(Collection()) + collection.should_receive('model_keys').once().and_return([1, 2, 3]) + relation.should_receive('_format_sync_list').with_args([1, 2, 3]).and_return({1: {}, 2: {}, 3: {}}) + + relation.sync(collection) + + def test_where_pivot_params_used_for_new_queries(self): + flexmock(BelongsToMany) + relation = self._get_relation() + + relation.get_query().should_receive('where').once().and_return(relation) + + query = flexmock() + mock_query_builder = flexmock() + relation.get_query().should_receive('get_query').and_return(mock_query_builder) + mock_query_builder.should_receive('new_query').once().and_return(query) + + query.should_receive('from_').once().with_args('user_role').and_return(query) + + query.should_receive('where').once().with_args('user_id', 1).and_return(query) + + query.should_receive('where').once().with_args('foo', '=', 'bar', 'and').and_return(query) + + query.should_receive('lists').once().with_args('role_id').and_return([1, 2, 3]) + relation.should_receive('_format_sync_list').with_args([1, 2, 3]).and_return({1: {}, 2: {}, 3: {}}) + + relation = relation.where_pivot('foo', '=', 'bar') + relation.sync([1, 2, 3]) + + def _get_relation(self): + builder, parent = self._get_relation_arguments()[:2] + + return BelongsToMany(builder, parent, 'user_role', 'user_id', 'role_id', 'relation_name') + + def _get_relation_arguments(self): + parent = flexmock(Model()) + parent.should_receive('get_key').and_return(1) + parent.should_receive('get_created_at_column').and_return('created_at') + parent.should_receive('get_updated_at_column').and_return('updated_at') + + query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + flexmock(Builder) + builder = Builder(query) + builder.should_receive('get_query').and_return(query) + related = flexmock(Model()) + builder.set_model(related) + builder.should_receive('get_model').and_return(related) + + related.should_receive('get_key_name').and_return('id') + related.should_receive('get_table').and_return('roles') + related.should_receive('new_pivot').replace_with(lambda *args: Pivot(*args)) + + builder.get_query().should_receive('join').once().with_args('user_role', 'roles.id', '=', 'user_role.role_id') + builder.should_receive('where').once().with_args('user_role.user_id', '=', 1) + + return builder, parent, 'user_role', 'user_id', 'role_id', 'relation_id' + + +class OrmBelongsToManyModelStub(Model): + + __guarded__ = [] + + +class OrmBelongsToManyModelPivotStub(Model): + + __guarded__ = [] + + def __init__(self): + super(OrmBelongsToManyModelPivotStub, self).__init__() + + self.pivot = OrmBelongsToManyPivotStub() + + +class OrmBelongsToManyPivotStub(object): + pass diff --git a/tests/orm/relations/test_has_many.py b/tests/orm/relations/test_has_many.py new file mode 100644 index 00000000..a61ca0fb --- /dev/null +++ b/tests/orm/relations/test_has_many.py @@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- + + +import arrow +from flexmock import flexmock, flexmock_teardown +from ... import EloquentTestCase + +from eloquent.query.builder import QueryBuilder +from eloquent.query.grammars import QueryGrammar +from eloquent.query.expression import QueryExpression +from eloquent.orm.builder import Builder +from eloquent.orm.model import Model +from eloquent.orm.relations import HasMany +from eloquent.orm.collection import Collection + + +class OrmHasManyTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_create_properly_creates_new_model(self): + relation = self._get_relation() + created = flexmock(Model(), save=lambda: True, set_attribute=lambda: None) + created.should_receive('save').once().and_return(True) + relation.get_related().should_receive('new_instance').once().with_args({'name': 'john'}).and_return(created) + created.should_receive('set_attribute').with_args('foreign_key', 1) + + self.assertEqual(created, relation.create(name='john')) + + def test_find_or_new_finds_model(self): + relation = self._get_relation() + model = flexmock() + model.foo = 'bar' + relation.get_query().should_receive('find').once().with_args('foo', ['*']).and_return(model) + model.should_receive('set_attribute').never() + + self.assertEqual('bar', relation.find_or_new('foo').foo) + + def test_find_or_new_returns_new_model_with_foreign_key_set(self): + relation = self._get_relation() + relation.get_query().should_receive('find').once().with_args('foo', ['*']).and_return(None) + model = flexmock() + model.foo = 'bar' + relation.get_related().should_receive('new_instance').once().with_args().and_return(model) + model.should_receive('set_attribute').once().with_args('foreign_key', 1) + + self.assertEqual('bar', relation.find_or_new('foo').foo) + + def test_first_or_new_finds_first_model(self): + relation = self._get_relation() + relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) + model = flexmock() + model.foo = 'bar' + relation.get_query().should_receive('first').once().with_args().and_return(model) + model.should_receive('set_attribute').never() + + self.assertEqual('bar', relation.first_or_new(foo='bar').foo) + + def test_first_or_new_returns_new_model_with_foreign_key_set(self): + relation = self._get_relation() + relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) + relation.get_query().should_receive('first').once().with_args().and_return(None) + + model = flexmock() + model.foo = 'bar' + relation.get_related().should_receive('new_instance').once().with_args().and_return(model) + model.should_receive('set_attribute').once().with_args('foreign_key', 1) + + self.assertEqual('bar', relation.first_or_new(foo='bar').foo) + + def test_first_or_create_finds_first_model(self): + relation = self._get_relation() + relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) + model = flexmock() + model.foo = 'bar' + relation.get_query().should_receive('first').once().with_args().and_return(model) + model.should_receive('set_attribute').never() + + self.assertEqual('bar', relation.first_or_create(foo='bar').foo) + + def test_first_or_create_returns_new_model_with_foreign_key_set(self): + relation = self._get_relation() + relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) + relation.get_query().should_receive('first').once().with_args().and_return(None) + + model = flexmock() + model.foo = 'bar' + relation.get_related().should_receive('new_instance').once().with_args({'foo': 'bar'}).and_return(model) + model.should_receive('save').once().and_return(True) + model.should_receive('set_attribute').once().with_args('foreign_key', 1) + + self.assertEqual('bar', relation.first_or_create(foo='bar').foo) + + def test_update_or_create_finds_first_model_and_updates(self): + relation = self._get_relation() + relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) + + model = flexmock() + model.foo = 'bar' + relation.get_query().should_receive('first').once().with_args().and_return(model) + relation.get_related().should_receive('new_instance').never() + model.should_receive('fill').once().with_args({'foo': 'baz'}) + model.should_receive('save').once() + + self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}, {'foo': 'baz'}).foo) + + def test_update_or_create_creates_new_model_with_foreign_key_set(self): + relation = self._get_relation() + relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) + + relation.get_query().should_receive('first').once().with_args().and_return(None) + + model = flexmock() + model.foo = 'bar' + relation.get_related().should_receive('new_instance').once().and_return(model) + model.should_receive('fill').once().with_args({'foo': 'baz'}) + model.should_receive('save').once() + model.should_receive('set_attribute').once().with_args('foreign_key', 1) + + self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}, {'foo': 'baz'}).foo) + + def test_update_updates_models_with_timestamps(self): + relation = self._get_relation() + relation.get_related().should_receive('uses_timestamps').once().and_return(True) + now = arrow.get() + relation.get_related().should_receive('fresh_timestamp').once().and_return(now) + relation.get_query().should_receive('update').once().with_args({'foo': 'bar', 'updated_at': now}).and_return('results') + + self.assertEqual('results', relation.update(foo='bar')) + + def test_relation_is_properly_initialized(self): + relation = self._get_relation() + model = flexmock(Model()) + model.should_receive('set_relation').once().with_args('foo', Collection) + models = relation.init_relation([model], 'foo') + + self.assertEqual([model], models) + + def test_eager_constraints_are_properly_added(self): + relation = self._get_relation() + relation.get_query().get_query().should_receive('where_in').once().with_args('table.foreign_key', [1, 2]) + + model1 = OrmHasOneModelStub() + model1.id = 1 + model2 = OrmHasOneModelStub() + model2.id = 2 + + relation.add_eager_constraints([model1, model2]) + + def test_models_are_properly_matched_to_parents(self): + relation = self._get_relation() + + result1 = OrmHasOneModelStub() + result1.foreign_key = 1 + result2 = OrmHasOneModelStub() + result2.foreign_key = 2 + result3 = OrmHasOneModelStub() + result3.foreign_key = 2 + + model1 = OrmHasOneModelStub() + model1.id = 1 + model2 = OrmHasOneModelStub() + model2.id = 2 + model3 = OrmHasOneModelStub() + model3.id = 3 + + relation.get_related().should_receive('new_collection').replace_with(lambda l: Collection(l)) + + models = relation.match([model1, model2, model3], Collection([result1, result2, result3]), 'foo') + + self.assertEqual(1, models[0].foo[0].foreign_key) + self.assertEqual(1, len(models[0].foo)) + self.assertEqual(2, models[1].foo[0].foreign_key) + self.assertEqual(2, models[1].foo[1].foreign_key) + self.assertEqual(2, len(models[1].foo)) + self.assertFalse(hasattr(models[2], 'foo')) + + def test_relation_count_query_can_be_built(self): + relation = self._get_relation() + query = flexmock(QueryBuilder(None, QueryGrammar(), None)) + builder = Builder(query) + builder.get_query().should_receive('select').once() + relation.get_parent().should_receive('get_table').and_return('table') + builder.should_receive('where').once().with_args('table.foreign_key', '=', QueryExpression) + parent_query = flexmock(QueryBuilder(None, None, None)) + relation.get_query().should_receive('get_query').and_return(parent_query) + grammar = flexmock() + parent_query.should_receive('get_grammar').once().and_return(grammar) + grammar.should_receive('wrap').once().with_args('table.id') + + relation.get_relation_count_query(builder, builder) + + def _get_relation(self): + flexmock(Builder) + query = flexmock(QueryBuilder(None, QueryGrammar(), None)) + builder = Builder(query) + builder.should_receive('where').with_args('table.foreign_key', '=', 1) + related = flexmock(Model()) + builder.should_receive('get_model').and_return(related) + parent = flexmock(Model()) + parent.should_receive('get_attribute').with_args('id').and_return(1) + parent.should_receive('get_created_at_column').and_return('created_at') + parent.should_receive('get_updated_at_column').and_return('updated_at') + parent.should_receive('new_query').and_return(builder) + + return HasMany(builder, parent, 'table.foreign_key', 'id') + + +class OrmHasOneModelStub(Model): + + pass diff --git a/tests/orm/relations/test_has_many_through.py b/tests/orm/relations/test_has_many_through.py new file mode 100644 index 00000000..40b8f049 --- /dev/null +++ b/tests/orm/relations/test_has_many_through.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- + +import arrow +from flexmock import flexmock, flexmock_teardown +from ... import EloquentTestCase +from ...utils import MockConnection + +from eloquent.query.builder import QueryBuilder +from eloquent.query.grammars import QueryGrammar +from eloquent.query.processors import QueryProcessor +from eloquent.query.expression import QueryExpression +from eloquent.orm.builder import Builder +from eloquent.orm.model import Model +from eloquent.orm.relations import HasManyThrough +from eloquent.orm.collection import Collection + + +class OrmHasManyThroughTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_relation_is_properly_initialized(self): + relation = self._get_relation() + model = flexmock(Model()) + relation.get_related().should_receive('new_collection').replace_with(lambda l=None: Collection(l or [])) + model.should_receive('set_relation').once().with_args('foo', Collection) + models = relation.init_relation([model], 'foo') + + self.assertEqual([model], models) + + def test_eager_constraints_are_properly_added(self): + relation = self._get_relation() + relation.get_query().get_query().should_receive('where_in').once().with_args('users.country_id', [1, 2]) + model1 = OrmHasManyThroughModelStub() + model1.id = 1 + model2 = OrmHasManyThroughModelStub() + model2.id = 2 + relation.add_eager_constraints([model1, model2]) + + def test_models_are_properly_matched_to_parents(self): + relation = self._get_relation() + + result1 = OrmHasManyThroughModelStub() + result1.country_id = 1 + result2 = OrmHasManyThroughModelStub() + result2.country_id = 2 + result3 = OrmHasManyThroughModelStub() + result3.country_id = 2 + + model1 = OrmHasManyThroughModelStub() + model1.id = 1 + model2 = OrmHasManyThroughModelStub() + model2.id = 2 + model3 = OrmHasManyThroughModelStub() + model3.id = 3 + + relation.get_related().should_receive('new_collection').replace_with(lambda l=None: Collection(l or [])) + models = relation.match([model1, model2, model3], Collection([result1, result2, result3]), 'foo') + + self.assertEqual(1, models[0].foo[0].country_id) + self.assertEqual(1, len(models[0].foo)) + self.assertEqual(2, models[1].foo[0].country_id) + self.assertEqual(2, models[1].foo[1].country_id) + self.assertEqual(2, len(models[1].foo)) + self.assertFalse(hasattr(models[2], 'foo')) + + def _get_relation(self): + flexmock(Builder) + query = flexmock(QueryBuilder(None, QueryGrammar(), None)) + builder = Builder(query) + builder.get_query().should_receive('join').once().with_args('users', 'users.id', '=', 'posts.user_id') + builder.should_receive('where').with_args('users.country_id', '=', 1) + country = flexmock(Model()) + country.should_receive('get_key').and_return(1) + country.should_receive('get_foreign_key').and_return('country_id') + user = flexmock(Model()) + user.should_receive('get_table').and_return('users') + user.should_receive('get_qualified_key_name').and_return('users.id') + post = flexmock(Model()) + post.should_receive('get_table').and_return('posts') + builder.should_receive('get_model').and_return(post) + + user.should_receive('get_key').and_return(1) + user.should_receive('get_created_at_column').and_return('created_at') + user.should_receive('get_updated_at_column').and_return('updated_at') + + parent = flexmock(Model()) + parent.should_receive('get_attribute').with_args('id').and_return(1) + parent.should_receive('get_created_at_column').and_return('created_at') + parent.should_receive('get_updated_at_column').and_return('updated_at') + parent.should_receive('new_query').and_return(builder) + + return HasManyThrough(builder, country, user, 'country_id', 'user_id') + + +class OrmHasManyThroughModelStub(Model): + + pass diff --git a/tests/orm/relations/test_has_one.py b/tests/orm/relations/test_has_one.py new file mode 100644 index 00000000..c0c64849 --- /dev/null +++ b/tests/orm/relations/test_has_one.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- + + +import arrow +from flexmock import flexmock, flexmock_teardown +from ... import EloquentTestCase + +from eloquent.query.builder import QueryBuilder +from eloquent.query.grammars import QueryGrammar +from eloquent.query.expression import QueryExpression +from eloquent.orm.builder import Builder +from eloquent.orm.model import Model +from eloquent.orm.relations import HasOne +from eloquent.orm.collection import Collection + + +class OrmHasOneTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_save_method_set_foreign_key_on_model(self): + relation = self._get_relation() + mock_model = flexmock(Model(), save=lambda: True) + mock_model.should_receive('save').once().and_return(True) + result = relation.save(mock_model) + + attributes = result.get_attributes() + self.assertEqual(1, attributes['foreign_key']) + + def test_create_properly_creates_new_model(self): + relation = self._get_relation() + created = flexmock(Model(), save=lambda: True, set_attribute=lambda: None) + created.should_receive('save').once().and_return(True) + relation.get_related().should_receive('new_instance').once().with_args({'name': 'john'}).and_return(created) + created.should_receive('set_attribute').with_args('foreign_key', 1) + + self.assertEqual(created, relation.create(name='john')) + + def test_update_updates_models_with_timestamps(self): + relation = self._get_relation() + relation.get_related().should_receive('uses_timestamps').once().and_return(True) + now = arrow.get() + relation.get_related().should_receive('fresh_timestamp').once().and_return(now) + relation.get_query().should_receive('update').once().with_args({'foo': 'bar', 'updated_at': now}).and_return('results') + + self.assertEqual('results', relation.update(foo='bar')) + + def test_relation_is_properly_initialized(self): + relation = self._get_relation() + model = flexmock(Model()) + model.should_receive('set_relation').once().with_args('foo', None) + models = relation.init_relation([model], 'foo') + + self.assertEqual([model], models) + + def test_eager_constraints_are_properly_added(self): + relation = self._get_relation() + relation.get_query().get_query().should_receive('where_in').once().with_args('table.foreign_key', [1, 2]) + + model1 = OrmHasOneModelStub() + model1.id = 1 + model2 = OrmHasOneModelStub() + model2.id = 2 + + relation.add_eager_constraints([model1, model2]) + + def test_models_are_properly_matched_to_parents(self): + relation = self._get_relation() + + result1 = OrmHasOneModelStub() + result1.foreign_key = 1 + result2 = OrmHasOneModelStub() + result2.foreign_key = 2 + + model1 = OrmHasOneModelStub() + model1.id = 1 + model2 = OrmHasOneModelStub() + model2.id = 2 + model3 = OrmHasOneModelStub() + model3.id = 3 + + models = relation.match([model1, model2, model3], Collection([result1, result2]), 'foo') + + self.assertEqual(1, models[0].foo.foreign_key) + self.assertEqual(2, models[1].foo.foreign_key) + self.assertFalse(hasattr(models[2], 'foo')) + + def test_relation_count_query_can_be_built(self): + relation = self._get_relation() + query = flexmock(QueryBuilder(None, QueryGrammar(), None)) + builder = Builder(query) + builder.get_query().should_receive('select').once() + relation.get_parent().should_receive('get_table').and_return('table') + builder.should_receive('where').once().with_args('table.foreign_key', '=', QueryExpression) + parent_query = flexmock(QueryBuilder(None, None, None)) + relation.get_query().should_receive('get_query').and_return(parent_query) + grammar = flexmock() + parent_query.should_receive('get_grammar').once().and_return(grammar) + grammar.should_receive('wrap').once().with_args('table.id') + + relation.get_relation_count_query(builder, builder) + + def _get_relation(self): + flexmock(Builder) + query = flexmock(QueryBuilder(None, QueryGrammar(), None)) + builder = Builder(query) + builder.should_receive('where').with_args('table.foreign_key', '=', 1) + related = flexmock(Model()) + builder.should_receive('get_model').and_return(related) + parent = flexmock(Model()) + parent.should_receive('get_attribute').with_args('id').and_return(1) + parent.should_receive('get_created_at_column').and_return('created_at') + parent.should_receive('get_updated_at_column').and_return('updated_at') + parent.should_receive('new_query').and_return(builder) + + return HasOne(builder, parent, 'table.foreign_key', 'id') + + +class OrmHasOneModelStub(Model): + + pass diff --git a/tests/orm/relations/test_relation.py b/tests/orm/relations/test_relation.py new file mode 100644 index 00000000..9e278f09 --- /dev/null +++ b/tests/orm/relations/test_relation.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- + + +import arrow +from flexmock import flexmock, flexmock_teardown +from ... import EloquentTestCase + +from eloquent.query.builder import QueryBuilder +from eloquent.orm.builder import Builder +from eloquent.orm.model import Model +from eloquent.orm.relations import HasOne + + +class OrmRelationTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_set_relation_fail(self): + parent = OrmRelationResetModelStub() + relation = OrmRelationResetModelStub() + parent.set_relation('test', relation) + parent.set_relation('foo', 'bar') + self.assertFalse('foo' in parent.to_dict()) + + def test_touch_method_updates_related_timestamps(self): + builder = flexmock(Builder, get_model=None, where=None) + parent = Model() + parent = flexmock(parent) + parent.should_receive('get_attribute').with_args('id').and_return(1) + related = Model() + related = flexmock(related) + builder.should_receive('get_model').and_return(related) + builder.should_receive('where') + relation = HasOne(Builder(QueryBuilder(None, None, None)), parent, 'foreign_key', 'id') + related.should_receive('get_table').and_return('table') + related.should_receive('get_updated_at_column').and_return('updated_at') + now = arrow.get() + related.should_receive('fresh_timestamp').and_return(now) + builder.should_receive('update').once().with_args({'updated_at': now}) + + relation.touch() + + +class OrmRelationResetModelStub(Model): + + def get_query(self): + return self.new_query().get_query() diff --git a/tests/orm/test_builder.py b/tests/orm/test_builder.py new file mode 100644 index 00000000..542ea05f --- /dev/null +++ b/tests/orm/test_builder.py @@ -0,0 +1,347 @@ +# -*- coding: utf-8 -*- + +from flexmock import flexmock, flexmock_teardown +from .. import EloquentTestCase, mock +from ..utils import MockModel, MockQueryBuilder, MockConnection, MockProcessor + +from eloquent.query.grammars.grammar import QueryGrammar +from eloquent.orm.builder import Builder +from eloquent.orm.model import Model +from eloquent.exceptions.orm import ModelNotFound +from eloquent.orm.collection import Collection + + +class BuilderTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_find_method(self): + builder = Builder(self.get_mock_query_builder()) + builder.set_model(self.get_mock_model()) + builder.get_query().where = mock.MagicMock() + builder.first = mock.MagicMock(return_value='baz') + + result = builder.find('bar', ['column']) + + builder.get_query().where.assert_called_once_with( + 'foo_table.foo', '=', 'bar' + ) + self.assertEqual('baz', result) + + def test_find_or_new_model_found(self): + model = self.get_mock_model() + model.find_or_new = mock.MagicMock(return_value='baz') + + builder = Builder(self.get_mock_query_builder()) + builder.set_model(model) + builder.get_query().where = mock.MagicMock() + builder.first = mock.MagicMock(return_value='baz') + + expected = model.find_or_new('bar', ['column']) + result = builder.find('bar', ['column']) + + builder.get_query().where.assert_called_once_with( + 'foo_table.foo', '=', 'bar' + ) + self.assertEqual(expected, result) + + def test_find_or_new_model_not_found(self): + model = self.get_mock_model() + model.find_or_new = mock.MagicMock(return_value=self.get_mock_model()) + + builder = Builder(self.get_mock_query_builder()) + builder.set_model(model) + builder.get_query().where = mock.MagicMock() + builder.first = mock.MagicMock(return_value=None) + + result = model.find_or_new('bar', ['column']) + find_result = builder.find('bar', ['column']) + + builder.get_query().where.assert_called_once_with( + 'foo_table.foo', '=', 'bar' + ) + self.assertIsNone(find_result) + self.assertIsInstance(result, Model) + + def test_find_or_fail_raises_model_not_found_exception(self): + model = self.get_mock_model() + + builder = Builder(self.get_mock_query_builder()) + builder.set_model(model) + builder.get_query().where = mock.MagicMock() + builder.first = mock.MagicMock(return_value=None) + + self.assertRaises( + ModelNotFound, + builder.find_or_fail, + 'bar', + ['column'] + ) + + builder.get_query().where.assert_called_once_with( + 'foo_table.foo', '=', 'bar' + ) + + builder.first.assert_called_once_with( + ['column'] + ) + + def test_find_or_fail_with_many_raises_model_not_found_exception(self): + model = self.get_mock_model() + + builder = Builder(self.get_mock_query_builder()) + builder.set_model(model) + builder.get_query().where_in = mock.MagicMock() + builder.get = mock.MagicMock(return_value=Collection([1])) + + self.assertRaises( + ModelNotFound, + builder.find_or_fail, + [1, 2], + ['column'] + ) + + builder.get_query().where_in.assert_called_once_with( + 'foo_table.foo', [1, 2] + ) + + builder.get.assert_called_once_with( + ['column'] + ) + + def test_first_or_fail_raises_model_not_found_exception(self): + model = self.get_mock_model() + + builder = Builder(self.get_mock_query_builder()) + builder.set_model(model) + builder.first = mock.MagicMock(return_value=None) + + self.assertRaises( + ModelNotFound, + builder.first_or_fail, + ['column'] + ) + + builder.first.assert_called_once_with( + ['column'] + ) + + def test_find_with_many(self): + model = self.get_mock_model() + + builder = Builder(self.get_mock_query_builder()) + builder.set_model(model) + builder.get_query().where_in = mock.MagicMock() + builder.get = mock.MagicMock(return_value='baz') + + result = builder.find([1, 2], ['column']) + self.assertEqual('baz', result) + + builder.get_query().where_in.assert_called_once_with( + 'foo_table.foo', [1, 2] + ) + + builder.get.assert_called_once_with( + ['column'] + ) + + def test_first(self): + model = self.get_mock_model() + + builder = Builder(self.get_mock_query_builder()) + builder.set_model(model) + builder.take = mock.MagicMock(return_value=builder) + builder.get = mock.MagicMock(return_value=Collection(['bar'])) + + result = builder.first() + self.assertEqual('bar', result) + + builder.take.assert_called_once_with( + 1 + ) + + builder.get.assert_called_once_with( + ['*'] + ) + + def test_get_loads_models_and_hydrates_eager_relations(self): + flexmock(Builder) + builder = Builder(self.get_mock_query_builder()) + builder.should_receive('get_models').with_args(['foo']).and_return(['bar']) + builder.should_receive('eager_load_relations').with_args(['bar']).and_return(['bar', 'baz']) + builder.set_model(self.get_mock_model()) + builder.get_model().new_collection = mock.MagicMock(return_value=Collection(['bar', 'baz'])) + + results = builder.get(['foo']) + self.assertEqual(['bar', 'baz'], results.all()) + + builder.get_model().new_collection.assert_called_with(['bar', 'baz']) + + def test_get_does_not_eager_relations_when_no_results_are_returned(self): + flexmock(Builder) + builder = Builder(self.get_mock_query_builder()) + builder.should_receive('get_models').with_args(['foo']).and_return(['bar']) + builder.should_receive('eager_load_relations').with_args(['bar']).and_return([]) + builder.set_model(self.get_mock_model()) + builder.get_model().new_collection = mock.MagicMock(return_value=Collection([])) + + results = builder.get(['foo']) + self.assertEqual([], results.all()) + + builder.get_model().new_collection.assert_called_with([]) + + def test_pluck_with_model_found(self): + builder = Builder(self.get_mock_query_builder()) + + model = {'name': 'foo'} + builder.first = mock.MagicMock(return_value=model) + + self.assertEqual('foo', builder.pluck('name')) + + builder.first.assert_called_once_with( + ['name'] + ) + + def test_pluck_with_model_not_found(self): + builder = Builder(self.get_mock_query_builder()) + + builder.first = mock.MagicMock(return_value=None) + + self.assertIsNone(builder.pluck('name')) + + def test_chunk(self): + builder = Builder(self.get_mock_query_builder()) + results = [['foo1', 'foo2'], ['foo3'], []] + builder.for_page = mock.MagicMock(return_value=builder) + builder.get = mock.MagicMock(side_effect=results) + + i = 0 + for result in builder.chunk(2): + self.assertEqual(result, results[i]) + + i += 1 + + builder.for_page.assert_has_calls([ + mock.call(1, 2), + mock.call(2, 2), + mock.call(3, 2) + ]) + + # TODO: lists with get mutators + + def test_lists_without_model_getters(self): + builder = self.get_builder() + builder.get_query().lists = mock.MagicMock(return_value=['bar', 'baz']) + builder.set_model(self.get_mock_model()) + builder.get_model().has_get_mutator = mock.MagicMock(return_value=False) + + result = builder.lists('name') + self.assertEqual(['bar', 'baz'], result) + + builder.get_query().lists.assert_called_once_with('name', '') + + def test_get_models_hydrates_models(self): + builder = Builder(self.get_mock_query_builder()) + records = [{ + 'name': 'john', 'age': 26 + }, { + 'name': 'jane', 'age': 28 + }] + + builder.get_query().get = mock.MagicMock(return_value=records) + model = self.get_mock_model() + builder.set_model(model) + model.get_connection_name = mock.MagicMock(return_value='foo_connection') + model.hydrate = mock.MagicMock(return_value=Collection(['hydrated'])) + models = builder.get_models(['foo']) + + self.assertEqual(models, ['hydrated']) + + model.get_table.assert_called_once_with() + model.get_connection_name.assert_called_once_with() + model.hydrate.assert_called_once_with( + records, 'foo_connection' + ) + + # TODO: eager_load_relations loads top level relationship + + # TODO: relationship eager load process + + # TODO: get relation properly set nested relationship + + # TODO: get relation properly set nested relationship with similar names + + # TODO: eager load parsing sets proper relationships + + def test_query_passthru(self): + builder = self.get_builder() + builder.get_query().foobar = mock.MagicMock(return_value='foo') + + self.assertIsInstance(builder.foobar(), Builder) + self.assertEqual(builder.foobar(), builder) + + builder = self.get_builder() + builder.get_query().insert = mock.MagicMock(return_value='foo') + + self.assertEqual('foo', builder.insert(['bar'])) + + builder.get_query().insert.assert_called_once_with(['bar']) + + # TODO: test query scopes + + def test_simple_where(self): + builder = self.get_builder() + builder.get_query().where = mock.MagicMock() + result = builder.where('foo', '=', 'bar') + + self.assertEqual(builder, result) + + builder.get_query().where.assert_called_once_with('foo', '=', 'bar', 'and') + + def test_nested_where(self): + nested_query = self.get_builder() + nested_raw_query = self.get_mock_query_builder() + nested_query.get_query = mock.MagicMock(return_value=nested_raw_query) + model = self.get_mock_model() + builder = self.get_builder() + builder.set_model(model) + builder.get_query().add_nested_where_query = mock.MagicMock() + + result = builder.where(nested_query) + self.assertEqual(builder, result) + + builder.get_query().add_nested_where_query.assert_called_once_with(nested_raw_query, 'and') + + # TODO: nested query with scopes + + def test_delete_override(self): + builder = self.get_builder() + + builder.on_delete(lambda builder_: {'foo': builder_}) + + self.assertEqual({'foo': builder}, builder.delete()) + + # TODO: has nested + + # TODO: has nested with constraints + + def get_builder(self): + return Builder(self.get_mock_query_builder()) + + def get_mock_model(self): + model = MockModel().prepare_mock() + + return model + + def get_mock_query_builder(self): + connection = MockConnection().prepare_mock() + processor = MockProcessor().prepare_mock() + + builder = MockQueryBuilder( + connection, + QueryGrammar(), + processor + ).prepare_mock() + + return builder diff --git a/tests/orm/test_model.py b/tests/orm/test_model.py new file mode 100644 index 00000000..0e11d08d --- /dev/null +++ b/tests/orm/test_model.py @@ -0,0 +1,855 @@ +# -*- coding: utf-8 -*- + +import simplejson as json +import hashlib +import time +import datetime +from arrow import Arrow +from flexmock import flexmock, flexmock_teardown +from .. import EloquentTestCase, mock +from ..utils import MockModel, MockQueryBuilder, MockConnection, MockProcessor + +from eloquent.query.builder import QueryBuilder +from eloquent.query.grammars import QueryGrammar +from eloquent.query.processors import QueryProcessor +from eloquent.orm.builder import Builder +from eloquent.orm.model import Model +from eloquent.exceptions.orm import ModelNotFound, MassAssignmentError +from eloquent.orm.collection import Collection +from eloquent.connections import Connection +from eloquent import DatabaseManager +from eloquent.utils import basestring + + +class OrmModelTestCase(EloquentTestCase): + + def tearDown(self): + flexmock_teardown() + + def test_attributes_manipulation(self): + model = OrmModelStub() + model.name = 'foo' + self.assertEqual('foo', model.name) + del model.name + self.assertFalse(hasattr(model, 'name')) + + # TODO: mutators + + def test_dirty_attributes(self): + model = OrmModelStub(foo='1', bar=2, baz=3) + model.foo = 1 + model.bar = 20 + model.baz = 30 + + self.assertTrue(model.is_dirty()) + self.assertTrue(model.is_dirty('foo')) + self.assertTrue(model.is_dirty('bar')) + self.assertTrue(model.is_dirty('baz')) + self.assertTrue(model.is_dirty('foo', 'bar', 'baz')) + + # TODO: test calculated attributes + + def test_new_instance_returns_instance_wit_attributes_set(self): + model = OrmModelStub() + instance = model.new_instance({'name': 'john'}) + self.assertIsInstance(instance, OrmModelStub) + self.assertEqual('john', instance.name) + + def test_hydrate_creates_collection_of_models(self): + data = [ + {'name': 'john'}, + {'name': 'jane'} + ] + collection = OrmModelStub.hydrate(data, 'foo_connection') + + self.assertIsInstance(collection, Collection) + self.assertEqual(2, len(collection)) + self.assertIsInstance(collection[0], OrmModelStub) + self.assertIsInstance(collection[1], OrmModelStub) + self.assertEqual(collection[0].get_attributes(), collection[0].get_original()) + self.assertEqual(collection[1].get_attributes(), collection[1].get_original()) + self.assertEqual('john', collection[0].name) + self.assertEqual('jane', collection[1].name) + self.assertEqual('foo_connection', collection[0].get_connection_name()) + self.assertEqual('foo_connection', collection[1].get_connection_name()) + + def test_hydrate_raw_makes_raw_query(self): + model = OrmModelHydrateRawStub() + connection = MockConnection().prepare_mock() + connection.select.return_value = [] + model.get_connection = mock.MagicMock(return_value=connection) + + def _set_connection(name): + model.__connection__ = name + + return model + + OrmModelHydrateRawStub.set_connection = mock.MagicMock(side_effect=_set_connection) + collection = OrmModelHydrateRawStub.hydrate_raw('SELECT ?', ['foo']) + self.assertEqual('hydrated', collection) + connection.select.assert_called_once_with( + 'SELECT ?', ['foo'] + ) + + def test_create_saves_new_model(self): + model = OrmModelSaveStub.create(name='john') + self.assertTrue(model.get_saved()) + self.assertEqual('john', model.name) + + def test_find_method_calls_query_builder_correctly(self): + result = OrmModelFindStub.find(1) + + self.assertEqual('foo', result) + + def test_find_use_write_connection(self): + OrmModelFindWithWriteConnectionStub.on_write_connection().find(1) + + def test_find_with_list_calls_query_builder_correctly(self): + result = OrmModelFindManyStub.find([1, 2]) + + self.assertEqual('foo', result) + + def test_destroy_method_calls_query_builder_correctly(self): + OrmModelDestroyStub.destroy(1, 2, 3) + + def test_with_calls_query_builder_correctly(self): + result = OrmModelWithStub.with_('foo', 'bar') + self.assertEqual('foo', result) + + def test_update_process(self): + query = flexmock(Builder) + query.should_receive('where').once().with_args('id', 1) + query.should_receive('update').once().with_args({'name': 'john'}) + + model = OrmModelStub() + model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + model._update_timestamps = mock.MagicMock() + + # TODO: events + + model.id = 1 + model.foo = 'bar' + model.sync_original() + model.name = 'john' + model.set_exists(True) + self.assertTrue(model.save()) + + model.new_query.assert_called_once_with() + model._update_timestamps.assert_called_once_with() + + def test_update_process_does_not_override_timestamps(self): + query = flexmock(Builder) + query.should_receive('where').once().with_args('id', 1) + query.should_receive('update').once().with_args({'created_at': 'foo', 'updated_at': 'bar'}) + + model = OrmModelStub() + model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + model._update_timestamps = mock.MagicMock() + + # TODO: events + + model.id = 1 + model.sync_original() + model.created_at = 'foo' + model.updated_at = 'bar' + model.set_exists(True) + self.assertTrue(model.save()) + + model.new_query.assert_called_once_with() + self.assertTrue(model._update_timestamps.called) + + # TODO: update cancelled if updating event return false + + def test_update_process_without_timestamps(self): + query = flexmock(Builder) + query.should_receive('where').once().with_args('id', 1) + query.should_receive('update').once().with_args({'name': 'john'}) + + model = OrmModelStub() + model.__timestamps__ = False + model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + model._update_timestamps = mock.MagicMock() + + # TODO: events + + model.id = 1 + model.sync_original() + model.name = 'john' + model.set_exists(True) + self.assertTrue(model.save()) + + model.new_query.assert_called_once_with() + self.assertFalse(model._update_timestamps.called) + + def test_update_process_uses_old_primary_key(self): + query = flexmock(Builder) + query.should_receive('where').once().with_args('id', 1) + query.should_receive('update').once().with_args({'id': 2, 'name': 'john'}) + + model = OrmModelStub() + model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + model._update_timestamps = mock.MagicMock() + + # TODO: events + + model.id = 1 + model.sync_original() + model.id = 2 + model.name = 'john' + model.set_exists(True) + self.assertTrue(model.save()) + + model.new_query.assert_called_once_with() + self.assertTrue(model._update_timestamps.called) + + def test_timestamps_are_returned_as_objects(self): + model = Model() + model.set_raw_attributes({ + 'created_at': '2015-03-24', + 'updated_at': '2015-03-24' + }) + + self.assertIsInstance(model.created_at, Arrow) + self.assertIsInstance(model.updated_at, Arrow) + + def test_timestamps_are_returned_as_objects_from_timestamps_and_datetime(self): + model = Model() + model.set_raw_attributes({ + 'created_at': datetime.datetime.utcnow(), + 'updated_at': time.time() + }) + + self.assertIsInstance(model.created_at, Arrow) + self.assertIsInstance(model.updated_at, Arrow) + + def test_timestamps_are_returned_as_objects_on_create(self): + model = Model() + model.unguard() + + timestamps = { + 'created_at': datetime.datetime.now(), + 'updated_at': datetime.datetime.now() + } + + instance = model.new_instance(timestamps) + + self.assertIsInstance(instance.created_at, Arrow) + self.assertIsInstance(instance.updated_at, Arrow) + + model.reguard() + + def test_timestamps_return_none_if_set_to_none(self): + model = Model() + model.unguard() + + timestamps = { + 'created_at': datetime.datetime.now(), + 'updated_at': datetime.datetime.now() + } + + instance = model.new_instance(timestamps) + instance.created_at = None + + self.assertIsNone(instance.created_at) + + model.reguard() + + def test_insert_process(self): + query = flexmock(Builder) + + model = OrmModelStub() + query_builder = flexmock(QueryBuilder) + query_builder.should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) + model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + model._update_timestamps = mock.MagicMock() + + # TODO: events + + model.name = 'john' + model.set_exists(False) + self.assertTrue(model.save()) + self.assertEqual(1, model.id) + self.assertTrue(model.exists) + self.assertTrue(model._update_timestamps.called) + + model = OrmModelStub() + query_builder.should_receive('insert').once().with_args({'name': 'john'}) + model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + model._update_timestamps = mock.MagicMock() + model.set_incrementing(False) + + # TODO: events + + model.name = 'john' + model.set_exists(False) + self.assertTrue(model.save()) + self.assertFalse(hasattr(model, 'id')) + self.assertTrue(model.exists) + self.assertTrue(model._update_timestamps.called) + + # TODO: insert cancelled if creating event return false + + def test_delete_properly_deletes_model(self): + query = flexmock(Builder) + model = OrmModelStub() + builder = Builder(QueryBuilder(None, None, None)) + query.should_receive('where').once().with_args('id', 1).and_return(builder) + query.should_receive('delete').once() + model.new_query = mock.MagicMock(return_value=builder) + model._touch_owners = mock.MagicMock() + + model.set_exists(True) + model.id = 1 + model.delete() + + self.assertTrue(model._touch_owners.called) + + def test_push_no_relations(self): + flexmock(Builder) + model = flexmock(Model()) + query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + builder = Builder(query) + builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) + model.should_receive('new_query').once().and_return(builder) + model.should_receive('_update_timestamps').once() + + model.name = 'john' + model.set_exists(False) + + self.assertTrue(model.push()) + self.assertEqual(1, model.id) + self.assertTrue(model.exists) + + def test_push_empty_one_relation(self): + flexmock(Builder) + model = flexmock(Model()) + query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + builder = Builder(query) + builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) + model.should_receive('new_query').once().and_return(builder) + model.should_receive('_update_timestamps').once() + + model.name = 'john' + model.set_exists(False) + model.set_relation('relation_one', None) + + self.assertTrue(model.push()) + self.assertEqual(1, model.id) + self.assertTrue(model.exists) + self.assertIsNone(model.relation_one) + + def test_push_one_relation(self): + flexmock(Builder) + related1 = flexmock(Model()) + query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + builder = Builder(query) + builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'related1'}, 'id').and_return(2) + related1.should_receive('new_query').once().and_return(builder) + related1.should_receive('_update_timestamps').once() + + related1.name = 'related1' + related1.set_exists(False) + + model = flexmock(Model()) + model.should_receive('resolve_connection').and_return(MockConnection().prepare_mock()) + query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + builder = Builder(query) + builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) + model.should_receive('new_query').once().and_return(builder) + model.should_receive('_update_timestamps').once() + + model.name = 'john' + model.set_exists(False) + model.set_relation('relation_one', related1) + + self.assertTrue(model.push()) + self.assertEqual(1, model.id) + self.assertTrue(model.exists) + self.assertEqual(2, model.relation_one.id) + self.assertTrue(model.relation_one.exists) + self.assertEqual(2, related1.id) + self.assertTrue(related1.exists) + + def test_push_empty_many_relation(self): + flexmock(Builder) + model = flexmock(Model()) + query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + builder = Builder(query) + builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) + model.should_receive('new_query').once().and_return(builder) + model.should_receive('_update_timestamps').once() + + model.name = 'john' + model.set_exists(False) + model.set_relation('relation_many', Collection([])) + + self.assertTrue(model.push()) + self.assertEqual(1, model.id) + self.assertTrue(model.exists) + self.assertEqual(0, len(model.relation_many)) + + def test_push_many_relation(self): + flexmock(Builder) + related1 = flexmock(Model()) + query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + builder = Builder(query) + builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'related1'}, 'id').and_return(2) + related1.should_receive('new_query').once().and_return(builder) + related1.should_receive('_update_timestamps').once() + + related1.name = 'related1' + related1.set_exists(False) + + flexmock(Builder) + related2 = flexmock(Model()) + query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + builder = Builder(query) + builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'related2'}, 'id').and_return(3) + related2.should_receive('new_query').once().and_return(builder) + related2.should_receive('_update_timestamps').once() + + related2.name = 'related2' + related2.set_exists(False) + + model = flexmock(Model()) + model.should_receive('resolve_connection').and_return(MockConnection().prepare_mock()) + query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + builder = Builder(query) + builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) + model.should_receive('new_query').once().and_return(builder) + model.should_receive('_update_timestamps').once() + + model.name = 'john' + model.set_exists(False) + model.set_relation('relation_many', Collection([related1, related2])) + + self.assertTrue(model.push()) + self.assertEqual(1, model.id) + self.assertTrue(model.exists) + self.assertEqual(2, len(model.relation_many)) + self.assertEqual([2, 3], model.relation_many.lists('id')) + + def test_new_query_returns_eloquent_query_builder(self): + conn = flexmock(Connection) + grammar = flexmock(QueryGrammar) + processor = flexmock(QueryProcessor) + conn.should_receive('get_query_grammar').and_return(grammar) + conn.should_receive('get_post_processor').and_return(processor) + resolver = flexmock(DatabaseManager) + resolver.should_receive('connection').and_return(Connection(None)) + OrmModelStub.set_connection_resolver(DatabaseManager({})) + + model = OrmModelStub() + builder = model.new_query() + self.assertIsInstance(builder, Builder) + + def test_get_and_set_table(self): + model = OrmModelStub() + self.assertEqual('stub', model.get_table()) + model.set_table('foo') + self.assertEqual('foo', model.get_table()) + + def test_get_key_returns_primary_key_value(self): + model = OrmModelStub() + model.id = 1 + self.assertEqual(1, model.get_key()) + self.assertEqual('id', model.get_key_name()) + + def test_connection_management(self): + resolver = flexmock(DatabaseManager) + resolver.should_receive('connection').once().with_args('foo').and_return('bar') + + OrmModelStub.set_connection_resolver(DatabaseManager({})) + model = OrmModelStub() + model.set_connection('foo') + + self.assertEqual('bar', model.get_connection()) + + def test_to_dict(self): + model = OrmModelStub() + model.name = 'foo' + model.age = None + model.password = 'password1' + model.set_hidden(['password']) + + # TODO: relations + + d = model.to_dict() + + self.assertIsInstance(d, dict) + self.assertEqual('foo', d['name']) + self.assertIsNone(d['age']) + + # TODO: relations + + def test_to_dict_includes_default_formatted_timestamps(self): + model = Model() + model.set_raw_attributes({ + 'created_at': '2015-03-24', + 'updated_at': '2015-03-25' + }) + + d = model.to_dict() + + self.assertEqual('2015-03-24T00:00:00+00:00', d['created_at']) + self.assertEqual('2015-03-25T00:00:00+00:00', d['updated_at']) + + def test_to_dict_includes_custom_formatted_timestamps(self): + class Stub(Model): + + def get_date_format(self): + return 'DD-MM-YY' + + model = Stub() + model.set_raw_attributes({ + 'created_at': '2015-03-24', + 'updated_at': '2015-03-25' + }) + + d = model.to_dict() + + self.assertEqual('24-03-15', d['created_at']) + self.assertEqual('25-03-15', d['updated_at']) + + def test_visible_creates_dict_whitelist(self): + model = OrmModelStub() + model.set_visible(['name']) + model.name = 'John' + model.age = 28 + d = model.to_dict() + + self.assertEqual({'name': 'John'}, d) + + # TODO: hidden also hides relationship + + # TODO: to_dict uses mutators + + def test_hidden_are_ignored_when_visible(self): + model = OrmModelStub(name='john', age=28, id='foo') + model.set_visible(['name', 'id']) + model.set_hidden(['name', 'age']) + d = model.to_dict() + + self.assertIn('name', d) + self.assertIn('id', d) + self.assertNotIn('age', d) + + def test_fillable(self): + model = OrmModelStub() + model.fillable(['name', 'age']) + model.fill(name='foo', age=28) + self.assertEqual('foo', model.name) + self.assertEqual(28, model.age) + + def test_unguard_allows_anything(self): + model = OrmModelStub() + model.unguard() + model.guard(['*']) + model.fill(name='foo', age=28) + self.assertEqual('foo', model.name) + self.assertEqual(28, model.age) + model.reguard() + + def test_underscore_properties_are_not_filled(self): + model = OrmModelStub() + model.fill(_foo='bar') + self.assertEqual({}, model.get_attributes()) + + def test_guarded(self): + model = OrmModelStub() + model.guard(['name', 'age']) + model.fill(name='foo', age='bar', foo='bar') + self.assertFalse(hasattr(model, 'name')) + self.assertFalse(hasattr(model, 'age')) + self.assertEqual('bar', model.foo) + + def test_fillable_overrides_guarded(self): + model = OrmModelStub() + model.guard(['name', 'age']) + model.fillable(['age', 'foo']) + model.fill(name='foo', age='bar', foo='bar') + self.assertFalse(hasattr(model, 'name')) + self.assertEqual('bar', model.age) + self.assertEqual('bar', model.foo) + + def test_global_guarded(self): + model = OrmModelStub() + model.guard(['*']) + self.assertRaises( + MassAssignmentError, + model.fill, + name='foo', age='bar', foo='bar' + ) + + # TODO: test relations + + def test_models_assumes_their_name(self): + model = OrmModelNoTableStub() + + self.assertEqual('orm_model_no_table_stubs', model.get_table()) + + # TODO: mutators cache + + def test_clone_model_makes_a_fresh_copy(self): + model = OrmModelStub() + model.id = 1 + model.set_exists(True) + model.first = 'john' + model.last = 'doe' + model.created_at = model.fresh_timestamp() + model.updated_at = model.fresh_timestamp() + # TODO: relation + + clone = model.replicate() + + self.assertFalse(hasattr(clone, 'id')) + self.assertFalse(clone.exists) + self.assertEqual('john', clone.first) + self.assertEqual('doe', clone.last) + self.assertFalse(hasattr(clone, 'created_at')) + self.assertFalse(hasattr(clone, 'updated_at')) + # TODO: relation + + clone.first = 'jane' + + self.assertEqual('john', model.first) + self.assertEqual('jane', clone.first) + + def test_get_attribute_raise_attribute_error(self): + model = OrmModelStub() + + try: + relation = model.incorrect_relation + self.fail('AttributeError not raised') + except AttributeError: + pass + + def test_increment(self): + query = flexmock() + model_mock = flexmock(OrmModelStub, new_query=lambda: query) + model = OrmModelStub() + model.set_exists(True) + model.id = 1 + model.sync_original_attribute('id') + model.foo = 2 + + model_mock.should_receive('new_query').and_return(query) + query.should_receive('where').and_return(query) + query.should_receive('increment') + + model.public_increment('foo') + + self.assertEqual(3, model.foo) + self.assertFalse(model.is_dirty()) + + # TODO: relationship touch_owners is propagated + + # TODO: relationship touch_owners is not propagated if no relationship result + + def test_timestamps_are_not_update_with_timestamps_false_save_option(self): + query = flexmock(Builder) + query.should_receive('where').once().with_args('id', 1) + query.should_receive('update').once().with_args({'name': 'john'}) + + model = OrmModelStub() + model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + + model.id = 1 + model.sync_original() + model.name = 'john' + model.set_exists(True) + self.assertTrue(model.save({'timestamps': False})) + self.assertFalse(hasattr(model, 'updated_at')) + + model.new_query.assert_called_once_with() + + def test_casts(self): + model = OrmModelCastingStub() + model.first = '3' + model.second = '4.0' + model.third = 2.5 + model.fourth = 1 + model.fifth = 0 + model.sixth = {'foo': 'bar'} + model.seventh = ['foo', 'bar'] + model.eighth = {'foo': 'bar'} + + self.assertIsInstance(model.first, int) + self.assertIsInstance(model.second, float) + self.assertIsInstance(model.third, basestring) + self.assertIsInstance(model.fourth, bool) + self.assertIsInstance(model.fifth, bool) + self.assertIsInstance(model.sixth, dict) + self.assertIsInstance(model.seventh, list) + self.assertIsInstance(model.eighth, dict) + self.assertTrue(model.fourth) + self.assertFalse(model.fifth) + self.assertEqual({'foo': 'bar'}, model.sixth) + self.assertEqual({'foo': 'bar'}, model.eighth) + self.assertEqual(['foo', 'bar'], model.seventh) + + d = model.to_dict() + + self.assertIsInstance(d['first'], int) + self.assertIsInstance(d['second'], float) + self.assertIsInstance(d['third'], basestring) + self.assertIsInstance(d['fourth'], bool) + self.assertIsInstance(d['fifth'], bool) + self.assertIsInstance(d['sixth'], dict) + self.assertIsInstance(d['seventh'], list) + self.assertIsInstance(d['eighth'], dict) + self.assertTrue(d['fourth']) + self.assertFalse(d['fifth']) + self.assertEqual({'foo': 'bar'}, d['sixth']) + self.assertEqual({'foo': 'bar'}, d['eighth']) + self.assertEqual(['foo', 'bar'], d['seventh']) + + def test_casts_preserve_null(self): + model = OrmModelCastingStub() + model.first = None + model.second = None + model.third = None + model.fourth = None + model.fifth = None + model.sixth = None + model.seventh = None + model.eighth = None + + self.assertIsNone(model.first) + self.assertIsNone(model.second) + self.assertIsNone(model.third) + self.assertIsNone(model.fourth) + self.assertIsNone(model.fifth) + self.assertIsNone(model.sixth) + self.assertIsNone(model.seventh) + self.assertIsNone(model.eighth) + + d = model.to_dict() + + self.assertIsNone(d['first']) + self.assertIsNone(d['second']) + self.assertIsNone(d['third']) + self.assertIsNone(d['fourth']) + self.assertIsNone(d['fifth']) + self.assertIsNone(d['sixth']) + self.assertIsNone(d['seventh']) + self.assertIsNone(d['eighth']) + + +class OrmModelStub(Model): + + __table__ = 'stub' + + __guarded__ = [] + + def get_list_items_attribute(self, value): + return json.loads(value) + + def set_list_items_attribute(self, value): + self._attributes['list_items'] = json.dumps(value) + + def get_password_attribute(self, _): + return '******' + + def set_password_attribute(self, value): + self._attributes['password_hash'] = hashlib.md5(value).hexdigest() + + def public_increment(self, column, amount=1): + return self._increment(column, amount) + + def get_dates(self): + return [] + + +class OrmModelHydrateRawStub(Model): + + @classmethod + def hydrate(cls, items, connection=None): + return 'hydrated' + + +class OrmModelWithStub(Model): + + def new_query(self): + mock = flexmock(Builder(None)) + mock.should_receive('with_').once().with_args('foo', 'bar').and_return('foo') + + return mock + + +class OrmModelSaveStub(Model): + + __table__ = 'save_stub' + + __guarded__ = [] + + __saved = False + + def save(self, options=None): + self.__saved = True + + def set_incrementing(self, value): + self.__incrementing__ = value + + def get_saved(self): + return self.__saved + + +class OrmModelFindStub(Model): + + def new_query(self): + flexmock(Builder).should_receive('find').once().with_args(1, ['*']).and_return('foo') + + return Builder(None) + + +class OrmModelFindWithWriteConnectionStub(Model): + + def new_query(self): + mock = flexmock(Builder) + mock_query = flexmock(QueryBuilder) + mock_query.should_receive('use_write_connection').once().and_return(flexmock) + mock.should_receive('find').once().with_args(1).and_return('foo') + + return Builder(QueryBuilder(None, None, None)) + + +class OrmModelFindManyStub(Model): + + def new_query(self): + mock = flexmock(Builder) + mock.should_receive('find').once().with_args([1, 2], ['*']).and_return('foo') + + return Builder(QueryBuilder(None, None, None)) + + +class OrmModelDestroyStub(Model): + + def new_query(self): + mock = flexmock(Builder) + model = flexmock() + mock_query = flexmock(QueryBuilder) + mock_query.should_receive('where_in').once().with_args('id', [1, 2, 3]).and_return(flexmock) + mock.should_receive('get').once().and_return([model]) + model.should_receive('delete').once() + + return Builder(QueryBuilder(None, None, None)) + + +class OrmModelNoTableStub(Model): + + pass + + +class OrmModelCastingStub(Model): + + __casts__ = { + 'first': 'int', + 'second': 'float', + 'third': 'str', + 'fourth': 'bool', + 'fifth': 'boolean', + 'sixth': 'dict', + 'seventh': 'list', + 'eighth': 'json' + } diff --git a/tests/query/test_query_builder.py b/tests/query/test_query_builder.py index 7f88a0d6..9c331143 100644 --- a/tests/query/test_query_builder.py +++ b/tests/query/test_query_builder.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +import re + from .. import EloquentTestCase from .. import mock @@ -1513,6 +1515,49 @@ def test_sub_select(self): self.assertEqual(expected_sql, builder.to_sql()) self.assertEqual(expected_bindings, builder.get_bindings()) + def test_chunk(self): + builder = self.get_builder() + results = [ + {'foo': 'bar'}, + {'foo': 'baz'}, + {'foo': 'bam'}, + {'foo': 'boom'} + ] + + def select(query, bindings, _): + index = int(re.search('OFFSET (\d+)', query).group(1)) + limit = int(re.search('LIMIT (\d+)', query).group(1)) + + if index >= len(results): + return [] + + return results[index:index + limit] + + builder.get_connection().select.side_effect = select + + builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_) + + i = 0 + for users in builder.from_('users').chunk(1): + self.assertEqual(users[0], results[i]) + + i += 1 + + builder = self.get_builder() + results = [ + {'foo': 'bar'}, + {'foo': 'baz'}, + {'foo': 'bam'}, + {'foo': 'boom'} + ] + + builder.get_connection().select.side_effect = select + + builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_) + + for users in builder.from_('users').chunk(2): + self.assertEqual(2, len(users)) + def get_mysql_builder(self): grammar = MySqlQueryGrammar() processor = MockProcessor().prepare_mock() diff --git a/tests/utils.py b/tests/utils.py index 68d3194a..64d47175 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,8 @@ from eloquent.query.processors.processor import QueryProcessor from eloquent.database_manager import DatabaseManager from eloquent.connectors.connection_factory import ConnectionFactory +from eloquent.query.builder import QueryBuilder +from eloquent.orm.model import Model class MockConnection(ConnectionInterface): @@ -48,3 +50,21 @@ def prepare_mock(self): self.make = mock.MagicMock(return_value=MockConnection().prepare_mock()) return self + + +class MockQueryBuilder(QueryBuilder): + + def prepare_mock(self): + self.from__ = 'foo_table' + + return self + + +class MockModel(Model): + + def prepare_mock(self): + self.get_key_name = mock.MagicMock(return_value='foo') + self.get_table = mock.MagicMock(return_value='foo_table') + self.get_qualified_key_name = mock.MagicMock(return_value='foo_table.foo') + + return self