diff --git a/.gitignore b/.gitignore
index 486a089..51dbe93 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,5 @@ venv/
MANIFEST
*.pyc
dist/
-*.egg-info
\ No newline at end of file
+*.egg-info
+_build
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..f288702
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year>  <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+ <program>  Copyright (C) <year>  <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/MANIFEST.in b/MANIFEST.in
index f057b7b..d8678ee 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,8 +1,8 @@
include requirements.txt
include README.md
-include vot/stack/*.yaml
+recursive-include vot/stack/ *.yaml
include vot/dataset/*.png
include vot/dataset/*.jpg
include vot/document/*.css
include vot/document/*.js
-include vot/document/*.tex
\ No newline at end of file
+include vot/document/*.tex
diff --git a/README.md b/README.md
index c3e2a63..b732bed 100644
--- a/README.md
+++ b/README.md
@@ -2,21 +2,32 @@
The VOT evaluation toolkit
==========================
-This repository contains the official evaluation toolkit for the [Visual Object Tracking (VOT) challenge](http://votchallenge.net/). This is the official version of the toolkit, implemented in Python 3 language. If you are looking for the old Matlab version, you can find an archived repository [here](https://github.com/vicoslab/toolkit-legacy).
+[![PyPI package version](https://badge.fury.io/py/vot-toolkit.svg)](https://badge.fury.io/py/vot-toolkit)
-For more detailed informations consult the documentation available in the source or a compiled version of the documentation [here](http://www.votchallenge.net/howto/). You can also subscribe to the VOT [mailing list](https://service.ait.ac.at/mailman/listinfo/votchallenge) to receive news about challenges and important software updates or join our [support form](https://groups.google.com/forum/?hl=en#!forum/votchallenge-help) to ask questions.
+This repository contains the official evaluation toolkit for the [Visual Object Tracking (VOT) challenge](http://votchallenge.net/). This is the official version of the toolkit, implemented in Python 3. If you are looking for the old Matlab version, you can find an archived repository [here](https://github.com/votchallenge/toolkit-legacy).
+
+For more detailed information, consult the documentation available in the source or a compiled version of the documentation [here](http://www.votchallenge.net/howto/). You can also subscribe to the VOT [mailing list](https://liste.arnes.si/mailman3/lists/votchallenge.lists.arnes.si/) to receive news about challenges and important software updates, or join our [support forum](https://groups.google.com/forum/?hl=en#!forum/votchallenge-help) to ask questions.
Developers
----------
-* Luka Čehovin Zajc, University of Ljubljana (lead developer)
-* Alan Lukežič, University of Ljubljana
+The VOT toolkit is developed and maintained by [Luka Čehovin Zajc](https://vicos.si/lukacu) with the help of the VOT initiative members and the VOT community.
+
+Contributors:
+
+* [Luka Čehovin Zajc](https://vicos.si/lukacu), University of Ljubljana
+* [Alan Lukežič](https://vicos.si/people/alan_lukezic/), University of Ljubljana
* Yan Song, Tampere University
+Acknowledgements
+----------------
+
+The development of this package was supported by the Slovenian research agency (ARRS), projects Z2-1866 and J2-316.
+
License
-------
-Copyright (C) 2021 Luka Čehovin Zajc and the [VOT Challenge innitiative](http://votchallenge.net/).
+Copyright (C) 2023 Luka Čehovin Zajc and the [VOT Challenge initiative](http://votchallenge.net/).
This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
@@ -27,4 +38,4 @@ You should have received a copy of the GNU General Public License along with thi
Enquiries, Question and Comments
--------------------------------
-If you have any further enquiries, question, or comments, please refer to the contact information link on the [VOT homepage](http://votchallenge.net/). If you would like to file a bug report or a feature request, use the [Github issue tracker](https://github.com/vicoslab/toolkit/issues). **The issue tracker is for toolkit issues only**, if you have a problem with tracker integration or any other questions, please use our [support forum](https://groups.google.com/forum/?hl=en#!forum/votchallenge-help).
+If you have any further enquiries, questions, or comments, please refer to the contact information link on the [VOT homepage](http://votchallenge.net/). If you would like to file a bug report or a feature request, use the [Github issue tracker](https://github.com/votchallenge/toolkit/issues). **The issue tracker is for toolkit issues only**; if you have a problem with tracker integration or any other questions, please use our [support forum](https://groups.google.com/forum/?hl=en#!forum/votchallenge-help).
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..676b72b
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,196 @@
+# Makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line.
+SPHINXOPTS =
+SPHINXBUILD = sphinx-build
+PAPER =
+BUILDDIR = _build
+
+# Internal variables.
+PAPEROPT_a4 = -D latex_paper_size=a4
+PAPEROPT_letter = -D latex_paper_size=letter
+ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
+# the i18n builder cannot share the environment and doctrees with the others
+I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
+
+.PHONY: help
+help:
+ @echo "Please use \`make ' where is one of"
+ @echo " html to make standalone HTML files"
+ @echo " dirhtml to make HTML files named index.html in directories"
+ @echo " singlehtml to make a single large HTML file"
+ @echo " pickle to make pickle files"
+ @echo " json to make JSON files"
+ @echo " htmlhelp to make HTML files and a HTML help project"
+ @echo " qthelp to make HTML files and a qthelp project"
+ @echo " applehelp to make an Apple Help Book"
+ @echo " devhelp to make HTML files and a Devhelp project"
+ @echo " epub to make an epub"
+ @echo " epub3 to make an epub3"
+ @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
+ @echo " latexpdf to make LaTeX files and run them through pdflatex"
+ @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
+ @echo " text to make text files"
+ @echo " man to make manual pages"
+ @echo " texinfo to make Texinfo files"
+ @echo " info to make Texinfo files and run them through makeinfo"
+ @echo " gettext to make PO message catalogs"
+ @echo " changes to make an overview of all changed/added/deprecated items"
+ @echo " xml to make Docutils-native XML files"
+ @echo " pseudoxml to make pseudoxml-XML files for display purposes"
+ @echo " linkcheck to check all external links for integrity"
+ @echo " doctest to run all doctests embedded in the documentation (if enabled)"
+ @echo " coverage to run coverage check of the documentation (if enabled)"
+ @echo " dummy to check syntax errors of document sources"
+
+.PHONY: clean
+clean:
+ rm -rf $(BUILDDIR)/*
+
+.PHONY: html
+html:
+ $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
+ @echo
+ @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
+
+.PHONY: dirhtml
+dirhtml:
+ $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
+ @echo
+ @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
+
+.PHONY: singlehtml
+singlehtml:
+ $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
+ @echo
+ @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
+
+.PHONY: pickle
+pickle:
+ $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
+ @echo
+ @echo "Build finished; now you can process the pickle files."
+
+.PHONY: json
+json:
+ $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
+ @echo
+ @echo "Build finished; now you can process the JSON files."
+
+.PHONY: htmlhelp
+htmlhelp:
+ $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
+ @echo
+ @echo "Build finished; now you can run HTML Help Workshop with the" \
+ ".hhp project file in $(BUILDDIR)/htmlhelp."
+
+.PHONY: epub
+epub:
+ $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
+ @echo
+ @echo "Build finished. The epub file is in $(BUILDDIR)/epub."
+
+.PHONY: epub3
+epub3:
+ $(SPHINXBUILD) -b epub3 $(ALLSPHINXOPTS) $(BUILDDIR)/epub3
+ @echo
+ @echo "Build finished. The epub3 file is in $(BUILDDIR)/epub3."
+
+.PHONY: latex
+latex:
+ $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
+ @echo
+ @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
+ @echo "Run \`make' in that directory to run these through (pdf)latex" \
+ "(use \`make latexpdf' here to do that automatically)."
+
+.PHONY: latexpdf
+latexpdf:
+ $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
+ @echo "Running LaTeX files through pdflatex..."
+ $(MAKE) -C $(BUILDDIR)/latex all-pdf
+ @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
+
+.PHONY: latexpdfja
+latexpdfja:
+ $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
+ @echo "Running LaTeX files through platex and dvipdfmx..."
+ $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
+ @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
+
+.PHONY: text
+text:
+ $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
+ @echo
+ @echo "Build finished. The text files are in $(BUILDDIR)/text."
+
+.PHONY: man
+man:
+ $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
+ @echo
+ @echo "Build finished. The manual pages are in $(BUILDDIR)/man."
+
+.PHONY: texinfo
+texinfo:
+ $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
+ @echo
+ @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
+ @echo "Run \`make' in that directory to run these through makeinfo" \
+ "(use \`make info' here to do that automatically)."
+
+.PHONY: info
+info:
+ $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
+ @echo "Running Texinfo files through makeinfo..."
+ make -C $(BUILDDIR)/texinfo info
+ @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
+
+.PHONY: gettext
+gettext:
+ $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
+ @echo
+ @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
+
+.PHONY: changes
+changes:
+ $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
+ @echo
+ @echo "The overview file is in $(BUILDDIR)/changes."
+
+.PHONY: linkcheck
+linkcheck:
+ $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
+ @echo
+ @echo "Link check complete; look for any errors in the above output " \
+ "or in $(BUILDDIR)/linkcheck/output.txt."
+
+.PHONY: doctest
+doctest:
+ $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
+ @echo "Testing of doctests in the sources finished, look at the " \
+ "results in $(BUILDDIR)/doctest/output.txt."
+
+.PHONY: coverage
+coverage:
+ $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage
+ @echo "Testing of coverage in the sources finished, look at the " \
+ "results in $(BUILDDIR)/coverage/python.txt."
+
+.PHONY: xml
+xml:
+ $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
+ @echo
+ @echo "Build finished. The XML files are in $(BUILDDIR)/xml."
+
+.PHONY: pseudoxml
+pseudoxml:
+ $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
+ @echo
+ @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
+
+.PHONY: dummy
+dummy:
+ $(SPHINXBUILD) -b dummy $(ALLSPHINXOPTS) $(BUILDDIR)/dummy
+ @echo
+ @echo "Build finished. Dummy builder generates no files."
diff --git a/docs/api.rst b/docs/api.rst
new file mode 100644
index 0000000..61f3434
--- /dev/null
+++ b/docs/api.rst
@@ -0,0 +1,22 @@
+Documentation
+=============
+
+This section contains documentation of individual structures and functions, generated from source code docstrings.
+
+.. toctree::
+ :maxdepth: 2
+
+ api/analysis
+ api/dataset
+ api/document
+ api/experiment
+ api/region
+ api/stack
+ api/tracker
+ api/utilities
+
+Core utilities
+--------------
+
+.. automodule:: vot
+ :members:
\ No newline at end of file
diff --git a/docs/api/analysis.rst b/docs/api/analysis.rst
new file mode 100644
index 0000000..c29896a
--- /dev/null
+++ b/docs/api/analysis.rst
@@ -0,0 +1,41 @@
+Analysis module
+===============
+
+The analysis module contains classes that implement various performance analysis methodologies. It also contains a parallel runtime with caching capabilities that
+enables efficient execution of large-scale evaluations.
+
+.. automodule:: vot.analysis
+ :members:
+
+.. automodule:: vot.analysis.parallel
+ :members:
+
+Accuracy analysis
+-----------------
+
+.. automodule:: vot.analysis.accuracy
+ :members:
+
+Failure analysis
+----------------
+
+.. automodule:: vot.analysis.failure
+ :members:
+
+Long-term measures
+------------------
+
+.. automodule:: vot.analysis.longterm
+ :members:
+
+Multi-start measures
+--------------------
+
+.. automodule:: vot.analysis.multistart
+ :members:
+
+Supervision analysis
+--------------------
+
+.. automodule:: vot.analysis.supervision
+ :members:
\ No newline at end of file
diff --git a/docs/api/dataset.rst b/docs/api/dataset.rst
new file mode 100644
index 0000000..bd34f02
--- /dev/null
+++ b/docs/api/dataset.rst
@@ -0,0 +1,31 @@
+Dataset module
+==============
+
+.. automodule:: vot.dataset
+ :members:
+
+.. automodule:: vot.dataset.common
+ :members:
+
+Extended dataset support
+------------------------
+
+Many datasets are supported by the toolkit using special adapters.
+
+OTB
+~~~
+
+.. automodule:: vot.dataset.otb
+ :members:
+
+GOT10k
+~~~~~~
+
+.. automodule:: vot.dataset.got10k
+ :members:
+
+TrackingNet
+~~~~~~~~~~~
+
+.. automodule:: vot.dataset.trackingnet
+ :members:
diff --git a/docs/api/document.rst b/docs/api/document.rst
new file mode 100644
index 0000000..09b1b2d
--- /dev/null
+++ b/docs/api/document.rst
@@ -0,0 +1,21 @@
+Document module
+===============
+
+.. automodule:: vot.document
+ :members:
+
+.. automodule:: vot.document.common
+ :members:
+
+HTML document generation
+------------------------
+
+.. automodule:: vot.document.html
+ :members:
+
+LaTeX document generation
+-------------------------
+
+.. automodule:: vot.document.latex
+ :members:
+
diff --git a/docs/api/experiment.rst b/docs/api/experiment.rst
new file mode 100644
index 0000000..25b3353
--- /dev/null
+++ b/docs/api/experiment.rst
@@ -0,0 +1,5 @@
+Experiment module
+=================
+
+.. automodule:: vot.experiment
+ :members:
\ No newline at end of file
diff --git a/docs/api/region.rst b/docs/api/region.rst
new file mode 100644
index 0000000..cf8e48b
--- /dev/null
+++ b/docs/api/region.rst
@@ -0,0 +1,23 @@
+Region module
+=============
+
+.. automodule:: vot.region
+ :members:
+
+Shapes
+------
+
+.. automodule:: vot.region.shapes
+ :members:
+
+Raster utilities
+----------------
+
+.. automodule:: vot.region.raster
+ :members:
+
+IO functions
+------------
+
+.. automodule:: vot.region.io
+ :members:
\ No newline at end of file
diff --git a/docs/api/stack.rst b/docs/api/stack.rst
new file mode 100644
index 0000000..6d360d7
--- /dev/null
+++ b/docs/api/stack.rst
@@ -0,0 +1,5 @@
+Stack module
+============
+
+.. automodule:: vot.stack
+ :members:
\ No newline at end of file
diff --git a/docs/api/tracker.rst b/docs/api/tracker.rst
new file mode 100644
index 0000000..c84ac22
--- /dev/null
+++ b/docs/api/tracker.rst
@@ -0,0 +1,18 @@
+Tracker module
+==============
+
+.. automodule:: vot.tracker
+ :members:
+
+TraX protocol module
+--------------------
+
+.. automodule:: vot.tracker.trax
+ :members:
+
+Results module
+--------------
+
+.. automodule:: vot.tracker.results
+ :members:
+
diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst
new file mode 100644
index 0000000..c1a2157
--- /dev/null
+++ b/docs/api/utilities.rst
@@ -0,0 +1,5 @@
+Utilities module
+================
+
+.. automodule:: vot.utilities
+ :members:
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 0000000..dbc9f00
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,136 @@
+# -*- coding: utf-8 -*-
+#
+import os
+import sys
+sys.path.insert(0, os.path.abspath('..'))
+
+# -- General configuration ------------------------------------------------
+
+
+extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon']
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+from recommonmark.parser import CommonMarkParser
+
+source_parsers = {
+ '.md': CommonMarkParser,
+}
+
+source_suffix = ['.rst', '.md']
+
+master_doc = 'index'
+
+# General information about the project.
+project = u'VOT Toolkit'
+copyright = u'2022, Luka Cehovin Zajc'
+author = u'Luka Cehovin Zajc'
+
+try:
+    # Read the toolkit version from vot/version.py without importing the
+    # package itself; fall back to "unknown" if the file cannot be read.
+
+ __version__ = "0.0.0"
+
+ exec(open(os.path.join(os.path.dirname(__file__), '..', 'vot', 'version.py')).read())
+
+ version = __version__
+except Exception:
+ version = 'unknown'
+
+# The full version, including alpha/beta/rc tags.
+release = version
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = None
+
+# There are two options for replacing |today|: either, you set today to some
+# non-false value, then it is used:
+#
+# today = ''
+#
+# Else, today_fmt is used as the format for a strftime call.
+#
+# today_fmt = '%B %d, %Y'
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This patterns also effect to html_static_path and html_extra_path
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+
+# The reST default role (used for this markup: `text`) to use for all
+# documents.
+#
+# default_role = None
+
+# If true, '()' will be appended to :func: etc. cross-reference text.
+#
+# add_function_parentheses = True
+
+# If true, the current module name will be prepended to all description
+# unit titles (such as .. function::).
+#
+# add_module_names = True
+
+# If true, sectionauthor and moduleauthor directives will be shown in the
+# output. They are ignored by default.
+#
+# show_authors = False
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = 'sphinx'
+
+# A list of ignored prefixes for module index sorting.
+# modindex_common_prefix = []
+
+# If true, keep warnings as "system message" paragraphs in the built documents.
+# keep_warnings = False
+
+# If true, `todo` and `todoList` produce output, else they produce nothing.
+todo_include_todos = False
+
+
+# -- Options for HTML output ----------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'sphinx_rtd_theme'
+
+html_static_path = ['_static']
+
+htmlhelp_basename = 'vottoolkitdoc'
+
+# -- Options for LaTeX output ---------------------------------------------
+
+latex_documents = [
+ (master_doc, 'vot-toolkit.tex', u'VOT Toolkit Documentation',
+ u'Luka Cehovin Zajc', 'manual'),
+]
+
+man_pages = [
+ (master_doc, 'vot-toolkit', u'VOT Toolkit Documentation',
+ [author], 1)
+]
+
+# If true, show URL addresses after external links.
+#
+# man_show_urls = False
+
+
+# -- Options for Texinfo output -------------------------------------------
+
+# Grouping the document tree into Texinfo files. List of tuples
+# (source start file, target name, title, author,
+# dir menu entry, description, category)
+texinfo_documents = [
+ (master_doc, 'VOT Toolkit', u'VOT Toolkit Documentation',
+ author, 'VOT Toolkit', 'The official VOT Challenge evaluation toolkit',
+ 'Miscellaneous'),
+]
+
diff --git a/docs/index.rst b/docs/index.rst
new file mode 100644
index 0000000..58b2683
--- /dev/null
+++ b/docs/index.rst
@@ -0,0 +1,24 @@
+Welcome to the VOT Toolkit documentation
+========================================
+
+The VOT toolkit is the official evaluation tool for the `Visual Object Tracking (VOT) challenge <http://votchallenge.net/>`_.
+It is written in Python 3. The toolkit is designed to be easy to use and to have broad support for various trackers,
+datasets and evaluation measures.
+
+Contributions and development
+-----------------------------
+
+The VOT toolkit is developed as an open-source project (GPLv3 license) by the VOT Committee, primarily by Luka Čehovin Zajc, and the tracking community.
+
+Contributions to the VOT toolkit are welcome; the preferred way to contribute is by submitting an issue or a pull request on `GitHub <https://github.com/votchallenge/toolkit>`_.
+
+Index
+-----
+
+.. toctree::
+ :maxdepth: 1
+
+ overview
+ tutorials
+ api
+
diff --git a/docs/overview.rst b/docs/overview.rst
new file mode 100644
index 0000000..ba24c58
--- /dev/null
+++ b/docs/overview.rst
@@ -0,0 +1,96 @@
+Overview
+========
+
+The toolkit is designed as a modular framework; its modules address different aspects of the performance evaluation problem.
+
+Key concepts
+------------
+
+Key concepts that are used throughout the toolkit are:
+
+* **Dataset** - a collection of **sequences** that is used for performance evaluation.
+* **Sequence** - a collection of **frames** with corresponding ground truth annotations for one or more objects.
+* **Tracker** - a tracker is an algorithm that takes frames from a sequence as input (one by one) and produces a set of **trajectories** as output.
+* **Experiment** - an experiment is a method that applies a tracker to a given sequence in a specific way.
+* **Analysis** - an analysis is a set of **measures** that are used to evaluate the performance of a tracker (compare predicted trajectories to groundtruth).
+* **Stack** - a stack is a collection of **experiments** and **analyses** that are performed on a given dataset.
+* **Workspace** - a workspace is a collection of experiments and analyses that are performed on a given dataset.
+
+Tracker support
+---------------
+
+The toolkit supports various ways of interacting with tracking methods. The primary (and at the moment the only supported) manner is the TraX protocol.
+The toolkit provides a wrapper for the TraX protocol that makes it possible to use any tracker that supports the protocol.
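+
+For example, trackers are typically registered in a ``trackers.ini`` file in the workspace; a minimal sketch (section name, command and keys are illustrative and depend on the tracker and toolkit version)::
+
+    [NCC]
+    label = NCC
+    protocol = traxpython
+    command = ncc_tracker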
+
+Dataset support
+---------------
+
+The toolkit can use any dataset that is provided in the toolkit format.
+The toolkit format is a simple directory structure that contains a set of sequences. Each sequence is a directory that contains a set of frames and a groundtruth file.
+The groundtruth file is a text file with one line per frame; each line contains the bounding box of the object in that frame in the format ``x,y,w,h``. The toolkit format is used by the toolkit itself and by the VOT challenges.
+
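+A minimal sketch of such a layout (file and directory names are illustrative)::
+
+    dataset/
+        list.txt
+        ball/
+            groundtruth.txt
+            color/
+                00000001.jpg
+                00000002.jpg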
+
+Performance methodology support
+-------------------------------
+
+Various performance measures and visualizations are implemented; most of them were used in VOT challenges.
+
+ * **Accuracy** - the overlap between the predicted and groundtruth regions, measured with the intersection-over-union (IoU) measure (see the sketch below).
+ * **Robustness** - the number of times the tracker fails. A failure is detected when the overlap between the predicted and groundtruth regions drops below 0.5.
+ * **Expected Average Overlap** - the expected average overlap (EAO) combines accuracy and robustness into a single score by averaging the expected overlap curve over an interval of typical sequence lengths.
+ * **Expected Overlap** - the expected overlap curve estimates the average overlap a tracker attains on sequences of a given length and serves as the basis for the EAO score.
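+
+As an illustration, here is a minimal sketch of the IoU computation for two axis-aligned ``x,y,w,h`` boxes (the toolkit itself implements overlaps for general region types in ``vot.region``):
+
+.. code-block:: python
+
+    def iou(a, b):
+        """Intersection over union of two (x, y, w, h) boxes."""
+        x1, y1 = max(a[0], b[0]), max(a[1], b[1])
+        x2 = min(a[0] + a[2], b[0] + b[2])
+        y2 = min(a[1] + a[3], b[1] + b[3])
+        inter = max(0, x2 - x1) * max(0, y2 - y1)
+        union = a[2] * a[3] + b[2] * b[3] - inter
+        return inter / union if union > 0 else 0.0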
diff --git a/docs/tutorials.rst b/docs/tutorials.rst
new file mode 100644
index 0000000..2e38175
--- /dev/null
+++ b/docs/tutorials.rst
@@ -0,0 +1,15 @@
+Tutorials
+=========
+
+The main purpose of the toolkit is to facilitate tracker evaluation for VOT challenges and benchmarks, but it can be used
+and extended in many other ways. The following tutorials are provided to help you get started with the toolkit.
+
+.. toctree::
+ :maxdepth: 1
+
+ tutorial_introduction
+ tutorial_integration
+ tutorial_evaluation
+ tutorial_dataset
+ tutorial_stack
+ tutorial_jupyter
diff --git a/requirements.txt b/requirements.txt
index a509fff..6e58b59 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-vot-trax>=3.0.2
+vot-trax>=4.0.1
tqdm>=4.37
numpy>=1.16
opencv-python>=4.0
@@ -16,4 +16,5 @@ dominate>=2.5
cachetools>=4.1
bidict>=0.19
phx-class-registry>=3.0
-attributee>=0.1.3
\ No newline at end of file
+attributee>=0.1.8
+lazy-object-proxy>=1.9
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 5163392..e2fdda1 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,7 @@
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
],
- python_requires='>=3.6',
+ python_requires='>=3.7',
entry_points={
'console_scripts': ['vot=vot.utilities.cli:main'],
},
diff --git a/vot/__init__.py b/vot/__init__.py
index ea9adbc..2db5dfc 100644
--- a/vot/__init__.py
+++ b/vot/__init__.py
@@ -1,3 +1,4 @@
+""" Some basic functions and classes used by the toolkit. """
import os
import logging
@@ -44,7 +45,7 @@ def check_updates() -> bool:
try:
get_logger().debug("Checking for new version")
- response = requests.get(version_url, timeout=2)
+ response = requests.get(version_url, timeout=5, allow_redirects=True)
except Exception as e:
get_logger().debug("Unable to retrieve version information %s", e)
return False, None
@@ -62,12 +63,44 @@ def check_updates() -> bool:
else:
return False, None
+from attributee import Attributee, Integer, Boolean
+
+class GlobalConfiguration(Attributee):
+ """Global configuration object for the toolkit. It is used to store global configuration options.
+ """
+
+ debug_mode = Boolean(default=False, description="Enables debug mode for the toolkit.")
+ sequence_cache_size = Integer(default=1000, description="Maximum number of sequences to keep in cache.")
+ results_binary = Boolean(default=True, description="Enables binary results format.")
+ mask_optimize_read = Boolean(default=True, description="Enables mask optimization when reading masks.")
+
+ def __init__(self):
+ """Initializes the global configuration object. It reads the configuration from environment variables.
+
+ Raises:
+ ValueError: When an invalid value is provided for an attribute.
+ """
+ kwargs = {}
+ for k in self.attributes():
+ envname = "VOT_{}".format(k.upper())
+ if envname in os.environ:
+ kwargs[k] = os.environ[envname]
+ super().__init__(**kwargs)
+ _logger.debug("Global configuration: %s", self)
+
+ def __repr__(self):
+ """Returns a string representation of the global configuration object."""
+ return "debug_mode={} sequence_cache_size={} results_binary={} mask_optimize_read={}".format(
+ self.debug_mode, self.sequence_cache_size, self.results_binary, self.mask_optimize_read
+ )
+
+config = GlobalConfiguration()
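+
+# Usage sketch (illustrative values): each attribute above maps to an
+# environment variable with a "VOT_" prefix that is read once at import, e.g.:
+#
+#   os.environ["VOT_DEBUG_MODE"] = "true"
+#   os.environ["VOT_SEQUENCE_CACHE_SIZE"] = "500"
+#   config = GlobalConfiguration()  # picks up both overrides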
+
def check_debug() -> bool:
"""Checks if debug is enabled for the toolkit via an environment variable.
Returns:
bool: True if debug is enabled, False otherwise
"""
- var = os.environ.get("VOT_TOOLKIT_DEBUG", "false").lower()
- return var in ["true", "1"]
+ return config.debug_mode
diff --git a/vot/__main__.py b/vot/__main__.py
index ab9bbe2..9db15b3 100644
--- a/vot/__main__.py
+++ b/vot/__main__.py
@@ -1,4 +1,4 @@
-
+""" This module is a shortcut for the CLI interface so that it can be run as a "vot" module. """
# Just a shortcut for the CLI interface so that it can be run as a "vot" module.
diff --git a/vot/analysis/__init__.py b/vot/analysis/__init__.py
index 83baddc..ffec339 100644
--- a/vot/analysis/__init__.py
+++ b/vot/analysis/__init__.py
@@ -1,14 +1,11 @@
-import logging
-import functools
-import threading
+"""This module contains classes and functions for analysis of tracker performance. The analysis is performed on the results of an experiment."""
+
from collections import namedtuple
-from enum import Enum, Flag, auto
-from typing import List, Optional, Tuple, Dict, Any, Set, Union, NamedTuple
+from enum import Enum, auto
+from typing import List, Optional, Tuple, Any
from abc import ABC, abstractmethod
-from concurrent.futures import Executor
import importlib
-from cachetools import Cache
from class_registry import ClassRegistry
from attributee import Attributee, String
@@ -26,8 +23,12 @@
class MissingResultsException(ToolkitException):
"""Exception class that denotes missing results during analysis
"""
- pass
-
+ def __init__(self, *args: object) -> None:
+ """Constructor"""
+ if not args:
+ args = ["Missing results"]
+ super().__init__(*args)
+
class Sorting(Enum):
"""Sorting direction enumeration class
"""
@@ -67,19 +68,24 @@ def __init__(self, name: str, abbreviation: Optional[str] = None, description: O
@property
def name(self) -> str:
+ """Name of the result, used in reports"""
return self._name
@property
def abbreviation(self) -> str:
+ """Abbreviation, if empty, then name is used. Can be used to define a shorter text representation."""
return self._abbreviation
@property
def description(self) -> str:
+ """Description of the result, used in reports"""
return self._description
class Label(Result):
+ """Label describes a single categorical output of an analysis. Can have a set of possible values."""
def __init__(self, *args, **kwargs):
+ """Constructor."""
super().__init__(*args, **kwargs)
class Measure(Result):
@@ -89,6 +95,19 @@ class Measure(Result):
def __init__(self, name: str, abbreviation: Optional[str] = None, minimal: Optional[float] = None, \
maximal: Optional[float] = None, direction: Optional[Sorting] = Sorting.UNSORTABLE):
+ """Constructor for Measure class.
+
+ Arguments:
+ name {str} -- Name of the measure, used in reports
+
+ Keyword Arguments:
+ abbreviation {Optional[str]} -- Abbreviation, if empty, then name is used.
+ Can be used to define a shorter text representation. (default: {None})
+ minimal {Optional[float]} -- Minimal value of the measure. If None, then the measure is not bounded from below. (default: {None})
+ maximal {Optional[float]} -- Maximal value of the measure. If None, then the measure is not bounded from above. (default: {None})
+ direction {Optional[Sorting]} -- Direction of sorting. If Sorting.UNSORTABLE, then the measure is not sortable. (default: {Sorting.UNSORTABLE})
+
+ """
super().__init__(name, abbreviation)
self._minimal = minimal
@@ -97,14 +116,17 @@ def __init__(self, name: str, abbreviation: Optional[str] = None, minimal: Optio
@property
def minimal(self) -> float:
+ """Minimal value of the measure. If None, then the measure is not bounded from below."""
return self._minimal
@property
def maximal(self) -> float:
+ """Maximal value of the measure. If None, then the measure is not bounded from above."""
return self._maximal
@property
def direction(self) -> Sorting:
+ """Direction of sorting. If Sorting.UNSORTABLE, then the measure is not sortable."""
return self._direction
class Drawable(Result):
@@ -124,11 +146,30 @@ def __init__(self, name: str, abbreviation: Optional[str] = None, trait: Optiona
@property
def trait(self):
+ """Trait of the data, used for specification"""
return self._trait
class Multidimensional(Drawable):
+ """Base class for multidimensional results. This class is used to describe results that can be visualized in a scatter plot."""
+
def __init__(self, name: str, dimensions: int, abbreviation: Optional[str] = None, minimal: Optional[Tuple[float]] = None, \
maximal: Optional[Tuple[float]] = None, labels: Optional[Tuple[str]] = None, trait: Optional[str] = None):
+ """Constructor for Multidimensional class.
+
+ Arguments:
+ name {str} -- Name of the measure, used in reports
+ dimensions {int} -- Number of dimensions of the result
+
+ Keyword Arguments:
+ abbreviation {Optional[str]} -- Abbreviation, if empty, then name is used.
+ Can be used to define a shorter text representation. (default: {None})
+ minimal {Optional[Tuple[float]]} -- Minimal value of the measure. If None, then the measure is not bounded from below. (default: {None})
+ maximal {Optional[Tuple[float]]} -- Maximal value of the measure. If None, then the measure is not bounded from above. (default: {None})
+ labels {Optional[Tuple[str]]} -- Labels for each dimension. (default: {None})
+ trait {Optional[str]} -- Trait of the data, used for specification. (default: {None})
+
+ """
+
assert(dimensions > 1)
super().__init__(name, abbreviation, trait)
self._dimensions = dimensions
@@ -138,15 +179,19 @@ def __init__(self, name: str, dimensions: int, abbreviation: Optional[str] = Non
@property
def dimensions(self):
+ """Number of dimensions of the result"""
return self._dimensions
def minimal(self, i):
+ """Minimal value of the i-th dimension. If None, then the measure is not bounded from below."""
return self._minimal[i]
def maximal(self, i):
+ """Maximal value of the i-th dimension. If None, then the measure is not bounded from above."""
return self._maximal[i]
def label(self, i):
+ """Label for the i-th dimension."""
return self._labels[i]
class Point(Multidimensional):
@@ -160,6 +205,20 @@ class Plot(Drawable):
def __init__(self, name: str, abbreviation: Optional[str] = None, wrt: str = "frames", minimal: Optional[float] = None, \
maximal: Optional[float] = None, trait: Optional[str] = None):
+ """Constructor for Plot class.
+
+ Arguments:
+ name {str} -- Name of the measure, used in reports
+
+ Keyword Arguments:
+ abbreviation {Optional[str]} -- Abbreviation, if empty, then name is used.
+ Can be used to define a shorter text representation. (default: {None})
+ wrt {str} -- Unit of the independent variable. (default: {"frames"})
+ minimal {Optional[float]} -- Minimal value of the measure. If None, then the measure is not bounded from below. (default: {None})
+ maximal {Optional[float]} -- Maximal value of the measure. If None, then the measure is not bounded from above. (default: {None})
+ trait {Optional[str]} -- Trait of the data, used for specification. (default: {None})
+
+ """
super().__init__(name, abbreviation, trait)
self._wrt = wrt
self._minimal = minimal
@@ -167,15 +226,17 @@ def __init__(self, name: str, abbreviation: Optional[str] = None, wrt: str = "fr
@property
def minimal(self):
+ """Minimal value of the measure. If None, then the measure is not bounded from below."""
return self._minimal
@property
def maximal(self):
+ """Maximal value of the measure. If None, then the measure is not bounded from above."""
return self._maximal
-
@property
def wrt(self):
+ """Unit of the independent variable."""
return self._wrt
class Curve(Multidimensional):
@@ -183,42 +244,73 @@ class Curve(Multidimensional):
"""
class Analysis(Attributee):
+ """Base class for all analysis classes. Analysis is a class that descibes computation of one or more performance metrics for a given experiment."""
- name = String(default=None)
+ name = String(default=None, description="Name of the analysis")
def __init__(self, **kwargs):
+ """Constructor for Analysis class.
+
+ Keyword Arguments:
+ name {str} -- Name of the analysis (default: {None})
+ """
super().__init__(**kwargs)
self._identifier_cache = None
def compatible(self, experiment: Experiment):
+ """Checks if the analysis is compatible with the experiment type."""
raise NotImplementedError()
@property
def title(self) -> str:
+ """Returns the title of the analysis. If name is not set, then the default title is returned."""
+
+ if self.name is None:
+ return self._title_default
+ else:
+ return self.name
+
+ @property
+ def _title_default(self) -> str:
+ """Returns the default title of the analysis. This is used when name is not set."""
raise NotImplementedError()
def dependencies(self) -> List["Analysis"]:
+ """Returns a list of dependencies of the analysis. This is used to determine the order of execution of the analysis."""
return []
@property
def identifier(self) -> str:
+ """Returns a unique identifier of the analysis. This is used to determine if the analysis has been already computed."""
+
if not self._identifier_cache is None:
return self._identifier_cache
params = self.dump()
del params["name"]
+
confighash = arg_hash(**params)
self._identifier_cache = class_fullname(self) + "@" + confighash
-
+
return self._identifier_cache
def describe(self) -> Tuple["Result"]:
- """Returns a tuple of descriptions of results
- """
+ """Returns a tuple of descriptions of results of the analysis."""
raise NotImplementedError()
def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid:
+ """Computes the analysis for the given experiment, trackers and sequences. The dependencies are the results of the dependnent analyses.
+ The result is a grid with the results of the analysis. The grid is indexed by trackers and sequences. The axes are described by the axes() method.
+
+ Args:
+ experiment (Experiment): Experiment to compute the analysis for.
+ trackers (List[Tracker]): List of trackers to compute the analysis for.
+ sequences (List[Sequence]): List of sequences to compute the analysis for.
+ dependencies (List[Grid]): List of dependencies of the analysis.
+
+ Returns: Grid with the results of the analysis.
+ """
raise NotImplementedError()
@property
@@ -227,15 +319,18 @@ def axes(self) -> Axes:
raise NotImplementedError()
def commit(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence]):
+ """Commits the analysis for execution on default processor."""
return AnalysisProcessor.commit_default(self, experiment, trackers, sequences)
def run(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence]):
+ """Runs the analysis on default processor."""
return AnalysisProcessor.run_default(self, experiment, trackers, sequences)
class SeparableAnalysis(Analysis):
"""Analysis that is separable with respect to trackers and/or sequences, each part can be processed in parallel
as a separate job. The separation is determined by the result of the axes() method: Axes.BOTH means separation
- in tracker-sequence pairs, Axes.TRACKER means separation according to
+ in tracker-sequence pairs, Axes.TRACKER means separation according to trackers and Axes.SEQUENCE means separation
+ according to sequences.
"""
SeparablePart = namedtuple("SeparablePart", ["trackers", "sequences", "tid", "sid"])
@@ -252,15 +347,13 @@ def subcompute(self, experiment: Experiment, tracker, sequence, dependencies: Li
note that each dependency is processed using select function to only contain
information relevant for the current part of the analysis
- Raises:
- NotImplementedError: [description]
-
Returns:
- Tuple[Any]: [description]
+ Tuple[Any]: Tuple of results of the analysis
"""
raise NotImplementedError()
def __init__(self, **kwargs):
+ """Initializes the analysis. The axes semantic description is checked to be compatible with the dependencies."""
super().__init__(**kwargs)
# All dependencies should be mappable to individual parts. If parts contain
@@ -270,6 +363,15 @@ def __init__(self, **kwargs):
assert all([dependency.axes != Axes.BOTH for dependency in self.dependencies()])
def separate(self, trackers: List[Tracker], sequences: List[Sequence]) -> List["SeparablePart"]:
+ """Separates the analysis into parts that can be processed separately.
+
+ Args:
+ trackers (List[Tracker]): List of trackers to compute the analysis for.
+ sequences (List[Sequence]): List of sequences to compute the analysis for.
+
+ Returns: List of parts of the analysis.
+
+ """
if self.axes == Axes.BOTH:
parts = []
for i, tracker in enumerate(trackers):
@@ -288,6 +390,17 @@ def separate(self, trackers: List[Tracker], sequences: List[Sequence]) -> List["
return parts
def join(self, trackers: List[Tracker], sequences: List[Sequence], results: List[Tuple[Any]]):
+ """Joins the results of the analysis into a single grid. The results are indexed by trackers and sequences.
+
+ Args:
+ trackers (List[Tracker]): List of trackers to compute the analysis for.
+ sequences (List[Sequence]): List of sequences to compute the analysis for.
+ results (List[Tuple[Any]]): List of results of the analysis.
+
+ Returns:
+ Grid: Grid with the results of the analysis.
+ """
+
if self.axes == Axes.BOTH:
transformed_results = Grid(len(trackers), len(sequences))
k = 0
@@ -356,7 +469,7 @@ def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: Li
return Grid.scalar(self.subcompute(experiment, trackers[0], sequences[0], dependencies))
elif self.axes == Axes.TRACKERS and len(trackers) == 1:
return Grid.scalar(self.subcompute(experiment, trackers[0], sequences, dependencies))
- elif self.axes == Axes.BOTH and len(sequences) == 1:
+ elif self.axes == Axes.SEQUENCES and len(sequences) == 1:
return Grid.scalar(self.subcompute(experiment, trackers, sequences[0], dependencies))
else:
parts = self.separate(trackers, sequences)
@@ -370,11 +483,14 @@ def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: Li
@property
def axes(self) -> Axes:
+ """Returns the axes of the analysis. This is used to determine how the analysis is split into parts."""
return Axes.BOTH
class SequenceAggregator(Analysis): # pylint: disable=W0223
+ """Base class for sequence aggregators. Sequence aggregators take the results of a tracker and aggregate them over sequences."""
def __init__(self, **kwargs):
+ """Base constructor."""
super().__init__(**kwargs)
# We only support one dependency in aggregator ...
assert len(self.dependencies()) == 1
@@ -383,9 +499,27 @@ def __init__(self, **kwargs):
@abstractmethod
def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]:
+ """Aggregate the results of the analysis over sequences for a single tracker.
+
+ Args:
+ tracker (Tracker): Tracker to aggregate the results for.
+ sequences (List[Sequence]): List of sequences to aggregate the results for.
+ results (Grid): Results of the analysis for the tracker and sequences.
+
+ """
raise NotImplementedError()
def compute(self, _: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid:
+ """Compute the analysis for a list of trackers and sequences.
+
+ Args:
+ trackers (List[Tracker]): List of trackers to compute the analysis for.
+ sequences (List[Sequence]): List of sequences to compute the analysis for.
+ dependencies (List[Grid]): List of dependencies, should be one grid with results of the dependency analysis.
+
+ Returns:
+ Grid: Grid with the results of the analysis.
+ """
results = dependencies[0]
transformed_results = Grid(len(trackers), 1)
@@ -396,6 +530,7 @@ def compute(self, _: Experiment, trackers: List[Tracker], sequences: List[Sequen
@property
def axes(self) -> Axes:
+ """The analysis is separable in trackers."""
return Axes.TRACKERS
class TrackerSeparableAnalysis(SeparableAnalysis):
@@ -404,10 +539,12 @@ class TrackerSeparableAnalysis(SeparableAnalysis):
@abstractmethod
def subcompute(self, experiment: Experiment, tracker: Tracker, sequences: List[Sequence], dependencies: List[Grid]) -> Tuple[Any]:
+ """Compute the analysis for a single tracker."""
raise NotImplementedError()
@property
def axes(self) -> Axes:
+ """The analysis is separable in trackers."""
return Axes.TRACKERS
class SequenceSeparableAnalysis(SeparableAnalysis):
@@ -416,18 +553,20 @@ class SequenceSeparableAnalysis(SeparableAnalysis):
@abstractmethod
def subcompute(self, experiment: Experiment, trackers: List[Tracker], sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
+ """Compute the analysis for a single sequence."""
raise NotImplementedError
@property
def axes(self) -> Axes:
+ """The analysis is separable in sequences."""
return Axes.SEQUENCES
def is_special(region: Region, code=None) -> bool:
+ """Check if the region is special (not a shape) and optionally if it has a specific code."""
if code is None:
return region.type == RegionType.SPECIAL
return region.type == RegionType.SPECIAL and region.code == code
-from ._processor import process_stack_analyses, AnalysisProcessor, AnalysisError
-
-for module in ["vot.analysis.multistart", "vot.analysis.supervised", "vot.analysis.basic", "vot.analysis.tpr"]:
- importlib.import_module(module)
+from .processor import process_stack_analyses, AnalysisProcessor, AnalysisError
+for module in [".multistart", ".supervised", ".accuracy", ".failures", ".longterm"]:
+ importlib.import_module(module, package="vot.analysis")
diff --git a/vot/analysis/accuracy.py b/vot/analysis/accuracy.py
new file mode 100644
index 0000000..27b9e45
--- /dev/null
+++ b/vot/analysis/accuracy.py
@@ -0,0 +1,291 @@
+"""Accuracy analysis. Computes average overlap between predicted and groundtruth regions."""
+
+from typing import List, Tuple, Any
+
+import numpy as np
+
+from attributee import Boolean, Integer, Include, Float
+
+from vot.analysis import (Measure,
+ MissingResultsException,
+ SequenceAggregator, Sorting,
+ is_special, SeparableAnalysis,
+ analysis_registry, Curve)
+from vot.dataset import Sequence
+from vot.experiment import Experiment
+from vot.experiment.multirun import (MultiRunExperiment)
+from vot.region import Region, calculate_overlaps
+from vot.tracker import Tracker, Trajectory
+from vot.utilities.data import Grid
+
+def gather_overlaps(trajectory: List[Region], groundtruth: List[Region], burnin: int = 10,
+ ignore_unknown: bool = True, ignore_invisible: bool = False, bounds = None, threshold: float = None) -> np.ndarray:
+ """Gather overlaps between trajectory and groundtruth regions.
+
+ Args:
+ trajectory (List[Region]): List of regions predicted by the tracker.
+ groundtruth (List[Region]): List of groundtruth regions.
+ burnin (int, optional): Number of frames to skip at the beginning of the sequence. Defaults to 10.
+ ignore_unknown (bool, optional): Ignore unknown regions in the groundtruth. Defaults to True.
+ ignore_invisible (bool, optional): Ignore invisible regions in the groundtruth. Defaults to False.
+ bounds ([type], optional): Bounds of the sequence. Defaults to None.
+ threshold (float, optional): Minimum overlap to consider. Defaults to None.
+
+ Returns:
+ np.ndarray: List of overlaps."""
+
+ overlaps = np.array(calculate_overlaps(trajectory, groundtruth, bounds))
+ mask = np.ones(len(overlaps), dtype=bool)
+
+ if threshold is None: threshold = -1
+
+ for i, (region_tr, region_gt) in enumerate(zip(trajectory, groundtruth)):
+ # Skip if groundtruth is unknown
+ if is_special(region_gt, Sequence.UNKNOWN):
+ mask[i] = False
+ elif ignore_invisible and region_gt.is_empty():
+ mask[i] = False
+ # Skip if predicted is unknown
+ elif is_special(region_tr, Trajectory.UNKNOWN) and ignore_unknown:
+ mask[i] = False
+ # Skip if predicted is initialization frame
+ elif is_special(region_tr, Trajectory.INITIALIZATION):
+ for j in range(i, min(len(trajectory), i + burnin)):
+ mask[j] = False
+ elif is_special(region_tr, Trajectory.FAILURE):
+ mask[i] = False
+ elif overlaps[i] <= threshold:
+ mask[i] = False
+
+ return overlaps[mask]
+
+@analysis_registry.register("accuracy")
+class SequenceAccuracy(SeparableAnalysis):
+ """Sequence accuracy analysis. Computes average overlap between predicted and groundtruth regions."""
+
+ burnin = Integer(default=10, val_min=0, description="Number of frames to skip after the initialization.")
+ ignore_unknown = Boolean(default=True, description="Ignore unknown regions in the groundtruth.")
+ ignore_invisible = Boolean(default=False, description="Ignore invisible regions in the groundtruth.")
+ bounded = Boolean(default=True, description="Consider only the bounded region of the sequence.")
+ threshold = Float(default=None, val_min=0, val_max=1, description="Minimum overlap to consider.")
+
+ def compatible(self, experiment: Experiment):
+ """Check if the experiment is compatible with the analysis."""
+ return isinstance(experiment, MultiRunExperiment)
+
+ @property
+ def _title_default(self):
+ """Default title of the analysis."""
+ return "Sequence accurarcy"
+
+ def describe(self):
+ """Describe the analysis."""
+ return Measure(self.title, "", 0, 1, Sorting.DESCENDING),
+
+ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
+ """Compute the analysis for a single sequence.
+
+ Args:
+ experiment (Experiment): Experiment.
+ tracker (Tracker): Tracker.
+ sequence (Sequence): Sequence.
+ dependencies (List[Grid]): List of dependencies.
+
+ Returns:
+ Tuple[Any]: Tuple of results.
+ """
+ assert isinstance(experiment, MultiRunExperiment)
+
+ objects = sequence.objects()
+ objects_accuracy = 0
+ bounds = (sequence.size) if self.bounded else None
+
+ for object in objects:
+ trajectories = experiment.gather(tracker, sequence, objects=[object])
+ if len(trajectories) == 0:
+ raise MissingResultsException()
+
+ cumulative = 0
+
+ for trajectory in trajectories:
+ overlaps = gather_overlaps(trajectory.regions(), sequence.object(object), self.burnin,
+ ignore_unknown=self.ignore_unknown, ignore_invisible=self.ignore_invisible, bounds=bounds, threshold=self.threshold)
+ if overlaps.size > 0:
+ cumulative += np.mean(overlaps)
+
+ objects_accuracy += cumulative / len(trajectories)
+
+ return objects_accuracy / len(objects),
+
+@analysis_registry.register("average_accuracy")
+class AverageAccuracy(SequenceAggregator):
+ """Average accuracy analysis. Computes average overlap between predicted and groundtruth regions."""
+
+ analysis = Include(SequenceAccuracy, description="Sequence accuracy analysis.")
+ weighted = Boolean(default=True, description="Weight accuracy by the number of frames.")
+
+ def compatible(self, experiment: Experiment):
+ """Check if the experiment is compatible with the analysis. This analysis requires a multirun experiment."""
+ return isinstance(experiment, MultiRunExperiment)
+
+ @property
+ def _title_default(self):
+ """Default title of the analysis."""
+ return "Accurarcy"
+
+ def dependencies(self):
+ """List of dependencies."""
+ return self.analysis,
+
+ def describe(self):
+ """Describe the analysis."""
+ return Measure(self.title, "", 0, 1, Sorting.DESCENDING),
+
+ def aggregate(self, _: Tracker, sequences: List[Sequence], results: Grid):
+ """Aggregate the results of the analysis.
+
+ Args:
+ tracker (Tracker): Tracker.
+ sequences (List[Sequence]): List of sequences.
+ results (Grid): Grid of results.
+
+ Returns:
+ Tuple[Any]: Tuple of results.
+ """
+
+ accuracy = 0
+ frames = 0
+
+ for i, sequence in enumerate(sequences):
+ if results[i, 0] is None:
+ continue
+
+ if self.weighted:
+ accuracy += results[i, 0][0] * len(sequence)
+ frames += len(sequence)
+ else:
+ accuracy += results[i, 0][0]
+ frames += 1
+
+ return accuracy / frames,
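+
+ # Worked example: two sequences of 100 and 50 frames with per-sequence
+ # accuracy 0.8 and 0.5 give (0.8*100 + 0.5*50) / 150 = 0.7 when weighted,
+ # and (0.8 + 0.5) / 2 = 0.65 when unweighted.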
+
+@analysis_registry.register("success_plot")
+class SuccessPlot(SeparableAnalysis):
+ """Success plot analysis. Computes the success plot of the tracker."""
+
+ ignore_unknown = Boolean(default=True, description="Ignore unknown regions in the groundtruth.")
+ ignore_invisible = Boolean(default=False, description="Ignore invisible regions in the groundtruth.")
+ burnin = Integer(default=0, val_min=0, description="Number of frames to skip after the initialization.")
+ bounded = Boolean(default=True, description="Consider only the bounded region of the sequence.")
+ threshold = Float(default=None, val_min=0, val_max=1, description="Minimum overlap to consider.")
+ resolution = Integer(default=100, val_min=2, description="Number of points in the plot.")
+
+ def compatible(self, experiment: Experiment):
+ """Check if the experiment is compatible with the analysis. This analysis is only compatible with multi-run experiments."""
+ return isinstance(experiment, MultiRunExperiment)
+
+ @property
+ def _title_default(self):
+ """Default title of the analysis."""
+ return "Sequence success plot"
+
+ def describe(self):
+ """Describe the analysis."""
+ return Curve("Plot", 2, "S", minimal=(0, 0), maximal=(1, 1), labels=("Threshold", "Success"), trait="success"),
+
+ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
+ """Compute the analysis for a single sequence.
+
+ Args:
+ experiment (Experiment): Experiment.
+ tracker (Tracker): Tracker.
+ sequence (Sequence): Sequence.
+ dependencies (List[Grid]): List of dependencies.
+
+ Returns:
+ Tuple[Any]: Tuple of results.
+ """
+
+ assert isinstance(experiment, MultiRunExperiment)
+
+ objects = sequence.objects()
+ bounds = (sequence.size) if self.bounded else None
+
+ axis_x = np.linspace(0, 1, self.resolution)
+ axis_y = np.zeros_like(axis_x)
+
+ for object in objects:
+ trajectories = experiment.gather(tracker, sequence, objects=[object])
+ if len(trajectories) == 0:
+ raise MissingResultsException()
+
+ object_y = np.zeros_like(axis_x)
+
+ for trajectory in trajectories:
+ overlaps = gather_overlaps(trajectory.regions(), sequence.object(object), burnin=self.burnin, ignore_unknown=self.ignore_unknown, ignore_invisible=self.ignore_invisible, bounds=bounds, threshold=self.threshold)
+
+ for i, threshold in enumerate(axis_x):
+ if threshold == 1:
+ # Nicer handling of the edge case
+ object_y[i] += np.sum(overlaps >= threshold) / len(overlaps)
+ else:
+ object_y[i] += np.sum(overlaps > threshold) / len(overlaps)
+
+ axis_y += object_y / len(trajectories)
+
+ axis_y /= len(objects)
+
+ return [(x, y) for x, y in zip(axis_x, axis_y)],
+
+
+@analysis_registry.register("average_success_plot")
+class AverageSuccessPlot(SequenceAggregator):
+ """Average success plot analysis. Computes the average success plot of the tracker."""
+
+ resolution = Integer(default=100, val_min=2, description="Number of points in the plot.")
+ analysis = Include(SuccessPlot, description="Success plot analysis.")
+
+ def dependencies(self):
+ """List of dependencies."""
+ return self.analysis,
+
+ def compatible(self, experiment: Experiment):
+ """Check if the experiment is compatible with the analysis. This analysis is only compatible with multi-run experiments."""
+ return isinstance(experiment, MultiRunExperiment)
+
+ @property
+ def _title_default(self):
+ """Default title of the analysis."""
+ return "Success plot"
+
+ def describe(self):
+ """Describe the analysis."""
+ return Curve("Plot", 2, "S", minimal=(0, 0), maximal=(1, 1), labels=("Threshold", "Success"), trait="success"),
+
+ def aggregate(self, _: Tracker, sequences: List[Sequence], results: Grid):
+ """Aggregate the results of the analysis.
+
+ Args:
+ tracker (Tracker): Tracker.
+ sequences (List[Sequence]): List of sequences.
+ results (Grid): Grid of results.
+
+ Returns:
+ Tuple[Any]: Tuple of results.
+ """
+
+ axis_x = np.linspace(0, 1, self.resolution)
+ axis_y = np.zeros_like(axis_x)
+
+ for i, _ in enumerate(sequences):
+ if results[i, 0] is None:
+ continue
+
+ curve = results[i, 0][0]
+
+ for j, (_, y) in enumerate(curve):
+ axis_y[j] += y
+
+ axis_y /= len(sequences)
+
+ return [(x, y) for x, y in zip(axis_x, axis_y)],
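+
+# Usage sketch (a configured workspace is assumed; object names are illustrative):
+#
+#   analysis = SequenceAccuracy(burnin=10)
+#   grid = analysis.run(experiment, trackers, sequences)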
diff --git a/vot/analysis/basic.py b/vot/analysis/basic.py
deleted file mode 100644
index de38c4b..0000000
--- a/vot/analysis/basic.py
+++ /dev/null
@@ -1,184 +0,0 @@
-from typing import List, Tuple, Any
-
-import numpy as np
-
-from attributee import Boolean, Integer, Include
-
-from vot.analysis import (Measure,
- MissingResultsException,
- SequenceAggregator, Sorting,
- is_special, SeparableAnalysis,
- analysis_registry)
-from vot.dataset import Sequence
-from vot.experiment import Experiment
-from vot.experiment.multirun import (MultiRunExperiment, SupervisedExperiment)
-from vot.region import Region, Special, calculate_overlaps
-from vot.tracker import Tracker
-from vot.utilities.data import Grid
-
-def compute_accuracy(trajectory: List[Region], sequence: Sequence, burnin: int = 10,
- ignore_unknown: bool = True, bounded: bool = True) -> float:
-
- overlaps = np.array(calculate_overlaps(trajectory, sequence.groundtruth(), (sequence.size) if bounded else None))
- mask = np.ones(len(overlaps), dtype=bool)
-
- for i, region in enumerate(trajectory):
- if is_special(region, Special.UNKNOWN) and ignore_unknown:
- mask[i] = False
- elif is_special(region, Special.INITIALIZATION):
- for j in range(i, min(len(trajectory), i + burnin)):
- mask[j] = False
- elif is_special(region, Special.FAILURE):
- mask[i] = False
-
- if any(mask):
- return np.mean(overlaps[mask]), np.sum(mask)
- else:
- return 0, 0
-
-def compute_eao_partial(overlaps: List, success: List[bool], curve_length: int):
- phi = curve_length * [float(0)]
- active = curve_length * [float(0)]
-
- for o, success in zip(overlaps, success):
-
- o_array = np.array(o)
-
- for j in range(1, curve_length):
-
- if j < len(o):
- phi[j] += np.mean(o_array[1:j+1])
- active[j] += 1
- elif not success:
- phi[j] += np.sum(o_array[1:len(o)]) / (j - 1)
- active[j] += 1
-
- phi = [p / a if a > 0 else 0 for p, a in zip(phi, active)]
- return phi, active
-
-def count_failures(trajectory: List[Region]) -> Tuple[int, int]:
- return len([region for region in trajectory if is_special(region, Special.FAILURE)]), len(trajectory)
-
-@analysis_registry.register("accuracy")
-class SequenceAccuracy(SeparableAnalysis):
-
- burnin = Integer(default=10, val_min=0)
- ignore_unknown = Boolean(default=True)
- bounded = Boolean(default=True)
-
- def compatible(self, experiment: Experiment):
- return isinstance(experiment, MultiRunExperiment)
-
- @property
- def title(self):
- return "Sequence accurarcy"
-
- def describe(self):
- return Measure("Accuracy", "AUC", 0, 1, Sorting.DESCENDING),
-
- def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
-
- assert isinstance(experiment, MultiRunExperiment)
-
- trajectories = experiment.gather(tracker, sequence)
-
- if len(trajectories) == 0:
- raise MissingResultsException()
-
- cummulative = 0
- for trajectory in trajectories:
- accuracy, _ = compute_accuracy(trajectory.regions(), sequence, self.burnin, self.ignore_unknown, self.bounded)
- cummulative = cummulative + accuracy
-
- return cummulative / len(trajectories),
-
-@analysis_registry.register("average_accuracy")
-class AverageAccuracy(SequenceAggregator):
-
- analysis = Include(SequenceAccuracy)
- weighted = Boolean(default=True)
-
- def compatible(self, experiment: Experiment):
- return isinstance(experiment, MultiRunExperiment)
-
- @property
- def title(self):
- return "Average accurarcy"
-
- def dependencies(self):
- return self.analysis,
-
- def describe(self):
- return Measure("Accuracy", "AUC", 0, 1, Sorting.DESCENDING),
-
- def aggregate(self, _: Tracker, sequences: List[Sequence], results: Grid):
- accuracy = 0
- frames = 0
-
- for i, sequence in enumerate(sequences):
- if results[i, 0] is None:
- continue
-
- if self.weighted:
- accuracy += results[i, 0][0] * len(sequence)
- frames += len(sequence)
- else:
- accuracy += results[i, 0][0]
- frames += 1
-
- return accuracy / frames,
-
-@analysis_registry.register("failures")
-class FailureCount(SeparableAnalysis):
-
- def compatible(self, experiment: Experiment):
- return isinstance(experiment, SupervisedExperiment)
-
- @property
- def title(self):
- return "Number of failures"
-
- def describe(self):
- return Measure("Failures", "F", 0, None, Sorting.ASCENDING),
-
- def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
-
- assert isinstance(experiment, SupervisedExperiment)
-
- trajectories = experiment.gather(tracker, sequence)
-
- if len(trajectories) == 0:
- raise MissingResultsException()
-
- failures = 0
- for trajectory in trajectories:
- failures = failures + count_failures(trajectory.regions())[0]
-
- return failures / len(trajectories), len(trajectories[0])
-
-
-@analysis_registry.register("cumulative_failures")
-class CumulativeFailureCount(SequenceAggregator):
-
- analysis = Include(FailureCount)
-
- def compatible(self, experiment: Experiment):
- return isinstance(experiment, SupervisedExperiment)
-
- def dependencies(self):
- return self.analysis,
-
- @property
- def title(self):
- return "Number of failures"
-
- def describe(self):
- return Measure("Failures", "F", 0, None, Sorting.ASCENDING),
-
- def aggregate(self, _: Tracker, sequences: List[Sequence], results: Grid):
- failures = 0
-
- for a in results:
- failures = failures + a[0]
-
- return failures,
diff --git a/vot/analysis/failures.py b/vot/analysis/failures.py
new file mode 100644
index 0000000..615bdb7
--- /dev/null
+++ b/vot/analysis/failures.py
@@ -0,0 +1,101 @@
+"""This module contains the implementation of the FailureCount analysis. The analysis counts the number of failures in one or more sequences."""
+
+from typing import List, Tuple, Any
+
+from attributee import Include
+
+from vot.analysis import (Measure,
+ MissingResultsException,
+ SequenceAggregator, Sorting,
+ is_special, SeparableAnalysis,
+ analysis_registry)
+from vot.dataset import Sequence
+from vot.experiment import Experiment
+from vot.experiment.multirun import (SupervisedExperiment)
+from vot.region import Region, Special, calculate_overlaps
+from vot.tracker import Tracker
+from vot.utilities.data import Grid
+
+
+def count_failures(trajectory: List[Region]) -> Tuple[int, int]:
+ """Count the number of failures in a trajectory. A failure is defined as a region that overlaps with a Special.FAILURE region."""
+ return len([region for region in trajectory if is_special(region, SupervisedExperiment.FAILURE)]), len(trajectory)
+
+
+@analysis_registry.register("failures")
+class FailureCount(SeparableAnalysis):
+ """Count the number of failures in a sequence. A failure is defined as a region that overlaps with a Special.FAILURE region."""
+
+ def compatible(self, experiment: Experiment):
+ """Check if the experiment is compatible with the analysis."""
+ return isinstance(experiment, SupervisedExperiment)
+
+ @property
+ def _title_default(self):
+ """Default title for the analysis."""
+ return "Number of failures"
+
+ def describe(self):
+ """Describe the analysis."""
+ return Measure("Failures", "F", 0, None, Sorting.ASCENDING),
+
+ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
+ """Compute the analysis for a single sequence."""
+
+ assert isinstance(experiment, SupervisedExperiment)
+
+ objects = sequence.objects()
+ objects_failures = 0
+
+ for object in objects:
+ trajectories = experiment.gather(tracker, sequence, objects=[object])
+ if len(trajectories) == 0:
+ raise MissingResultsException()
+
+ failures = 0
+ for trajectory in trajectories:
+ failures = failures + count_failures(trajectory.regions())[0]
+ objects_failures += failures / len(trajectories)
+
+ return objects_failures / len(objects), len(sequence)
+
+@analysis_registry.register("cumulative_failures")
+class CumulativeFailureCount(SequenceAggregator):
+ """Count the number of failures in a sequence. A failure is defined as a region that overlaps with a Special.FAILURE region."""
+
+ analysis = Include(FailureCount)
+
+ def compatible(self, experiment: Experiment):
+ """Check if the experiment is compatible with the analysis."""
+ return isinstance(experiment, SupervisedExperiment)
+
+ def dependencies(self):
+ """Return the dependencies of the analysis."""
+ return self.analysis,
+
+ @property
+ def _title_default(self):
+ """Default title for the analysis."""
+ return "Number of failures"
+
+ def describe(self):
+ """Describe the analysis."""
+ return Measure("Failures", "F", 0, None, Sorting.ASCENDING),
+
+ def aggregate(self, _: Tracker, sequences: List[Sequence], results: Grid):
+ """Aggregate the analysis for a list of sequences. The aggregation is done by summing the number of failures for each sequence.
+
+ Args:
+ sequences (List[Sequence]): The list of sequences to aggregate.
+ results (Grid): The results of the analysis for each sequence.
+
+ Returns:
+ Tuple[Any]: The aggregated analysis.
+ """
+
+ failures = 0
+
+ for a in results:
+ failures = failures + a[0]
+
+ return failures,
diff --git a/vot/analysis/longterm.py b/vot/analysis/longterm.py
new file mode 100644
index 0000000..a64bc40
--- /dev/null
+++ b/vot/analysis/longterm.py
@@ -0,0 +1,677 @@
+"""This module contains the implementation of the long term tracking performance measures."""
+import math
+import numpy as np
+from typing import List, Iterable, Tuple, Any
+import itertools
+
+from attributee import Float, Integer, Boolean, Include
+
+from vot.tracker import Tracker
+from vot.dataset import Sequence
+from vot.region import Region, RegionType, calculate_overlaps
+from vot.experiment import Experiment
+from vot.experiment.multirun import UnsupervisedExperiment, MultiRunExperiment
+from vot.analysis import SequenceAggregator, Analysis, SeparableAnalysis, \
+ MissingResultsException, Measure, Sorting, Curve, Plot, SequenceAggregator, \
+ Axes, analysis_registry, Point, is_special, Analysis
+from vot.utilities.data import Grid
+
+def determine_thresholds(scores: Iterable[float], resolution: int) -> List[float]:
+ """Determine thresholds for a given set of scores and a resolution.
+ The thresholds are determined by sorting the scores and selecting the thresholds that divide the sorted scores into equal sized bins.
+
+ Args:
+ scores (Iterable[float]): Scores to determine thresholds for.
+ resolution (int): Number of thresholds to determine.
+
+ Returns:
+ List[float]: List of thresholds.
+ """
+ scores = [score for score in scores if not math.isnan(score)] #and not score is None]
+ scores = sorted(scores, reverse=True)
+
+ if len(scores) > resolution - 2:
+ delta = math.floor(len(scores) / (resolution - 2))
+ idxs = np.round(np.linspace(delta, len(scores) - delta, num=resolution - 2)).astype(int)
+ thresholds = [scores[idx] for idx in idxs]
+ else:
+ thresholds = scores
+
+ thresholds.insert(0, math.inf)
+ thresholds.insert(len(thresholds), -math.inf)
+
+ return thresholds
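+
+# Example: determine_thresholds([0.9, 0.5, 0.1], resolution=5) returns
+# [inf, 0.9, 0.5, 0.1, -inf]; the sentinel values add the degenerate
+# "report nothing" and "report everything" operating points to the curve.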
+
+def compute_tpr_curves(trajectory: List[Region], confidence: List[float], sequence: Sequence, thresholds: List[float],
+ ignore_unknown: bool = True, bounded: bool = True):
+ """Compute the TPR curves for a given trajectory and confidence scores.
+
+ Args:
+ trajectory (List[Region]): Trajectory to compute the TPR curves for.
+ confidence (List[float]): Confidence scores for the trajectory.
+ sequence (Sequence): Sequence to compute the TPR curves for.
+ thresholds (List[float]): Thresholds to compute the TPR curves for.
+ ignore_unknown (bool, optional): Ignore unknown regions. Defaults to True.
+ bounded (bool, optional): Bounded evaluation. Defaults to True.
+
+ Returns:
+ List[float], List[float]: TPR curves for the given thresholds.
+ """
+
+ overlaps = np.array(calculate_overlaps(trajectory, sequence.groundtruth(), (sequence.size) if bounded else None))
+ confidence = np.array(confidence)
+
+ n_visible = len([region for region in sequence.groundtruth() if region.type is not RegionType.SPECIAL])
+
+ precision = len(thresholds) * [float(0)]
+ recall = len(thresholds) * [float(0)]
+
+ for i, threshold in enumerate(thresholds):
+
+ subset = confidence >= threshold
+
+ if np.sum(subset) == 0:
+ precision[i] = 1
+ recall[i] = 0
+ else:
+ precision[i] = np.mean(overlaps[subset])
+ recall[i] = np.sum(overlaps[subset]) / n_visible
+
+ return precision, recall
+
+class _ConfidenceScores(SeparableAnalysis):
+ """Computes the confidence scores for a tracker for given sequences. This is internal analysis and should not be used directly."""
+
+ @property
+ def _title_default(self):
+ """Title of the analysis."""
+ return "Aggregate confidence scores"
+
+ def describe(self):
+ """Describes the analysis."""
+ return None,
+
+ def compatible(self, experiment: Experiment):
+ """Checks if the experiment is compatible with the analysis. """
+ return isinstance(experiment, UnsupervisedExperiment)
+
+ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
+ """Computes the confidence scores for a tracker for given sequences.
+
+ Args:
+ experiment (Experiment): Experiment to compute the confidence scores for.
+ tracker (Tracker): Tracker to compute the confidence scores for.
+ sequence (Sequence): Sequence to compute the confidence scores for.
+ dependencies (List[Grid]): Dependencies of the analysis.
+
+ Returns:
+ Tuple[Any]: Confidence scores for the given sequence.
+ """
+
+ scores_all = []
+ trajectories = experiment.gather(tracker, sequence)
+
+ if len(trajectories) == 0:
+ raise MissingResultsException("Missing results for sequence {}".format(sequence.name))
+
+ for trajectory in trajectories:
+ confidence = [trajectory.properties(i).get('confidence', 0) for i in range(len(trajectory))]
+ scores_all.extend(confidence)
+
+ return scores_all,
+
+class _Thresholds(SequenceAggregator):
+ """Computes the thresholds for a tracker for given sequences. This is internal analysis and should not be used directly."""
+
+ resolution = Integer(default=100)
+
+ @property
+ def _title_default(self):
+ """Title of the analysis."""
+ return "Thresholds for tracking precision/recall"
+
+ def describe(self):
+ """Describes the analysis."""
+ return None,
+
+ def compatible(self, experiment: Experiment):
+ """Checks if the experiment is compatible with the analysis. """
+ return isinstance(experiment, UnsupervisedExperiment)
+
+ def dependencies(self):
+ """Dependencies of the analysis."""
+ return _ConfidenceScores(),
+
+ def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]:
+ """Computes the thresholds for a tracker for given sequences.
+
+ Args:
+ tracker (Tracker): Tracker to compute the thresholds for.
+ sequences (List[Sequence]): Sequences to compute the thresholds for.
+ results (Grid): Results of the dependencies.
+
+ Returns:
+ Tuple[Any]: Thresholds for the given sequences."""
+
+ thresholds = determine_thresholds(itertools.chain(*[result[0] for result in results]), self.resolution),
+
+ return thresholds,
+
+@analysis_registry.register("pr_curves")
+class PrecisionRecallCurves(SeparableAnalysis):
+ """ Computes the precision/recall curves for a tracker for given sequences. """
+
+ thresholds = Include(_Thresholds)
+ ignore_unknown = Boolean(default=True, description="Ignore unknown regions")
+ bounded = Boolean(default=True, description="Bounded evaluation")
+
+ @property
+ def _title_default(self):
+ """Title of the analysis."""
+ return "Tracking precision/recall"
+
+ def describe(self):
+ """Describes the analysis."""
+ return Curve("Precision Recall curve", dimensions=2, abbreviation="PR", minimal=(0, 0), maximal=(1, 1), labels=("Recall", "Precision")), None
+
+ def compatible(self, experiment: Experiment):
+ """Checks if the experiment is compatible with the analysis."""
+ return isinstance(experiment, UnsupervisedExperiment)
+
+ def dependencies(self):
+ """Dependencies of the analysis."""
+ return self.thresholds,
+
+ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
+ """Computes the precision/recall curves for a tracker for given sequences.
+
+ Args:
+ experiment (Experiment): Experiment to compute the precision/recall curves for.
+ tracker (Tracker): Tracker to compute the precision/recall curves for.
+ sequence (Sequence): Sequence to compute the precision/recall curves for.
+ dependencies (List[Grid]): Dependencies of the analysis.
+
+ Returns:
+ Tuple[Any]: Precision/recall curves for the given sequence.
+ """
+
+ thresholds = dependencies[0, 0][0][0] # dependencies[0][0, 0]
+
+ trajectories = experiment.gather(tracker, sequence)
+
+ if len(trajectories) == 0:
+ raise MissingResultsException()
+
+ precision = len(thresholds) * [float(0)]
+ recall = len(thresholds) * [float(0)]
+ for trajectory in trajectories:
+ confidence = [trajectory.properties(i).get('confidence', 0) for i in range(len(trajectory))]
+ pr, re = compute_tpr_curves(trajectory.regions(), confidence, sequence, thresholds, self.ignore_unknown, self.bounded)
+ for i in range(len(thresholds)):
+ precision[i] += pr[i]
+ recall[i] += re[i]
+
+# return [(re / len(trajectories), pr / len(trajectories)) for pr, re in zip(precision, recall)], thresholds
+ return [(pr / len(trajectories), re / len(trajectories)) for pr, re in zip(precision, recall)], thresholds
+
+@analysis_registry.register("pr_curve")
+class PrecisionRecallCurve(SequenceAggregator):
+ """ Computes the average precision/recall curve for a tracker. """
+
+ curves = Include(PrecisionRecallCurves)
+
+ @property
+ def _title_default(self):
+ """Title of the analysis."""
+ return "Tracking precision/recall average curve"
+
+ def describe(self):
+ """Describes the analysis."""
+ return self.curves.describe()
+
+ def compatible(self, experiment: Experiment):
+ """Checks if the experiment is compatible with the analysis. This analysis is compatible with unsupervised experiments."""
+ return isinstance(experiment, UnsupervisedExperiment)
+
+ def dependencies(self):
+ """Dependencies of the analysis."""
+ return self.curves,
+
+ def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]:
+ """Computes the average precision/recall curve for a tracker.
+
+ Args:
+ tracker (Tracker): Tracker to compute the average precision/recall curve for.
+ sequences (List[Sequence]): Sequences to compute the average precision/recall curve for.
+ results (Grid): Results of the dependencies.
+
+ Returns:
+ Tuple[Any]: Average precision/recall curve for the given sequences.
+ """
+
+ curve = None
+ thresholds = None
+
+ for partial, thresholds in results:
+ if curve is None:
+ curve = partial
+ continue
+
+ curve = [(pr1 + pr2, re1 + re2) for (pr1, re1), (pr2, re2) in zip(curve, partial)]
+
+ curve = [(re / len(results), pr / len(results)) for pr, re in curve]
+
+ return curve, thresholds
+
+
+@analysis_registry.register("f_curve")
+class FScoreCurve(Analysis):
+ """ Computes the F-score curve for a tracker. """
+
+ beta = Float(default=1, description="Beta value for the F-score")
+ prcurve = Include(PrecisionRecallCurve)
+
+ @property
+ def _title_default(self):
+ """Title of the analysis."""
+ return "Tracking precision/recall"
+
+ def describe(self):
+ """Describes the analysis."""
+ return Plot("Tracking F-score curve", "F", wrt="normalized threshold", minimal=0, maximal=1), None
+
+ def compatible(self, experiment: Experiment):
+ """Checks if the experiment is compatible with the analysis. This analysis is compatible with unsupervised experiments."""
+ return isinstance(experiment, UnsupervisedExperiment)
+
+ def dependencies(self):
+ """Dependencies of the analysis."""
+ return self.prcurve,
+
+ def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid:
+ """Computes the F-score curve for a tracker.
+
+ Args:
+ experiment (Experiment): Experiment to compute the F-score curve for.
+ trackers (List[Tracker]): Trackers to compute the F-score curve for.
+ sequences (List[Sequence]): Sequences to compute the F-score curve for.
+ dependencies (List[Grid]): Dependencies of the analysis.
+
+ Returns:
+ Grid: F-score curve for the given sequences.
+ """
+
+ processed_results = Grid(len(trackers), 1)
+
+ for i, result in enumerate(dependencies[0]):
+ beta2 = (self.beta * self.beta)
+ f_curve = [((1 + beta2) * pr_ * re_) / (beta2 * pr_ + re_) for pr_, re_ in result[0]]
+
+ processed_results[i, 0] = (f_curve, result[0][1])
+
+ return processed_results
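+
+ # Note: for beta = 1 the F-score expression above reduces to the harmonic
+ # mean F = 2 * pr * re / (pr + re).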
+
+ @property
+ def axes(self):
+ """Axes of the analysis."""
+ return Axes.TRACKERS
+
+@analysis_registry.register("average_tpr")
+class PrecisionRecall(Analysis):
+ """ Computes the average precision/recall for a tracker. """
+
+ prcurve = Include(PrecisionRecallCurve)
+ fcurve = Include(FScoreCurve)
+
+ @property
+ def _title_default(self):
+ """Title of the analysis."""
+ return "Tracking precision/recall"
+
+ def describe(self):
+ """Describes the analysis."""
+ return Measure("Precision", "Pr", minimal=0, maximal=1, direction=Sorting.DESCENDING), \
+ Measure("Recall", "Re", minimal=0, maximal=1, direction=Sorting.DESCENDING), \
+ Measure("F Score", "F", minimal=0, maximal=1, direction=Sorting.DESCENDING)
+
+ def compatible(self, experiment: Experiment):
+ """Checks if the experiment is compatible with the analysis. This analysis is compatible with unsupervised experiments."""
+ return isinstance(experiment, UnsupervisedExperiment)
+
+ def dependencies(self):
+ """Dependencies of the analysis."""
+ return self.prcurve, self.fcurve
+
+ def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid:
+ """Computes the average precision/recall for a tracker.
+
+ Args:
+ experiment (Experiment): Experiment to compute the average precision/recall for.
+ trackers (List[Tracker]): Trackers to compute the average precision/recall for.
+ sequences (List[Sequence]): Sequences to compute the average precision/recall for.
+ dependencies (List[Grid]): Dependencies of the analysis.
+
+ Returns:
+ Grid: Average precision/recall for the given sequences.
+ """
+
+ f_curves = dependencies[1]
+ pr_curves = dependencies[0]
+
+ joined = Grid(len(trackers), 1)
+
+ for i, (f_curve, pr_curve) in enumerate(zip(f_curves, pr_curves)):
+ # get optimal F-score and Pr and Re at this threshold
+ f_score = max(f_curve[0])
+ best_i = f_curve[0].index(f_score)
+ re_score = pr_curve[0][best_i][0]
+ pr_score = pr_curve[0][best_i][1]
+ joined[i, 0] = (pr_score, re_score, f_score)
+
+ return joined
+
+ @property
+ def axes(self):
+ """Axes of the analysis."""
+ return Axes.TRACKERS
+
+
+def count_frames(trajectory: List[Region], groundtruth: List[Region], bounds = None, threshold: float = 0) -> Tuple[int, int, int, int, int]:
+ """Counts the frames in which the tracker is correctly tracking, fails, misses the object, hallucinates, or correctly notices the object's absence.
+
+ Args:
+ trajectory (List[Region]): Trajectory of the tracker.
+ groundtruth (List[Region]): Groundtruth trajectory.
+ bounds (Optional[Region]): Bounds of the sequence.
+ threshold (float): Threshold for the overlap.
+
+ Returns:
+ Tuple[int, int, int, int, int]: Frame counts (T, F, M, H, N) for tracking, failure, miss, hallucination and notice.
+ """
+
+ overlaps = np.array(calculate_overlaps(trajectory, groundtruth, bounds))
+ if threshold is None: threshold = -1
+
+ # Tracking, Failure, Miss, Hallucination, Notice
+ T, F, M, H, N = 0, 0, 0, 0, 0
+
+ for i, (region_tr, region_gt) in enumerate(zip(trajectory, groundtruth)):
+ if (is_special(region_gt, Sequence.UNKNOWN)):
+ continue
+ if region_gt.is_empty():
+ if region_tr.is_empty():
+ N += 1
+ else:
+ H += 1
+ else:
+ if overlaps[i] > threshold:
+ T += 1
+ else:
+ if region_tr.is_empty():
+ M += 1
+ else:
+ F += 1
+
+ return T, F, M, H, N
+
+class CountFrames(SeparableAnalysis):
+ """Counts the number of frames where the tracker is correct, fails, misses, hallucinates or notices an object."""
+
+ threshold = Float(default=0.0, val_min=0, val_max=1)
+ bounded = Boolean(default=True)
+
+ def compatible(self, experiment: Experiment):
+ """Checks if the experiment is compatible with the analysis. This analysis is compatible with multi-run experiments."""
+ return isinstance(experiment, MultiRunExperiment)
+
+ def describe(self):
+ """Describes the analysis."""
+ return None,
+
+ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
+ """Computes the number of frames where the tracker is correct, fails, misses, hallucinates or notices an object."""
+
+ assert isinstance(experiment, MultiRunExperiment)
+
+ objects = sequence.objects()
+ distribution = []
+ bounds = (sequence.size) if self.bounded else None
+
+ for object in objects:
+ trajectories = experiment.gather(tracker, sequence, objects=[object])
+ if len(trajectories) == 0:
+ raise MissingResultsException()
+
+ CN, CF, CM, CH, CT = 0, 0, 0, 0, 0
+
+ for trajectory in trajectories:
+ T, F, M, H, N = count_frames(trajectory.regions(), sequence.object(object), bounds=bounds)
+ CN += N
+ CF += F
+ CM += M
+ CH += H
+ CT += T
+ CN /= len(trajectories)
+ CF /= len(trajectories)
+ CM /= len(trajectories)
+ CH /= len(trajectories)
+ CT /= len(trajectories)
+
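+ # One tuple of run-averaged counts (T, F, M, H, N) per object in the sequence.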
+ distribution.append((CT, CF, CM, CH, CN))
+
+ return distribution,
+
+
+@analysis_registry.register("quality_auxiliary")
+class QualityAuxiliary(SeparableAnalysis):
+ """Computes the non-reported error, drift-rate error and absence-detection quality."""
+
+ threshold = Float(default=0.0, val_min=0, val_max=1)
+ bounded = Boolean(default=True)
+ absence_threshold = Integer(default=10, val_min=0)
+
+ def compatible(self, experiment: Experiment):
+ """Checks if the experiment is compatible with the analysis. This analysis is compatible with multi-run experiments."""
+ return isinstance(experiment, MultiRunExperiment)
+
+ @property
+ def _title_default(self):
+ """Default title of the analysis."""
+ return "Quality Auxiliary"
+
+ def describe(self):
+ """Describes the analysis."""
+ return Measure("Non-reported Error", "NRE", 0, 1, Sorting.DESCENDING), \
+ Measure("Drift-rate Error", "DRE", 0, 1, Sorting.DESCENDING), \
+ Measure("Absence-detection Quality", "ADQ", 0, 1, Sorting.DESCENDING),
+
+ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
+ """Computes the non-reported error, drift-rate error and absence-detection quality.
+
+ Args:
+ experiment (Experiment): Experiment.
+ tracker (Tracker): Tracker.
+ sequence (Sequence): Sequence.
+ dependencies (List[Grid]): Dependencies.
+
+ Returns:
+ Tuple[Any]: Non-reported error, drift-rate error and absence-detection quality.
+
+ """
+
+ assert isinstance(experiment, MultiRunExperiment)
+
+ not_reported_error = 0
+ drift_rate_error = 0
+ absence_detection = 0
+
+ objects = sequence.objects()
+ bounds = (sequence.size) if self.bounded else None
+
+ absence_valid = 0
+
+ for object in objects:
+ trajectories = experiment.gather(tracker, sequence, objects=[object])
+ if len(trajectories) == 0:
+ raise MissingResultsException()
+
+ CN, CF, CM, CH, CT = 0, 0, 0, 0, 0
+
+ for trajectory in trajectories:
+ T, F, M, H, N = count_frames(trajectory.regions(), sequence.object(object), bounds=bounds)
+ CN += N
+ CF += F
+ CM += M
+ CH += H
+ CT += T
+ CN /= len(trajectories)
+ CF /= len(trajectories)
+ CM /= len(trajectories)
+ CH /= len(trajectories)
+ CT /= len(trajectories)
+
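+ # From the run-averaged counts: NRE = M / (T + F + M), DRE = F / (T + F + M) and
+ # ADQ = N / (N + H), the latter only when enough absence frames are present.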
+ not_reported_error += CM / (CT + CF + CM)
+ drift_rate_error += CF / (CT + CF + CM)
+
+ if CN + CH > self.absence_threshold:
+ absence_detection += CN / (CN + CH)
+ absence_valid += 1
+
+ if absence_valid > 0:
+ absence_detection /= absence_valid
+ else:
+ absence_detection = None
+
+ return not_reported_error / len(objects), drift_rate_error / len(objects), absence_detection,
+
+
+@analysis_registry.register("average_quality_auxiliary")
+class AverageQualityAuxiliary(SequenceAggregator):
+ """Computes the average non-reported error, drift-rate error and absence-detection quality."""
+
+ analysis = Include(QualityAuxiliary)
+
+ @property
+ def _title_default(self):
+ """Default title of the analysis."""
+ return "Quality Auxiliary"
+
+ def dependencies(self):
+ """Returns the dependencies of the analysis."""
+ return self.analysis,
+
+ def describe(self):
+ """Describes the analysis."""
+ return Measure("Non-reported Error", "NRE", 0, 1, Sorting.DESCENDING), \
+ Measure("Drift-rate Error", "DRE", 0, 1, Sorting.DESCENDING), \
+ Measure("Absence-detection Quality", "ADQ", 0, 1, Sorting.DESCENDING),
+
+ def compatible(self, experiment: Experiment):
+ """Checks if the experiment is compatible with the analysis. This analysis is compatible with multi-run experiments."""
+ return isinstance(experiment, MultiRunExperiment)
+
+ def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid):
+ """Aggregates the non-reported error, drift-rate error and absence-detection quality.
+
+ Args:
+ tracker (Tracker): Tracker.
+ sequences (List[Sequence]): Sequences.
+ results (Grid): Results.
+
+ Returns:
+ Tuple[Any]: Non-reported error, drift-rate error and absence-detection quality.
+ """
+
+ not_reported_error = 0
+ drift_rate_error = 0
+ absence_detection = 0
+ absence_count = 0
+
+ for nre, dre, ad in results:
+ not_reported_error += nre
+ drift_rate_error += dre
+ if ad is not None:
+ absence_count += 1
+ absence_detection += ad
+
+ if absence_count > 0:
+ absence_detection /= absence_count
+
+ return not_reported_error / len(sequences), drift_rate_error / len(sequences), absence_detection
+
+from vot.analysis import SequenceAggregator
+from vot.analysis.accuracy import SequenceAccuracy
+
+@analysis_registry.register("longterm_ar")
+class AccuracyRobustness(Analysis):
+ """Longterm multi-object accuracy-robustness measure. """
+
+ threshold = Float(default=0.0, val_min=0, val_max=1)
+ bounded = Boolean(default=True)
+ counts = Include(CountFrames)
+
+ def dependencies(self) -> List[Analysis]:
+ """Returns the dependencies of the analysis."""
+ return self.counts, SequenceAccuracy(burnin=0, threshold=self.threshold, bounded=self.bounded, ignore_invisible=True, ignore_unknown=False)
+
+ def compatible(self, experiment: Experiment):
+ """Checks if the experiment is compatible with the analysis. This analysis is compatible with multi-run experiments."""
+ return isinstance(experiment, MultiRunExperiment)
+
+ @property
+ def _title_default(self):
+ """Default title of the analysis."""
+ return "Accuracy-robustness"
+
+ def describe(self):
+ """Describes the analysis."""
+ return Measure("Accuracy", "A", minimal=0, maximal=1, direction=Sorting.DESCENDING), \
+ Measure("Robustness", "R", minimal=0, direction=Sorting.DESCENDING), \
+ Point("AR plot", dimensions=2, abbreviation="AR", minimal=(0, 0), \
+ maximal=(1, 1), labels=("Robustness", "Accuracy"), trait="ar")
+
+ def compute(self, _: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid:
+ """Aggregate results from multiple sequences into a single value.
+
+ Args:
+ experiment (Experiment): Experiment.
+ trackers (List[Tracker]): Trackers.
+ sequences (List[Sequence]): Sequences.
+ dependencies (List[Grid]): Dependencies.
+
+ Returns:
+ Grid: Aggregated results.
+ """
+
+ frame_counts = dependencies[0]
+ accuracy_analysis = dependencies[1]
+
+ results = Grid(len(trackers), 1)
+
+ for j, _ in enumerate(trackers):
+ accuracy = 0
+ robustness = 0
+ count = 0
+
+ for i, _ in enumerate(sequences):
+ if accuracy_analysis[j, i] is None:
+ continue
+
+ accuracy += accuracy_analysis[j, i][0]
+
+ frame_counts_sequence = frame_counts[j, i][0]
+
+ objects = len(frame_counts_sequence)
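+ # Robustness of a sequence: T / (T + F + M) per object, i.e. the fraction of frames
+ # with successful tracking, averaged uniformly over all objects.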
+ for o in range(objects):
+ robustness += (1/objects) * frame_counts_sequence[o][0] / (frame_counts_sequence[o][0] + frame_counts_sequence[o][1] + frame_counts_sequence[o][2])
+
+ count += 1
+
+ results[j, 0] = (accuracy / count, robustness / count, (robustness / count, accuracy / count))
+
+ return results
+
+ @property
+ def axes(self) -> Axes:
+ """Returns the axes of the analysis."""
+ return Axes.TRACKERS
\ No newline at end of file
diff --git a/vot/analysis/multistart.py b/vot/analysis/multistart.py
index e03e5ff..21a1aeb 100644
--- a/vot/analysis/multistart.py
+++ b/vot/analysis/multistart.py
@@ -1,4 +1,5 @@
-import math
+"""This module contains the implementation of the accuracy-robustness analysis and EAO analysis for the multistart experiment."""
+
from typing import List, Tuple, Any
import numpy as np
@@ -16,6 +17,16 @@
from vot.utilities.data import Grid
def compute_eao_partial(overlaps: List, success: List[bool], curve_length: int):
+ """Compute the EAO curve for a single sequence. The curve is computed as the average overlap at each frame.
+
+ Args:
+ overlaps (List): List of overlaps for each frame.
+ success (List[bool]): List of success flags for each frame.
+ curve_length (int): Length of the curve.
+
+ Returns:
+ List[float]: EAO curve.
+ """
phi = curve_length * [float(0)]
active = curve_length * [float(0)]
@@ -38,6 +49,7 @@ def compute_eao_partial(overlaps: List, success: List[bool], curve_length: int):
@analysis_registry.register("multistart_ar")
class AccuracyRobustness(SeparableAnalysis):
+ """This analysis computes the accuracy-robustness curve for the multistart experiment."""
burnin = Integer(default=10, val_min=0)
grace = Integer(default=10, val_min=0)
@@ -45,10 +57,12 @@ class AccuracyRobustness(SeparableAnalysis):
threshold = Float(default=0.1, val_min=0, val_max=1)
@property
- def title(self):
+ def _title_default(self):
+ """Title of the analysis."""
return "AR Analysis"
def describe(self):
+ """Return the description of the analysis."""
return Measure("Accuracy", "A", minimal=0, maximal=1, direction=Sorting.DESCENDING), \
Measure("Robustness", "R", minimal=0, direction=Sorting.DESCENDING), \
Point("AR plot", dimensions=2, abbreviation="AR",
@@ -56,9 +70,21 @@ def describe(self):
None, None
def compatible(self, experiment: Experiment):
+ """Check if the experiment is compatible with the analysis. The experiment must be a multistart experiment."""
return isinstance(experiment, MultiStartExperiment)
def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
+ """Compute the accuracy-robustness for each sequence.
+
+ Args:
+ experiment (Experiment): Experiment.
+ tracker (Tracker): Tracker.
+ sequence (Sequence): Sequence.
+ dependencies (List[Grid]): List of dependencies.
+
+ Returns:
+ Tuple[Any]: Accuracy, robustness, AR plot point, raw robustness and the length of the sequence.
+ """
results = experiment.results(tracker, sequence)
@@ -79,7 +105,7 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc
if reverse:
proxy = FrameMapSequence(sequence, list(reversed(range(0, i + 1))))
else:
- proxy = FrameMapSequence(sequence, list(range(i, sequence.length)))
+ proxy = FrameMapSequence(sequence, list(range(i, len(sequence))))
trajectory = Trajectory.read(results, name)
@@ -107,17 +133,21 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc
@analysis_registry.register("multistart_average_ar")
class AverageAccuracyRobustness(SequenceAggregator):
+ """This analysis computes the average accuracy-robustness curve for the multistart experiment."""
analysis = Include(AccuracyRobustness)
@property
- def title(self):
+ def _title_default(self):
+ """Title of the analysis."""
return "AR Analysis"
def dependencies(self):
+ """Return the dependencies of the analysis."""
return self.analysis,
def describe(self):
+ """Return the description of the analysis."""
return Measure("Accuracy", "A", minimal=0, maximal=1, direction=Sorting.DESCENDING), \
Measure("Robustness", "R", minimal=0, direction=Sorting.DESCENDING), \
Point("AR plot", dimensions=2, abbreviation="AR",
@@ -125,9 +155,20 @@ def describe(self):
None, None
def compatible(self, experiment: Experiment):
+ """Check if the experiment is compatible with the analysis. The experiment must be a multistart experiment."""
return isinstance(experiment, MultiStartExperiment)
def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid):
+ """Aggregate the results of the analysis.
+
+ Args:
+ tracker (Tracker): Tracker.
+ sequences (List[Sequence]): List of sequences.
+ results (Grid): Grid of results.
+
+ Returns:
+ Tuple[Any]: Aggregated results.
+ """
total_accuracy = 0
total_robustness = 0
weight_accuracy = 0
@@ -145,6 +186,7 @@ def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid):
@analysis_registry.register("multistart_fragments")
class MultiStartFragments(SeparableAnalysis):
+ """This analysis computes the accuracy-robustness curve for the multistart experiment."""
burnin = Integer(default=10, val_min=0)
grace = Integer(default=10, val_min=0)
@@ -152,16 +194,29 @@ class MultiStartFragments(SeparableAnalysis):
threshold = Float(default=0.1, val_min=0, val_max=1)
@property
- def title(self):
+ def _title_default(self):
+ """Title of the analysis."""
return "Fragment Analysis"
def describe(self):
+ """Return the description of the analysis."""
return Curve("Success", 2, "Sc", minimal=(0, 0), maximal=(1,1), trait="points"), Curve("Accuracy", 2, "Ac", minimal=(0, 0), maximal=(1,1), trait="points")
def compatible(self, experiment: Experiment):
+ """Check if the experiment is compatible with the analysis. The experiment must be a multistart experiment."""
return isinstance(experiment, MultiStartExperiment)
def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
+ """Compute the analysis for a single sequence. The sequence must contain at least one anchor.
+
+ Args:
+ experiment (Experiment): Experiment.
+ tracker (Tracker): Tracker.
+ sequence (Sequence): Sequence.
+ dependencies (List[Grid]): List of dependencies.
+
+ Returns:
+ Tuple[Any]: Results of the analysis."""
results = experiment.results(tracker, sequence)
@@ -182,7 +237,7 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc
if reverse:
proxy = FrameMapSequence(sequence, list(reversed(range(0, i + 1))))
else:
- proxy = FrameMapSequence(sequence, list(range(i, sequence.length)))
+ proxy = FrameMapSequence(sequence, list(range(i, len(sequence))))
trajectory = Trajectory.read(results, name)
@@ -208,6 +263,7 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc
# TODO: remove high
@analysis_registry.register("multistart_eao_curves")
class EAOCurves(SeparableAnalysis):
+ """This analysis computes the expected average overlap curve for the multistart experiment."""
burnin = Integer(default=10, val_min=0)
grace = Integer(default=10, val_min=0)
@@ -217,17 +273,31 @@ class EAOCurves(SeparableAnalysis):
high = Integer()
@property
- def title(self):
+ def _title_default(self):
+ """Title of the analysis."""
return "EAO Curve"
def describe(self):
+ """Return the description of the analysis."""
return Plot("Expected average overlap", "EAO", minimal=0, maximal=1, wrt="frames", trait="eao"),
def compatible(self, experiment: Experiment):
+ """Check if the experiment is compatible with the analysis. The experiment must be a multistart experiment."""
return isinstance(experiment, MultiStartExperiment)
def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
-
+ """Compute the analysis for a single sequence. The sequence must contain at least one anchor.
+
+ Args:
+ experiment (Experiment): Experiment.
+ tracker (Tracker): Tracker.
+ sequence (Sequence): Sequence.
+ dependencies (List[Grid]): List of dependencies.
+
+ Returns:
+ Tuple[Any]: Results of the analysis.
+ """
+
results = experiment.results(tracker, sequence)
forward, backward = find_anchors(sequence, experiment.anchor)
@@ -247,7 +317,7 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc
if reverse:
proxy = FrameMapSequence(sequence, list(reversed(range(0, i + 1))))
else:
- proxy = FrameMapSequence(sequence, list(range(i, sequence.length)))
+ proxy = FrameMapSequence(sequence, list(range(i, len(sequence))))
trajectory = Trajectory.read(results, name)
@@ -279,23 +349,38 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc
#TODO: remove high
@analysis_registry.register("multistart_eao_curve")
class EAOCurve(SequenceAggregator):
+ """This analysis computes the expected average overlap curve for the multistart experiment. It is an aggregator of the curves for individual sequences."""
curves = Include(EAOCurves)
@property
- def title(self):
+ def _title_default(self):
+ """Title of the analysis."""
return "EAO Curve"
def describe(self):
+ """Return the description of the analysis."""
return Plot("Expected average overlap", "EAO", minimal=0, maximal=1, wrt="frames", trait="eao"),
def compatible(self, experiment: Experiment):
+ """Check if the experiment is compatible with the analysis. The experiment must be a multistart experiment."""
return isinstance(experiment, MultiStartExperiment)
def dependencies(self):
+ """Return the dependencies of the analysis."""
return self.curves,
def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]:
+ """Aggregate the results of the analysis for multiple sequences. The sequences must contain at least one anchor.
+
+ Args:
+ tracker (Tracker): Tracker.
+ sequences (List[Sequence]): List of sequences.
+ results (Grid): Grid of results.
+
+ Returns:
+ Tuple[Any]: Results of the analysis.
+ """
eao_curve = self.curves.high * [float(0)]
eao_weights = self.curves.high * [float(0)]
@@ -309,29 +394,46 @@ def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid)
@analysis_registry.register("multistart_eao_score")
class EAOScore(Analysis):
+ """This analysis computes the expected average overlap score for the multistart experiment. It does this by computing the EAO curve and then integrating it."""
low = Integer()
high = Integer()
eaocurve = Include(EAOCurve)
@property
- def title(self):
+ def _title_default(self):
+ """Title of the analysis."""
return "EAO analysis"
def describe(self):
+ """Return the description of the analysis."""
return Measure("Expected average overlap", "EAO", minimal=0, maximal=1, direction=Sorting.DESCENDING),
def compatible(self, experiment: Experiment):
+ """Check if the experiment is compatible with the analysis. The experiment must be a multistart experiment."""
return isinstance(experiment, MultiStartExperiment)
def dependencies(self):
+ """Return the dependencies of the analysis."""
return self.eaocurve,
def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid:
+ """Compute the analysis for multiple sequences. The sequences must contain at least one anchor.
+
+ Args:
+ experiment (Experiment): Experiment.
+ trackers (List[Tracker]): List of trackers.
+ sequences (List[Sequence]): List of sequences.
+ dependencies (List[Grid]): List of dependencies.
+
+ Returns:
+ Grid: Grid of results.
+ """
return dependencies[0].foreach(lambda x, i, j: (float(np.mean(x[0][self.low:self.high + 1])), ) )
@property
def axes(self):
+ """Return the axes of the analysis."""
return Axes.TRACKERS
diff --git a/vot/analysis/_processor.py b/vot/analysis/processor.py
similarity index 66%
rename from vot/analysis/_processor.py
rename to vot/analysis/processor.py
index e7794e0..1015a45 100644
--- a/vot/analysis/_processor.py
+++ b/vot/analysis/processor.py
@@ -1,7 +1,13 @@
+"""This module contains the implementation of the analysis processor. The processor is responsible for executing the analysis tasks in parallel and caching the results."""
import logging
+import sys
import threading
-from collections import Iterable, OrderedDict, namedtuple
+from collections import OrderedDict, namedtuple
+if sys.version_info >= (3, 3):
+ from collections.abc import Iterable
+else:
+ from collections import Iterable
from functools import partial
from typing import List, Union, Mapping, Tuple, Any
from concurrent.futures import Executor, Future, ThreadPoolExecutor
@@ -22,7 +28,9 @@
logger = logging.getLogger("vot")
def hashkey(analysis: Analysis, *args):
+ """Compute a hash key for the analysis and its arguments. The key is used for caching the results."""
def transform(arg):
+ """Transform an argument into a hashable object."""
if isinstance(arg, Sequence):
return arg.name
if isinstance(arg, Tracker):
@@ -37,13 +45,24 @@ def transform(arg):
return (analysis.identifier, *[transform(arg) for arg in args])
def unwrap(arg):
+ """Unwrap a single element list."""
+
if isinstance(arg, list) and len(arg) == 1:
return arg[0]
else:
return arg
class AnalysisError(ToolkitException):
+ """An exception that is raised when an analysis fails."""
+
def __init__(self, cause, task=None):
+ """Creates an analysis error.
+
+ Args:
+ cause (Exception): The cause of the error.
+ task (AnalysisTask, optional): The task that caused the error. Defaults to None.
+
+ """
self._tasks = []
self._cause = cause
super().__init__(cause, task)
@@ -51,12 +70,15 @@ def __init__(self, cause, task=None):
@property
def task(self):
+ """The task that caused the error."""
return self._tasks[-1]
def __str__(self):
+ """String representation of the error."""
return "Error during analysis {}".format(self.task)
def print(self, logoutput):
+ """Print the error to the log output."""
logoutput.error(str(self))
if len(self._tasks) > 1:
for task in reversed(self._tasks[:-1]):
@@ -65,6 +87,7 @@ def print(self, logoutput):
@property
def root_cause(self):
+ """The root cause of the error."""
cause = self._cause
if cause is None:
return None
@@ -94,6 +117,8 @@ def __init__(self, strict=True):
self._thread.start()
def submit(self, fn, *args, **kwargs):
+ """Submits a task to the executor."""
+
promise = Future()
with self._lock:
self._queue.put(DebugExecutor.Task(fn, args, kwargs, promise))
@@ -102,6 +127,7 @@ def submit(self, fn, *args, **kwargs):
return promise
def _run(self):
+ """The main loop of the executor."""
while True:
@@ -130,7 +156,7 @@ def _run(self):
except TypeError as e:
- logger.debug("Task %s call resulted in error: %s", task.fn, e)
+ logger.info("Task %s call resulted in error: %s", task.fn, e)
error = e
@@ -138,7 +164,11 @@ def _run(self):
error = e
- logger.debug("Task %s resulted in exception: %s", task.fn, e)
+ logger.info("Task %s resulted in exception: %s", task.fn, e)
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.exception(e)
if error is not None:
task.promise.set_exception(error)
@@ -149,6 +179,7 @@ def _run(self):
break
def _clear(self):
+ """Clears the queue."""
with self._lock:
while True:
@@ -160,6 +191,12 @@ def _clear(self):
break
def shutdown(self, wait=True):
+ """Shuts down the executor. If wait is True, the method blocks until all tasks are completed.
+
+ Args:
+ wait (bool, optional): Wait for all tasks to complete. Defaults to True.
+ """
+
self._alive = False
self._clear()
if wait:
@@ -169,8 +206,14 @@ def shutdown(self, wait=True):
self._thread.join()
class ExecutorWrapper(object):
+ """A wrapper for an executor that allows to submit tasks with dependencies."""
def __init__(self, executor: Executor):
+ """Creates an executor wrapper.
+
+ Args:
+ executor (Executor): The executor to wrap.
+ """
self._lock = RLock()
self._executor = executor
self._pending = OrderedDict()
@@ -178,9 +221,20 @@ def __init__(self, executor: Executor):
@property
def total(self):
+ """The total number of tasks submitted to the executor."""
return self._total
def submit(self, fn, *futures: Tuple[Future], mapping=None) -> Future:
+ """Submits a task to the executor. The task will be executed when all futures are completed.
+
+ Args:
+ fn (Callable): The task to execute.
+ futures (Tuple[Future]): The futures that must be completed before the task is executed.
+ mapping (Dict[Future, Any], optional): A mapping of futures to values. Defaults to None.
+
+ Returns:
+ Future: A future that will be completed when the task is completed.
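+
+ Example (a sketch; dep_a and dep_b are futures of previously submitted tasks):
+
+ joined = wrapper.submit(join_fn, dep_a, dep_b)
+ # join_fn is scheduled on the wrapped executor once both dependencies complete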
+ """
with self._lock:
@@ -200,7 +254,7 @@ def submit(self, fn, *futures: Tuple[Future], mapping=None) -> Future:
return proxy
def _ready_callback(self, fn, mapping, proxy: Future, future: Future):
- """ Internally handles completion of dependencies
+ """ Internally handles completion of dependencies. Submits the task to the executor.
"""
with self._lock:
@@ -229,9 +283,10 @@ def _ready_callback(self, fn, mapping, proxy: Future, future: Future):
def _done_callback(self, proxy: Future, future: Future):
""" Internally handles completion of executor future, copies result to proxy
+
Args:
- fn (function): [description]
- future (Future): [description]
+ proxy (Future): Proxy future
+ future (Future): Executor future
"""
if future.cancelled():
@@ -246,6 +301,10 @@ def _done_callback(self, proxy: Future, future: Future):
def _proxy_done(self, future: Future):
""" Internally handles events for proxy futures, this means handling cancellation.
+
+ Args:
+ future (Future): Proxy future
+
"""
with self._lock:
@@ -261,8 +320,15 @@ def _proxy_done(self, future: Future):
dependency.cancel()
class FuturesAggregator(Future):
+ """A future that aggregates results from other futures."""
def __init__(self, *futures: Tuple[Future]):
+ """Initializes the aggregator.
+
+ Args:
+ *futures (Tuple[Future]): The futures to aggregate.
+ """
+
super().__init__()
self._lock = RLock()
self._results = [None] * len(futures)
@@ -275,6 +341,8 @@ def __init__(self, *futures: Tuple[Future]):
self.set_result([])
def _on_result(self, i, future):
+ """Handles completion of a dependency future."""
+
with self._lock:
if self.done():
return
@@ -288,6 +356,8 @@ def _on_result(self, i, future):
self.set_result(self._results)
def _on_done(self, future):
+ """Handles completion of the future."""
+
with self._lock:
try:
self.set_result(future.result())
@@ -295,6 +365,8 @@ def _on_done(self, future):
self.set_exception(e)
def cancel(self):
+ """Cancels the future and all dependencies."""
+
with self._lock:
for promise in self._tasks:
promise.cancel()
@@ -302,9 +374,19 @@ def cancel(self):
class AnalysisTask(object):
+ """A task that computes an analysis."""
def __init__(self, analysis: Analysis, experiment: Experiment,
trackers: List[Tracker], sequences: List[Sequence]):
+ """Initializes a new instance of the AnalysisTask class.
+
+ Args:
+ analysis (Analysis): The analysis to compute.
+ experiment (Experiment): The experiment to compute the analysis for.
+ trackers (List[Tracker]): The trackers to compute the analysis for.
+ sequences (List[Sequence]): The sequences to compute the analysis for.
+ """
+
self._analysis = analysis
self._trackers = trackers
self._experiment = experiment
@@ -312,6 +394,15 @@ def __init__(self, analysis: Analysis, experiment: Experiment,
self._key = hashkey(analysis, experiment, trackers, sequences)
def __call__(self, dependencies: List[Grid] = None):
+ """Computes the analysis.
+
+ Args:
+ dependencies (List[Grid], optional): The dependencies to use. Defaults to None.
+
+ Returns:
+ Grid: The computed analysis.
+ """
+
try:
if dependencies is None:
dependencies = []
@@ -320,9 +411,18 @@ def __call__(self, dependencies: List[Grid] = None):
raise AnalysisError(cause=e, task=self._key)
class AnalysisPartTask(object):
+ """A task that computes a part of a separable analysis."""
def __init__(self, analysis: SeparableAnalysis, experiment: Experiment,
trackers: List[Tracker], sequences: List[Sequence]):
+ """Initializes a new instance of the AnalysisPartTask class.
+
+ Args:
+ analysis (SeparableAnalysis): The analysis to compute.
+ experiment (Experiment): The experiment to compute the analysis for.
+ trackers (List[Tracker]): The trackers to compute the analysis for.
+ sequences (List[Sequence]): The sequences to compute the analysis for.
+ """
self._analysis = analysis
self._trackers = trackers
self._experiment = experiment
@@ -330,6 +430,14 @@ def __init__(self, analysis: SeparableAnalysis, experiment: Experiment,
self._key = hashkey(analysis, experiment, unwrap(trackers), unwrap(sequences))
def __call__(self, dependencies: List[Grid] = None):
+ """Computes the analysis.
+
+ Args:
+ dependencies (List[Grid], optional): The dependencies to use. Defaults to None.
+
+ Returns:
+ Grid: The computed analysis.
+ """
try:
if dependencies is None:
dependencies = []
@@ -338,9 +446,19 @@ def __call__(self, dependencies: List[Grid] = None):
raise AnalysisError(cause=e, task=self._key)
class AnalysisJoinTask(object):
+ """A task that joins the results of a separable analysis."""
def __init__(self, analysis: SeparableAnalysis, experiment: Experiment,
trackers: List[Tracker], sequences: List[Sequence]):
+
+ """Initializes a new instance of the AnalysisJoinTask class.
+
+ Args:
+ analysis (Analysis): The analysis to join.
+ experiment (Experiment): The experiment to join the analysis for.
+ trackers (List[Tracker]): The trackers to join the analysis for.
+ sequences (List[Sequence]): The sequences to join the analysis for.
+ """
self._analysis = analysis
self._trackers = trackers
self._experiment = experiment
@@ -348,29 +466,55 @@ def __init__(self, analysis: SeparableAnalysis, experiment: Experiment,
self._key = hashkey(analysis, experiment, trackers, sequences)
def __call__(self, results: List[Grid]):
+ """Joins the results of the analysis.
+
+ Args:
+ results (List[Grid]): The results to join.
+
+ Returns:
+ Grid: The joined analysis.
+ """
+
try:
return self._analysis.join(self._trackers, self._sequences, results)
except BaseException as e:
raise AnalysisError(cause=e, task=self._key)
class AnalysisFuture(Future):
+ """A future that represents the result of an analysis."""
def __init__(self, key):
+ """Initializes a new instance of the AnalysisFuture class.
+
+ Args:
+ key (str): The key of the analysis.
+ """
+
super().__init__()
self._key = key
@property
def key(self):
+ """Gets the key of the analysis."""
return self._key
def __repr__(self) -> str:
+ """Gets a string representation of the future."""
return "".format(self._key)
class AnalysisProcessor(object):
+ """A processor that computes analyses."""
_context = threading.local()
def __init__(self, executor: Executor = None, cache: Cache = None):
+ """Initializes a new instance of the AnalysisProcessor class.
+
+ Args:
+ executor (Executor, optional): The executor to use for computations. Defaults to None.
+ cache (Cache, optional): The cache to use for computations. Defaults to None.
+
+ """
if executor is None:
executor = ThreadPoolExecutor(1)
@@ -383,6 +527,17 @@ def __init__(self, executor: Executor = None, cache: Cache = None):
def commit(self, analysis: Analysis, experiment: Experiment,
trackers: Union[Tracker, List[Tracker]], sequences: Union[Sequence, List[Sequence]]) -> Future:
+ """Commits an analysis for computation. If the analysis is already being computed, the existing future is returned.
+
+ Args:
+ analysis (Analysis): The analysis to commit.
+ experiment (Experiment): The experiment to commit the analysis for.
+ trackers (Union[Tracker, List[Tracker]]): The trackers to commit the analysis for.
+ sequences (Union[Sequence, List[Sequence]]): The sequences to commit the analysis for.
+
+ Returns:
+ Future: A future that represents the result of the analysis.
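+
+ Example (a sketch of typical usage):
+
+ future = processor.commit(analysis, experiment, trackers, sequences)
+ result = future.result() # blocks until the result Grid is available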
+ """
key = hashkey(analysis, experiment, trackers, sequences)
@@ -401,6 +556,7 @@ def commit(self, analysis: Analysis, experiment: Experiment,
if isinstance(analysis, SeparableAnalysis):
def select_dependencies(analysis: SeparableAnalysis, tracker: int, sequence: int, *dependencies):
+ """Selects the dependencies for a part of a separable analysis."""
return [analysis.select(meta, data, tracker, sequence) for meta, data in zip(analysis.dependencies(), dependencies)]
promise = AnalysisFuture(key)
@@ -440,7 +596,16 @@ def select_dependencies(analysis: SeparableAnalysis, tracker: int, sequence: int
return promise
- def _exists(self, key):
+ def _exists(self, key: str):
+ """Checks if an analysis is already being computed.
+
+ Args:
+ key (str): The key of the analysis to check.
+
+ Returns:
+ AnalysisFuture: The future that represents the analysis if it is already being computed, None otherwise.
+ """
+
if self._cache is not None and key in self._cache:
promise = AnalysisFuture(key)
promise.set_result(self._cache[key])
@@ -454,6 +619,12 @@ def _exists(self, key):
return None
def _future_done(self, future: Future):
+ """Handles the completion of a future.
+
+ Args:
+ future (Future): The future that completed.
+
+ """
with self._lock:
@@ -489,6 +660,16 @@ def _future_done(self, future: Future):
def _promise_cancelled(self, future: Future):
+ """Handles the cancellation of a promise. If the promise is the last promise for a computation, the computation is cancelled.
+
+ Args:
+ future (Future): The promise that was cancelled.
+
+ Returns:
+ bool: True if the promise was the last promise for a computation, False otherwise.
+
+ """
+
if not future.cancelled():
return
@@ -508,20 +689,27 @@ def _promise_cancelled(self, future: Future):
@property
def pending(self):
+ """The number of pending analyses."""
+
with self._lock:
return len(self._pending)
@property
def total(self):
+ """The total number of analyses."""
+
with self._lock:
return self._executor.total
def cancel(self):
+ """Cancels all pending analyses."""
+
with self._lock:
for _, future in list(self._pending.items()):
future.cancel()
def wait(self):
+ """Waits for all pending analyses to complete. If no analyses are pending, this method returns immediately."""
if self.total == 0:
return
@@ -542,6 +730,11 @@ def wait(self):
progress.close()
def __enter__(self):
+ """Sets this analysis processor as the default for the current thread.
+
+ Returns:
+ AnalysisProcessor: This analysis processor.
+ """
processor = getattr(AnalysisProcessor._context, 'analysis_processor', None)
@@ -557,6 +750,8 @@ def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
+ """Clears the default analysis processor for the current thread."""
+
processor = getattr(AnalysisProcessor._context, 'analysis_processor', None)
if processor == self:
@@ -565,6 +760,11 @@ def __exit__(self, exc_type, exc_value, traceback):
@staticmethod
def default():
+ """Returns the default analysis processor for the current thread.
+
+ Returns:
+ AnalysisProcessor: The default analysis processor for the current thread.
+ """
processor = getattr(AnalysisProcessor._context, 'analysis_processor', None)
@@ -581,11 +781,23 @@ def default():
@staticmethod
def commit_default(analysis: Analysis, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence]):
+ """Commits an analysis to the default analysis processor. This method is thread-safe. If the analysis is already being computed, this method returns immediately."""
processor = AnalysisProcessor.default()
return processor.commit(analysis, experiment, trackers, sequences)
def run(self, analysis: Analysis, experiment: Experiment,
trackers: Union[Tracker, List[Tracker]], sequences: Union[Sequence, List[Sequence]]) -> Grid:
+ """Runs an analysis on a set of trackers and sequences. This method is thread-safe. If the analysis is already being computed, this method returns immediately.
+
+ Args:
+ analysis (Analysis): The analysis to run.
+ experiment (Experiment): The experiment to run the analysis on.
+ trackers (Union[Tracker, List[Tracker]]): The trackers to run the analysis on.
+ sequences (Union[Sequence, List[Sequence]]): The sequences to run the analysis on.
+
+ Returns:
+ Grid: The results of the analysis.
+ """
assert self.pending == 0
@@ -597,23 +809,49 @@ def run(self, analysis: Analysis, experiment: Experiment,
@staticmethod
def run_default(analysis: Analysis, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence]):
+ """Runs an analysis on a set of trackers and sequences. This method is thread-safe. If the analysis is already being computed, this method returns immediately.
+
+ Args:
+ analysis (Analysis): The analysis to run.
+ experiment (Experiment): The experiment to run the analysis on.
+ trackers (List[Tracker]): The trackers to run the analysis on.
+ sequences (List[Sequence]): The sequences to run the analysis on.
+
+ Returns:
+ Grid: The results of the analysis."""
+
processor = AnalysisProcessor.default()
return processor.run(analysis, experiment, trackers, sequences)
def process_stack_analyses(workspace: "Workspace", trackers: List[Tracker]):
+ """Process all analyses in the workspace stack. This function is used by the command line interface to run all the analyses provided in a stack.
+
+ Args:
+ workspace (Workspace): The workspace to process.
+ trackers (List[Tracker]): The trackers to run the analyses on.
+
+ """
processor = AnalysisProcessor.default()
results = dict()
condition = Condition()
+ errors = []
def insert_result(container: dict, key):
+ """Creates a callback that inserts the result of a computation into a container. The container is a dictionary that maps analyses to their results.
+
+ Args:
+ container (dict): The container to insert the result into.
+ key (Analysis): The analysis to insert the result for.
+ """
def insert(future: Future):
+ """Inserts the result of a computation into a container."""
try:
container[key] = future.result()
except AnalysisError as e:
- e.print(logger)
+ errors.append(e)
except Exception as e:
logger.exception(e)
with condition:
@@ -628,7 +866,7 @@ def insert(future: Future):
results[experiment] = experiment_results
- sequences = [experiment.transform(sequence) for sequence in workspace.dataset]
+ sequences = experiment.transform(workspace.dataset)
for analysis in experiment.analyses:
@@ -665,4 +903,12 @@ def insert(future: Future):
logger.info("Analysis interrupted by user, aborting.")
return None
+ if len(errors) > 0:
+ logger.info("Errors occured during analysis, incomplete.")
+ for e in errors:
+ logger.info("Failed task {}: {}".format(e.task, e.root_cause))
+ if logger.isEnabledFor(logging.DEBUG):
+ e.print(logger)
+ return None
+
return results
\ No newline at end of file
diff --git a/vot/analysis/supervised.py b/vot/analysis/supervised.py
index 0077d45..93f89b6 100644
--- a/vot/analysis/supervised.py
+++ b/vot/analysis/supervised.py
@@ -11,28 +11,40 @@
from vot.tracker import Tracker, Trajectory
from vot.dataset import Sequence
-from vot.dataset.proxy import FrameMapSequence
from vot.experiment import Experiment
from vot.experiment.multirun import SupervisedExperiment
-from vot.experiment.multistart import MultiStartExperiment, find_anchors
-from vot.region import Region, Special, calculate_overlaps
+from vot.region import Region, calculate_overlaps
from vot.analysis import MissingResultsException, Measure, Point, is_special, Plot, Analysis, \
Sorting, SeparableAnalysis, SequenceAggregator, analysis_registry, TrackerSeparableAnalysis, Axes
from vot.utilities.data import Grid
def compute_accuracy(trajectory: List[Region], sequence: Sequence, burnin: int = 10,
ignore_unknown: bool = True, bounded: bool = True) -> float:
+ """ Computes accuracy of a tracker on a given sequence. Accuracy is defined as mean overlap of the tracker
+ region with the groundtruth region. The overlap is computed only for frames where the tracker is not in
+ initialization or failure state. The overlap is computed only for frames after the burnin period.
+
+ Args:
+ trajectory (List[Region]): Tracker trajectory.
+ sequence (Sequence): Sequence to compute accuracy on.
+ burnin (int, optional): Burnin period. Defaults to 10.
+ ignore_unknown (bool, optional): Ignore unknown regions. Defaults to True.
+ bounded (bool, optional): Clip regions to the image bounds when computing overlap. Defaults to True.
+
+ Returns:
+ Tuple[float, int]: Mean overlap over the valid frames and the number of valid frames, or (0, 0) if no frame is valid.
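+
+ Example (illustrative):
+
+ mean_overlap, frames_used = compute_accuracy(trajectory, sequence, burnin=10)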
+ """
overlaps = np.array(calculate_overlaps(trajectory, sequence.groundtruth(), (sequence.size) if bounded else None))
mask = np.ones(len(overlaps), dtype=bool)
for i, region in enumerate(trajectory):
- if is_special(region, Special.UNKNOWN) and ignore_unknown:
+ if is_special(region, Trajectory.UNKNOWN) and ignore_unknown:
mask[i] = False
- elif is_special(region, Special.INITIALIZATION):
+ elif is_special(region, Trajectory.INITIALIZATION):
for j in range(i, min(len(trajectory), i + burnin)):
mask[j] = False
- elif is_special(region, Special.FAILURE):
+ elif is_special(region, Trajectory.FAILURE):
mask[i] = False
if any(mask):
@@ -41,14 +53,17 @@ def compute_accuracy(trajectory: List[Region], sequence: Sequence, burnin: int =
return 0, 0
def count_failures(trajectory: List[Region]) -> Tuple[int, int]:
- return len([region for region in trajectory if is_special(region, Special.FAILURE)]), len(trajectory)
+ """Counts number of failures in a trajectory. Failure is defined as a frame where the tracker is in failure state."""
+ return len([region for region in trajectory if is_special(region, Trajectory.FAILURE)]), len(trajectory)
def locate_failures_inits(trajectory: List[Region]) -> Tuple[int, int]:
- return [i for i, region in enumerate(trajectory) if is_special(region, Special.FAILURE)], \
- [i for i, region in enumerate(trajectory) if is_special(region, Special.INITIALIZATION)]
+ """Locates failures and initializations in a trajectory. Failure is defined as a frame where the tracker is in failure state."""
+ return [i for i, region in enumerate(trajectory) if is_special(region, Trajectory.FAILURE)], \
+ [i for i, region in enumerate(trajectory) if is_special(region, Trajectory.INITIALIZATION)]
def compute_eao_curve(overlaps: List, weights: List[float], success: List[bool]):
+ """Computes EAO curve from a list of overlaps, weights and success flags."""
max_length = max([len(el) for el in overlaps])
total_runs = len(overlaps)
@@ -74,6 +89,11 @@ def compute_eao_curve(overlaps: List, weights: List[float], success: List[bool])
@analysis_registry.register("supervised_ar")
class AccuracyRobustness(SeparableAnalysis):
+ """Accuracy-Robustness analysis. Computes accuracy and robustness of a tracker on a given sequence.
+ Accuracy is defined as mean overlap of the tracker region with the groundtruth region. The overlap is computed only for frames where the tracker is not in
+ initialization or failure state. The overlap is computed only for frames after the burnin period.
+ Robustness is defined as a number of failures divided by the total number of frames.
+ """
sensitivity = Float(default=30, val_min=1)
burnin = Integer(default=10, val_min=0)
@@ -81,10 +101,12 @@ class AccuracyRobustness(SeparableAnalysis):
bounded = Boolean(default=True)
@property
- def title(self):
+ def _title_default(self):
+ """Returns title of the analysis."""
return "AR analysis"
def describe(self):
+ """Returns description of the analysis."""
return Measure("Accuracy", "A", minimal=0, maximal=1, direction=Sorting.DESCENDING), \
Measure("Robustness", "R", minimal=0, direction=Sorting.ASCENDING), \
Point("AR plot", dimensions=2, abbreviation="AR", minimal=(0, 0), \
@@ -92,9 +114,21 @@ def describe(self):
None
def compatible(self, experiment: Experiment):
+ """Returns True if the analysis is compatible with the experiment. Only SupervisedExperiment is compatible."""
return isinstance(experiment, SupervisedExperiment)
def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
+ """Computes accuracy and robustness of a tracker on a given sequence.
+
+ Args:
+ experiment (Experiment): Experiment.
+ tracker (Tracker): Tracker.
+ sequence (Sequence): Sequence.
+ dependencies (List[Grid]): Dependencies.
+
+ Returns:
+ Tuple[Any]: Accuracy, robustness, AR, number of frames.
+ """
trajectories = experiment.gather(tracker, sequence)
if len(trajectories) == 0:
@@ -112,17 +146,27 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc
@analysis_registry.register("supervised_average_ar")
class AverageAccuracyRobustness(SequenceAggregator):
+ """Average accuracy-robustness analysis. Computes average accuracy and robustness of a tracker on a given sequence.
+
+ Accuracy is defined as mean overlap of the tracker region with the groundtruth region. The overlap is computed only for frames where the tracker is not in
+ initialization or failure state. The overlap is computed only for frames after the burnin period.
+ Robustness is defined as a number of failures divided by the total number of frames.
+ The analysis is computed as an average of accuracy and robustness over all sequences.
+ """
analysis = Include(AccuracyRobustness)
@property
- def title(self):
+ def _title_default(self):
+ """Returns title of the analysis."""
return "AR Analysis"
def dependencies(self):
+ """Returns dependencies of the analysis."""
return self.analysis,
def describe(self):
+ """Returns description of the analysis."""
return Measure("Accuracy", "A", minimal=0, maximal=1, direction=Sorting.DESCENDING), \
Measure("Robustness", "R", minimal=0, direction=Sorting.ASCENDING), \
Point("AR plot", dimensions=2, abbreviation="AR", minimal=(0, 0), \
@@ -130,9 +174,21 @@ def describe(self):
None
def compatible(self, experiment: Experiment):
+ """Returns True if the analysis is compatible with the experiment. Only SupervisedExperiment is compatible."""
return isinstance(experiment, SupervisedExperiment)
def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid):
+ """Aggregates results of the analysis.
+
+ Args:
+ tracker (Tracker): Tracker.
+ sequences (List[Sequence]): List of sequences.
+ results (Grid): Results of the analysis.
+
+ Returns:
+ Tuple[Any]: Accuracy, robustness, AR, number of frames.
+ """
+
failures = 0
accuracy = 0
weight_total = 0
@@ -148,21 +204,40 @@ def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid):
@analysis_registry.register("supervised_eao_curve")
class EAOCurve(TrackerSeparableAnalysis):
+ """Expected Average Overlap curve analysis. Computes expected average overlap of a tracker on a given sequence.
+ The overlap is computed only for frames where the tracker is not in initialization or failure state.
+ The overlap is computed only for frames after the burnin period.
+ The analysis is computed as an average of accuracy and robustness over all sequences.
+ """
burnin = Integer(default=10, val_min=0)
bounded = Boolean(default=True)
@property
- def title(self):
+ def _title_default(self):
+ """Returns title of the analysis."""
return "EAO Curve"
def describe(self):
+ """Returns description of the analysis."""
return Plot("Expected Average Overlap", "EAO", minimal=0, maximal=1, trait="eao"),
def compatible(self, experiment: Experiment):
+ """Returns True if the analysis is compatible with the experiment. Only SupervisedExperiment is compatible."""
return isinstance(experiment, SupervisedExperiment)
def subcompute(self, experiment: Experiment, tracker: Tracker, sequences: List[Sequence], dependencies: List[Grid]) -> Tuple[Any]:
+ """Computes expected average overlap of a tracker on a given sequence.
+
+ Args:
+ experiment (Experiment): Experiment.
+ tracker (Tracker): Tracker.
+ sequences (List[Sequence]): List of sequences.
+ dependencies (List[Grid]): Dependencies.
+
+ Returns:
+ Tuple[Any]: Expected average overlap.
+ """
overlaps_all = []
weights_all = []
@@ -203,27 +278,45 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequences: List[S
@analysis_registry.register("supervised_eao_score")
class EAOScore(Analysis):
+ """Expected Average Overlap score analysis. The analysis is computed as an average of EAO scores over multiple sequences.
+ """
eaocurve = Include(EAOCurve)
low = Integer()
high = Integer()
@property
- def title(self):
+ def _title_default(self):
+ """Returns title of the analysis."""
return "EAO analysis"
def describe(self):
+ """Returns description of the analysis."""
return Measure("Expected average overlap", "EAO", 0, 1, Sorting.DESCENDING),
def compatible(self, experiment: Experiment):
+ """Returns True if the analysis is compatible with the experiment. Only SupervisedExperiment is compatible."""
return isinstance(experiment, SupervisedExperiment)
def dependencies(self):
+ """Returns dependencies of the analysis."""
return self.eaocurve,
def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid:
+ """Computes expected average overlap of a tracker on a given sequence.
+
+ Args:
+ experiment (Experiment): Experiment.
+ trackers (List[Tracker]): List of trackers.
+ sequences (List[Sequence]): List of sequences.
+ dependencies (List[Grid]): Dependencies.
+
+ Returns:
+ Grid: Expected average overlap.
+ """
return dependencies[0].foreach(lambda x, i, j: (float(np.mean(x[0][self.low:self.high + 1])), ) )
@property
def axes(self):
+ """Returns axes of the analysis."""
return Axes.TRACKERS
\ No newline at end of file
diff --git a/vot/analysis/tpr.py b/vot/analysis/tpr.py
deleted file mode 100644
index d53d0d6..0000000
--- a/vot/analysis/tpr.py
+++ /dev/null
@@ -1,261 +0,0 @@
-import math
-import numpy as np
-from typing import List, Iterable, Tuple, Any
-import itertools
-
-from attributee import Float, Integer, Boolean, Include
-
-from vot.tracker import Tracker
-from vot.dataset import Sequence
-from vot.region import Region, RegionType, calculate_overlaps
-from vot.experiment import Experiment
-from vot.experiment.multirun import UnsupervisedExperiment
-from vot.analysis import SequenceAggregator, Analysis, SeparableAnalysis, \
- MissingResultsException, Measure, Sorting, Curve, Plot, SequenceAggregator, \
- Axes, analysis_registry
-from vot.utilities.data import Grid
-
-def determine_thresholds(scores: Iterable[float], resolution: int) -> List[float]:
- scores = [score for score in scores if not math.isnan(score)] #and not score is None]
- scores = sorted(scores, reverse=True)
-
- if len(scores) > resolution - 2:
- delta = math.floor(len(scores) / (resolution - 2))
- idxs = np.round(np.linspace(delta, len(scores) - delta, num=resolution - 2)).astype(np.int)
- thresholds = [scores[idx] for idx in idxs]
- else:
- thresholds = scores
-
- thresholds.insert(0, math.inf)
- thresholds.insert(len(thresholds), -math.inf)
-
- return thresholds
-
-def compute_tpr_curves(trajectory: List[Region], confidence: List[float], sequence: Sequence, thresholds: List[float],
- ignore_unknown: bool = True, bounded: bool = True):
-
- overlaps = np.array(calculate_overlaps(trajectory, sequence.groundtruth(), (sequence.size) if bounded else None))
- confidence = np.array(confidence)
-
- n_visible = len([region for region in sequence.groundtruth() if region.type is not RegionType.SPECIAL])
-
-
-
- precision = len(thresholds) * [float(0)]
- recall = len(thresholds) * [float(0)]
-
- for i, threshold in enumerate(thresholds):
-
- subset = confidence >= threshold
-
- if np.sum(subset) == 0:
- precision[i] = 1
- recall[i] = 0
- else:
- precision[i] = np.mean(overlaps[subset])
- recall[i] = np.sum(overlaps[subset]) / n_visible
-
- return precision, recall
-
-class _ConfidenceScores(SeparableAnalysis):
-
- @property
- def title(self):
- return "Aggregate confidence scores"
-
- def describe(self):
- return None,
-
- def compatible(self, experiment: Experiment):
- return isinstance(experiment, UnsupervisedExperiment)
-
- def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
-
- scores_all = []
- trajectories = experiment.gather(tracker, sequence)
-
- if len(trajectories) == 0:
- raise MissingResultsException("Missing results for sequence {}".format(sequence.name))
-
- for trajectory in trajectories:
- confidence = [trajectory.properties(i).get('confidence', 0) for i in range(len(trajectory))]
- scores_all.extend(confidence)
-
- return scores_all,
-
-
-class _Thresholds(SequenceAggregator):
-
- resolution = Integer(default=100)
-
- @property
- def title(self):
- return "Thresholds for tracking precision/recall"
-
- def describe(self):
- return None,
-
- def compatible(self, experiment: Experiment):
- return isinstance(experiment, UnsupervisedExperiment)
-
- def dependencies(self):
- return _ConfidenceScores(),
-
- def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]:
-
- thresholds = determine_thresholds(itertools.chain(*[result[0] for result in results]), self.resolution),
-
- return thresholds,
-
-@analysis_registry.register("pr_curves")
-class PrecisionRecallCurves(SeparableAnalysis):
-
- thresholds = Include(_Thresholds)
- ignore_unknown = Boolean(default=True)
- bounded = Boolean(default=True)
-
- @property
- def title(self):
- return "Tracking precision/recall"
-
- def describe(self):
- return Curve("Precision Recall curve", dimensions=2, abbreviation="PR", minimal=(0, 0), maximal=(1, 1), labels=("Recall", "Precision")), None
-
- def compatible(self, experiment: Experiment):
- return isinstance(experiment, UnsupervisedExperiment)
-
- def dependencies(self):
- return self.thresholds,
-
- def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]:
-
- thresholds = dependencies[0, 0][0][0] # dependencies[0][0, 0]
-
- trajectories = experiment.gather(tracker, sequence)
-
- if len(trajectories) == 0:
- raise MissingResultsException()
-
- precision = len(thresholds) * [float(0)]
- recall = len(thresholds) * [float(0)]
- for trajectory in trajectories:
- confidence = [trajectory.properties(i).get('confidence', 0) for i in range(len(trajectory))]
- pr, re = compute_tpr_curves(trajectory.regions(), confidence, sequence, thresholds, self.ignore_unknown, self.bounded)
- for i in range(len(thresholds)):
- precision[i] += pr[i]
- recall[i] += re[i]
-
-# return [(re / len(trajectories), pr / len(trajectories)) for pr, re in zip(precision, recall)], thresholds
- return [(pr / len(trajectories), re / len(trajectories)) for pr, re in zip(precision, recall)], thresholds
-
-@analysis_registry.register("pr_curve")
-class PrecisionRecallCurve(SequenceAggregator):
-
- curves = Include(PrecisionRecallCurves)
-
- @property
- def title(self):
- return "Tracking precision/recall average curve"
-
- def describe(self):
- return self.curves.describe()
-
- def compatible(self, experiment: Experiment):
- return isinstance(experiment, UnsupervisedExperiment)
-
- def dependencies(self):
- return self.curves,
-
- # def collapse(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]:
- def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]:
-
- curve = None
- thresholds = None
-
- for partial, thresholds in results:
- if curve is None:
- curve = partial
- continue
-
- curve = [(pr1 + pr2, re1 + re2) for (pr1, re1), (pr2, re2) in zip(curve, partial)]
-
- curve = [(re / len(results), pr / len(results)) for pr, re in curve]
-
- return curve, thresholds
-
-
-@analysis_registry.register("f_curve")
-class FScoreCurve(Analysis):
-
- beta = Float(default=1)
- prcurve = Include(PrecisionRecallCurve)
-
- @property
- def title(self):
- return "Tracking precision/recall"
-
- def describe(self):
- return Plot("Tracking F-score curve", "F", wrt="normalized threshold", minimal=0, maximal=1), None
-
- def compatible(self, experiment: Experiment):
- return isinstance(experiment, UnsupervisedExperiment)
-
- def dependencies(self):
- return self.prcurve,
-
- def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid:
- processed_results = Grid(len(trackers), 1)
-
- for i, result in enumerate(dependencies[0]):
- beta2 = (self.beta * self.beta)
- f_curve = [((1 + beta2) * pr_ * re_) / (beta2 * pr_ + re_) for pr_, re_ in result[0]]
-
- processed_results[i, 0] = (f_curve, result[0][1])
-
- return processed_results
-
- @property
- def axes(self):
- return Axes.TRACKERS
-
-@analysis_registry.register("average_tpr")
-class PrecisionRecall(Analysis):
-
- prcurve = Include(PrecisionRecallCurve)
- fcurve = Include(FScoreCurve)
-
- @property
- def title(self):
- return "Tracking precision/recall"
-
- def describe(self):
- return Measure("Precision", "Pr", minimal=0, maximal=1, direction=Sorting.DESCENDING), \
- Measure("Recall", "Re", minimal=0, maximal=1, direction=Sorting.DESCENDING), \
- Measure("F Score", "F", minimal=0, maximal=1, direction=Sorting.DESCENDING)
-
- def compatible(self, experiment: Experiment):
- return isinstance(experiment, UnsupervisedExperiment)
-
- def dependencies(self):
- return self.prcurve, self.fcurve
-
- def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid:
-
- f_curves = dependencies[1]
- pr_curves = dependencies[0]
-
- joined = Grid(len(trackers), 1)
-
- for i, (f_curve, pr_curve) in enumerate(zip(f_curves, pr_curves)):
- # get optimal F-score and Pr and Re at this threshold
- f_score = max(f_curve[0])
- best_i = f_curve[0].index(f_score)
- re_score = pr_curve[0][best_i][0]
- pr_score = pr_curve[0][best_i][1]
- joined[i, 0] = (pr_score, re_score, f_score)
-
- return joined
-
- @property
- def axes(self):
- return Axes.TRACKERS
diff --git a/vot/dataset/__init__.py b/vot/dataset/__init__.py
index 78b9f2b..960dfc5 100644
--- a/vot/dataset/__init__.py
+++ b/vot/dataset/__init__.py
@@ -1,98 +1,242 @@
+"""Dataset module provides an interface for accessing the datasets and sequences. It also provides a set of utility functions for downloading and extracting datasets."""
+
import os
-import json
-import glob
+import logging
+from numbers import Number
+from collections import namedtuple
from abc import abstractmethod, ABC
+from typing import List, Mapping, Optional, Set, Tuple, Iterator
+
+from class_registry import ClassRegistry
from PIL.Image import Image
import numpy as np
+from cachetools import cached, LRUCache
+
+from vot.region import Region
from vot import ToolkitException
-from vot.utilities import read_properties
-from vot.region import parse
import cv2
+logger = logging.getLogger("vot")
+
+dataset_downloader = ClassRegistry("vot_downloader")
+sequence_indexer = ClassRegistry("vot_indexer")
+sequence_reader = ClassRegistry("vot_sequence")
+
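+# Registration sketch (illustrative, not part of the toolkit): third-party code can plug
+# a downloader into one of the registries under an alias, e.g.
+#
+#     @dataset_downloader.register("my-dataset")
+#     def _download_my_dataset(path):
+#         download_bundle("https://example.org/my-dataset.zip", path)
+#
+# The alias and URL above are hypothetical; download_bundle is defined later in this module.
+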
class DatasetException(ToolkitException):
+ """Dataset and sequence related exceptions
+ """
pass
class Channel(ABC):
+ """Abstract representation of individual image channel, a sequence of images with
+ uniform dimensions.
+ """
def __init__(self):
+ """ Base constructor for channel"""
pass
- @property
- @abstractmethod
- def length(self):
- pass
+ def __len__(self) -> int:
+ """Returns the length of channel
+
+ Returns:
+ int: Length of channel
+ """
+ raise NotImplementedError()
@abstractmethod
- def frame(self, index):
+ def frame(self, index: int) -> "Frame":
+ """Returns frame object for the given index
+
+ Args:
+ index (int): Index of the frame
+
+ Returns:
+ Frame: Frame object
+ """
pass
@abstractmethod
- def filename(self, index):
+ def filename(self, index: int) -> str:
+ """Returns filename for the given index of the channel sequence
+
+ Args:
+ index (int): Index of the frame
+
+ Returns:
+ str: Filename of the frame
+ """
pass
@property
@abstractmethod
- def size(self):
+    def size(self) -> tuple:
+        """Returns the size of the channel frames as a tuple (width, height)"""
pass
class Frame(object):
+ """Frame object represents a single frame in the sequence. It provides access to the
+ image data, groundtruth, tags and values as a wrapper around the sequence object."""
def __init__(self, sequence, index):
+ """Base constructor for frame object
+
+ Args:
+ sequence (Sequence): Sequence object
+ index (int): Index of the frame
+
+ Returns:
+ Frame: Frame object
+ """
self._sequence = sequence
self._index = index
@property
def index(self) -> int:
+ """Returns the index of the frame
+
+ Returns:
+ int: Index of the frame
+ """
return self._index
@property
def sequence(self) -> 'Sequence':
+ """Returns the sequence object of the frame object
+
+ Returns:
+ Sequence: Sequence object
+ """
return self._sequence
def channels(self):
+ """Returns the list of channels in the sequence
+
+ Returns:
+ List[str]: List of channels
+ """
return self._sequence.channels()
- def channel(self, channel=None):
+ def channel(self, channel: Optional[str] = None):
+ """Returns the channel object for the given channel name
+
+ Args:
+ channel (Optional[str], optional): Name of the channel. Defaults to None.
+ """
channelobj = self._sequence.channel(channel)
if channelobj is None:
return None
return channelobj.frame(self._index)
- def filename(self, channel=None):
+ def filename(self, channel: Optional[str] = None):
+ """Returns the filename for the given channel name and frame index
+
+ Args:
+ channel (Optional[str], optional): Name of the channel. Defaults to None.
+
+ Returns:
+ str: Filename of the frame
+ """
channelobj = self._sequence.channel(channel)
if channelobj is None:
return None
return channelobj.filename(self._index)
- def image(self, channel=None):
+ def image(self, channel: Optional[str] = None) -> np.ndarray:
+ """Returns the image for the given channel name and frame index
+
+ Args:
+ channel (Optional[str], optional): Name of the channel. Defaults to None.
+
+ Returns:
+ np.ndarray: Image object
+ """
channelobj = self._sequence.channel(channel)
if channelobj is None:
return None
return channelobj.frame(self._index)
- def groundtruth(self):
+    def objects(self) -> Mapping[str, Region]:
+        """Returns the objects visible in this frame
+
+        Returns:
+            Mapping[str, Region]: Mapping of object id to region for this frame
+        """
+ objects = {}
+ for o in self._sequence.objects():
+ region = self._sequence.object(o, self._index)
+ if region is not None:
+ objects[o] = region
+ return objects
+
+ def object(self, id: str) -> Region:
+ """Returns the object region for the given object id and frame index
+
+ Args:
+ id (str): Id of the object
+
+ Returns:
+ Region: Object region
+ """
+ return self._sequence.object(id, self._index)
+
+ def groundtruth(self) -> Region:
+ """Returns the groundtruth region for the frame
+
+ Returns:
+ Region: Groundtruth region
+
+ Raises:
+ DatasetException: If groundtruth is not available
+ """
return self._sequence.groundtruth(self._index)
- def tags(self, index = None):
+ def tags(self) -> List[str]:
+ """Returns the tags for the frame
+
+ Returns:
+ List[str]: List of tags
+ """
return self._sequence.tags(self._index)
- def values(self, index=None):
+ def values(self) -> Mapping[str, float]:
+ """Returns the values for the frame
+
+ Returns:
+ Mapping[str, float]: Mapping of values
+ """
return self._sequence.values(self._index)
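+
+# Usage sketch (illustrative, assumes `sequence` is an already loaded Sequence object):
+#
+#     frame = sequence.frame(0)
+#     image = frame.image()              # RGB numpy array from the default channel
+#     regions = frame.objects()          # mapping of object id -> Region in this frame
+#     print(frame.tags(), frame.values())
+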
class SequenceIterator(object):
-
- def __init__(self, sequence):
+ """Sequence iterator provides an iterator interface for the sequence object"""
+
+ def __init__(self, sequence: "Sequence"):
+ """Base constructor for sequence iterator
+
+ Args:
+ sequence (Sequence): Sequence object
+ """
self._position = 0
self._sequence = sequence
def __iter__(self):
+ """Returns the iterator object
+
+ Returns:
+ SequenceIterator: Sequence iterator object
+ """
return self
- def __next__(self):
+ def __next__(self) -> Frame:
+ """Returns the next frame object in the sequence iterator
+
+ Returns:
+ Frame: Frame object
+ """
if self._position >= len(self._sequence):
raise StopIteration()
index = self._position
@@ -100,8 +244,11 @@ def __next__(self):
return Frame(self._sequence, index)
class InMemoryChannel(Channel):
+ """In-memory channel represents a sequence of images with uniform dimensions. It is
+ used to represent a sequence of images in memory."""
def __init__(self):
+ """Base constructor for in-memory channel"""
super().__init__()
self._images = []
self._width = 0
@@ -109,6 +256,11 @@ def __init__(self):
self._depth = 0
def append(self, image):
+ """Appends an image to the channel
+
+ Args:
+ image (np.ndarray): Image object
+ """
if isinstance(image, Image):
image = np.asarray(image)
@@ -133,40 +285,104 @@ def append(self, image):
self._images.append(image)
-    @property
-    def length(self):
+    def __len__(self) -> int:
+        """Returns the length of the sequence channel in number of frames
+
+        Returns:
+            int: Length of the sequence channel
+        """
return len(self._images)
def frame(self, index):
- if index < 0 or index >= self.length:
+ """Returns the frame object for the given index in the sequence channel
+
+ Args:
+ index (int): Index of the frame
+
+ Returns:
+ Frame: Frame object
+ """
+ if index < 0 or index >= len(self):
return None
return self._images[index]
@property
def size(self):
+ """Returns the size of the channel in the format (width, height)
+
+ Returns:
+ Tuple[int, int]: Size of the channel
+ """
return self._width, self._height
def filename(self, index):
+ """ Thwows an exception as the sequence is available in memory and not in files. """
raise DatasetException("Sequence is available in memory, image files not available")
class PatternFileListChannel(Channel):
-
- def __init__(self, path, start=1, step=1, end=None):
+ """Sequence channel implementation where each frame is stored in a file and all file names
+ follow a specific pattern.
+ """
+
+ def __init__(self, path, start=1, step=1, end=None, check_files=True):
+ """ Creates a new channel object
+
+ Args:
+ path (str): Path to the sequence
+ start (int, optional): First frame index
+ step (int, optional): Step between frames
+ end (int, optional): Last frame index
+ check_files (bool, optional): Check that files exist
+
+ Raises:
+ DatasetException: If the pattern is invalid
+
+ Returns:
+ PatternFileListChannel: Channel object
+ """
super().__init__()
base, pattern = os.path.split(path)
self._base = base
self._pattern = pattern
- self.__scan(pattern, start, step, end)
+ self.__scan(pattern, start, step, end, check_files=check_files)
@property
- def base(self):
+ def base(self) -> str:
+ """Returns the base path of the sequence
+
+ Returns:
+ str: Base path
+ """
return self._base
@property
def pattern(self):
+ """Returns the pattern of the sequence
+
+ Returns:
+ str: Pattern
+ """
return self._pattern
- def __scan(self, pattern, start, step, end):
+ def __scan(self, pattern, start, step, end, check_files=True):
+ """Scans the sequence directory for files matching the pattern and stores the file names in
+ the internal list. The pattern must contain a single %d placeholder for the frame index. The
+ placeholder must be at the end of the pattern. The pattern may contain a file extension. If
+ the pattern contains no file extension, .jpg is assumed. If end frame is specified, the
+ scanning stops when the end frame is reached. If check_files is True, and end frame is set
+ then files are checked to exist.
+
+ Args:
+ pattern (str): Pattern
+ start (int): First frame index
+ step (int): Step between frames
+ end (int): Last frame index
+ check_files (bool, optional): Check that files exist
+
+ Raises:
+ DatasetException: If the pattern is invalid
+ """
extension = os.path.splitext(pattern)[1]
if not extension in {'.jpg', '.png'}:
@@ -177,10 +393,12 @@ def __scan(self, pattern, start, step, end):
fullpattern = os.path.join(self.base, pattern)
+ assert end is not None or check_files
+
while True:
image_file = os.path.join(fullpattern % i)
- if not os.path.isfile(image_file):
+ if check_files and not os.path.isfile(image_file):
break
self._files.append(os.path.basename(image_file))
i = i + step
@@ -191,234 +409,564 @@ def __scan(self, pattern, start, step, end):
if i <= start:
raise DatasetException("Empty sequence, no frames found.")
- im = cv2.imread(self.filename(0))
- self._width = im.shape[1]
- self._height = im.shape[0]
- self._depth = im.shape[2]
+ if os.path.isfile(self.filename(0)):
+ im = cv2.imread(self.filename(0))
+ self._width = im.shape[1]
+ self._height = im.shape[0]
+ self._depth = im.shape[2]
+ else:
+ self._depth = None
+ self._width = None
+ self._height = None
- @property
- def length(self):
+ def __len__(self) -> int:
+ """Returns the number of frames in the sequence
+
+ Returns:
+ int: Number of frames
+ """
return len(self._files)
- def frame(self, index):
- if index < 0 or index >= self.length:
+ def frame(self, index: int) -> np.ndarray:
+ """Returns the frame at the specified index as a numpy array. The image is loaded using
+ OpenCV and converted to RGB color space if necessary.
+
+ Args:
+ index (int): Frame index
+
+ Returns:
+ np.ndarray: Frame
+
+ Raises:
+ DatasetException: If the index is out of bounds
+
+ """
+ if index < 0 or index >= len(self):
return None
bgr = cv2.imread(self.filename(index))
+
+ # Check if the image is grayscale
+ if len(bgr.shape) == 2:
+ return bgr
+
return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
@property
- def size(self):
+ def size(self) -> tuple:
+ """Returns the size of the frames in the sequence as a tuple (width, height)
+
+ Returns:
+ tuple: Size of the frames
+ """
return self._width, self._height
@property
- def width(self):
+ def width(self) -> int:
+ """Returns the width of the frames in the sequence
+
+ Returns:
+ int: Width of the frames
+ """
return self._width
@property
- def height(self):
+ def height(self) -> int:
+ """Returns the height of the frames in the sequence
+
+ Returns:
+ int: Height of the frames
+ """
return self._height
- def filename(self, index):
- if index < 0 or index >= self.length:
+ def filename(self, index) -> str:
+ """Returns the filename of the frame at the specified index
+
+ Args:
+ index (int): Frame index
+
+ Returns:
+ str: Filename
+ """
+ if index < 0 or index >= len(self):
return None
return os.path.join(self.base, self._files[index])
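+
+# Construction sketch (illustrative, the path is hypothetical):
+#
+#     channel = PatternFileListChannel("/data/sequence/color/%08d.jpg")
+#     print(len(channel), channel.size)
+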
-class FrameList(ABC):
+class FrameList(object):
+ """Abstract base for all sequences, just a list of frame objects
+ """
def __iter__(self):
+ """Returns an iterator over the frames in the sequence
+
+ Returns:
+ SequenceIterator: Iterator
+ """
return SequenceIterator(self)
- @abstractmethod
def __len__(self) -> int:
- pass
-
- @abstractmethod
- def frame(self, index) -> Frame:
- pass
+ """Returns the number of frames in the sequence.
+
+ Returns:
+ int: Number of frames
+ """
+        raise NotImplementedError()
+
+ def frame(self, index: int) -> Frame:
+ """Returns the frame at the specified index
+
+ Args:
+ index (int): Frame index
+ """
+        raise NotImplementedError()
class Sequence(FrameList):
+ """A sequence is a list of frames (multiple channels) and a list of one or more annotated objects. It also contains
+ additional metadata and per-frame information, such as tags and values.
+ """
- def __init__(self, name: str, dataset: "Dataset" = None):
- self._name = name
- self._dataset = dataset
+ UNKNOWN = 0 # object state is unknown in this frame
+ INVISIBLE = 1 # object is not visible in this frame
- def __len__(self) -> int:
- return self.length
+ def __init__(self, name: str):
+ """Creates a new sequence with the specified name"""
+ self._name = name
@property
def name(self) -> str:
+ """Returns the name of the sequence
+
+ Returns:
+ str: Name
+ """
return self._name
@property
def identifier(self) -> str:
+ """Returns the identifier of the sequence. The identifier is a string that uniquely identifies the sequence in
+ the dataset. The identifier is usually the same as the name, but may be different if the name is not unique.
+
+ Returns:
+ str: Identifier
+ """
return self._name
- @property
- def dataset(self):
- return self._dataset
-
@abstractmethod
def metadata(self, name, default=None):
+ """Returns the value of the specified metadata field. If the field does not exist, the default value is returned
+
+ Args:
+ name (str): Name of the metadata field
+ default (object, optional): Default value
+
+ Returns:
+ object: Value of the metadata field
+ """
pass
@abstractmethod
- def channel(self, channel=None):
+ def channel(self, channel=None) -> Channel:
+ """Returns the channel with the specified name or the default channel if no name is specified
+
+ Args:
+ channel (str, optional): Name of the channel
+
+ Returns:
+ Channel: Channel
+ """
pass
@abstractmethod
- def channels(self):
+ def channels(self) -> Set[str]:
+ """Returns the names of all channels in the sequence
+
+ Returns:
+ set: Names of all channels
+ """
pass
@abstractmethod
- def groundtruth(self, index: int):
+ def objects(self) -> Set[str]:
+ """Returns the names of all objects in the sequence
+
+ Returns:
+ set: Names of all objects
+ """
pass
@abstractmethod
- def tags(self, index=None):
+ def object(self, id, index=None):
+ """Returns the object with the specified name or identifier. If the index is specified, the object is returned
+ only if it is visible in the frame at the specified index.
+
+ Args:
+ id (str): Name or identifier of the object
+ index (int, optional): Frame index
+
+ Returns:
+ Region: Object
+ """
pass
@abstractmethod
- def values(self, index=None):
- pass
+ def groundtruth(self, index: int) -> Region:
+ """Returns the ground truth region for the specified frame index or None if no ground truth is available for the
+ frame or the frame index is out of bounds. This is a legacy method for compatibility with single-object datasets
+ and should not be used in new code.
- @property
- @abstractmethod
- def size(self):
- pass
+ Args:
+ index (int): Frame index
- @property
- @abstractmethod
- def length(self):
+ Returns:
+ Region: Ground truth region
+ """
pass
- def describe(self):
- data = dict(length=self.length, size=self.size)
- return data
-
-class Dataset(ABC):
-
- def __init__(self, path):
- self._path = path
-
- def __len__(self):
- return self.length
-
- @property
- def path(self):
- return self._path
-
- @property
@abstractmethod
- def length(self):
+ def tags(self, index=None) -> List[str]:
+ """Returns the tags for the specified frame index or None if no tags are available for the frame or the frame
+ index is out of bounds.
+
+ Args:
+ index (int, optional): Frame index
+
+ Returns:
+ list: List of tags
+ """
pass
@abstractmethod
- def __getitem__(self, key):
+ def values(self, index=None) -> Mapping[str, Number]:
+ """Returns the values for the specified frame index or None if no values are available for the frame or the frame
+ index is out of bounds.
+
+ Args:
+ index (int, optional): Frame index
+
+ Returns:
+ dict: Dictionary of values"""
pass
+ @property
@abstractmethod
- def __contains__(self, key):
- return False
+ def width(self) -> int:
+ """Returns the width of the frames in the sequence in pixels
- @abstractmethod
- def __iter__(self):
+ Returns:
+ int: Width of the frames
+ """
pass
+ @property
@abstractmethod
- def list(self):
- return []
+ def height(self) -> int:
+ """Returns the height of the frames in the sequence in pixels
- def keys(self):
- return self.list()
-
-class BaseSequence(Sequence):
+ Returns:
+ int: Height of the frames
+ """
+ pass
- def __init__(self, name, dataset=None):
- super().__init__(name, dataset)
- self._metadata = self._read_metadata()
- self._data = None
+ @property
+ def size(self) -> Tuple[int, int]:
+ """Returns the size of the frames in the sequence in pixels as a tuple (width, height)
+
+ Returns:
+ tuple: Size of the frames
+ """
+ return self.width, self.height
- @abstractmethod
- def _read_metadata(self):
- raise NotImplementedError
+ def describe(self):
+ """Returns a dictionary with information about the sequence
+
+ Returns:
+ dict: Dictionary with information
+ """
+ return dict(length=len(self), width=self.width, height=self.height)
+
+class Dataset(object):
+ """Base abstract class for a tracking dataset, a list of image sequences addressable by their names and interatable.
+ """
+
+ def __init__(self, sequences: Mapping[str, Sequence]) -> None:
+ """Creates a new dataset with the specified sequences
+
+ Args:
+ sequences (dict): Dictionary of sequences
+ """
+ self._sequences = sequences
- @abstractmethod
- def _read(self):
- raise NotImplementedError
+ def __len__(self) -> int:
+ """Returns the number of sequences in the dataset
+
+ Returns:
+ int: Number of sequences
+ """
+ return len(self._sequences)
+
+ def __getitem__(self, key: str) -> Sequence:
+ """Returns the sequence with the specified name
+
+ Args:
+ key (str): Sequence name
+
+ Returns:
+ Sequence: Sequence
+ """
+ return self._sequences[key]
+
+ def __contains__(self, key: str) -> bool:
+ """Returns true if the dataset contains a sequence with the specified name
+
+ Args:
+ key (str): Sequence name
+
+ Returns:
+ bool: True if the dataset contains the sequence
+ """
+ return key in self._sequences
+
+ def __iter__(self) -> Iterator[Sequence]:
+ """Returns an iterator over the sequences in the dataset
+
+ Returns:
+            Iterator[Sequence]: Iterator over the sequences
+ """
+ return iter(self._sequences.values())
+
+ def list(self) -> List[str]:
+ """Returns a list of unique sequence names
+
+ Returns:
+ List[str]: List of sequence names
+ """
+ return list(self._sequences.keys())
+
+ def keys(self) -> List[str]:
+ """Returns a list of unique sequence names
+
+ Returns:
+ List[str]: List of sequence names
+ """
+ return list(self._sequences.keys())
+
+SequenceData = namedtuple("SequenceData", ["channels", "objects", "tags", "values", "length"])
+
+from vot import config
+
+@cached(LRUCache(maxsize=config.sequence_cache_size))
+def _cached_loader(sequence):
+ """Loads the sequence data from the sequence object. This function serves as a cache for the sequence data and is only
+ called if the sequence data is not already loaded. The cache is implemented as a LRU cache with a maximum size
+ specified in the configuration."""
+ return sequence._loader(sequence._metadata)
+
+class BasedSequence(Sequence):
+ """This class implements the caching of the sequence data. The sequence data is loaded only when it is needed.
+ """
+
+ def __init__(self, name: str, loader: callable, metadata: dict = None):
+ """Initializes the sequence
+
+ Args:
+ name (str): Sequence name
+ loader (callable): Loader function that takes the metadata as an argument and returns a SequenceData object
+ metadata (dict, optional): Sequence metadata. Defaults to None.
+
+ Raises:
+ ValueError: If the loader is not callable
+ """
+ super().__init__(name)
+ self._loader = loader
+ self._metadata = metadata if metadata is not None else {}
def __preload(self):
- if self._data is None:
- self._data = self._read()
-
- def metadata(self, name, default=None):
+ """Loads the sequence data if needed. This is an internal function that should not be called directly.
+ It calles a cached loader function that is implemented as a LRU cache with a configurable maximum size."""
+ return _cached_loader(self)
+
+ def metadata(self, name: str, default: object=None) -> object:
+ """Returns the metadata value with the specified name.
+
+ Args:
+ name (str): Metadata name
+ default (object, optional): Default value. Defaults to None.
+
+ Returns:
+ object: Metadata value"""
return self._metadata.get(name, default)
- def channels(self):
- self.__preload()
- return self._data[0]
-
- def channel(self, channel=None):
- self.__preload()
+ def channels(self) -> List[str]:
+ """Returns a list of channel names in the sequence
+
+ Returns:
+ List[str]: List of channel names
+ """
+ data = self.__preload()
+ return data.channels.keys()
+
+ def channel(self, channel: str=None) -> Channel:
+ """Returns the channel with the specified name. If the channel name is not specified, the default channel is returned.
+
+ Args:
+ channel (str, optional): Channel name. Defaults to None.
+
+ Returns:
+ Channel: Channel
+ """
+ data = self.__preload()
if channel is None:
channel = self.metadata("channel.default")
- return self._data[0].get(channel, None)
+ return data.channels.get(channel, None)
def frame(self, index):
+ """Returns the frame with the specified index in the sequence as a Frame object
+
+ Args:
+ index (int): Frame index
+
+ Returns:
+ Frame: Frame
+ """
return Frame(self, index)
- def groundtruth(self, index=None):
- self.__preload()
+ def objects(self) -> List[str]:
+ """Returns a list of object ids in the sequence
+
+ Returns:
+ List[str]: List of object ids
+ """
+ data = self.__preload()
+ return data.objects.keys()
+
+ def object(self, id, index=None) -> Region:
+ """Returns the object with the specified id. If the index is specified, the object is returned as a Region object.
+
+ Args:
+ id (str): Object id
+ index (int, optional): Frame index. Defaults to None.
+
+ Returns:
+ Region: Object region
+ """
+ data = self.__preload()
+ obj = data.objects.get(id, None)
if index is None:
- return self._data[1]
- return self._data[1][index]
+ return obj
+ if obj is None:
+ return None
+ return obj[index]
- def tags(self, index=None):
- self.__preload()
+ def groundtruth(self, index=None):
+ """Returns the groundtruth object. If the index is specified, the object is returned as a Region object. If the
+ sequence contains more than one object, an exception is raised.
+
+ Args:
+ index (int, optional): Frame index. Defaults to None.
+
+ Returns:
+ Region: Groundtruth region
+ """
+ data = self.__preload()
+ if len(self.objects()) != 1:
+ raise DatasetException("More than one object in sequence")
+
+ id = next(iter(data.objects))
+ return self.object(id, index)
+
+ def tags(self, index: int = None) -> List[str]:
+ """Returns a list of tags in the sequence. If the index is specified, only the tags that are present in the frame
+ with the specified index are returned.
+
+ Args:
+ index (int, optional): Frame index. Defaults to None.
+
+ Returns:
+ List[str]: List of tags
+ """
+ data = self.__preload()
if index is None:
- return self._data[2].keys()
- return [t for t, sq in self._data[2].items() if sq[index]]
-
- def values(self, index=None):
- self.__preload()
+ return data.tags.keys()
+ return [t for t, sq in data.tags.items() if sq[index]]
+
+ def values(self, index: int = None) -> List[float]:
+ """Returns a list of values in the sequence. If the index is specified, only the values that are present in the frame
+ with the specified index are returned.
+
+ Args:
+ index (int, optional): Frame index. Defaults to None.
+
+ Returns:
+ List[float]: List of values
+ """
+ data = self.__preload()
if index is None:
- return self._data[3].keys()
- return {v: sq[index] for v, sq in self._data[3].items()}
+ return data.values.keys()
+ return {v: sq[index] for v, sq in data.values.items()}
@property
def size(self):
- return self.channel().size
+ """Returns the sequence size as a tuple (width, height)
+
+ Returns:
+ tuple: Sequence size
+
+ """
+ return self.width, self.height
@property
def width(self):
- return self.channel().width
+ """Returns the sequence width"""
+ return self._metadata["width"]
@property
def height(self):
- return self.channel().height
+ """Returns the sequence height"""
+ return self._metadata["height"]
- @property
- def length(self):
- self.__preload()
- return len(self._data[1])
-
-class InMemorySequence(BaseSequence):
+ def __len__(self):
+ """Returns the sequence length in frames
+
+ Returns:
+ int: Sequence length
+ """
+ data = self.__preload()
+ return data.length
+
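+# Loader sketch (illustrative; paths and names below are hypothetical). A loader receives the
+# metadata dictionary and returns a SequenceData tuple:
+#
+#     from vot.region.io import read_trajectory
+#
+#     def my_loader(metadata):
+#         root = metadata["root"]
+#         channels = {"color": PatternFileListChannel(os.path.join(root, "color", "%08d.jpg"))}
+#         objects = {"object": read_trajectory(os.path.join(root, "groundtruth.txt"))}
+#         return SequenceData(channels, objects, {}, {}, len(objects["object"]))
+#
+#     sequence = BasedSequence("example", my_loader, {"root": "/data/example", "width": 640, "height": 480})
+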
+class InMemorySequence(Sequence):
+ """ An in-memory sequence that can be used to construct a sequence programmatically and store it do disk.
+ Used mainly for testing and debugging.
+
+ Only single object sequences are supported at the moment.
+ """
def __init__(self, name, channels):
+ """Creates a new in-memory sequence.
+
+ Args:
+ name (str): Sequence name
+ channels (list): List of channel names
+
+ Raises:
+ DatasetException: If images are not provided for all channels
+ """
-        super().__init__(name, None)
+        super().__init__(name)
self._channels = {c: InMemoryChannel() for c in channels}
self._tags = {}
self._values = {}
self._groundtruth = []
- def _read_metadata(self):
- return dict()
-
- def _read(self):
- return self._channels, self._groundtruth, self._tags, self._values
-
def append(self, images: dict, region: "Region", tags: list = None, values: dict = None):
+ """Appends a new frame to the sequence. The frame is specified by a dictionary of images, a region and optional
+ tags and values.
+
+ Args:
+ images (dict): Dictionary of images
+ region (Region): Region
+ tags (list, optional): List of tags
+ values (dict, optional): Dictionary of values
+ """
if not set(images.keys()).issuperset(self._channels.keys()):
raise DatasetException("Images not provided for all channels")
@@ -432,7 +980,7 @@ def append(self, images: dict, region: "Region", tags: list = None, values: dict
tags = set(tags)
for tag in tags:
if not tag in self._tags:
- self._tags[tag] = [False] * self.length
+ self._tags[tag] = [False] * len(self)
self._tags[tag].append(True)
for tag in set(self._tags.keys()).difference(tags):
self._tags[tag].append(False)
@@ -441,51 +989,290 @@ def append(self, images: dict, region: "Region", tags: list = None, values: dict
values = dict()
for name, value in values.items():
if not name in self._values:
- self._values[name] = [0] * self.length
+ self._values[name] = [0] * len(self)
-            self._values[tag].append(value)
+            self._values[name].append(value)
for name in set(self._values.keys()).difference(values.keys()):
self._values[name].append(0)
self._groundtruth.append(region)
+    def channel(self, channel: str = None) -> "Channel":
+        """Returns the specified channel object, or the first channel if no name is given.
+
+        Args:
+            channel (str, optional): Channel name. Defaults to None.
+
+        Returns:
+            Channel: Channel object
+        """
+        if channel is None and self._channels:
+            channel = next(iter(self._channels))
+        return self._channels.get(channel, None)
+
+ def channels(self) -> List[str]:
+ """Returns a list of channel names.
+
+ Returns:
+ List[str]: List of channel names
+
+ """
+ return self._channels.keys()
+
+ def frame(self, index : int) -> "Frame":
+ """Returns the specified frame. The frame is returned as a Frame object.
+
+ Args:
+ index (int): Frame index
+
+ Returns:
+ Frame: Frame object
+ """
+ return Frame(self, index)
+
+ def groundtruth(self, index: int = None) -> "Region":
+ """Returns the groundtruth object. If the index is specified, the object is returned as a Region object. If the
+ sequence contains more than one object, an exception is raised. If the index is not specified, the groundtruth
+ object is returned as a Region object. If the sequence contains more than one object, an exception is raised.
+
+ Args:
+ index (int, optional): Frame index. Defaults to None.
+
+ Returns:
+ Region: Groundtruth object
+ """
+ if index is None:
+ return self._groundtruth
+ return self._groundtruth[index]
+
+ def object(self, id: str, index: int = None) -> "Region":
+ """Returns the specified object. If the index is specified, the object is returned as a Region object. If the
+ sequence contains more than one object, an exception is raised. If the index is not specified, the groundtruth
+ object is returned as a Region object. If the sequence contains more than one object, an exception is raised.
+
+ Args:
+ id (str): Object id
+ index (int, optional): Frame index. Defaults to None.
+
+ Returns:
+ Region: Object
+ """
+ if id != "object":
+ return None
+
+ if index is None:
+ return self._groundtruth
+ return self._groundtruth[index]
+
+    def objects(self, index: int = None) -> List[str]:
+ """Returns a list of object ids. If the index is specified, only the objects that are present in the frame with
+ the specified index are returned.
+
+ Since only single object sequences are supported, the only object id that is returned is "object".
+
+ Args:
+ index (int, optional): Frame index. Defaults to None."""
+ return ["object"]
+
+ def tags(self, index=None):
+ """Returns a list of tags in the sequence. If the index is specified, only the tags that are present in the frame
+ with the specified index are returned.
+
+ Args:
+ index (int, optional): Frame index. Defaults to None.
+
+ Returns:
+ List[str]: List of tags
+ """
+ if index is None:
+ return self._tags.keys()
+ return [t for t, sq in self._tags.items() if sq[index]]
+
+ def values(self, index=None):
+ """Returns a list of values in the sequence. If the index is specified, only the values that are present in the
+ frame with the specified index are returned.
+
+ Args:
+ index (int, optional): Frame index. Defaults to None.
+
+ Returns:
+ List[str]: List of values
+ """
+ if index is None:
+ return self._values.keys()
+ return {v: sq[index] for v, sq in self._values.items()}
+
+ def __len__(self):
+ """Returns the sequence length in frames
+
+ Returns:
+ int: Sequence length
+ """
+ return len(self._groundtruth)
+
+ @property
+ def width(self) -> int:
+ """Returns the sequence width
+
+ Returns:
+ int: Sequence width
+ """
+ return self.channel().width
+
+ @property
+ def height(self) -> int:
+ """Returns the sequence height
+
+ Returns:
+ int: Sequence height
+ """
+ return self.channel().height
+
+ @property
+ def size(self) -> tuple:
+ """Returns the sequence size as a tuple (width, height)
+
+ Returns:
+ tuple: Sequence size
+ """
+ return self.channel().size
+
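+# Construction sketch (illustrative): build a two-frame, single-object sequence in memory.
+# Rectangle is assumed to be importable from vot.region.
+#
+#     seq = InMemorySequence("test", ["color"])
+#     image = np.zeros((480, 640, 3), dtype=np.uint8)
+#     seq.append({"color": image}, Rectangle(10, 10, 50, 50))
+#     seq.append({"color": image}, Rectangle(12, 10, 50, 50))
+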
+def download_bundle(url: str, path: str = "."):
+ """Downloads a dataset bundle as a ZIP file and decompresses it.
+
+ Args:
+ url (str): Source bundle URL
+ path (str, optional): Destination directory. Defaults to ".".
+
+ Raises:
+ DatasetException: If the bundle cannot be downloaded or is not supported.
+ """
+
+ from vot.utilities.net import download_uncompress, NetworkException
+
+ if not url.endswith(".zip"):
+ raise DatasetException("Unknown bundle format")
+
+ logger.info('Downloading sequence bundle from "%s". This may take a while ...', url)
+
+ try:
+ download_uncompress(url, path)
+ except NetworkException as e:
+ raise DatasetException("Unable do download dataset bundle, Please try to download the bundle manually from {} and uncompress it to {}'".format(url, path))
+ except IOError as e:
+ raise DatasetException("Unable to extract dataset bundle, is the target directory writable and do you have enough space?")
+
+def download_dataset(url: str, path: str):
+ """Downloads a dataset from a given url or an alias.
+
+ Args:
+ url (str): URL to the data bundle or metadata description file
+ path (str): Destination directory
+
+ Raises:
+        DatasetException: If the dataset is not found or a network error occurred
+ """
+ from urllib.parse import urlsplit
+
+ try:
+ res = urlsplit(url)
+
+ if res.scheme in ["http", "https"]:
+ if res.path.endswith(".json"):
+ from .common import download_dataset_meta
+ download_dataset_meta(url, path)
+ return
+ else:
+ download_bundle(url, path)
+ return
+
+ raise DatasetException("Unknown dataset domain: {}".format(res.scheme))
+
+ except ValueError:
+
+ if url in dataset_downloader:
+ dataset_downloader[url](path)
+ return
+
+ raise DatasetException("Illegal dataset identifier: {}".format(url))
+
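+# Usage sketch (illustrative; the URL and alias below are hypothetical):
+#
+#     download_dataset("https://example.org/dataset/description.json", "/data/mydataset")
+#     download_dataset("my-dataset", "/data/mydataset")  # alias resolved via dataset_downloader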
+
+def load_dataset(path: str) -> Dataset:
+ """Loads a dataset from a local directory
+
+ Args:
+ path (str): The path to the local dataset data
+
+ Raises:
+ DatasetException: When a folder does not exist or the format is not recognized.
+
+ Returns:
+ Dataset: Dataset object
+ """
+
+ from collections import OrderedDict
+
+ names = []
+
+ if os.path.isfile(path):
+        with open(path, 'r') as fd:
+ names = fd.readlines()
+
+ if os.path.isdir(path):
+ if os.path.isfile(os.path.join(path, "list.txt")):
+ with open(os.path.join(path, "list.txt"), 'r') as fd:
+ names = fd.readlines()
+
+    if len(names) == 0:
+        raise DatasetException("Dataset source does not contain a sequence list (list.txt)")
+
+ sequences = OrderedDict()
+
+ logger.debug("Loading sequences...")
+
+    base = path if os.path.isdir(path) else os.path.dirname(path)
+
+    for name in names:
+        root = os.path.join(base, name.strip())
+        sequences[name.strip()] = load_sequence(root)
+
+ logger.debug("Found %d sequences in dataset" % len(names))
-from .vot import VOTDataset, VOTSequence
-from .got10k import GOT10kSequence, GOT10kDataset
+ return Dataset(sequences)
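+
+# Usage sketch (illustrative, assumes a directory containing a list.txt index):
+#
+#     dataset = load_dataset("/data/mydataset")
+#     for sequence in dataset:
+#         print(sequence.name, len(sequence))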
-def download_dataset(identifier: str, path: str):
+def load_sequence(path: str) -> Sequence:
+ """Loads a sequence from a given path (directory), tries to guess the format of the sequence.
- split = identifier.find(":")
- domain = "vot"
+ Args:
+ path (str): The path to the local sequence data
- if split > 0:
- domain = identifier[0:split].lower()
- identifier = identifier[split+1:]
+ Raises:
+        DatasetException: If a loading error occurs, the format is unsupported, or another issue arises.
- if domain == "vot":
- from .vot import download_dataset
- download_dataset(identifier, path)
- elif domain == "otb":
- from .otb import download_dataset
- download_dataset(path, identifier == "otb50")
- else:
- raise DatasetException("Unknown dataset domain: {}".format(domain))
+ Returns:
+ Sequence: Sequence object
+ """
-def load_dataset(path: str):
+ for _, loader in sequence_reader.items():
+ logger.debug("Trying to load sequence with {}.{}".format(loader.__module__, loader.__name__))
+ sequence = loader(path)
+ if sequence is not None:
+ return sequence
- if not os.path.isdir(path):
- raise DatasetException("Dataset directory does not exist")
+ raise DatasetException("Unable to load sequence, unknown format or unsupported sequence: {}".format(path))
- if VOTDataset.check(path):
- return VOTDataset(path)
- elif GOT10kDataset.check(path):
- return GOT10kDataset(path)
- else:
- raise DatasetException("Unsupported dataset type")
+import importlib
+for module in [".common", ".otb", ".got10k", ".trackingnet"]:
+ importlib.import_module(module, package="vot.dataset")
-def load_sequence(path: str):
- if VOTSequence.check(path):
- return VOTSequence(path)
- elif GOT10kSequence.check(path):
- return GOT10kSequence(path)
- else:
- raise DatasetException("Unsupported sequence type")
\ No newline at end of file
+# Legacy reader is registered last, otherwise it will cause problems
+# TODO: implement explicit ordering of readers
+@sequence_reader.register("legacy")
+def read_legacy_sequence(path: str) -> Sequence:
+ """Wrapper around the legacy sequence reader."""
+ from vot.dataset.common import read_sequence_legacy
+ return read_sequence_legacy(path)
\ No newline at end of file
diff --git a/vot/dataset/common.py b/vot/dataset/common.py
new file mode 100644
index 0000000..439f647
--- /dev/null
+++ b/vot/dataset/common.py
@@ -0,0 +1,323 @@
+"""This module contains functionality for reading sequences from the storage using VOT compatible format."""
+
+import os
+import glob
+import logging
+
+import six
+
+import cv2
+
+from vot.dataset import Dataset, DatasetException, Sequence, BasedSequence, PatternFileListChannel, SequenceData
+from vot.region.io import write_trajectory, read_trajectory
+from vot.region import Special
+from vot.utilities import Progress, localize_path, read_properties, write_properties
+
+logger = logging.getLogger("vot")
+
+def convert_int(value: str) -> int:
+ """Converts the given value to an integer. If the value is not a valid integer, None is returned.
+
+ Args:
+ value (str): The value to convert.
+
+ Returns:
+ int: The converted value or None if the value is not a valid integer.
+ """
+ try:
+ if value is None:
+ return None
+ return int(value)
+ except ValueError:
+ return None
+
+def _load_channel(source, length=None):
+ """Loads a channel from the given source.
+
+ Args:
+ source (str): The source to load the channel from.
+        length (int): The expected length of the channel. If specified, file existence checks are
+            skipped and the channel is assumed to contain that many frames.
+
+ Returns:
+ Channel: The loaded channel.
+ """
+
+ extension = os.path.splitext(source)[1]
+
+ if extension == '':
+ source = os.path.join(source, '%08d.jpg')
+ return PatternFileListChannel(source, end=length, check_files=length is None)
+
+def _read_data(metadata):
+ """Reads data from the given metadata.
+
+ Args:
+ metadata (dict): The metadata to read data from.
+
+ Returns:
+        SequenceData: The sequence data read from the metadata.
+ """
+
+ channels = {}
+ tags = {}
+ values = {}
+ length = metadata["length"]
+
+ root = metadata["root"]
+
+ for c in ["color", "depth", "ir"]:
+ channel_path = metadata.get("channels.%s" % c, None)
+ if not channel_path is None:
+ channels[c] = _load_channel(os.path.join(root, localize_path(channel_path)), length)
+
+ # Load default channel if no explicit channel data available
+ if len(channels) == 0:
+ channels["color"] = _load_channel(os.path.join(root, "color", "%08d.jpg"), length=length)
+ else:
+ metadata["channel.default"] = next(iter(channels.keys()))
+
+ if metadata.get("width", None) is None or metadata.get("height", None) is None:
+ metadata["width"], metadata["height"] = six.next(six.itervalues(channels)).size
+
+ lengths = [len(t) for t in channels.values()]
+ assert all([x == lengths[0] for x in lengths]), "Sequence channels have different lengths"
+ length = lengths[0]
+
+ objectsfiles = glob.glob(os.path.join(root, 'groundtruth_*.txt'))
+ objects = {}
+ if len(objectsfiles) > 0:
+ for objectfile in objectsfiles:
+ groundtruth = read_trajectory(os.path.join(objectfile))
+ if len(groundtruth) < length: groundtruth += [Special(Sequence.UNKNOWN)] * (length - len(groundtruth))
+ objectid = os.path.basename(objectfile)[12:-4]
+ objects[objectid] = groundtruth
+ else:
+ groundtruth_file = os.path.join(root, metadata.get("groundtruth", "groundtruth.txt"))
+ groundtruth = read_trajectory(groundtruth_file)
+ if len(groundtruth) < length: groundtruth += [Special(Sequence.UNKNOWN)] * (length - len(groundtruth))
+ objects["object"] = groundtruth
+
+ metadata["length"] = length
+
+ tagfiles = glob.glob(os.path.join(root, '*.tag')) + glob.glob(os.path.join(root, '*.label'))
+
+ for tagfile in tagfiles:
+ with open(tagfile, 'r') as filehandle:
+ tagname = os.path.splitext(os.path.basename(tagfile))[0]
+ tag = [line.strip() == "1" for line in filehandle.readlines()]
+ while not len(tag) >= length:
+ tag.append(False)
+ tags[tagname] = tag
+
+ valuefiles = glob.glob(os.path.join(root, '*.value'))
+
+ for valuefile in valuefiles:
+ with open(valuefile, 'r') as filehandle:
+ valuename = os.path.splitext(os.path.basename(valuefile))[0]
+ value = [float(line.strip()) for line in filehandle.readlines()]
+ while not len(value) >= length:
+ value.append(0.0)
+ values[valuename] = value
+
+    for name, tag in tags.items():
+        if not len(tag) == length:
+            tag_tmp = length * [False]
+            tag_tmp[:len(tag)] = tag
+            tags[name] = tag_tmp  # write back, assigning only to the loop variable would have no effect
+
+ for name, value in values.items():
+ if not len(value) == length:
+ raise DatasetException("Length mismatch for value %s" % name)
+
+ return SequenceData(channels, objects, tags, values, length)
+
+def _read_metadata(path):
+ """Reads metadata from the given path. The metadata is read from the sequence file in the given path.
+
+ Args:
+ path (str): The path to read metadata from.
+
+ Returns:
+ dict: The metadata read from the given path.
+ """
+ metadata = dict(fps=30, format="default")
+ metadata["channel.default"] = "color"
+
+ metadata_file = os.path.join(path, 'sequence')
+ metadata.update(read_properties(metadata_file))
+
+ metadata["height"] = convert_int(metadata.get("height", None))
+ metadata["width"] = convert_int(metadata.get("width", None))
+ metadata["length"] = convert_int(metadata.get("length", None))
+ metadata["fps"] = convert_int(metadata.get("fps", None))
+
+ metadata["root"] = path
+
+ return metadata
+
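+# A minimal `sequence` properties file that this parser accepts could look like this
+# (values are illustrative):
+#
+#     channels.color=color/%08d.jpg
+#     fps=30
+#     length=120
+#     width=640
+#     height=480
+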
+from vot.dataset import sequence_reader
+
+@sequence_reader.register("default")
+def read_sequence(path):
+ """Reads a sequence from the given path.
+
+ Args:
+ path (str): The path to read the sequence from.
+
+ Returns:
+ Sequence: The sequence read from the given path.
+ """
+ if not os.path.isfile(os.path.join(path, "sequence")):
+ return None
+
+ return BasedSequence(os.path.basename(path), _read_data, _read_metadata(path))
+
+def read_sequence_legacy(path):
+ """Reads a sequence from the given path.
+
+ Args:
+ path (str): The path to read the sequence from.
+
+ Returns:
+ Sequence: The sequence read from the given path.
+ """
+ if not os.path.isfile(os.path.join(path, "groundtruth.txt")):
+ return None
+
+    metadata = dict(fps=30, format="default")
+    metadata["channel.default"] = "color"
+    metadata["channels.color"] = "%08d.jpg"  # _read_data looks up "channels.<name>" keys
+    metadata["root"] = path  # _read_data resolves all files relative to this root
+    metadata["length"] = None
+
+ return BasedSequence(os.path.basename(path), _read_data, metadata=metadata)
+
+def download_dataset_meta(url: str, path: str) -> None:
+ """Downloads the metadata of a dataset from a given URL and stores it in the given path.
+
+ Args:
+ url (str): The URL to download the metadata from.
+ path (str): The path to store the metadata in.
+
+ """
+ from vot.utilities.net import download_uncompress, download_json, get_base_url, join_url, NetworkException
+ from vot.utilities import format_size
+
+ meta = download_json(url)
+
+ total_size = 0
+ for sequence in meta["sequences"]:
+ total_size += sequence["annotations"]["uncompressed"]
+ for channel in sequence["channels"].values():
+ total_size += channel["uncompressed"]
+
+ logger.info('Downloading sequence dataset "%s" with %s sequences (total %s).', meta["name"], len(meta["sequences"]), format_size(total_size))
+
+ base_url = get_base_url(url) + "/"
+
+ failed = []
+
+ with Progress("Downloading", len(meta["sequences"])) as progress:
+ for sequence in meta["sequences"]:
+ sequence_directory = os.path.join(path, sequence["name"])
+ os.makedirs(sequence_directory, exist_ok=True)
+
+ if os.path.isfile(os.path.join(sequence_directory, "sequence")):
+ refdata = read_properties(os.path.join(sequence_directory, "sequence"))
+ if "uid" in refdata and refdata["uid"] == sequence["annotations"]["uid"]:
+ logger.info('Sequence "%s" already downloaded.', sequence["name"])
+ progress.relative(1)
+ continue
+
+ data = {'name': sequence["name"], 'fps': sequence["fps"], 'format': 'default'}
+
+ annotations_url = join_url(base_url, sequence["annotations"]["url"])
+
+ data["uid"] = sequence["annotations"]["uid"]
+
+ try:
+ download_uncompress(annotations_url, sequence_directory)
+ except NetworkException as e:
+ logger.exception(e)
+ failed.append(sequence["name"])
+ continue
+ except IOError as e:
+ logger.exception(e)
+ failed.append(sequence["name"])
+ continue
+
+ failure = False
+
+ for cname, channel in sequence["channels"].items():
+ channel_directory = os.path.join(sequence_directory, cname)
+ os.makedirs(channel_directory, exist_ok=True)
+
+ channel_url = join_url(base_url, channel["url"])
+
+ try:
+ download_uncompress(channel_url, channel_directory)
+ except NetworkException as e:
+ logger.exception(e)
+ failed.append(sequence["name"])
+                    failure = True
+ except IOError as e:
+ logger.exception(e)
+ failed.append(sequence["name"])
+                    failure = True
+
+ if "pattern" in channel:
+ data["channels." + cname] = cname + os.path.sep + channel["pattern"]
+ else:
+ data["channels." + cname] = cname + os.path.sep
+
+ if failure:
+ continue
+
+ write_properties(os.path.join(sequence_directory, 'sequence'), data)
+ progress.relative(1)
+
+ if len(failed) > 0:
+ logger.error('Failed to download %d sequences.', len(failed))
+ logger.error('Failed sequences: %s', ', '.join(failed))
+ else:
+ logger.info('Successfully downloaded all sequences.')
+ with open(os.path.join(path, "list.txt"), "w") as fp:
+ for sequence in meta["sequences"]:
+ fp.write('{}\n'.format(sequence["name"]))
+
+def write_sequence(directory: str, sequence: Sequence):
+ """Writes a sequence to a directory. The sequence is written as a set of images in a directory structure
+ corresponding to the channel names. The sequence metadata is written to a file called sequence in the root
+ directory.
+
+ Args:
+ directory (str): The directory to write the sequence to.
+ sequence (Sequence): The sequence to write.
+ """
+
+ channels = sequence.channels()
+
+ metadata = dict()
+ metadata["channel.default"] = sequence.metadata("channel.default", "color")
+ metadata["fps"] = sequence.metadata("fps", "30")
+
+ for channel in channels:
+ cdir = os.path.join(directory, channel)
+ os.makedirs(cdir, exist_ok=True)
+
+ metadata["channels.%s" % channel] = os.path.join(channel, "%08d.jpg")
+
+ for i in range(len(sequence)):
+ frame = sequence.frame(i).channel(channel)
+ cv2.imwrite(os.path.join(cdir, "%08d.jpg" % (i + 1)), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+
+ for tag in sequence.tags():
+ data = "\n".join(["1" if tag in sequence.tags(i) else "0" for i in range(len(sequence))])
+ with open(os.path.join(directory, "%s.tag" % tag), "w") as fp:
+ fp.write(data)
+
+ for value in sequence.values():
+ data = "\n".join([ str(sequence.values(i).get(value, "")) for i in range(len(sequence))])
+ with open(os.path.join(directory, "%s.value" % value), "w") as fp:
+ fp.write(data)
+
+ write_trajectory(os.path.join(directory, "groundtruth.txt"), [f.groundtruth() for f in sequence])
+ write_properties(os.path.join(directory, "sequence"), metadata)
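+
+# Round-trip sketch (illustrative): write a sequence to disk and read it back with the
+# default reader registered above.
+#
+#     write_sequence("/tmp/exported", sequence)
+#     restored = read_sequence("/tmp/exported")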
diff --git a/vot/dataset/dummy.py b/vot/dataset/dummy.py
index 28fa57d..bb77261 100644
--- a/vot/dataset/dummy.py
+++ b/vot/dataset/dummy.py
@@ -1,73 +1,97 @@
+""" Dummy sequences for testing purposes."""
import os
import math
import tempfile
-from vot.dataset import VOTSequence
-from vot.region import Rectangle, write_file
+from vot.dataset import BasedSequence
+from vot.region import Rectangle
+from vot.region.io import write_trajectory
from vot.utilities import write_properties
from PIL import Image
import numpy as np
-class DummySequence(VOTSequence):
+def _generate(base, length, size, objects):
+ """Generate a new dummy sequence.
+
+ Args:
+ base (str): The base directory for the sequence.
+ length (int): The length of the sequence.
+ size (tuple): The size of the sequence.
+ objects (int): The number of objects in the sequence.
+ """
- def __init__(self, length=100, size=(640, 480)):
- base = os.path.join(tempfile.gettempdir(), "vot_dummy_%d_%d_%d" % (length, size[0], size[1]))
- if not os.path.isdir(base) or not os.path.isfile(os.path.join(base, "groundtruth.txt")):
- DummySequence._generate(base, length, size)
- super().__init__(base, None)
+ background_color = Image.fromarray(np.random.normal(15, 5, (size[1], size[0], 3)).astype(np.uint8))
+ background_depth = Image.fromarray(np.ones((size[1], size[0]), dtype=np.uint8) * 200)
+ background_ir = Image.fromarray(np.zeros((size[1], size[0]), dtype=np.uint8))
- @staticmethod
- def _generate(base, length, size):
+ template = Image.open(os.path.join(os.path.dirname(__file__), "cow.png"))
- background_color = Image.fromarray(np.random.normal(15, 5, (size[1], size[0], 3)).astype(np.uint8))
- background_depth = Image.fromarray(np.ones((size[1], size[0]), dtype=np.uint8) * 200)
- background_ir = Image.fromarray(np.zeros((size[1], size[0]), dtype=np.uint8))
+ dir_color = os.path.join(base, "color")
+ dir_depth = os.path.join(base, "depth")
+ dir_ir = os.path.join(base, "ir")
- template = Image.open(os.path.join(os.path.dirname(__file__), "cow.png"))
+ os.makedirs(dir_color, exist_ok=True)
+ os.makedirs(dir_depth, exist_ok=True)
+ os.makedirs(dir_ir, exist_ok=True)
- dir_color = os.path.join(base, "color")
- dir_depth = os.path.join(base, "depth")
- dir_ir = os.path.join(base, "ir")
+ path_color = os.path.join(dir_color, "%08d.jpg")
+ path_depth = os.path.join(dir_depth, "%08d.png")
+ path_ir = os.path.join(dir_ir, "%08d.png")
- os.makedirs(dir_color, exist_ok=True)
- os.makedirs(dir_depth, exist_ok=True)
- os.makedirs(dir_ir, exist_ok=True)
+ groundtruth = {i : [] for i in range(objects)}
- path_color = os.path.join(dir_color, "%08d.jpg")
- path_depth = os.path.join(dir_depth, "%08d.png")
- path_ir = os.path.join(dir_ir, "%08d.png")
+ center_x = size[0] / 2
+ center_y = size[1] / 2
- groundtruth = []
+ radius = min(center_x - template.size[0], center_y - template.size[1])
- center_x = size[0] / 2
- center_y = size[1] / 2
+ speed = (math.pi * 2) / length
+ offset = (math.pi * 2) / objects
- radius = min(center_x - template.size[0], center_y - template.size[1])
+ for i in range(length):
+ frame_color = background_color.copy()
+ frame_depth = background_depth.copy()
+ frame_ir = background_ir.copy()
- speed = (math.pi * 2) / length
+ for o in range(objects):
- for i in range(length):
- frame_color = background_color.copy()
- frame_depth = background_depth.copy()
- frame_ir = background_ir.copy()
-
- x = int(center_x + math.cos(i * speed) * radius - template.size[0] / 2)
- y = int(center_y + math.sin(i * speed) * radius - template.size[1] / 2)
+ x = int(center_x + math.cos(i * speed + offset * o) * radius - template.size[0] / 2)
+ y = int(center_y + math.sin(i * speed + offset * o) * radius - template.size[1] / 2)
frame_color.paste(template, (x, y), template)
frame_depth.paste(10, (x, y), template)
frame_ir.paste(240, (x, y), template)
- frame_color.save(path_color % (i + 1))
- frame_depth.save(path_depth % (i + 1))
- frame_ir.save(path_ir % (i + 1))
-
- groundtruth.append(Rectangle(x, y, template.size[0], template.size[1]))
-
- write_file(os.path.join(base, "groundtruth.txt"), groundtruth)
- metadata = {"name": "dummy", "fps" : 30, "format" : "dummy",
- "channel.default": "color"}
- write_properties(os.path.join(base, "sequence"), metadata)
-
+ groundtruth[o].append(Rectangle(x, y, template.size[0], template.size[1]))
+
+ frame_color.save(path_color % (i + 1))
+ frame_depth.save(path_depth % (i + 1))
+ frame_ir.save(path_ir % (i + 1))
+
+ if objects == 1:
+ write_trajectory(os.path.join(base, "groundtruth.txt"), groundtruth[0])
+ else:
+ for i, g in groundtruth.items():
+ write_trajectory(os.path.join(base, "groundtruth_%03d.txt" % i), g)
+
+ metadata = {"name": "dummy", "fps" : 30, "format" : "dummy",
+ "channel.default": "color"}
+ write_properties(os.path.join(base, "sequence"), metadata)
+
+def generate_dummy(length=100, size=(640, 480), objects=1):
+ """Create a new dummy sequence.
+
+ Args:
+ length (int, optional): The length of the sequence. Defaults to 100.
+ size (tuple, optional): The size of the sequence. Defaults to (640, 480).
+ objects (int, optional): The number of objects in the sequence. Defaults to 1.
+ """
+ from vot.dataset import load_sequence
+
+ base = os.path.join(tempfile.gettempdir(), "vot_dummy_%d_%d_%d_%d" % (length, size[0], size[1], objects))
+ if not os.path.isdir(base) or not os.path.isfile(os.path.join(base, "sequence")):
+ _generate(base, length, size, objects)
+
+ return load_sequence(base)
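+
+# Usage sketch (illustrative): generate a short two-object dummy sequence for tests.
+#
+#     sequence = generate_dummy(length=20, objects=2)
+#     assert len(sequence) == 20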
diff --git a/vot/dataset/got10k.py b/vot/dataset/got10k.py
index 48f6a7b..e354bd8 100644
--- a/vot/dataset/got10k.py
+++ b/vot/dataset/got10k.py
@@ -1,153 +1,128 @@
+""" GOT-10k dataset adapter module. The format of GOT-10k dataset is very similar to a subset of VOT, so there
+is a lot of code duplication."""
import os
import glob
-import logging
-from collections import OrderedDict
import configparser
import six
-from vot.dataset import Dataset, DatasetException, BaseSequence, PatternFileListChannel
-from vot.region import parse
-from vot.utilities import Progress
+from vot import get_logger
+from vot.dataset import DatasetException, BasedSequence, \
+ PatternFileListChannel, SequenceData, Sequence
+from vot.region import Special
+from vot.region.io import read_trajectory
-logger = logging.getLogger("vot")
+logger = get_logger()
def load_channel(source):
-
+ """ Load channel from the given source.
+
+ Args:
+ source (str): Path to the source. If the source is a directory, it is
+ assumed to be a pattern file list. If the source is a file, it is
+ assumed to be a video file.
+
+ Returns:
+ Channel: Channel object.
+ """
extension = os.path.splitext(source)[1]
if extension == '':
source = os.path.join(source, '%08d.jpg')
return PatternFileListChannel(source)
-class GOT10kSequence(BaseSequence):
-
- def __init__(self, base, name=None, dataset=None):
- self._base = base
- if name is None:
- name = os.path.basename(base)
- super().__init__(name, dataset)
-
- @staticmethod
- def check(path: str):
- return os.path.isfile(os.path.join(path, 'groundtruth.txt')) and not os.path.isfile(os.path.join(path, 'sequence'))
-
- def _read_metadata(self):
- metadata = dict(fps=30, format="default")
-
- if os.path.isfile(os.path.join(self._base, 'meta_info.ini')):
- config = configparser.ConfigParser()
- config.read(os.path.join(self._base, 'meta_info.ini'))
- metadata.update(config["METAINFO"])
- metadata["fps"] = int(metadata["anno_fps"][:-3])
-
- metadata["channel.default"] = "color"
-
- return metadata
-
- def _read(self):
-
- channels = {}
- tags = {}
- values = {}
- groundtruth = []
-
- channels["color"] = load_channel(os.path.join(self._base, "%08d.jpg"))
- self._metadata["channel.default"] = "color"
- self._metadata["width"], self._metadata["height"] = six.next(six.itervalues(channels)).size
-
- groundtruth_file = os.path.join(self._base, self.metadata("groundtruth", "groundtruth.txt"))
-
- with open(groundtruth_file, 'r') as filehandle:
- for region in filehandle.readlines():
- groundtruth.append(parse(region))
-
- self._metadata["length"] = len(groundtruth)
- tagfiles = glob.glob(os.path.join(self._base, '*.label'))
+def _read_data(metadata):
+ """ Read data from the given metadata.
+
+ Args:
+ metadata (dict): Metadata dictionary.
+ """
+ channels = {}
+ tags = {}
+ values = {}
+ groundtruth = []
- for tagfile in tagfiles:
- with open(tagfile, 'r') as filehandle:
- tagname = os.path.splitext(os.path.basename(tagfile))[0]
- tag = [line.strip() == "1" for line in filehandle.readlines()]
- while not len(tag) >= len(groundtruth):
- tag.append(False)
- tags[tagname] = tag
+ base = metadata["root"]
- valuefiles = glob.glob(os.path.join(self._base, '*.value'))
+ channels["color"] = load_channel(os.path.join(base, "%08d.jpg"))
+ metadata["channel.default"] = "color"
+ metadata["width"], metadata["height"] = six.next(six.itervalues(channels)).size
- for valuefile in valuefiles:
- with open(valuefile, 'r') as filehandle:
- valuename = os.path.splitext(os.path.basename(valuefile))[0]
- value = [float(line.strip()) for line in filehandle.readlines()]
- while not len(value) >= len(groundtruth):
- value.append(0.0)
- values[valuename] = value
+ groundtruth_file = os.path.join(base, metadata.get("groundtruth", "groundtruth.txt"))
+ groundtruth = read_trajectory(groundtruth_file)
- for name, channel in channels.items():
- if not channel.length == len(groundtruth):
- raise DatasetException("Length mismatch for channel %s" % name)
+ if len(groundtruth) == 1 and channels["color"].length > 1:
+        # We are dealing with a test split where only the first frame is annotated, so we pad the
+        # groundtruth with unknowns. Only the unsupervised experiment will work in this case.
+ groundtruth.extend([Special(Sequence.UNKNOWN)] * (channels["color"].length - 1))
- for name, tag in tags.items():
- if not len(tag) == len(groundtruth):
- tag_tmp = len(groundtruth) * [False]
- tag_tmp[:len(tag)] = tag
- tag = tag_tmp
+ metadata["length"] = len(groundtruth)
- for name, value in values.items():
- if not len(value) == len(groundtruth):
- raise DatasetException("Length mismatch for value %s" % name)
+ tagfiles = glob.glob(os.path.join(base, '*.label'))
- return channels, groundtruth, tags, values
+ for tagfile in tagfiles:
+ with open(tagfile, 'r') as filehandle:
+ tagname = os.path.splitext(os.path.basename(tagfile))[0]
+ tag = [line.strip() == "1" for line in filehandle.readlines()]
+ while not len(tag) >= len(groundtruth):
+ tag.append(False)
+ tags[tagname] = tag
-class GOT10kDataset(Dataset):
+ valuefiles = glob.glob(os.path.join(base, '*.value'))
- def __init__(self, path, sequence_list="list.txt"):
- super().__init__(path)
+ for valuefile in valuefiles:
+ with open(valuefile, 'r') as filehandle:
+ valuename = os.path.splitext(os.path.basename(valuefile))[0]
+ value = [float(line.strip()) for line in filehandle.readlines()]
+ while not len(value) >= len(groundtruth):
+ value.append(0.0)
+ values[valuename] = value
- if not os.path.isabs(sequence_list):
- sequence_list = os.path.join(path, sequence_list)
+ for name, channel in channels.items():
+ if not channel.length == len(groundtruth):
+ raise DatasetException("Length mismatch for channel %s" % name)
- if not os.path.isfile(sequence_list):
- raise DatasetException("Sequence list does not exist")
+ for name, tag in tags.items():
+ if not len(tag) == len(groundtruth):
+ tag_tmp = len(groundtruth) * [False]
+ tag_tmp[:len(tag)] = tag
+ tag = tag_tmp
- with open(sequence_list, 'r') as handle:
- names = handle.readlines()
+ for name, value in values.items():
+ if not len(value) == len(groundtruth):
+ raise DatasetException("Length mismatch for value %s" % name)
- self._sequences = OrderedDict()
+ objects = {"object" : groundtruth}
- with Progress("Loading dataset", len(names)) as progress:
+ return SequenceData(channels, objects, tags, values, len(groundtruth))
- for name in names:
- self._sequences[name.strip()] = GOT10kSequence(os.path.join(path, name.strip()), dataset=self)
- progress.relative(1)
+from vot.dataset import sequence_reader
- @staticmethod
- def check(path: str):
- if not os.path.isfile(os.path.join(path, 'list.txt')):
- return False
+@sequence_reader.register("GOT-10k")
+def read_sequence(path):
+ """ Read GOT-10k sequence from the given path.
+
+ Args:
+ path (str): Path to the sequence.
+ """
- with open(os.path.join(path, 'list.txt'), 'r') as handle:
- sequence = handle.readline().strip()
- return GOT10kSequence.check(os.path.join(path, sequence))
+ if not (os.path.isfile(os.path.join(path, 'groundtruth.txt')) and os.path.isfile(os.path.join(path, 'meta_info.ini'))):
+ return None
- @property
- def path(self):
- return self._path
+ metadata = dict(fps=30, format="default")
- @property
- def length(self):
- return len(self._sequences)
+ if os.path.isfile(os.path.join(path, 'meta_info.ini')):
+ config = configparser.ConfigParser()
+ config.read(os.path.join(path, 'meta_info.ini'))
+ metadata.update(config["METAINFO"])
+ metadata["fps"] = int(metadata["anno_fps"][:-3])
- def __getitem__(self, key):
- return self._sequences[key]
+ metadata["root"] = path
+ metadata["name"] = os.path.basename(path)
+ metadata["channel.default"] = "color"
- def __contains__(self, key):
- return key in self._sequences
+ return BasedSequence(metadata["name"], _read_data, metadata)
- def __iter__(self):
- return self._sequences.values().__iter__()
- def list(self):
- return list(self._sequences.keys())
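
For context, the GOT-10k adapter now exposes a plain reader function instead of a GOT10kDataset class, and the returned BasedSequence defers _read_data until the sequence is first accessed. A minimal usage sketch, assuming the registered reader can also be called directly; the dataset path is hypothetical:

    # Hedged sketch: load one GOT-10k sequence through the new reader function.
    # read_sequence returns None when the directory does not look like a
    # GOT-10k sequence (missing groundtruth.txt or meta_info.ini).
    from vot.dataset.got10k import read_sequence

    sequence = read_sequence("/datasets/got10k/val/GOT-10k_Val_000001")
    if sequence is not None:
        # Metadata is available immediately; channels and groundtruth are
        # loaded lazily by _read_data on first access.
        print(sequence.name, sequence.metadata("fps", 30))
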
diff --git a/vot/dataset/otb.py b/vot/dataset/otb.py
index b4c8afa..bc956e3 100644
--- a/vot/dataset/otb.py
+++ b/vot/dataset/otb.py
@@ -1,15 +1,16 @@
+""" OTB dataset adapter module. OTB is one of the earliest tracking benchmarks. It is a collection of 50/100 sequences
+with ground truth annotations. The dataset is available at http://cvlab.hanyang.ac.kr/tracker_benchmark/datasets.html.
+"""
-from collections import OrderedDict
import os
-import logging
import six
-from vot.dataset import BaseSequence, Dataset, DatasetException, PatternFileListChannel
+from vot import get_logger
+from vot.dataset import BasedSequence, DatasetException, PatternFileListChannel, SequenceData
from vot.utilities import Progress
-from vot.region import parse
-
-logger = logging.getLogger("vot")
+from vot.region.io import parse_region
+logger = get_logger()
_BASE_URL = "http://cvlab.hanyang.ac.kr/tracker_benchmark/seq/"
@@ -17,8 +18,8 @@
"Car1", "Car4", "CarDark", "CarScale", "ClifBar", "Couple", "Crowds", "David", "Deer", "Diving",
"DragonBaby", "Dudek", "Football", "Freeman4", "Girl", "Human3", "Human4", "Human6", "Human9",
"Ironman", "Jump", "Jumping", "Liquor", "Matrix", "MotorRolling", "Panda", "RedTeam", "Shaking",
- "Singer2", "Skating1", "Skating2", "Skiing", "Soccer", "Surfer", "Sylvester", "Tiger2",
- "Trellis", "Walking", "Walking2", "Woman" ]
+ "Singer2", "Skating1", "Skating2_1", "Skating2_2", "Skiing", "Soccer", "Surfer", "Sylvester", "Tiger2",
+ "Trellis", "Walking", "Walking2", "Woman"]
_SEQUENCES = {
"Basketball": {"attributes": ["IV", "OCC", "DEF", "OPR", "BC"]},
@@ -39,7 +40,7 @@
"Crowds": {"attributes": ["IV", "DEF", "BC"]},
"David": {"attributes": ["IV", "SV", "OCC", "DEF", "MB", "IPR", "OPR"], "start": 300, "stop": 770},
"Deer": {"attributes": ["MB", "FM", "IPR", "BC", "LR"]},
- "Diving": {"attributes": ["SV", "DEF", "IPR"]},
+ "Diving": {"attributes": ["SV", "DEF", "IPR"], "stop": 215},
"DragonBaby": {"attributes": ["SV", "OCC", "MB", "FM", "IPR", "OPR", "OV"]},
"Dudek": {"attributes": ["SV", "OCC", "DEF", "FM", "IPR", "OPR", "OV", "BC"]},
"Football": {"attributes": ["OCC", "IPR", "OPR", "BC"]},
@@ -60,7 +61,8 @@
"Shaking": {"attributes": ["IV", "SV", "IPR", "OPR", "BC"]},
"Singer2": {"attributes": ["IV", "DEF", "IPR", "OPR", "BC"]},
"Skating1": {"attributes": ["IV", "SV", "OCC", "DEF", "OPR", "BC"]},
- "Skating2": {"objects": 2, "attributes": ["SV", "OCC", "DEF", "FM", "OPR"]},
+ "Skating2_1": {"attributes": ["SV", "OCC", "DEF", "FM", "OPR"], "base": "Skating2", "groundtruth" : "groundtruth_rect.1.txt"},
+ "Skating2_2": {"attributes": ["SV", "OCC", "DEF", "FM", "OPR"], "base": "Skating2", "groundtruth" : "groundtruth_rect.2.txt"},
"Skiing": {"attributes": ["IV", "SV", "DEF", "IPR", "OPR"]},
"Soccer": {"attributes": ["IV", "SV", "OCC", "MB", "FM", "IPR", "OPR", "BC"]},
"Surfer": {"attributes": ["SV", "FM", "IPR", "OPR", "LR"]},
@@ -72,10 +74,10 @@
"Woman": {"attributes": ["IV", "SV", "OCC", "DEF", "MB", "FM", "OPR"]},
# OTB-100 sequences
"Bird2": {"attributes": ["OCC", "DEF", "FM", "IPR", "OPR"]},
- "BlurCar1": {"attributes": ["MB", "FM"]},
- "BlurCar3": {"attributes": ["MB", "FM"]},
- "BlurCar4": {"attributes": ["MB", "FM"]},
- "Board": {"attributes": ["SV", "MB", "FM", "OPR", "OV", "BC"]},
+ "BlurCar1": {"attributes": ["MB", "FM"], "start": 247},
+ "BlurCar3": {"attributes": ["MB", "FM"], "start": 3},
+ "BlurCar4": {"attributes": ["MB", "FM"], "start": 18},
+ "Board": {"attributes": ["SV", "MB", "FM", "OPR", "OV", "BC"], "zeros": 5},
"Bolt2": {"attributes": ["DEF", "BC"]},
"Boy": {"attributes": ["SV", "MB", "FM", "IPR", "OPR"]},
"Car2": {"attributes": ["IV", "SV", "MB", "FM", "BC"]},
@@ -103,8 +105,8 @@
"Human5": {"attributes": ["SV", "OCC", "DEF"]},
"Human7": {"attributes": ["IV", "SV", "OCC", "DEF", "MB", "FM"]},
"Human8": {"attributes": ["IV", "SV", "DEF"]},
- "Jogging1": {"attributes": ["OCC", "DEF", "OPR"], "base": "Jogging"},
- "Jogging2": {"attributes": ["OCC", "DEF", "OPR"], "base": "Jogging"},
+ "Jogging1": {"attributes": ["OCC", "DEF", "OPR"], "base": "Jogging", "groundtruth" : "groundtruth_rect.1.txt"},
+ "Jogging2": {"attributes": ["OCC", "DEF", "OPR"], "base": "Jogging", "groundtruth" : "groundtruth_rect.2.txt"},
"KiteSurf": {"attributes": ["IV", "OCC", "IPR", "OPR"]},
"Lemming": {"attributes": ["IV", "SV", "OCC", "FM", "OPR", "OV"]},
"Man": {"attributes": ["IV"]},
@@ -123,107 +125,96 @@
"Vase": {"attributes": ["SV", "FM", "IPR"]},
}
-class OTBSequence(BaseSequence):
-
- def __init__(self, root, name=None, dataset=None):
-
- metadata = _SEQUENCES[self.name]
- self._base = os.path.join(root, metadata.get("base", name))
-
- super().__init__(name, dataset)
-
- @staticmethod
- def check(path: str):
- return os.path.isfile(os.path.join(path, 'groundtruth_rect.txt'))
-
- def _read_metadata(self):
-
- metadata = _SEQUENCES[self.name]
-
- return {"attributes": metadata["attributes"]}
-
- def _read(self):
-
- channels = {}
- groundtruth = []
-
- metadata = _SEQUENCES[self.name]
-
- channels["color"] = PatternFileListChannel(os.path.join(self._base, "img", "%04d.jpg"),
- start=metadata.get("start", 1), end=metadata.get("end", None))
-
- self._metadata["channel.default"] = "color"
- self._metadata["width"], self._metadata["height"] = six.next(six.itervalues(channels)).size
+def _load_sequence(metadata):
+ """Load a sequence from the OTB dataset.
+
+ Args:
+ metadata (dict): Sequence metadata.
+ """
- groundtruth_file = os.path.join(self._base, "groundtruth_rect.txt")
+ channels = {}
+ groundtruth = []
- with open(groundtruth_file, 'r') as filehandle:
- for region in filehandle.readlines():
- groundtruth.append(parse(region))
+ attributes = metadata.get("attributes", {})
- self._metadata["length"] = len(groundtruth)
+ channels["color"] = PatternFileListChannel(os.path.join(metadata["path"], "img", "%%0%dd.jpg" % attributes.get("zeros", 4)),
+ start=attributes.get("start", 1), end=attributes.get("stop", None))
- if not channels["color"].length == len(groundtruth):
- raise DatasetException("Length mismatch between groundtruth and images")
+ metadata["channel.default"] = "color"
+ metadata["width"], metadata["height"] = six.next(six.itervalues(channels)).size
- return channels, groundtruth, {}, {}
+ groundtruth_file = os.path.join(metadata["path"], attributes.get("groundtruth", "groundtruth_rect.txt"))
-class GOT10kDataset(Dataset):
+ with open(groundtruth_file, 'r') as filehandle:
+ for region in filehandle.readlines():
+ groundtruth.append(parse_region(region.replace("\t", ",").replace(" ", ",")))
- def __init__(self, path, otb50: bool = False):
- super().__init__(path)
+ metadata["length"] = len(groundtruth)
- dataset = _SEQUENCES
+ if not len(channels["color"]) == len(groundtruth):
+ raise DatasetException("Length mismatch between groundtruth and images %d != %d" % (len(channels["color"]), len(groundtruth)))
+
+ objects = {"object" : groundtruth}
- if otb50:
- dataset = {k: v for k, v in dataset.items() if k in _OTB50_SUBSET}
+ return SequenceData(channels, objects, {}, {}, len(groundtruth))
- self._sequences = OrderedDict()
+from vot.dataset import sequence_reader
- with Progress("Loading dataset", len(dataset)) as progress:
+@sequence_reader.register("otb")
+def read_sequence(path: str):
+ """Reads a sequence from OTB dataset. The sequence is identified by the name of the folder and the
+ groundtruth_rect.txt file is expected to be present in the folder.
+
+ Args:
+ path (str): Path to the sequence folder.
+
+ Returns:
+ Sequence: The sequence object.
+ """
- for name in sorted(list(dataset.keys())):
- self._sequences[name.strip()] = OTBSequence(path, name, dataset=self)
- progress.relative(1)
+ if not os.path.isfile(os.path.join(path, 'groundtruth_rect.txt')):
+ return None
- @staticmethod
- def check(path: str):
- if os.path.isfile(os.path.join(path, 'list.txt')):
- return False
+ name = os.path.basename(path)
- for sequence in _OTB50_SUBSET:
- return OTBSequence.check(os.path.join(path, sequence))
+ if name not in _SEQUENCES:
+ return None
- @property
- def path(self):
- return self._path
+ metadata = {"attributes": _SEQUENCES[name], "path": path}
+ return BasedSequence(name.strip(), _load_sequence, metadata)
- @property
- def length(self):
- return len(self._sequences)
-
- def __getitem__(self, key):
- return self._sequences[key]
-
- def __contains__(self, key):
- return key in self._sequences
-
- def __iter__(self):
- return self._sequences.values().__iter__()
-
- def list(self):
- return list(self._sequences.keys())
+from vot.dataset import dataset_downloader
+@dataset_downloader.register("otb50")
+def download_otb50(path: str):
+ """Downloads OTB50 dataset to the given path.
+
+ Args:
+ path (str): Path to the dataset folder.
+ """
+ dataset = _SEQUENCES
+ dataset = {k: v for k, v in dataset.items() if k in _OTB50_SUBSET}
+ _download_dataset(path, dataset)
+
+@dataset_downloader.register("otb100")
+def download_otb100(path: str):
+ """Downloads OTB100 dataset to the given path.
+
+ Args:
+ path (str): Path to the dataset folder.
+ """
+ dataset = _SEQUENCES
+ _download_dataset(path, dataset)
-def download_dataset(path: str, otb50: bool = False):
+def _download_dataset(path: str, dataset: dict):
+ """Downloads the given dataset to the given path.
+
+ Args:
+ path (str): Path to the dataset folder.
+ """
from vot.utilities.net import download_uncompress, join_url, NetworkException
- dataset = _SEQUENCES
-
- if otb50:
- dataset = {k: v for k, v in dataset.items() if k in _OTB50_SUBSET}
-
with Progress("Downloading", len(dataset)) as progress:
for name, metadata in dataset.items():
name = metadata.get("base", name)
@@ -237,6 +228,7 @@ def download_dataset(path: str, otb50: bool = False):
progress.relative(1)
-
-if __name__ == "__main__":
- download_dataset("")
\ No newline at end of file
+ # Write sequence list to a list.txt file
+ with open(os.path.join(path, "list.txt"), 'w') as filehandle:
+ for name in dataset.keys():
+ filehandle.write("%s\n" % name)
\ No newline at end of file
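
The per-sequence "zeros" attribute above feeds a two-stage pattern construction in _load_sequence: the doubled percent sign survives the first substitution, so only the padding width is filled in, and the channel substitutes frame numbers later. A short, self-contained illustration:

    # Hedged sketch of the two-stage formatting used for OTB frame patterns.
    for zeros in (4, 5):
        pattern = "%%0%dd.jpg" % zeros  # first stage: fix the padding width
        print(pattern)                  # -> "%04d.jpg", then "%05d.jpg"
        print(pattern % 1)              # second stage: -> "0001.jpg", "00001.jpg"
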
diff --git a/vot/dataset/proxy.py b/vot/dataset/proxy.py
index 0dde363..c1bc1d1 100644
--- a/vot/dataset/proxy.py
+++ b/vot/dataset/proxy.py
@@ -1,46 +1,209 @@
+""" Proxy sequence classes that allow to modify the behaviour of a sequence without changing the underlying data."""
+from typing import List, Set, Tuple
-from typing import List
+from vot.region import Region
from vot.dataset import Channel, Sequence, Frame
+class ProxySequence(Sequence):
+ """A proxy sequence base that forwards requests to undelying source sequence. Meant as a base class.
+ """
+
+ def __init__(self, source: Sequence, name: str = None):
+ """Creates a proxy sequence.
+
+ Args:
+ source (Sequence): Source sequence object
+ """
+ if name is None:
+ name = source.name
+ super().__init__(name)
+ self._source = source
+
+ def __len__(self):
+ """Returns the length of the sequence. Forwards the request to the source sequence.
+
+ Returns:
+ int: Length of the sequence.
+ """
+        return len(self._source)
+
+ def frame(self, index: int) -> Frame:
+ """Returns a frame object for the given index. Forwards the request to the source sequence.
+
+ Args:
+ index (int): Index of the frame.
+
+ Returns:
+ Frame: Frame object."""
+ return Frame(self, index)
+
+ def metadata(self, name, default=None):
+ """Returns a metadata value for the given name. Forwards the request to the source sequence.
+
+ Args:
+ name (str): Name of the metadata.
+ default (object, optional): Default value to return if the metadata is not found. Defaults to None.
+
+ Returns:
+ object: Metadata value.
+ """
+ return self._source.metadata(name, default)
+
+ def channel(self, channel=None):
+ """Returns a channel object for the given name. Forwards the request to the source sequence.
+
+ Args:
+ channel (str, optional): Name of the channel. Defaults to None.
+
+ Returns:
+ Channel: Channel object.
+ """
+ return self._source.channel(channel)
+
+ def channels(self):
+ """Returns a list of channel names. Forwards the request to the source sequence.
+
+ Returns:
+ list: List of channel names.
+ """
+ return self._source.channels()
+
+ def objects(self):
+ """Returns a list of object ids. Forwards the request to the source sequence.
+
+ Returns:
+ list: List of object ids.
+ """
+ return self._source.objects()
+
+ def object(self, id, index=None):
+ """Returns an object for the given id. Forwards the request to the source sequence.
+
+ Args:
+ id (str): Id of the object.
+ index (int, optional): Index of the frame. Defaults to None.
+
+ Returns:
+            Region: Object region at the given frame, or the entire trajectory if index is None.
+ """
+ return self._source.object(id, index)
+
+ def groundtruth(self, index: int = None) -> List[Region]:
+ """Returns a list of groundtruth regions for the given index. Forwards the request to the source sequence.
+
+ Args:
+ index (int, optional): Index of the frame. Defaults to None.
+
+ Returns:
+ list: List of groundtruth regions.
+ """
+ return self._source.groundtruth(index)
+
+ def tags(self, index=None):
+ """Returns a list of tags for the given index. Forwards the request to the source sequence.
+
+ Args:
+ index (int, optional): Index of the frame. Defaults to None.
+
+ Returns:
+ list: List of tags.
+ """
+ return self._source.tags(index)
+
+ def values(self, index=None):
+ """Returns a list of values for the given index. Forwards the request to the source sequence.
+
+ Args:
+            index (int, optional): Index of the frame. Defaults to None.
+
+        Returns:
+            list: List of values.
+        """
+ return self._source.values(index)
+
+ @property
+ def size(self) -> Tuple[int, int]:
+ """Returns the size of the sequence. Forwards the request to the source sequence.
+
+ Returns:
+ Tuple[int, int]: Size of the sequence.
+ """
+ return self._source.size
+
+
class FrameMapChannel(Channel):
+ """A proxy channel that maps frames from a source channel in another order."""
def __init__(self, source: Channel, frame_map: List[int]):
+ """Creates a frame mapping proxy channel.
+
+ Args:
+ source (Channel): Source channel object
+            frame_map (List[int]): A list of frame indices in the source channel that will form the proxy.
+                The indices are assumed valid; FrameMapSequence filters out-of-bounds entries before constructing the channel.
+
+ """
super().__init__()
self._source = source
self._map = frame_map
- @property
- def length(self):
+ def __len__(self):
+ """Returns the length of the channel."""
return len(self._map)
def frame(self, index):
+ """Returns a frame object for the given index.
+
+ Args:
+ index (int): Index of the frame.
+
+ Returns:
+ Frame: Frame object.
+ """
return self._source.frame(self._map[index])
def filename(self, index):
+ """Returns the filename of the frame for the given index. Index is mapped according to the frame map before the request is forwarded to the source channel.
+
+ Args:
+ index (int): Index of the frame.
+
+ Returns:
+ str: Filename of the frame.
+ """
return self._source.filename(self._map[index])
@property
def size(self):
+ """Returns the size of the channel.
+
+ Returns:
+ Tuple[int, int]: Size of the channel.
+ """
return self._source.size
-class FrameMapSequence(Sequence):
+class FrameMapSequence(ProxySequence):
+ """A proxy sequence that maps frames from a source sequence in another order.
+ """
def __init__(self, source: Sequence, frame_map: List[int]):
- super().__init__(source.name, source.dataset)
- self._source = source
- self._map = [i for i in frame_map if i >= 0 and i < source.length]
+ """Creates a frame mapping proxy sequence.
- def __len__(self):
- return self.length
-
- def frame(self, index):
- return Frame(self, index)
-
- def metadata(self, name, default=None):
- return self._source.metadata(name, default)
+ Args:
+ source (Sequence): Source sequence object
+ frame_map (List[int]): A list of frame indices in the source sequence that will form the proxy. The list is filtered
+ so that all indices that are out of bounds are removed.
+ """
+ super().__init__(source)
+ self._map = [i for i in frame_map if i >= 0 and i < len(source)]
def channel(self, channel=None):
+ """Returns a channel object for the given channel name.
+
+ Args:
+ channel (str): Name of the channel.
+
+ Returns:
+ Channel: Channel object.
+ """
sourcechannel = self._source.channel(channel)
if sourcechannel is None:
@@ -49,9 +212,32 @@ def channel(self, channel=None):
return FrameMapChannel(sourcechannel, self._map)
def channels(self):
+ """Returns a list of channel names.
+
+ Returns:
+ list: List of channel names.
+ """
return self._source.channels()
- def groundtruth(self, index=None):
+ def frame(self, index: int) -> Frame:
+ """Returns a frame object for the given index. Forwards the request to the source sequence with the mapped index.
+
+ Args:
+ index (int): Index of the frame.
+
+ Returns:
+ Frame: Frame object.
+ """
+ return self._source.frame(self._map[index])
+
+ def groundtruth(self, index: int = None) -> List[Region]:
+ """Returns a list of groundtruth regions for the given index. Forwards the request to the source sequence with the mapped index.
+
+ Args:
+ index (int, optional): Index of the frame. Defaults to None.
+
+ Returns:
+ list: List of groundtruth regions."""
if index is None:
groundtruth = [None] * len(self)
for i, m in enumerate(self._map):
@@ -60,21 +246,142 @@ def groundtruth(self, index=None):
else:
return self._source.groundtruth(self._map[index])
+ def object(self, id, index=None):
+ """Returns an object for the given id. Forwards the request to the source sequence with the mapped index.
+
+ Args:
+ id (str): Id of the object.
+ index (int, optional): Index of the frame. Defaults to None.
+
+ Returns:
+ Region: Object region or a list of object regions.
+ """
+ if index is None:
+ groundtruth = [None] * len(self)
+ for i, m in enumerate(self._map):
+ groundtruth[i] = self._source.object(id, m)
+ return groundtruth
+ else:
+ return super().object(id, self._map[index])
+
def tags(self, index=None):
+ """Returns a list of tags for the given index. Forwards the request to the source sequence with the mapped index.
+
+ Args:
+ index (int, optional): Index of the frame. Defaults to None.
+
+ Returns:
+ list: List of tags.
+ """
if index is None:
+ # TODO: this is probably not correct
return self._source.tags()
else:
return self._source.tags(self._map[index])
def values(self, index=None):
+ """Returns a list of values for the given index. Forwards the request to the source sequence with the mapped index.
+
+ Args:
+ index (int, optional): Index of the frame. Defaults to None.
+
+ Returns:
+ list: List of values.
+ """
if index is None:
+ # TODO: this is probably not correct
return self._source.values()
return self._source.values(self._map[index])
- @property
- def size(self):
- return self._source.size
+ def __len__(self) -> int:
+ """Returns the length of the sequence. The length is the same as the length of the frame map.
- @property
- def length(self):
+ Returns:
+ int: Length of the sequence.
+ """
return len(self._map)
+
+class ChannelFilterSequence(ProxySequence):
+ """A proxy sequence that only makes specific channels visible.
+ """
+
+ def __init__(self, source: Sequence, channels: Set[str]):
+ """Creates a channel filter proxy sequence.
+
+ Args:
+ source (Sequence): Source sequence object
+ channels (Set[str]): A set of channel names that will be visible in the proxy sequence. The set is filtered
+ so that all channel names that are not in the source sequence are removed.
+ """
+ super().__init__(source)
+ self._filter = [i for i in channels if i in source.channels()]
+
+ def channel(self, channel=None):
+ """Returns a channel object for the given channel name. If the channel is not in the filter, None is returned.
+
+ Args:
+ channel (str): Name of the channel.
+
+ Returns:
+ Channel: Channel object.
+ """
+ if channel not in self._filter:
+ return None
+ return self._source.channel(channel)
+
+ def channels(self):
+ """Returns a list of channel names.
+
+ Returns:
+ list: List of channel names.
+ """
+ return set(self._filter)
+
+class ObjectFilterSequence(ProxySequence):
+ """A proxy sequence that only makes specific object visible.
+ """
+
+ def __init__(self, source: Sequence, id: str, trim: bool=False):
+ """Creates an object filter proxy sequence.
+
+ Args:
+ source (Sequence): Source sequence object
+ id (str): ID of the object that will be visible in the proxy sequence.
+
+ Keyword Args:
+ trim (bool): If true, the sequence will be trimmed to the first and last frame where the object is visible.
+ """
+ super().__init__(source, "%s_%s" % (source.name, id))
+ self._id = id
+ # TODO: implement trim
+ self._trim = trim
+
+ def objects(self):
+ """Returns a dictionary of all objects in the sequence.
+
+ Returns:
+ Dict[str, Object]: Dictionary of all objects in the sequence.
+ """
+ objects = self._source.objects()
+ return {self._id: objects[id]}
+
+ def object(self, id, index=None):
+ """Returns an object for the given id.
+
+ Args:
+ id (str): ID of the object.
+
+ Returns:
+            Region: Object region, or None if the id does not match the filtered object.
+ """
+ if id != self._id:
+ return None
+ return self._source.object(id, index)
+
+ def groundtruth(self, index: int = None) -> List[Region]:
+ """Returns the groundtruth for the given index.
+
+ Args:
+ index (int): Index of the frame.
+ """
+ return self._source.object(self._id, index)
\ No newline at end of file
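
Because all of the classes above derive from ProxySequence, they can be stacked without copying the underlying data. A minimal sketch, assuming seq is any loaded Sequence; the helper names are illustrative:

    # Hedged sketch: composing the proxy classes introduced in this file.
    from vot.dataset.proxy import FrameMapSequence, ObjectFilterSequence

    def every_other_frame(seq):
        # Out-of-range indices would be filtered out by the constructor anyway.
        return FrameMapSequence(seq, list(range(0, len(seq), 2)))

    def isolate_object(seq, object_id):
        # ObjectFilterSequence renames the proxy to "<sequence>_<object>".
        return ObjectFilterSequence(seq, object_id)

    # The two compose: a subsampled view of a single object.
    # subsampled = isolate_object(every_other_frame(seq), "object")
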
diff --git a/vot/dataset/trackingnet.py b/vot/dataset/trackingnet.py
new file mode 100644
index 0000000..25d597c
--- /dev/null
+++ b/vot/dataset/trackingnet.py
@@ -0,0 +1,129 @@
+""" Dataset adapter for the TrackingNet dataset. Note that the dataset is organized a different way than the VOT datasets,
+annotated frames are stored in a separate directory. The dataset also contains train and test splits. The loader
+assumes that only one of the splits is used at a time and that the path is given to this part of the dataset. """
+
+import os
+import glob
+
+import six
+
+from vot import get_logger
+from vot.dataset import BasedSequence, PatternFileListChannel, SequenceData, \
+    Sequence
+from vot.region import Special
+from vot.region.io import read_trajectory
+
+logger = get_logger()
+
+def load_channel(source):
+ """ Load channel from the given source.
+
+ Args:
+        source (str): Path to the source. If the source is a directory, a
+            default frame name pattern is appended; otherwise the source is
+            used as a file name pattern directly.
+
+ Returns:
+ Channel: Channel object.
+ """
+
+ extension = os.path.splitext(source)[1]
+
+ if extension == '':
+ source = os.path.join(source, '%d.jpg')
+ return PatternFileListChannel(source)
+
+
+def _read_data(metadata):
+ """Internal function for reading data from the given metadata for a TrackingNet sequence.
+
+ Args:
+ metadata (dict): Metadata dictionary.
+
+ Returns:
+ SequenceData: Sequence data object.
+ """
+
+ channels = {}
+ tags = {}
+ values = {}
+ groundtruth = []
+
+ name = metadata["name"]
+ root = metadata["root"]
+
+ channels["color"] = load_channel(os.path.join(root, 'frames', name))
+ metadata["channel.default"] = "color"
+ metadata["width"], metadata["height"] = six.next(six.itervalues(channels)).size
+
+    groundtruth = read_trajectory(os.path.join(root, "anno", name + ".txt"))
+
+ if len(groundtruth) == 1 and channels["color"].length > 1:
+        # We are dealing with a test split where only the first frame is annotated, so we pad the
+        # groundtruth with unknowns. Only the unsupervised experiment will work in this case.
+ groundtruth.extend([Special(Sequence.UNKNOWN)] * (channels["color"].length - 1))
+
+ metadata["length"] = len(groundtruth)
+
+ objects = {"object" : groundtruth}
+
+ return SequenceData(channels, objects, tags, values, len(groundtruth))
+
+from vot.dataset import sequence_reader
+
+sequence_reader.register("trackingnet")
+def read_sequence(path):
+ """ Read sequence from the given path. Different to VOT datasets, the sequence is not
+ a directory, but a file. From the file name the sequence name is extracted and the
+ path to image frames is inferred based on standard TrackingNet directory structure.
+
+ Args:
+ path (str): Path to the sequence groundtruth.
+
+ Returns:
+ Sequence: Sequence object.
+ """
+ if not os.path.isfile(path):
+ return None
+
+ name, ext = os.path.splitext(os.path.basename(path))
+
+ if ext != '.txt':
+ return None
+
+    root = os.path.dirname(os.path.dirname(path))
+
+    if not os.path.isdir(os.path.join(root, 'frames', name)):
+ return None
+
+ metadata = dict(fps=30)
+ metadata["channel.default"] = "color"
+ metadata["name"] = name
+ metadata["root"] = root
+
+ return BasedSequence(name, _read_data, metadata)
+
+from vot.dataset import sequence_indexer
+
+sequence_indexer.register("trackingnet")
+def list_sequences(path):
+ """ List sequences in the given path. The path is expected to be the root of the TrackingNet dataset split.
+
+ Args:
+ path (str): Path to the dataset root.
+
+ Returns:
+ list: List of sequences.
+ """
+ for dirname in ["anno", "frames"]:
+ if not os.path.isdir(os.path.join(path, dirname)):
+ return None
+
+ sequences = list(glob.glob(os.path.join(path, "anno", "*.txt")))
+
+ return sequences
+
+
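
Both registrations above follow the same registry pattern as the GOT-10k and OTB adapters: a module-level decorator plugs the reader or indexer into a global resolver. A hedged sketch of how a third-party adapter would hook in, assuming register acts as a decorator factory the way it is used elsewhere in this diff; all other names are illustrative:

    # Hedged sketch: registering a custom sequence reader with the resolver.
    import os
    from vot.dataset import sequence_reader

    @sequence_reader.register("mydataset")
    def read_my_sequence(path):
        # Return None quickly when the path does not belong to this format,
        # so the resolver can try the other registered readers.
        if not os.path.isfile(os.path.join(path, "annotations.txt")):
            return None
        ...  # construct and return a BasedSequence here
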
diff --git a/vot/dataset/vot.py b/vot/dataset/vot.py
deleted file mode 100644
index efdc768..0000000
--- a/vot/dataset/vot.py
+++ /dev/null
@@ -1,278 +0,0 @@
-
-import os
-import glob
-import logging
-from collections import OrderedDict
-
-import six
-
-import cv2
-
-from vot.dataset import Dataset, DatasetException, Sequence, BaseSequence, PatternFileListChannel
-from vot.region import parse, write_file
-from vot.utilities import Progress, localize_path, read_properties, write_properties
-
-logger = logging.getLogger("vot")
-
-def load_channel(source):
-
- extension = os.path.splitext(source)[1]
-
- if extension == '':
- source = os.path.join(source, '%08d.jpg')
- return PatternFileListChannel(source)
-
-class VOTSequence(BaseSequence):
-
- def __init__(self, base, name=None, dataset=None):
- self._base = base
- if name is None:
- name = os.path.basename(base)
- super().__init__(name, dataset)
-
- @staticmethod
- def check(path):
- return os.path.isfile(os.path.join(path, 'sequence'))
-
- def _read_metadata(self):
- metadata = dict(fps=30, format="default")
- metadata["channel.default"] = "color"
-
- metadata_file = os.path.join(self._base, 'sequence')
- metadata.update(read_properties(metadata_file))
-
- return metadata
-
- def _read(self):
-
- channels = {}
- tags = {}
- values = {}
- groundtruth = []
-
- for c in ["color", "depth", "ir"]:
- channel_path = self.metadata("channels.%s" % c, None)
- if not channel_path is None:
- channels[c] = load_channel(os.path.join(self._base, localize_path(channel_path)))
-
- # Load default channel if no explicit channel data available
- if len(channels) == 0:
- channels["color"] = load_channel(os.path.join(self._base, "color", "%08d.jpg"))
- else:
- self._metadata["channel.default"] = next(iter(channels.keys()))
-
- self._metadata["width"], self._metadata["height"] = six.next(six.itervalues(channels)).size
-
- groundtruth_file = os.path.join(self._base, self.metadata("groundtruth", "groundtruth.txt"))
-
- with open(groundtruth_file, 'r') as filehandle:
- for region in filehandle.readlines():
- groundtruth.append(parse(region))
-
- self._metadata["length"] = len(groundtruth)
-
- tagfiles = glob.glob(os.path.join(self._base, '*.tag')) + glob.glob(os.path.join(self._base, '*.label'))
-
- for tagfile in tagfiles:
- with open(tagfile, 'r') as filehandle:
- tagname = os.path.splitext(os.path.basename(tagfile))[0]
- tag = [line.strip() == "1" for line in filehandle.readlines()]
- while not len(tag) >= len(groundtruth):
- tag.append(False)
- tags[tagname] = tag
-
- valuefiles = glob.glob(os.path.join(self._base, '*.value'))
-
- for valuefile in valuefiles:
- with open(valuefile, 'r') as filehandle:
- valuename = os.path.splitext(os.path.basename(valuefile))[0]
- value = [float(line.strip()) for line in filehandle.readlines()]
- while not len(value) >= len(groundtruth):
- value.append(0.0)
- values[valuename] = value
-
- for name, channel in channels.items():
- if not channel.length == len(groundtruth):
- raise DatasetException("Length mismatch for channel %s (%d != %d)" % (name, channel.length, len(groundtruth)))
-
- for name, tag in tags.items():
- if not len(tag) == len(groundtruth):
- tag_tmp = len(groundtruth) * [False]
- tag_tmp[:len(tag)] = tag
- tag = tag_tmp
-
- for name, value in values.items():
- if not len(value) == len(groundtruth):
- raise DatasetException("Length mismatch for value %s" % name)
-
- return channels, groundtruth, tags, values
-
-class VOTDataset(Dataset):
-
- def __init__(self, path):
- super().__init__(path)
-
- if not os.path.isfile(os.path.join(path, "list.txt")):
- raise DatasetException("Dataset not available locally")
-
- with open(os.path.join(path, "list.txt"), 'r') as fd:
- names = fd.readlines()
-
- self._sequences = OrderedDict()
-
- with Progress("Loading dataset", len(names)) as progress:
-
- for name in names:
- self._sequences[name.strip()] = VOTSequence(os.path.join(path, name.strip()), dataset=self)
- progress.relative(1)
-
- @staticmethod
- def check(path: str):
- if not os.path.isfile(os.path.join(path, 'list.txt')):
- return False
-
- with open(os.path.join(path, 'list.txt'), 'r') as handle:
- sequence = handle.readline().strip()
- return VOTSequence.check(os.path.join(path, sequence))
-
- @property
- def path(self):
- return self._path
-
- @property
- def length(self):
- return len(self._sequences)
-
- def __getitem__(self, key):
- return self._sequences[key]
-
- def __contains__(self, key):
- return key in self._sequences
-
- def __iter__(self):
- return self._sequences.values().__iter__()
-
- def list(self):
- return list(self._sequences.keys())
-
- @classmethod
- def download(self, url, path="."):
- from vot.utilities.net import download_uncompress, download_json, get_base_url, join_url, NetworkException
-
- if os.path.splitext(url)[1] == '.zip':
- logger.info('Downloading sequence bundle from "%s". This may take a while ...', url)
-
- try:
- download_uncompress(url, path)
- except NetworkException as e:
- raise DatasetException("Unable do download dataset bundle, Please try to download the bundle manually from {} and uncompress it to {}'".format(url, path))
- except IOError as e:
- raise DatasetException("Unable to extract dataset bundle, is the target directory writable and do you have enough space?")
-
- else:
-
- meta = download_json(url)
-
- logger.info('Downloading sequence dataset "%s" with %s sequences.', meta["name"], len(meta["sequences"]))
-
- base_url = get_base_url(url) + "/"
-
- with Progress("Downloading", len(meta["sequences"])) as progress:
- for sequence in meta["sequences"]:
- sequence_directory = os.path.join(path, sequence["name"])
- os.makedirs(sequence_directory, exist_ok=True)
-
- data = {'name': sequence["name"], 'fps': sequence["fps"], 'format': 'default'}
-
- annotations_url = join_url(base_url, sequence["annotations"]["url"])
-
- try:
- download_uncompress(annotations_url, sequence_directory)
- except NetworkException as e:
- raise DatasetException("Unable do download annotations bundle")
- except IOError as e:
- raise DatasetException("Unable to extract annotations bundle, is the target directory writable and do you have enough space?")
-
- for cname, channel in sequence["channels"].items():
- channel_directory = os.path.join(sequence_directory, cname)
- os.makedirs(channel_directory, exist_ok=True)
-
- channel_url = join_url(base_url, channel["url"])
-
- try:
- download_uncompress(channel_url, channel_directory)
- except NetworkException as e:
- raise DatasetException("Unable do download channel bundle")
- except IOError as e:
- raise DatasetException("Unable to extract channel bundle, is the target directory writable and do you have enough space?")
-
- if "pattern" in channel:
- data["channels." + cname] = cname + os.path.sep + channel["pattern"]
- else:
- data["channels." + cname] = cname + os.path.sep
-
- write_properties(os.path.join(sequence_directory, 'sequence'), data)
-
- progress.relative(1)
-
- with open(os.path.join(path, "list.txt"), "w") as fp:
- for sequence in meta["sequences"]:
- fp.write('{}\n'.format(sequence["name"]))
-
-def write_sequence(directory: str, sequence: Sequence):
-
- channels = sequence.channels()
-
- metadata = dict()
- metadata["channel.default"] = sequence.metadata("channel.default", "color")
- metadata["fps"] = sequence.metadata("fps", "30")
-
- for channel in channels:
- cdir = os.path.join(directory, channel)
- os.makedirs(cdir, exist_ok=True)
-
- metadata["channels.%s" % channel] = os.path.join(channel, "%08d.jpg")
-
- for i in range(sequence.length):
- frame = sequence.frame(i).channel(channel)
- cv2.imwrite(os.path.join(cdir, "%08d.jpg" % (i + 1)), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
-
- for tag in sequence.tags():
- data = "\n".join(["1" if tag in sequence.tags(i) else "0" for i in range(sequence.length)])
- with open(os.path.join(directory, "%s.tag" % tag), "w") as fp:
- fp.write(data)
-
- for value in sequence.values():
- data = "\n".join([ str(sequence.values(i).get(value, "")) for i in range(sequence.length)])
- with open(os.path.join(directory, "%s.value" % value), "w") as fp:
- fp.write(data)
-
- write_file(os.path.join(directory, "groundtruth.txt"), [f.groundtruth() for f in sequence])
- write_properties(os.path.join(directory, "sequence"), metadata)
-
-
-VOT_DATASETS = {
- "vot2013" : "http://data.votchallenge.net/vot2013/dataset/description.json",
- "vot2014" : "http://data.votchallenge.net/vot2014/dataset/description.json",
- "vot2015" : "http://data.votchallenge.net/vot2015/dataset/description.json",
- "vot-tir2015" : "http://www.cvl.isy.liu.se/research/datasets/ltir/version1.0/ltir_v1_0_8bit.zip",
- "vot2016" : "http://data.votchallenge.net/vot2016/main/description.json",
- "vot-tir2016" : "http://data.votchallenge.net/vot2016/vot-tir2016.zip",
- "vot2017" : "http://data.votchallenge.net/vot2017/main/description.json",
- "vot-st2018" : "http://data.votchallenge.net/vot2018/main/description.json",
- "vot-lt2018" : "http://data.votchallenge.net/vot2018/longterm/description.json",
- "vot-st2019" : "http://data.votchallenge.net/vot2019/main/description.json",
- "vot-lt2019" : "http://data.votchallenge.net/vot2019/longterm/description.json",
- "vot-rgbd2019" : "http://data.votchallenge.net/vot2019/rgbd/description.json",
- "vot-rgbt2019" : "http://data.votchallenge.net/vot2019/rgbtir/meta/description.json",
- "vot-st2020" : "https://data.votchallenge.net/vot2020/shortterm/description.json",
- "vot-rgbt2020" : "http://data.votchallenge.net/vot2020/rgbtir/meta/description.json",
- "vot-st2021": "https://data.votchallenge.net/vot2021/shortterm/description.json",
- "test" : "http://data.votchallenge.net/toolkit/test.zip",
- "segmentation" : "http://box.vicos.si/tracking/vot20_test_dataset.zip"
-}
-
-def download_dataset(name, path="."):
- if not name in VOT_DATASETS:
- raise ValueError("Unknown dataset")
- VOTDataset.download(VOT_DATASETS[name], path)
\ No newline at end of file
diff --git a/vot/document/__init__.py b/vot/document/__init__.py
index 3ddbd52..893976f 100644
--- a/vot/document/__init__.py
+++ b/vot/document/__init__.py
@@ -1,15 +1,14 @@
+""" This module contains classes for generating reports and visualizations. """
-import os
import typing
from abc import ABC, abstractmethod
import json
-import math
import inspect
import threading
-import logging
-import tempfile
import datetime
import collections
+import collections.abc
from asyncio import wait
from asyncio.futures import wrap_future
@@ -24,18 +23,28 @@
from attributee import Attributee, Object, Nested, String, Callable, Integer, List
from vot import __version__ as version
-from vot import check_debug
+from vot import get_logger
from vot.dataset import Sequence
from vot.tracker import Tracker
from vot.analysis import Axes
-from vot.experiment import Experiment, analysis_resolver
from vot.utilities import class_fullname
from vot.utilities.data import Grid
class Plot(object):
+ """ Base class for all plots. """
def __init__(self, identifier: str, xlabel: str, ylabel: str,
xlimits: typing.Tuple[float, float], ylimits: typing.Tuple[float, float], trait = None):
+ """ Initializes the plot.
+
+ Args:
+ identifier (str): The identifier of the plot.
+ xlabel (str): The label of the x axis.
+ ylabel (str): The label of the y axis.
+ xlimits (tuple): The limits of the x axis.
+ ylimits (tuple): The limits of the y axis.
+ trait (str): The trait of the plot.
+ """
self._identifier = identifier
@@ -54,25 +63,32 @@ def __init__(self, identifier: str, xlabel: str, ylabel: str,
self._axes.autoscale(False, axis="y")
def __call__(self, key, data):
+ """ Draws the data on the plot."""
self.draw(key, data)
def draw(self, key, data):
+ """ Draws the data on the plot."""
raise NotImplementedError
@property
def axes(self) -> Axes:
+ """ Returns the axes of the plot."""
return self._axes
def save(self, output, fmt):
+ """ Saves the plot to a file."""
self._figure.savefig(output, format=fmt, bbox_inches='tight', transparent=True)
@property
def identifier(self):
+ """ Returns the identifier of the plot."""
return self._identifier
class ScatterPlot(Plot):
+ """ A scatter plot."""
def draw(self, key, data):
+ """ Draws the data on the plot. """
if data is None or len(data) != 2:
return
@@ -81,8 +97,10 @@ def draw(self, key, data):
#handle.set_gid("report_%s_%d" % (self._identifier, style["number"]))
class LinePlot(Plot):
+ """ A line plot."""
def draw(self, key, data):
+ """ Draws the data on the plot."""
if data is None or len(data) < 1:
return
@@ -101,8 +119,10 @@ def draw(self, key, data):
# handle[0].set_gid("report_%s_%d" % (self._identifier, style["number"]))
class ResultsJSONEncoder(json.JSONEncoder):
+ """ JSON encoder for results. """
def default(self, o):
+ """ Default encoder. """
if isinstance(o, Grid):
return list(o)
elif isinstance(o, datetime.date):
@@ -113,12 +133,15 @@ def default(self, o):
return super().default(o)
class ResultsYAMLEncoder(yaml.Dumper):
+ """ YAML encoder for results."""
def represent_tuple(self, data):
+ """ Represents a tuple. """
return self.represent_list(list(data))
def represent_object(self, o):
+ """ Represents an object. """
if isinstance(o, Grid):
return self.represent_list(list(o))
elif isinstance(o, datetime.date):
@@ -136,6 +159,7 @@ def represent_object(self, o):
ResultsYAMLEncoder.add_multi_representer(np.inexact, ResultsYAMLEncoder.represent_float)
def generate_serialized(trackers: typing.List[Tracker], sequences: typing.List[Sequence], results, storage: "Storage", serializer: str):
+ """ Generates a serialized report of the results. """
doc = dict()
doc["toolkit"] = version
@@ -162,6 +186,7 @@ def generate_serialized(trackers: typing.List[Tracker], sequences: typing.List[S
raise RuntimeError("Unknown serializer")
def configure_axes(figure, rect=None, _=None):
+ """ Configures the axes of the plot. """
axes = PlotAxes(figure, rect or [0, 0, 1, 1])
@@ -170,6 +195,7 @@ def configure_axes(figure, rect=None, _=None):
return axes
def configure_figure(traits=None):
+ """ Configures the figure of the plot. """
args = {}
if traits == "ar":
@@ -177,63 +203,94 @@ def configure_figure(traits=None):
elif traits == "eao":
args["figsize"] = (7, 5)
elif traits == "attributes":
- args["figsize"] = (15, 5)
+ args["figsize"] = (10, 5)
return Figure(**args)
class PlotStyle(object):
+ """ A style for a plot."""
def line_style(self, opacity=1):
+ """ Returns the style for a line."""
raise NotImplementedError
def point_style(self):
+ """ Returns the style for a point."""
raise NotImplementedError
class DefaultStyle(PlotStyle):
+ """ The default style for a plot."""
colormap = get_cmap("tab20b")
colorcount = 20
markers = ["o", "v", "<", ">", "^", "8", "*"]
def __init__(self, number):
+ """ Initializes the style.
+
+ Args:
+ number (int): The number of the style.
+ """
super().__init__()
self._number = number
def line_style(self, opacity=1):
+ """ Returns the style for a line.
+
+ Args:
+ opacity (float): The opacity of the line.
+ """
color = DefaultStyle.colormap((self._number % DefaultStyle.colorcount + 1) / DefaultStyle.colorcount)
if opacity < 1:
color = colors.to_rgba(color, opacity)
return dict(linewidth=1, c=color)
def point_style(self):
+ """ Returns the style for a point.
+
+ Args:
+ color (str): The color of the point.
+ opacity (float): The opacity of the line.
+ """
color = DefaultStyle.colormap((self._number % DefaultStyle.colorcount + 1) / DefaultStyle.colorcount)
marker = DefaultStyle.markers[self._number % len(DefaultStyle.markers)]
return dict(marker=marker, c=[color])
class Legend(object):
+ """ A legend for a plot."""
def __init__(self, style_factory=DefaultStyle):
+ """ Initializes the legend.
+
+ Args:
+ style_factory (PlotStyleFactory): The style factory.
+ """
self._mapping = collections.OrderedDict()
self._counter = 0
self._style_factory = style_factory
def _number(self, key):
+ """ Returns the number for a key."""
if not key in self._mapping:
self._mapping[key] = self._counter
self._counter += 1
return self._mapping[key]
def __getitem__(self, key) -> PlotStyle:
+ """ Returns the style for a key."""
number = self._number(key)
return self._style_factory(number)
def _style(self, number):
+ """ Returns the style for a number."""
raise NotImplementedError
def keys(self):
+ """ Returns the keys of the legend."""
return self._mapping.keys()
def figure(self, key):
+ """ Returns a figure for a key."""
style = self[key]
figure = Figure(figsize=(0.1, 0.1)) # TODO: hardcoded
axes = PlotAxes(figure, [0, 0, 1, 1], yticks=[], xticks=[], frame_on=False)
@@ -245,6 +302,7 @@ def figure(self, key):
return figure
class StyleManager(Attributee):
+ """ A manager for styles. """
plots = Callable(default=DefaultStyle)
axes = Callable(default=configure_axes)
@@ -253,13 +311,16 @@ class StyleManager(Attributee):
_context = threading.local()
def __init__(self, **kwargs):
+ """ Initializes a new instance of the StyleManager class."""
super().__init__(**kwargs)
self._legends = dict()
def __getitem__(self, key) -> PlotStyle:
+ """ Gets the style for the given key."""
return self.plot_style(key)
def legend(self, key) -> Legend:
+ """ Gets the legend for the given key."""
if inspect.isclass(key):
klass = key
else:
@@ -271,18 +332,29 @@ def legend(self, key) -> Legend:
return self._legends[klass]
def plot_style(self, key) -> PlotStyle:
+ """ Gets the plot style for the given key."""
return self.legend(key)[key]
def make_axes(self, figure, rect=None, trait=None) -> Axes:
+ """ Makes the axes for the given figure."""
return self.axes(figure, rect, trait)
def make_figure(self, trait=None) -> typing.Tuple[Figure, Axes]:
+ """ Makes the figure for the given trait.
+
+ Args:
+ trait: The trait for which to make the figure.
+
+ Returns:
+ A tuple containing the figure and the axes.
+ """
figure = self.figure(trait)
axes = self.make_axes(figure, trait=trait)
return figure, axes
def __enter__(self):
+ """Enters the context of the style manager."""
manager = getattr(StyleManager._context, 'style_manager', None)
@@ -294,6 +366,7 @@ def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
+ """Exits the context of the style manager."""
manager = getattr(StyleManager._context, 'style_manager', None)
if manager == self:
@@ -301,21 +374,36 @@ def __exit__(self, exc_type, exc_value, traceback):
@staticmethod
def default() -> "StyleManager":
+ """ Gets the default style manager."""
manager = getattr(StyleManager._context, 'style_manager', None)
if manager is None:
+ get_logger().info("Creating new style manager")
manager = StyleManager()
StyleManager._context.style_manager = manager
return manager
class TrackerSorter(Attributee):
+ """ A sorter for trackers. """
experiment = String(default=None)
analysis = String(default=None)
result = Integer(val_min=0, default=0)
def __call__(self, experiments, trackers, sequences):
+ """ Sorts the trackers.
+
+ Arguments:
+ experiments (list of Experiment): The experiments.
+ trackers (list of Tracker): The trackers.
+ sequences (list of Sequence): The sequences.
+
+ Returns:
+ A list of indices of the trackers in the sorted order.
+ """
+ from vot.analysis import AnalysisError
+
if self.experiment is None or self.analysis is None:
return range(len(trackers))
@@ -329,8 +417,12 @@ def __call__(self, experiments, trackers, sequences):
if analysis is None:
raise RuntimeError("Analysis not found")
- future = analysis.commit(experiment, trackers, sequences)
- result = future.result()
+        try:
+            future = analysis.commit(experiment, trackers, sequences)
+            result = future.result()
+        except AnalysisError as e:
+            raise RuntimeError("Unable to sort trackers") from e
scores = [x[self.result] for x in result]
indices = [i[0] for i in sorted(enumerate(scores), reverse=True, key=lambda x: x[1])]
@@ -338,12 +430,17 @@ def __call__(self, experiments, trackers, sequences):
return indices
class Generator(Attributee):
+ """ A generator for reports."""
async def generate(self, experiments, trackers, sequences):
raise NotImplementedError
async def process(self, analyses, experiment, trackers, sequences):
- if not isinstance(analyses, collections.Iterable):
+ if sys.version_info >= (3, 3):
+ _Iterable = collections.abc.Iterable
+ else:
+ _Iterable = collections.Iterable
+ if not isinstance(analyses, _Iterable):
analyses = [analyses]
futures = []
@@ -359,6 +456,7 @@ async def process(self, analyses, experiment, trackers, sequences):
return (future.result() for future in futures)
class ReportConfiguration(Attributee):
+ """ A configuration for reports."""
style = Nested(StyleManager)
sort = Nested(TrackerSorter)
@@ -366,6 +464,17 @@ class ReportConfiguration(Attributee):
# TODO: replace this with report generator and separate json/yaml dump
def generate_document(format: str, config: ReportConfiguration, trackers: typing.List[Tracker], sequences: typing.List[Sequence], results, storage: "Storage"):
+ """ Generates a report document.
+
+ Args:
+ format: The format of the report.
+ config: The configuration of the report.
+ trackers: The trackers to include in the report.
+ sequences: The sequences to include in the report.
+ results: The results to include in the report.
+ storage: The storage to use for the report.
+
+ """
from .html import generate_html_document
from .latex import generate_latex_document
@@ -376,13 +485,12 @@ def generate_document(format: str, config: ReportConfiguration, trackers: typing
generate_serialized(trackers, sequences, results, storage, "yaml")
else:
order = config.sort(results.keys(), trackers, sequences)
- trackers = [trackers[i] for i in order]
with config.style:
if format == "html":
generate_html_document(trackers, sequences, results, storage)
elif format == "latex":
- generate_latex_document(trackers, sequences, results, storage, False)
+ generate_latex_document(trackers, sequences, results, storage, False, order=order)
elif format == "pdf":
- generate_latex_document(trackers, sequences, results, storage, True)
+ generate_latex_document(trackers, sequences, results, storage, True, order=order)
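
The StyleManager additions rely on a thread-local context so that nested report code resolves one consistent set of styles. A minimal sketch, assuming default() resolves to the active context manager as the __enter__/__exit__ pair suggests:

    # Hedged sketch: sharing one legend across all plots in a report.
    from vot.document import StyleManager
    from vot.tracker import Tracker

    with StyleManager() as styles:
        # Within the block, StyleManager.default() is expected to return this
        # instance, so every plot resolves the same legend and numbering.
        legend = styles.legend(Tracker)
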
diff --git a/vot/document/common.py b/vot/document/common.py
index 3b748a9..8a645e0 100644
--- a/vot/document/common.py
+++ b/vot/document/common.py
@@ -1,3 +1,4 @@
+"""Common functions for document generation."""
import os
import math
@@ -5,11 +6,23 @@
from vot.analysis import Measure, Point, Plot, Curve, Sorting, Axes
def read_resource(name):
+ """Reads a resource file from the package directory. The file is read as a string."""
path = os.path.join(os.path.dirname(__file__), name)
with open(path, "r") as filehandle:
return filehandle.read()
+def per_tracker(a):
+ """Returns true if the analysis is per-tracker."""
+ return a.axes == Axes.TRACKERS
+
def extract_measures_table(trackers, results):
+ """Extracts a table of measures from the results. The table is a list of lists, where each list is a column.
+ The first column is the tracker name, the second column is the measure name, and the rest of the columns are the values for each tracker.
+
+ Args:
+ trackers (list): List of trackers.
+ results (dict): Dictionary of results.
+ """
table_header = [[], [], []]
table_data = dict()
column_order = []
@@ -19,7 +32,7 @@ def extract_measures_table(trackers, results):
descriptions = analysis.describe()
# Ignore all non per-tracker results
- if analysis.axes != Axes.TRACKERS:
+ if not per_tracker(analysis):
continue
for i, description in enumerate(descriptions):
@@ -31,6 +44,9 @@ def extract_measures_table(trackers, results):
table_header[2].append(description)
column_order.append(description.direction)
+ if aresults is None:
+ continue
+
for tracker, values in zip(trackers, aresults):
if not tracker in table_data:
table_data[tracker] = list()
@@ -63,9 +79,8 @@ def extract_measures_table(trackers, results):
return table_header, table_data, table_order
-
-
-def extract_plots(trackers, results):
+def extract_plots(trackers, results, order=None):
+ """Extracts a list of plots from the results. The list is a list of tuples, where each tuple is a pair of strings and a plot."""
plots = dict()
j = 0
@@ -75,7 +90,7 @@ def extract_plots(trackers, results):
descriptions = analysis.describe()
# Ignore all non per-tracker results
- if analysis.axes != Axes.TRACKERS:
+ if not per_tracker(analysis):
continue
for i, description in enumerate(descriptions):
@@ -103,17 +118,20 @@ def extract_plots(trackers, results):
else:
continue
- for tracker, values in zip(trackers, aresults):
+ for t in order if order is not None else range(len(trackers)):
+ tracker = trackers[t]
+ values = aresults[t, 0]
data = values[i] if not values is None else None
plot(tracker, data)
- experiment_plots.append((description.name, plot))
+ experiment_plots.append((analysis.title + " - " + description.name, plot))
plots[experiment] = experiment_plots
return plots
def format_value(data):
+ """Formats a value for display."""
if data is None:
return "N/A"
if isinstance(data, str):
@@ -125,6 +143,7 @@ def format_value(data):
return str(data)
def merge_repeats(objects):
+ """Merges repeated objects in a list into a list of tuples (object, count)."""
if not objects:
return []
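
The body of merge_repeats is untouched by this diff; for reference, a behaviour-equivalent sketch of what the new docstring describes, based on the docstring rather than the actual implementation:

    # Hedged sketch of the documented behaviour; the real code may differ.
    def merge_repeats_sketch(objects):
        merged = []
        for obj in objects:
            if merged and merged[-1][0] == obj:
                merged[-1] = (obj, merged[-1][1] + 1)
            else:
                merged.append((obj, 1))
        return merged

    assert merge_repeats_sketch(["a", "a", "b"]) == [("a", 2), ("b", 1)]
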
diff --git a/vot/document/html.py b/vot/document/html.py
index c27953c..21ef299 100644
--- a/vot/document/html.py
+++ b/vot/document/html.py
@@ -1,4 +1,4 @@
-
+"""HTML report generation. This module is used to generate HTML reports from the results of the experiments."""
import os
import io
import logging
@@ -20,12 +20,14 @@
ORDER_CLASSES = {1: "first", 2: "second", 3: "third"}
def insert_cell(value, order):
+ """Inserts a cell into the data table."""
attrs = dict(data_sort_value=order, data_value=value)
if order in ORDER_CLASSES:
attrs["cls"] = ORDER_CLASSES[order]
td(format_value(value), **attrs)
def table_cell(value):
+ """Returns a cell for the data table."""
if isinstance(value, str):
return value
elif isinstance(value, Tracker):
@@ -35,6 +37,7 @@ def table_cell(value):
return format_value(value)
def grid_table(data: Grid, rows: List[str], columns: List[str]):
+ """Generates a table from a grid object."""
assert data.dimensions == 2
assert data.size(0) == len(rows) and data.size(1) == len(columns)
@@ -57,24 +60,36 @@ def grid_table(data: Grid, rows: List[str], columns: List[str]):
return element
def generate_html_document(trackers: List[Tracker], sequences: List[Sequence], results, storage: Storage):
+ """Generates an HTML document from the results of the experiments.
+
+ Args:
+ trackers (list): List of trackers.
+ sequences (list): List of sequences.
+ results (dict): Dictionary of results.
+ storage (Storage): Storage object.
+ """
def insert_figure(figure):
+ """Inserts a matplotlib figure into the document."""
buffer = io.StringIO()
figure.save(buffer, "SVG")
raw(buffer.getvalue())
def insert_mplfigure(figure):
+ """Inserts a matplotlib figure into the document."""
buffer = io.StringIO()
figure.savefig(buffer, format="SVG", bbox_inches='tight', pad_inches=0.01, dpi=200)
raw(buffer.getvalue())
def add_style(name, linked=False):
+ """Adds a style to the document."""
if linked:
link(rel='stylesheet', href='file://' + os.path.join(os.path.dirname(__file__), name))
else:
style(read_resource(name))
def add_script(name, linked=False):
+ """Adds a script to the document."""
if linked:
script(type='text/javascript', src='file://' + os.path.join(os.path.dirname(__file__), name))
else:
diff --git a/vot/document/latex.py b/vot/document/latex.py
index 83c8fb4..e768031 100644
--- a/vot/document/latex.py
+++ b/vot/document/latex.py
@@ -1,16 +1,15 @@
-
+"""This module contains functions for generating LaTeX documents with results."""
import io
-import logging
import tempfile
import datetime
from typing import List
from pylatex.base_classes import Container
from pylatex.package import Package
-from pylatex import Document, Section, Command, LongTable, MultiColumn, MultiRow, Figure, UnsafeCommand
+from pylatex import Document, Section, Command, LongTable, MultiColumn, Figure, UnsafeCommand
from pylatex.utils import NoEscape
-from vot import toolkit_version
+from vot import toolkit_version, get_logger
from vot.tracker import Tracker
from vot.dataset import Sequence
from vot.workspace import Storage
@@ -20,25 +19,31 @@
TRACKER_GROUP = "default"
class Chunk(Container):
+ """A container that does not add a newline after the content."""
def dumps(self):
+ """Returns the LaTeX representation of the container."""
return self.dumps_content()
def strip_comments(src, wrapper=True):
+ """Strips comments from a LaTeX source file."""
return "\n".join([line for line in src.split("\n") if not line.startswith("%") and (wrapper or not line.startswith(r"\makeat"))])
def insert_figure(figure):
+ """Inserts a figure into a LaTeX document."""
buffer = io.StringIO()
figure.save(buffer, "PGF")
return NoEscape(strip_comments(buffer.getvalue()))
def insert_mplfigure(figure, wrapper=True):
+ """Inserts a matplotlib figure into a LaTeX document."""
buffer = io.StringIO()
figure.savefig(buffer, format="PGF", bbox_inches='tight', pad_inches=0.01)
return NoEscape(strip_comments(buffer.getvalue(), wrapper))
def generate_symbols(container, trackers):
+ """Generates a LaTeX command for each tracker. The command is named after the tracker reference and contains the tracker symbol."""
legend = StyleManager.default().legend(Tracker)
@@ -50,20 +55,37 @@ def generate_symbols(container, trackers):
container.append(Command("makeatother"))
-def generate_latex_document(trackers: List[Tracker], sequences: List[Sequence], results, storage: Storage, build=False, multipart=True):
+def generate_latex_document(trackers: List[Tracker], sequences: List[Sequence], results, storage: Storage, build=False, multipart=True, order=None) -> str:
+ """Generates a LaTeX document with the results. The document is returned as a string. If build is True, the document is compiled and the PDF is returned.
+
+ Args:
+ trackers (list): List of trackers.
+ sequences (list): List of sequences.
+ results (dict): Dictionary of results.
+ storage (Storage): Storage object.
+ build (bool): If True, the document is compiled and the resulting PDF is stored.
+ multipart (bool): If True, the document is split into multiple files.
+ order (list): List of tracker indices to use for ordering.
+ """
order_marks = {1: "first", 2: "second", 3: "third"}
def format_cell(value, order):
+ """Formats a cell in the data table."""
cell = format_value(value)
if order in order_marks:
cell = Command(order_marks[order], cell)
return cell
- logger = logging.getLogger("vot")
+ logger = get_logger()
table_header, table_data, table_order = extract_measures_table(trackers, results)
- plots = extract_plots(trackers, results)
+
+ if order is not None:
+ ordered_trackers = [trackers[i] for i in order]
+ else:
+ ordered_trackers = trackers
doc = Document(page_numbers=True)
@@ -79,18 +101,19 @@ def format_cell(value, order):
if multipart:
container = Chunk()
- generate_symbols(container, trackers)
+ generate_symbols(container, ordered_trackers)
with storage.write("symbols.tex") as out:
container.dump(out)
doc.preamble.append(Command("input", "symbols.tex"))
else:
- generate_symbols(doc.preamble, trackers)
+ generate_symbols(doc.preamble, ordered_trackers)
doc.preamble.append(Command('title', 'VOT report'))
doc.preamble.append(Command('author', 'Toolkit version ' + toolkit_version()))
doc.preamble.append(Command('date', datetime.datetime.now().isoformat()))
doc.append(NoEscape(r'\maketitle'))
+
if len(table_header[2]) == 0:
logger.debug("No measures found, skipping table")
else:
@@ -107,10 +130,20 @@ def format_cell(value, order):
data_table.end_table_header()
data_table.add_hline()
- for tracker, data in table_data.items():
+ for tracker in ordered_trackers:
+ data = table_data[tracker]
data_table.add_row([UnsafeCommand("Tracker", [tracker.reference, TRACKER_GROUP])] +
[format_cell(x, order[tracker] if not order is None else None) for x, order in zip(data, table_order)])
+ if order is not None:
+ z_order = [0] * len(order)
+ for i, j in enumerate(order):
+ z_order[max(order) - i] = j
+ else:
+ z_order = list(range(len(trackers)))
+
+ plots = extract_plots(trackers, results, z_order)
+
for experiment, experiment_plots in plots.items():
if len(experiment_plots) == 0:
continue
@@ -131,7 +164,7 @@ def format_cell(value, order):
if build:
temp = tempfile.mktemp()
- logger.debug("Generating to tempourary output %s", temp)
+ logger.debug("Generating to temporary output %s", temp)
doc.generate_pdf(temp, clean_tex=True)
storage.copy(temp + ".pdf", "report.pdf")
else:
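
A worked example of the `z_order` computation added to `generate_latex_document` above: when `order` is a permutation of `0..n-1`, the loop simply reverses it, so the best-ranked tracker is handed to `extract_plots` last and drawn on top. This is a standalone illustration, not toolkit code:

```python
order = [2, 0, 1]              # tracker indices, best-ranked first
z_order = [0] * len(order)
for i, j in enumerate(order):
    z_order[max(order) - i] = j
print(z_order)                 # [1, 0, 2], i.e. list(reversed(order))
```
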
diff --git a/vot/experiment/__init__.py b/vot/experiment/__init__.py
index 8f4b55d..bb1a770 100644
--- a/vot/experiment/__init__.py
+++ b/vot/experiment/__init__.py
@@ -1,4 +1,5 @@
-
+""" Experiments are the main building blocks of the toolkit. They are used to evaluate trackers on sequences in
+various ways."""
import logging
import typing
@@ -9,33 +10,44 @@
from attributee import Attributee, Object, Integer, Float, Nested, List
-from vot.tracker import RealtimeTrackerRuntime, TrackerException
+from vot import get_logger
+from vot.tracker import TrackerException
from vot.utilities import Progress, to_number, import_class
experiment_registry = ClassRegistry("vot_experiment")
-
transformer_registry = ClassRegistry("vot_transformer")
class RealtimeConfig(Attributee):
+ """Config proxy for real-time experiment.
+ """
grace = Integer(val_min=0, default=0)
fps = Float(val_min=0, default=20)
class NoiseConfig(Attributee):
+ """Config proxy for noise modifiers in experiments."""
# Not implemented yet
placeholder = Integer(default=1)
class InjectConfig(Attributee):
+ """Config proxy for parameter injection in experiments."""
# Not implemented yet
placeholder = Integer(default=1)
def transformer_resolver(typename, context, **kwargs):
+ """Resolve a transformer from a string. If the transformer is not registered, it is imported as a class and
+ instantiated with the provided arguments.
+
+ Args:
+ typename (str): Name of the transformer
+ context (Attributee): Context of the resolver
+
+ Returns:
+ Transformer: Resolved transformer
+ """
from vot.experiment.transformer import Transformer
- if "parent" in context:
- storage = context["parent"].storage.substorage("cache").substorage("transformer")
- else:
- storage = None
+ storage = context.parent.storage.substorage("cache").substorage("transformer")
if typename in transformer_registry:
transformer = transformer_registry.get(typename, cache=storage, **kwargs)
@@ -47,6 +59,17 @@ def transformer_resolver(typename, context, **kwargs):
return transformer_class(cache=storage, **kwargs)
def analysis_resolver(typename, context, **kwargs):
+ """Resolve an analysis from a string. If the analysis is not registered, it is imported as a class and
+ instantiated with the provided arguments.
+
+ Args:
+ typename (str): Name of the analysis
+ context (Attributee): Context of the resolver
+
+ Returns:
+ Analysis: Resolved analysis
+ """
+
from vot.analysis import Analysis, analysis_registry
if typename in analysis_registry:
@@ -57,20 +80,37 @@ def analysis_resolver(typename, context, **kwargs):
assert issubclass(analysis_class, Analysis)
analysis = analysis_class(**kwargs)
- if "parent" in context:
- assert analysis.compatible(context["parent"])
+ assert analysis.compatible(context.parent)
return analysis
class Experiment(Attributee):
+ """Experiment abstract base class. Each experiment is responsible for running a tracker on a sequence and
+ storing results into dedicated storage.
+ """
- realtime = Nested(RealtimeConfig, default=None)
+ UNKNOWN = 0
+ INITIALIZATION = 1
+
+ realtime = Nested(RealtimeConfig, default=None, description="Realtime modifier config")
noise = Nested(NoiseConfig, default=None)
inject = Nested(InjectConfig, default=None)
transformers = List(Object(transformer_resolver), default=[])
analyses = List(Object(analysis_resolver), default=[])
- def __init__(self, _identifier: str, _storage: "LocalStorage", **kwargs):
+ def __init__(self, _identifier: str, _storage: "Storage", **kwargs):
+ """Initialize an experiment.
+
+ Args:
+ _identifier (str): Identifier of the experiment
+ _storage (Storage): Storage to use for storing results
+
+ Keyword Args:
+ **kwargs: Additional arguments
+
+ Raises:
+ ValueError: If the identifier is not valid
+ """
self._identifier = _identifier
self._storage = _storage
super().__init__(**kwargs)
@@ -78,71 +118,229 @@ def __init__(self, _identifier: str, _storage: "LocalStorage", **kwargs):
@property
def identifier(self) -> str:
+ """Identifier of the experiment.
+
+ Returns:
+ str: Identifier of the experiment
+ """
return self._identifier
+ @property
+ def _multiobject(self) -> bool:
+ """Whether the experiment is multi-object or not.
+
+ Returns:
+ bool: Whether the experiment is multi-object or not
+ """
+ # TODO: at some point this may be a property for all experiments
+ return False
+
@property
def storage(self) -> "Storage":
+ """Storage used by the experiment.
+
+ Returns:
+ Storage: Storage used by the experiment
+ """
return self._storage
- def _get_initialization(self, sequence: "Sequence", index: int):
- return sequence.groundtruth(index)
+ def _get_initialization(self, sequence: "Sequence", index: int, id: str = None):
+ """Get initialization for a given sequence, index and object id.
+
+ Args:
+ sequence (Sequence): Sequence to get initialization for
+ index (int): Index of the frame to get initialization for
+ id (str): Object id to get initialization for
+
+ Returns:
+ Initialization: Initialization for the given sequence, index and object id
+
+ Raises:
+ ValueError: If the sequence does not contain the given index or object id
+ """
+ if not self._multiobject and id is None:
+ return sequence.groundtruth(index)
+ else:
+ return sequence.frame(index).object(id)
+
+ def _get_runtime(self, tracker: "Tracker", sequence: "Sequence", multiobject=False):
+ """Get runtime for a given tracker and sequence. Can convert single-object runtimes to multi-object runtimes.
+
+ Args:
+ tracker (Tracker): Tracker to get runtime for
+ sequence (Sequence): Sequence to get runtime for
+ multiobject (bool): Whether the runtime should be multi-object or not
+
+ Returns:
+ TrackerRuntime: Runtime for the given tracker and sequence
+
+ Raises:
+ TrackerException: If the tracker does not support multi-object experiments
+ """
+ from ..tracker import SingleObjectTrackerRuntime, RealtimeTrackerRuntime, MultiObjectTrackerRuntime
+
+ runtime = tracker.runtime()
+
+ if multiobject:
+ if not runtime.multiobject:
+ raise TrackerException("Tracker {} does not support multi-object experiments".format(tracker.identifier))
+ #runtime = MultiObjectTrackerRuntime(runtime)
+ else:
+ runtime = SingleObjectTrackerRuntime(runtime)
- def _get_runtime(self, tracker: "Tracker", sequence: "Sequence"):
if not self.realtime is None:
grace = to_number(self.realtime.grace, min_n=0)
fps = to_number(self.realtime.fps, min_n=0, conversion=float)
interval = 1 / float(sequence.metadata("fps", fps))
- runtime = RealtimeTrackerRuntime(tracker.runtime(), grace, interval)
- else:
- runtime = tracker.runtime()
+ runtime = RealtimeTrackerRuntime(runtime, grace, interval)
+
return runtime
@abstractmethod
def execute(self, tracker: "Tracker", sequence: "Sequence", force: bool = False, callback: typing.Callable = None):
+ """Execute the experiment for a given tracker and sequence.
+
+ Args:
+ tracker (Tracker): Tracker to execute
+ sequence (Sequence): Sequence to execute
+ force (bool): Whether to force execution even if the results are already present
+ callback (typing.Callable): Callback to call after each frame
+
+ Returns:
+ Results: Results for the tracker and sequence
+ """
raise NotImplementedError
@abstractmethod
def scan(self, tracker: "Tracker", sequence: "Sequence"):
+ """ Scan results for a given tracker and sequence.
+
+ Args:
+ tracker (Tracker): Tracker to scan results for
+ sequence (Sequence): Sequence to scan results for
+
+ Returns:
+ Results: Results for the tracker and sequence
+ """
raise NotImplementedError
def results(self, tracker: "Tracker", sequence: "Sequence") -> "Results":
- from vot.tracker import Results
- from vot.workspace import LocalStorage
+ """Get results for a given tracker and sequence.
+
+ Args:
+ tracker (Tracker): Tracker to get results for
+ sequence (Sequence): Sequence to get results for
+
+ Returns:
+ Results: Results for the tracker and sequence
+ """
if tracker.storage is not None:
return tracker.storage.results(tracker, self, sequence)
return self._storage.results(tracker, self, sequence)
def log(self, identifier: str):
+ """Get a log file for the experiment.
+
+ Args:
+ identifier (str): Identifier of the log
+
+ Returns:
+ file: A writable file handle for the log
+ """
return self._storage.substorage("logs").write("{}_{:%Y-%m-%dT%H-%M-%S.%f%z}.log".format(identifier, datetime.now()))
- def transform(self, sequence: "Sequence"):
- for transformer in self.transformers:
- sequence = transformer(sequence)
- return sequence
+ def transform(self, sequences):
+ """Transform a list of sequences using the experiment transformers.
+
+ Args:
+ sequences (typing.List[Sequence]): List of sequences to transform
+
+ Returns:
+ typing.List[Sequence]: List of transformed sequences. The number of sequences may be larger than the input as some transformers may split sequences.
+ """
+ from vot.dataset import Sequence
+ from vot.experiment.transformer import SingleObject
+ if isinstance(sequences, Sequence):
+ sequences = [sequences]
+
+ transformers = list(self.transformers)
+
+ if not self._multiobject:
+ get_logger().debug("Adding single object transformer since experiment is not multi-object")
+ transformers.insert(0, SingleObject(cache=None))
+
+ # Process sequences one transformer at a time. The number of sequences may grow
+ for transformer in transformers:
+ transformed = []
+ for sequence in sequences:
+ get_logger().debug("Transforming sequence {} with transformer {}.{}".format(sequence.identifier, transformer.__class__.__module__, transformer.__class__.__name__))
+ transformed.extend(transformer(sequence))
+ sequences = transformed
+
+ return sequences
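
To make the fan-out in `transform` concrete, here is a self-contained sketch of the loop above; plain strings stand in for `Sequence` objects and the transformer is hypothetical (a real transformer such as `SingleObject` returns a list of proxy sequences):

```python
def split_transformer(sequence):
    # Stand-in for a transformer that multiplies sequences, one per object
    return [sequence + "-obj1", sequence + "-obj2"]

sequences = ["seq1", "seq2"]
for transformer in [split_transformer]:
    transformed = []
    for sequence in sequences:
        transformed.extend(transformer(sequence))
    sequences = transformed
print(sequences)  # ['seq1-obj1', 'seq1-obj2', 'seq2-obj1', 'seq2-obj2']
```
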
from .multirun import UnsupervisedExperiment, SupervisedExperiment
from .multistart import MultiStartExperiment
def run_experiment(experiment: Experiment, tracker: "Tracker", sequences: typing.List["Sequence"], force: bool = False, persist: bool = False):
+ """A helper function that performs a given experiment with a given tracker on a list of sequences.
+
+ Args:
+ experiment (Experiment): The experiment object
+ tracker (Tracker): The tracker object
+ sequences (typing.List[Sequence]): List of sequences.
+ force (bool, optional): Ignore the cached results, rerun all the experiments. Defaults to False.
+ persist (bool, optional): Continue running even if exceptions are raised. Defaults to False.
+
+ Raises:
+ TrackerException: If the experiment is interrupted
+ """
class EvaluationProgress(object):
+ """A helper class that wraps a progress bar and updates it based on the number of finished sequences."""
def __init__(self, description, total):
+ """Initialize the progress bar.
+
+ Args:
+ description (str): Description of the progress bar
+ total (int): Total number of sequences
+
+ Raises:
+ ValueError: If the total number of sequences is not positive
+ """
self.bar = Progress(description, total)
self._finished = 0
def __call__(self, progress):
+ """Update the progress bar. The progress is a number between 0 and 1.
+
+ Args:
+ progress (float): Progress of the current sequence
+
+ Note:
+ The progress value is clamped to the interval [0, 1].
+ """
self.bar.absolute(self._finished + min(1, max(0, progress)))
def push(self):
+ """Push the progress bar."""
self._finished = self._finished + 1
self.bar.absolute(self._finished)
+ def close(self):
+ """Close the progress bar."""
+ self.bar.close()
+
logger = logging.getLogger("vot")
+ transformed = []
+ for sequence in sequences:
+ transformed.extend(experiment.transform(sequence))
+ sequences = transformed
+
progress = EvaluationProgress("{}/{}".format(tracker.identifier, experiment.identifier), len(sequences))
for sequence in sequences:
- sequence = experiment.transform(sequence)
try:
experiment.execute(tracker, sequence, force=force, callback=progress)
except TrackerException as te:
@@ -153,6 +351,7 @@ def push(self):
flog.write(te.log)
logger.error("Tracker output written to file: %s", flog.name)
if not persist:
- raise te
+ raise TrackerException("Experiment interrupted", te, tracker=tracker)
progress.push()
+ progress.close()
\ No newline at end of file
diff --git a/vot/experiment/helpers.py b/vot/experiment/helpers.py
new file mode 100644
index 0000000..8488c79
--- /dev/null
+++ b/vot/experiment/helpers.py
@@ -0,0 +1,53 @@
+""" Helper classes for experiments."""
+
+from vot.dataset import Sequence
+from vot.region import RegionType
+
+def _objectstart(sequence: Sequence, id: str):
+ """Returns the first frame where the object appears in the sequence."""
+ trajectory = sequence.object(id)
+ return [x is None or x.type == RegionType.SPECIAL for x in trajectory].index(False)
+
+class MultiObjectHelper(object):
+ """Helper class for multi-object sequences. It provides methods for querying active objects at a given frame."""
+
+ def __init__(self, sequence: Sequence):
+ """Initialize the helper class.
+
+ Args:
+ sequence (Sequence): The sequence to be used.
+ """
+ self._sequence = sequence
+ self._ids = list(sequence.objects())
+ start = [_objectstart(sequence, id) for id in self._ids]
+ self._ids = sorted(zip(start, self._ids), key=lambda x: x[0])
+
+ def new(self, position: int):
+ """Returns a list of objects that appear at the given frame.
+
+ Args:
+ position (int): The frame number.
+
+ Returns:
+ [list]: A list of object ids.
+ """
+ return [x[1] for x in self._ids if x[0] == position]
+
+ def objects(self, position: int):
+ """Returns a list of objects that are active at the given frame.
+
+ Args:
+ position (int): The frame number.
+
+ Returns:
+ [list]: A list of object ids.
+ """
+ return [x[1] for x in self._ids if x[0] <= position]
+
+ def all(self):
+ """Returns a list of all objects in the sequence.
+
+ Returns:
+ [list]: A list of object ids.
+ """
+ return [x[1] for x in self._ids]
\ No newline at end of file
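
A runnable sketch of the `MultiObjectHelper` query semantics, with a small stub standing in for the real `Sequence` (object ids and start frames are made up):

```python
class HelperSketch:
    """Mimics MultiObjectHelper given {object id: first visible frame}."""
    def __init__(self, starts):
        self._ids = sorted((frame, oid) for oid, frame in starts.items())
    def new(self, position):
        return [oid for frame, oid in self._ids if frame == position]
    def objects(self, position):
        return [oid for frame, oid in self._ids if frame <= position]
    def all(self):
        return [oid for _, oid in self._ids]

helper = HelperSketch({"a": 0, "b": 5})
print(helper.all())       # ['a', 'b']
print(helper.new(5))      # ['b']
print(helper.objects(3))  # ['a']
```
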
diff --git a/vot/experiment/multirun.py b/vot/experiment/multirun.py
index e9186d4..422eaf0 100644
--- a/vot/experiment/multirun.py
+++ b/vot/experiment/multirun.py
@@ -1,5 +1,5 @@
-#pylint: disable=W0223
-
+"""Multi-run experiments. This module contains the implementation of multi-run experiments.
+ Multi-run experiments are used to run a tracker multiple times on the same sequence. """
from typing import Callable
from vot.dataset import Sequence
@@ -8,100 +8,196 @@
from attributee import Boolean, Integer, Float, List, String
from vot.experiment import Experiment, experiment_registry
-from vot.tracker import Tracker, Trajectory
+from vot.tracker import Tracker, Trajectory, ObjectStatus
class MultiRunExperiment(Experiment):
+ """Base class for multi-run experiments. Multi-run experiments are used to run a tracker multiple times on the same sequence."""
repetitions = Integer(val_min=1, default=1)
early_stop = Boolean(default=True)
def _can_stop(self, tracker: Tracker, sequence: Sequence):
+ """Check whether the experiment can be stopped early.
+
+ Args:
+ tracker (Tracker): The tracker to be checked.
+ sequence (Sequence): The sequence to be checked.
+
+ Returns:
+ bool: True if the experiment can be stopped early, False otherwise.
+ """
if not self.early_stop:
return False
- trajectories = self.gather(tracker, sequence)
- if len(trajectories) < 3:
- return False
+
+ for o in sequence.objects():
- for trajectory in trajectories[1:]:
- if not trajectory.equals(trajectories[0]):
+ trajectories = self.gather(tracker, sequence, objects=[o])
+ if len(trajectories) < 3:
return False
+
+ for trajectory in trajectories[1:]:
+ if not trajectory.equals(trajectories[0]):
+ return False
return True
def scan(self, tracker: Tracker, sequence: Sequence):
+ """Scan the results of the experiment for the given tracker and sequence.
+
+ Args:
+ tracker (Tracker): The tracker to be scanned.
+ sequence (Sequence): The sequence to be scanned.
+
+ Returns:
+ [tuple]: A tuple containing three elements. The first element is a boolean indicating whether the experiment is complete. The second element is a list of files that are present. The third element is the results object.
+ """
results = self.results(tracker, sequence)
files = []
complete = True
+ multiobject = len(sequence.objects()) > 1
+ assert self._multiobject or not multiobject
- for i in range(1, self.repetitions+1):
- name = "%s_%03d" % (sequence.name, i)
- if Trajectory.exists(results, name):
- files.extend(Trajectory.gather(results, name))
- elif self._can_stop(tracker, sequence):
- break
- else:
- complete = False
- break
+ for o in sequence.objects():
+ prefix = sequence.name if not multiobject else "%s_%s" % (sequence.name, o)
+ for i in range(1, self.repetitions+1):
+ name = "%s_%03d" % (prefix, i)
+ if Trajectory.exists(results, name):
+ files.extend(Trajectory.gather(results, name))
+ elif self._can_stop(tracker, sequence):
+ break
+ else:
+ complete = False
+ break
return complete, files, results
- def gather(self, tracker: Tracker, sequence: Sequence):
+ def gather(self, tracker: Tracker, sequence: Sequence, objects = None, pad = False):
+ """Gather trajectories for the given tracker and sequence.
+
+ Args:
+ tracker (Tracker): The tracker to be used.
+ sequence (Sequence): The sequence to be used.
+ objects (list, optional): The list of objects to be gathered. Defaults to None.
+ pad (bool, optional): Whether to pad the list of trajectories with None values. Defaults to False.
+
+ Returns:
+ list: The list of trajectories.
+ """
trajectories = list()
+
+ multiobject = len(sequence.objects()) > 1
+
+ assert self._multiobject or not multiobject
results = self.results(tracker, sequence)
- for i in range(1, self.repetitions+1):
- name = "%s_%03d" % (sequence.name, i)
- if Trajectory.exists(results, name):
- trajectories.append(Trajectory.read(results, name))
+
+ if objects is None:
+ objects = list(sequence.objects())
+
+ for o in objects:
+ prefix = sequence.name if not multiobject else "%s_%s" % (sequence.name, o)
+ for i in range(1, self.repetitions+1):
+ name = "%s_%03d" % (prefix, i)
+ if Trajectory.exists(results, name):
+ trajectories.append(Trajectory.read(results, name))
+ elif pad:
+ trajectories.append(None)
return trajectories
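
The per-object result naming used by `scan` and `gather` can be spelled out with illustrative values (in the single-object case the prefix is just the sequence name):

```python
sequence_name, repetitions = "ball", 3
for o in ["obj1", "obj2"]:                    # hypothetical object ids
    prefix = "%s_%s" % (sequence_name, o)
    print(["%s_%03d" % (prefix, i) for i in range(1, repetitions + 1)])
# ['ball_obj1_001', 'ball_obj1_002', 'ball_obj1_003']
# ['ball_obj2_001', 'ball_obj2_002', 'ball_obj2_003']
```
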
@experiment_registry.register("unsupervised")
class UnsupervisedExperiment(MultiRunExperiment):
+ """Unsupervised experiment. This experiment is used to run a tracker multiple times on the same sequence without any supervision."""
+
+ multiobject = Boolean(default=False)
+
+ @property
+ def _multiobject(self) -> bool:
+ """Whether the experiment is multi-object or not.
+
+ Returns:
+ bool: True if the experiment is multi-object, False otherwise.
+ """
+ return self.multiobject
def execute(self, tracker: Tracker, sequence: Sequence, force: bool = False, callback: Callable = None):
+ """Execute the experiment for the given tracker and sequence.
+
+ Args:
+ tracker (Tracker): The tracker to be used.
+ sequence (Sequence): The sequence to be used.
+ force (bool, optional): Whether to force the execution. Defaults to False.
+ callback (Callable, optional): The callback to be used. Defaults to None.
+ """
+
+ from .helpers import MultiObjectHelper
results = self.results(tracker, sequence)
- with self._get_runtime(tracker, sequence) as runtime:
+ multiobject = len(sequence.objects()) > 1
+ assert self._multiobject or not multiobject
+
+ helper = MultiObjectHelper(sequence)
+
+ def result_name(sequence, o, i):
+ """Get the name of the result file."""
+ return "%s_%s_%03d" % (sequence.name, o, i) if multiobject else "%s_%03d" % (sequence.name, i)
+
+ with self._get_runtime(tracker, sequence, self._multiobject) as runtime:
for i in range(1, self.repetitions+1):
- name = "%s_%03d" % (sequence.name, i)
- if Trajectory.exists(results, name) and not force:
+ trajectories = {}
+
+ for o in helper.all(): trajectories[o] = Trajectory(len(sequence))
+
+ if all([Trajectory.exists(results, result_name(sequence, o, i)) for o in trajectories.keys()]) and not force:
continue
if self._can_stop(tracker, sequence):
return
- trajectory = Trajectory(sequence.length)
+ _, elapsed = runtime.initialize(sequence.frame(0), [ObjectStatus(self._get_initialization(sequence, 0, x), {}) for x in helper.new(0)])
- _, properties, elapsed = runtime.initialize(sequence.frame(0), self._get_initialization(sequence, 0))
+ for x in helper.new(0):
+ trajectories[x].set(0, Special(Trajectory.INITIALIZATION), {"time": elapsed})
- properties["time"] = elapsed
+ for frame in range(1, len(sequence)):
+ state, elapsed = runtime.update(sequence.frame(frame), [ObjectStatus(self._get_initialization(sequence, frame, x), {}) for x in helper.new(frame)])
- trajectory.set(0, Special(Special.INITIALIZATION), properties)
+ if not isinstance(state, list):
+ state = [state]
- for frame in range(1, sequence.length):
- region, properties, elapsed = runtime.update(sequence.frame(frame))
+ for x, object in zip(helper.objects(frame), state):
+ object.properties["time"] = elapsed # TODO: what to do with time stats?
+ trajectories[x].set(frame, object.region, object.properties)
- properties["time"] = elapsed
+ if callback:
+ callback(float(i-1) / self.repetitions + (float(frame) / (self.repetitions * len(sequence))))
- trajectory.set(frame, region, properties)
-
- trajectory.write(results, name)
+ for o, trajectory in trajectories.items():
+ trajectory.write(results, result_name(sequence, o, i))
- if callback:
- callback(i / self.repetitions)
@experiment_registry.register("supervised")
class SupervisedExperiment(MultiRunExperiment):
+ """Supervised experiment. This experiment is used to run a tracker multiple times on the same sequence with supervision (reinitialization in case of failure)."""
+
+ FAILURE = 2
skip_initialize = Integer(val_min=1, default=1)
skip_tags = List(String(), default=[])
failure_overlap = Float(val_min=0, val_max=1, default=0)
def execute(self, tracker: Tracker, sequence: Sequence, force: bool = False, callback: Callable = None):
+ """Execute the experiment for the given tracker and sequence.
+
+ Args:
+ tracker (Tracker): The tracker to be used.
+ sequence (Sequence): The sequence to be used.
+ force (bool, optional): Whether to force the execution. Defaults to False.
+ callback (Callable, optional): The callback to be used. Defaults to None.
+ """
results = self.results(tracker, sequence)
@@ -116,37 +212,35 @@ def execute(self, tracker: Tracker, sequence: Sequence, force: bool = False, cal
if self._can_stop(tracker, sequence):
return
- trajectory = Trajectory(sequence.length)
+ trajectory = Trajectory(len(sequence))
frame = 0
- while frame < sequence.length:
-
- _, properties, elapsed = runtime.initialize(sequence.frame(frame), self._get_initialization(sequence, frame))
+ while frame < len(sequence):
- properties["time"] = elapsed
+ _, elapsed = runtime.initialize(sequence.frame(frame), self._get_initialization(sequence, frame))
- trajectory.set(frame, Special(Special.INITIALIZATION), properties)
+ trajectory.set(frame, Special(Trajectory.INITIALIZATION), {"time" : elapsed})
frame = frame + 1
- while frame < sequence.length:
+ while frame < len(sequence):
- region, properties, elapsed = runtime.update(sequence.frame(frame))
+ object, elapsed = runtime.update(sequence.frame(frame))
- properties["time"] = elapsed
+ object.properties["time"] = elapsed
- if calculate_overlap(region, sequence.groundtruth(frame), sequence.size) <= self.failure_overlap:
- trajectory.set(frame, Special(Special.FAILURE), properties)
+ if calculate_overlap(object.region, sequence.groundtruth(frame), sequence.size) <= self.failure_overlap:
+ trajectory.set(frame, Special(SupervisedExperiment.FAILURE), object.properties)
frame = frame + self.skip_initialize
if self.skip_tags:
- while frame < sequence.length:
+ while frame < len(sequence):
if not [t for t in sequence.tags(frame) if t in self.skip_tags]:
break
frame = frame + 1
break
else:
- trajectory.set(frame, region, properties)
+ trajectory.set(frame, object.region, object.properties)
frame = frame + 1
if callback:
diff --git a/vot/experiment/multistart.py b/vot/experiment/multistart.py
index 04f0e4f..dd1bccd 100644
--- a/vot/experiment/multistart.py
+++ b/vot/experiment/multistart.py
@@ -1,4 +1,6 @@
+""" This module implements the multistart experiment. """
+
from typing import Callable
from vot.dataset import Sequence
@@ -11,9 +13,18 @@
from vot.tracker import Tracker, Trajectory
def find_anchors(sequence: Sequence, anchor="anchor"):
+ """Find anchor frames in the sequence. Anchor frames are frames where the given object is visible and can be used for initialization.
+
+ Args:
+ sequence (Sequence): The sequence to be scanned.
+ anchor (str, optional): The name of the object to be used as an anchor. Defaults to "anchor".
+
+ Returns:
+ [tuple]: A tuple containing two lists of frames. The first list contains forward anchors, the second list contains backward anchors.
+ """
forward = []
backward = []
- for frame in range(sequence.length):
+ for frame in range(len(sequence)):
values = sequence.values(frame)
if anchor in values:
if values[anchor] > 0:
@@ -24,10 +35,22 @@ def find_anchors(sequence: Sequence, anchor="anchor"):
@experiment_registry.register("multistart")
class MultiStartExperiment(Experiment):
+ """The multistart experiment. The experiment works by utilizing anchor frames in the sequence.
+ Anchor frames are frames where the given object is visible and can be used for initialization.
+ The tracker is then initialized in each anchor frame and run until the end of the sequence either forward or backward.
+ """
anchor = String(default="anchor")
- def scan(self, tracker: Tracker, sequence: Sequence):
+ def scan(self, tracker: Tracker, sequence: Sequence) -> tuple:
+ """Scan the results of the experiment for the given tracker and sequence.
+
+ Args:
+ tracker (Tracker): The tracker to be scanned.
+ sequence (Sequence): The sequence to be scanned.
+
+ Returns:
+ [tuple]: A tuple containing three elements. The first element is a boolean indicating whether the experiment is complete. The second element is a list of files that are present. The third element is the results object."""
files = []
complete = True
@@ -48,7 +71,18 @@ def scan(self, tracker: Tracker, sequence: Sequence):
return complete, files, results
- def execute(self, tracker: Tracker, sequence: Sequence, force: bool = False, callback: Callable = None):
+ def execute(self, tracker: Tracker, sequence: Sequence, force: bool = False, callback: Callable = None) -> None:
+ """Execute the experiment for the given tracker and sequence.
+
+ Args:
+ tracker (Tracker): The tracker to be executed.
+ sequence (Sequence): The sequence to be executed.
+ force (bool, optional): Force re-execution of the experiment. Defaults to False.
+ callback (Callable, optional): A callback function that is called after each frame. Defaults to None.
+
+ Raises:
+ RuntimeError: If the sequence does not contain any anchors.
+ """
results = self.results(tracker, sequence)
@@ -71,22 +105,20 @@ def execute(self, tracker: Tracker, sequence: Sequence, force: bool = False, cal
if reverse:
proxy = FrameMapSequence(sequence, list(reversed(range(0, i + 1))))
else:
- proxy = FrameMapSequence(sequence, list(range(i, sequence.length)))
-
- trajectory = Trajectory(proxy.length)
+ proxy = FrameMapSequence(sequence, list(range(i, len(sequence))))
- _, properties, elapsed = runtime.initialize(proxy.frame(0), self._get_initialization(proxy, 0))
+ trajectory = Trajectory(len(proxy))
- properties["time"] = elapsed
+ _, elapsed = runtime.initialize(proxy.frame(0), self._get_initialization(proxy, 0))
- trajectory.set(0, Special(Special.INITIALIZATION), properties)
+ trajectory.set(0, Special(Trajectory.INITIALIZATION), {"time": elapsed})
- for frame in range(1, proxy.length):
- region, properties, elapsed = runtime.update(proxy.frame(frame))
+ for frame in range(1, len(proxy)):
+ object, elapsed = runtime.update(proxy.frame(frame))
- properties["time"] = elapsed
+ object.properties["time"] = elapsed
- trajectory.set(frame, region, properties)
+ trajectory.set(frame, object.region, object.properties)
trajectory.write(results, name)
diff --git a/vot/experiment/transformer.py b/vot/experiment/transformer.py
index 0fe856d..d8a932e 100644
--- a/vot/experiment/transformer.py
+++ b/vot/experiment/transformer.py
@@ -1,36 +1,82 @@
+""" Transformer module for experiments."""
+
import os
from abc import abstractmethod
+from typing import List
from PIL import Image
-from attributee import Attributee, Integer, Float
+from attributee import Attributee, Integer, Float, Boolean
-from vot.dataset import Sequence, VOTSequence, InMemorySequence
+from vot.dataset import Sequence, InMemorySequence
from vot.dataset.proxy import FrameMapSequence
-from vot.dataset.vot import write_sequence
+from vot.dataset.common import write_sequence, read_sequence
from vot.region import RegionType
from vot.utilities import arg_hash
from vot.experiment import transformer_registry
class Transformer(Attributee):
+ """Base class for transformers. Transformers are used to generate new modified sequences from existing ones."""
def __init__(self, cache: "LocalStorage", **kwargs):
+ """Initialize the transformer.
+
+ Args:
+ cache (LocalStorage): The cache to be used for storing generated sequences.
+ """
super().__init__(**kwargs)
self._cache = cache
@abstractmethod
- def __call__(self, sequence: Sequence) -> Sequence:
+ def __call__(self, sequence: Sequence) -> List[Sequence]:
+ """Generate a list of sequences from the given sequence. The generated sequences are stored in the cache if needed.
+
+ Args:
+ sequence (Sequence): The sequence to be transformed.
+
+ Returns:
+ [list]: A list of generated sequences.
+ """
raise NotImplementedError
+@transformer_registry.register("singleobject")
+class SingleObject(Transformer):
+ """Transformer that generates a sequence for each object in the given sequence."""
+
+ trim = Boolean(default=False, description="Trim each generated sequence to a visible subsection for the selected object")
+
+ def __call__(self, sequence: Sequence) -> List[Sequence]:
+ """Generate a list of sequences from the given sequence.
+
+ Args:
+ sequence (Sequence): The sequence to be transformed.
+ """
+ from vot.dataset.proxy import ObjectFilterSequence
+
+ if len(sequence.objects()) == 1:
+ return [sequence]
+
+ return [ObjectFilterSequence(sequence, id, self.trim) for id in sequence.objects()]
+
@transformer_registry.register("redetection")
class Redetection(Transformer):
+ """Transformer that test redetection of the object in the sequence. The object is shown in several frames and then moved to a different location.
+
+ This tranformer can only be used with single-object sequences."""
length = Integer(default=100, val_min=1)
initialization = Integer(default=5, val_min=1)
padding = Float(default=2, val_min=0)
scaling = Float(default=1, val_min=0.1, val_max=10)
- def __call__(self, sequence: Sequence) -> Sequence:
+ def __call__(self, sequence: Sequence) -> List[Sequence]:
+ """Generate a list of sequences from the given sequence.
+
+ Args:
+ sequence (Sequence): The sequence to be transformed.
+ """
+
+ assert len(sequence.objects()) == 1, "Redetection transformer can only be used with single-object sequences."
chache_dir = self._cache.directory(self, arg_hash(sequence.name, **self.dump()))
@@ -61,6 +107,6 @@ def __call__(self, sequence: Sequence) -> Sequence:
write_sequence(chache_dir, generated)
- source = VOTSequence(chache_dir, name=sequence.name)
- mapping = [0] * self.initialization + [1] * (self.length - self.initialization)
- return FrameMapSequence(source, mapping)
+ source = read_sequence(chache_dir)
+ mapping = [0] * self.initialization + [1] * (self.length - self.initialization)
+ return [FrameMapSequence(source, mapping)]
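
A worked example of the frame mapping built above: with the defaults `initialization=5` and `length=100`, frame 0 of the generated two-frame sequence (object at its original position) is shown five times, after which frame 1 (object moved) fills the remainder:

```python
initialization, length = 5, 100
mapping = [0] * initialization + [1] * (length - initialization)
print(mapping[:7], len(mapping))  # [0, 0, 0, 0, 0, 1, 1] 100
```
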
diff --git a/vot/region/__init__.py b/vot/region/__init__.py
index 32bb161..58524ac 100644
--- a/vot/region/__init__.py
+++ b/vot/region/__init__.py
@@ -1,17 +1,29 @@
+""" This module contains classes for region representation and manipulation. Regions are also used to represent results
+ of trackers as well as groundtruth trajectories. The module also contains functions for calculating overlaps between
+ regions and for converting between different region types."""
+
from abc import abstractmethod, ABC
-from typing import Tuple
from enum import Enum
from vot import ToolkitException
from vot.utilities.draw import DrawHandle
-class RegionException(Exception):
+class RegionException(ToolkitException):
"""General region exception"""
class ConversionException(RegionException):
"""Region conversion exception, the conversion cannot be performed
"""
def __init__(self, *args, source=None):
+ """Constructor
+
+ Args:
+ *args: Arguments for the base exception
+
+ Keyword Arguments:
+ source (Region): Source region (default: {None})
+
+ """
super().__init__(*args)
self._source = source
@@ -25,29 +37,37 @@ class RegionType(Enum):
class Region(ABC):
"""
- Base class for all region containers
-
- :var type: type of the region
+ Base class for all region containers.
"""
def __init__(self):
+ """Base constructor"""
pass
@property
@abstractmethod
def type(self):
+ """Return type of the region
+
+ Returns:
+ RegionType -- Type of the region
+ """
pass
@abstractmethod
def copy(self):
"""Copy region to another object
+
+ Returns:
+ Region -- Copy of the region
"""
@abstractmethod
def convert(self, rtype: RegionType):
"""Convert region to another type. Note that some conversions
degrade information.
- Arguments:
- rtype {RegionType} -- Desired type.
+
+ Args:
+ rtype (RegionType): Target region type to convert to.
"""
@abstractmethod
@@ -57,19 +77,16 @@ def is_empty(self):
class Special(Region):
"""
- Special region
+ Special region, meaning of the code can change depending on the context
:var code: Code value
"""
- UNKNOWN = 0
- INITIALIZATION = 1
- FAILURE = 2
-
def __init__(self, code):
""" Constructor
- :param code: Special code
+ Args:
+ code (int): Code value
"""
super().__init__()
self._code = int(code)
@@ -80,12 +97,26 @@ def __str__(self):
@property
def type(self):
+ """Return type of the region"""
return RegionType.SPECIAL
def copy(self):
+ """Copy region to another object"""
return Special(self._code)
def convert(self, rtype: RegionType):
+ """Convert region to another type. Note that some conversions degrade information.
+
+ Args:
+ rtype (RegionType): Target region type to convert to.
+
+ Raises:
+ ConversionException: Unable to convert special region to another type
+
+ Returns:
+ Region -- Converted region
+ """
+
if rtype == RegionType.SPECIAL:
return self.copy()
else:
@@ -93,19 +124,23 @@ def convert(self, rtype: RegionType):
@property
def code(self):
- """Retiurns special code for this region
+ """Retiurns special code for this region.
Returns:
int -- Type code
"""
return self._code
def draw(self, handle: DrawHandle):
+ """Draw region to the image using the provided handle.
+
+ Args:
+ handle (DrawHandle): Draw handle
+ """
pass
def is_empty(self):
- return False
+ """ Check if region is empty. Special regions are always empty by definition."""
+ return True
-from vot.region.io import read_file, write_file
-from .shapes import Rectangle, Polygon, Mask
-from .io import read_file, write_file, parse
from .raster import calculate_overlap, calculate_overlaps
+from .shapes import Rectangle, Polygon, Mask
\ No newline at end of file
diff --git a/vot/region/io.py b/vot/region/io.py
index 8e5371d..5847da8 100644
--- a/vot/region/io.py
+++ b/vot/region/io.py
@@ -1,16 +1,22 @@
+""" Utilities for reading and writing regions from and to files. """
+
import math
-from typing import Union, TextIO
+from typing import List, Union, TextIO
+import io
import numpy as np
-from numba import jit
+import numba
-from vot.region import Special
+@numba.njit(cache=True)
+def mask_to_rle(m, maxstride=100000000):
+ """ Converts a binary mask to RLE encoding. This is a Numba decorated function that is compiled just-in-time for faster execution.
-@jit(nopython=True)
-def mask_to_rle(m):
- """
- # Input: 2-D numpy array
- # Output: list of numbers (1st number = #0s, 2nd number = #1s, 3rd number = #0s, ...)
+ Args:
+ m (np.ndarray): 2-D binary mask
+ maxstride (int): Maximum number of consecutive 0s or 1s in the RLE encoding. If the number of consecutive 0s or 1s is larger than maxstride, it is split into multiple elements.
+
+ Returns:
+ List[int]: RLE encoding of the mask
"""
# reshape mask to vector
v = m.reshape((m.shape[0] * m.shape[1]))
@@ -29,31 +35,49 @@ def mask_to_rle(m):
# go over all elements and check if two consecutive are the same
for i in range(1, v.size):
if v[i] != v[i - 1]:
- rle.append(i - last_idx)
+ length = i - last_idx
+ # if length is larger than maxstride, split it into multiple elements
+ while length > maxstride:
+ rle.append(maxstride)
+ rle.append(0)
+ length -= maxstride
+ # add remaining length
+ if length > 0:
+ rle.append(length)
last_idx = i
if v.size > 0:
# handle last element of rle
if last_idx < v.size - 1:
# last element is the same as one element before it - add number of these last elements
- rle.append(v.size - last_idx)
+ length = v.size - last_idx
+ while length > maxstride:
+ rle.append(maxstride)
+ rle.append(0)
+ length -= maxstride
+ if length > 0:
+ rle.append(length)
else:
# last element is different than one element before - add 1
rle.append(1)
return rle
-@jit(nopython=True)
+@numba.njit(cache=True)
def rle_to_mask(rle, width, height):
+ """ Converts RLE encoding to a binary mask. This is a Numba decorated function that is compiled just-in-time for faster execution.
+
+ Args:
+ rle (List[int]): RLE encoding of the mask
+ width (int): Width of the mask
+ height (int): Height of the mask
+
+ Returns:
+ np.ndarray: 2-D binary mask
"""
- rle: input rle mask encoding
- each evenly-indexed element represents number of consecutive 0s
- each oddly indexed element represents number of consecutive 1s
- width and height are dimensions of the mask
- output: 2-D binary mask
- """
+
# allocate list of zeros
- v = [0] * (width * height)
+ v = np.zeros(width * height, dtype=np.uint8)
# set id of the last different element to the beginning of the vector
idx_ = 0
@@ -66,7 +90,8 @@ def rle_to_mask(rle, width, height):
# reshape vector into 2-D mask
# return np.reshape(np.array(v, dtype=np.uint8), (height, width)) # numba bug / not supporting np.reshape
- return np.array(v, dtype=np.uint8).reshape((height, width))
+ #return np.array(v, dtype=np.uint8).reshape((height, width))
+ return v.reshape((height, width))
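
A round-trip sanity check for the two helpers above, using a made-up 2x3 mask (the RLE convention is #0s, #1s, #0s, ... as in the original comment):

```python
import numpy as np
from vot.region.io import mask_to_rle, rle_to_mask

m = np.array([[0, 1, 1],
              [0, 0, 1]], dtype=np.uint8)
rle = mask_to_rle(m)                    # [1, 2, 2, 1]: one 0, two 1s, two 0s, one 1
restored = rle_to_mask(np.array(list(rle)), 3, 2)
assert np.array_equal(m, restored)
```
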
def create_mask_from_string(mask_encoding):
"""
@@ -87,12 +112,13 @@ def create_mask_from_string(mask_encoding):
from vot.region.raster import mask_bounds
def encode_mask(mask):
- """
- mask: input binary mask, type: uint8
- output: full RLE encoding in the format: (x0, y0, w, h), RLE
- first get minimal axis-aligned region which contains all positive pixels
- extract this region from mask and calculate mask RLE within the region
- output position and size of the region, dimensions of the full mask and RLE encoding
+ """ Encode a binary mask to a string in the following format: x0, y0, w, h, RLE.
+
+ Args:
+ mask (np.ndarray): 2-D binary mask
+
+ Returns:
+ str: Encoded mask
"""
# calculate coordinates of the top-left corner and region width and height (minimal region containing all 1s)
x_min, y_min, x_max, y_max = mask_bounds(mask)
@@ -113,17 +139,23 @@ def encode_mask(mask):
return (tl_x, tl_y, region_w, region_h), rle
-def parse(string):
- """
- parse string to the appropriate region format and return region object
+def parse_region(string: str) -> "Region":
+ """Parse input string to the appropriate region format and return Region object
+
+ Args:
+ string (str): comma separated list of values
+
+ Returns:
+ Region: resulting region
"""
+ from vot import config
+ from vot.region import Special
from vot.region.shapes import Rectangle, Polygon, Mask
-
if string[0] == 'm':
# input is a mask - decode it
m_, offset_ = create_mask_from_string(string[1:].split(','))
- return Mask(m_, offset=offset_)
+ return Mask(m_, offset=offset_, optimize=config.mask_optimize_read)
else:
# input is not a mask - check if special, rectangle or polygon
tokens = [float(t) for t in string.split(',')]
@@ -139,30 +171,141 @@ def parse(string):
return Special(0)
else:
return Polygon([(x_, y_) for x_, y_ in zip(tokens[::2], tokens[1::2])])
- print('Unknown region format.')
return None
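
Hedged examples of the string formats dispatched by `parse_region` (the rectangle and polygon layouts follow the comma-separated convention visible above; the mask string uses the 'm' + offset + size + RLE layout that `Mask.__str__` produces, with illustrative values):

```python
from vot.region.io import parse_region

rect = parse_region("10,20,30,40")          # four tokens -> Rectangle
poly = parse_region("0,0,10,0,10,10,0,10")  # 2N tokens   -> Polygon with four points
mask = parse_region("m0,0,2,2,0,3,1")       # 2x2 mask at offset (0,0), RLE 0,3,1
print(rect.type, poly.type, mask.type)
```
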
-def read_file(fp: Union[str, TextIO]):
+def read_trajectory_binary(fp: io.RawIOBase):
+ """Reads a trajectory from a binary file and returns a list of regions.
+
+ Args:
+ fp (io.RawIOBase): File pointer to the binary file
+
+ Returns:
+ list: List of regions
+ """
+ import struct
+ from cachetools import LRUCache, cached
+ from vot.region import Special
+ from vot.region.shapes import Rectangle, Polygon, Mask
+
+ buffer = dict(data=fp.read(), offset = 0)
+
+ @cached(cache=LRUCache(maxsize=32))
+ def calcsize(format):
+ """Calculate size of the struct format"""
+ return struct.calcsize(format)
+
+ def read(format: str):
+ """Read struct from the buffer and update offset"""
+ unpacked = struct.unpack_from(format, buffer["data"], buffer["offset"])
+ buffer["offset"] += calcsize(format)
+ return unpacked
+
+ _, length = read("= union[2] or union[1] >= union[3]:
+ # Two empty regons are considered to be identical
+ return float(1)
+
if not bounds is None:
raster_bounds = (max(0, union[0]), max(0, union[1]), min(bounds[0] - 1, union[2]), min(bounds[1] - 1, union[3]))
else:
raster_bounds = union
if raster_bounds[0] >= raster_bounds[2] or raster_bounds[1] >= raster_bounds[3]:
+ # Regions are not identical, but are outside rasterization bounds.
return float(0)
m1 = _region_raster(a, raster_bounds, at, ao)
@@ -228,15 +333,20 @@ def _calculate_overlap(a: np.ndarray, b: np.ndarray, at: int, bt: int, ao: Optio
from vot.region import Region, RegionException
from vot.region.shapes import Shape, Rectangle, Polygon, Mask
-def calculate_overlap(reg1: Shape, reg2: Shape, bounds: Optional[Tuple[int, int]] = None):
- """
- Inputs: reg1 and reg2 are Region objects (Rectangle, Polygon or Mask)
- bounds: size of the image, format: [width, height]
- function first rasterizes both regions to 2-D binary masks and calculates overlap between them
- """
+Bounds = Tuple[int, int]
- if not isinstance(reg1, Shape) or not isinstance(reg2, Shape):
- return float(0)
+def calculate_overlap(reg1: Shape, reg2: Shape, bounds: Optional[Bounds] = None):
+ """ Calculate the overlap between two regions. The function first rasterizes both regions to 2-D binary masks and calculates overlap between them
+
+ Args:
+ reg1: first region
+ reg2: second region
+ bounds: 2-tuple with the bounds of the image (width, height)
+
+ Returns:
+ float with the overlap between the two regions. Note that overlap is one by definition if both regions are empty.
+
+ """
if isinstance(reg1, Rectangle):
data1 = np.round(reg1._data)
@@ -250,6 +360,10 @@ def calculate_overlap(reg1: Shape, reg2: Shape, bounds: Optional[Tuple[int, int]
data1 = reg1.mask
offset1 = reg1.offset
type1 = _TYPE_MASK
+ else:
+ data1 = np.zeros((1, 1))
+ offset1 = (0, 0)
+ type1 = _TYPE_EMPTY
if isinstance(reg2, Rectangle):
data2 = np.round(reg2._data)
@@ -263,14 +377,26 @@ def calculate_overlap(reg1: Shape, reg2: Shape, bounds: Optional[Tuple[int, int]
data2 = reg2.mask
offset2 = reg2.offset
type2 = _TYPE_MASK
+ else:
+ data2 = np.zeros((1, 1))
+ offset2 = (0, 0)
+ type2 = _TYPE_EMPTY
return _calculate_overlap(data1, data2, type1, type2, offset1, offset2, bounds)
-def calculate_overlaps(first: List[Region], second: List[Region], bounds: Optional[Tuple[int, int]]):
- """
- first and second are lists containing objects of type Region
- bounds is in the format [width, height]
- output: list of per-frame overlaps (floats)
+def calculate_overlaps(first: List[Region], second: List[Region], bounds: Optional[Bounds] = None):
+ """ Calculate the overlap between two lists of regions. The function first rasterizes both regions to 2-D binary masks and calculates overlap between them
+
+ Args:
+ first: first list of regions
+ second: second list of regions
+ bounds: 2-tuple with the bounds of the image (width, height)
+
+ Returns:
+ list of floats with the overlap between the two regions. Note that overlap is one by definition if both regions are empty.
+
+ Raises:
+ RegionException: if the lists are not of the same size
"""
if not len(first) == len(second):
raise RegionException("List not of the same size {} != {}".format(len(first), len(second)))
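
A quick sanity check of the documented edge cases: identical regions give an overlap of 1.0 and disjoint regions give 0.0 (values are illustrative):

```python
from vot.region.shapes import Rectangle
from vot.region.raster import calculate_overlap, calculate_overlaps

a = Rectangle(0, 0, 10, 10)
b = Rectangle(100, 100, 10, 10)
print(calculate_overlap(a, a))             # 1.0, identical regions
print(calculate_overlap(a, b))             # 0.0, disjoint regions
print(calculate_overlaps([a, a], [a, b]))  # [1.0, 0.0]
```
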
diff --git a/vot/region/shapes.py b/vot/region/shapes.py
index b49edf3..d067fad 100644
--- a/vot/region/shapes.py
+++ b/vot/region/shapes.py
@@ -1,4 +1,4 @@
-import sys
+""" Module for region shapes. """
from copy import copy
from functools import reduce
@@ -13,43 +13,67 @@
from vot.utilities.draw import DrawHandle
class Shape(Region, ABC):
+ """ Base class for all shape regions. """
@abstractmethod
def draw(self, handle: DrawHandle) -> None:
+ """ Draw the region to the given handle.
+
+ """
pass
@abstractmethod
def resize(self, factor=1) -> "Shape":
+ """ Resize the region by the given factor. """
pass
@abstractmethod
def move(self, dx=0, dy=0) -> "Shape":
+ """ Move the region by the given offset.
+
+ Args:
+ dx (float, optional): X offset. Defaults to 0.
+ dy (float, optional): Y offset. Defaults to 0.
+
+ Returns:
+ Shape: Moved region.
+ """
pass
@abstractmethod
def rasterize(self, bounds: Tuple[int, int, int, int]) -> np.ndarray:
+ """ Rasterize the region to a binary mask.
+
+ Args:
+ bounds (Tuple[int, int, int, int]): Bounds of the mask.
+
+ Returns:
+ np.ndarray: Binary mask.
+ """
pass
@abstractmethod
def bounds(self) -> Tuple[int, int, int, int]:
+ """ Get the bounding box of the region.
+
+ Returns:
+ Tuple[int, int, int, int]: Bounding box (left, top, right, bottom).
+ """
+
pass
class Rectangle(Shape):
"""
- Rectangle region
-
- :var x: top left x coord of the rectangle region
- :var float y: top left y coord of the rectangle region
- :var float w: width of the rectangle region
- :var float h: height of the rectangle region
+ Rectangle region class for representing rectangular regions.
"""
def __init__(self, x=0, y=0, width=0, height=0):
- """ Constructor
+ """ Constructor for rectangle region.
- :param float x: top left x coord of the rectangle region
- :param float y: top left y coord of the rectangle region
- :param float w: width of the rectangle region
- :param float h: height of the rectangle region
+ Args:
+ x (float, optional): X coordinate of the top left corner. Defaults to 0.
+ y (float, optional): Y coordinate of the top left corner. Defaults to 0.
+ width (float, optional): Width of the rectangle. Defaults to 0.
+ height (float, optional): Height of the rectangle. Defaults to 0.
"""
super().__init__()
self._data = np.array([[x], [y], [width], [height]], dtype=np.float32)
@@ -60,28 +84,42 @@ def __str__(self):
@property
def x(self):
- return self._data[0, 0]
+ """ X coordinate of the top left corner. """
+ return float(self._data[0, 0])
@property
def y(self):
- return self._data[1, 0]
+ """ Y coordinate of the top left corner. """
+ return float(self._data[1, 0])
@property
def width(self):
- return self._data[2, 0]
+ """ Width of the rectangle."""
+ return float(self._data[2, 0])
@property
def height(self):
- return self._data[3, 0]
+ """ Height of the rectangle."""
+ return float(self._data[3, 0])
@property
def type(self):
+ """ Type of the region."""
return RegionType.RECTANGLE
def copy(self):
+ """ Copy region to another object. """
return copy(self)
def convert(self, rtype: RegionType):
+ """ Convert region to another type. Note that some conversions degrade information.
+
+ Args:
+ rtype (RegionType): Desired type.
+
+ Raises:
+ ConversionException: Unable to convert rectangle region to {rtype}
+ """
if rtype == RegionType.RECTANGLE:
return self.copy()
elif rtype == RegionType.POLYGON:
@@ -97,46 +135,86 @@ def convert(self, rtype: RegionType):
raise ConversionException("Unable to convert rectangle region to {}".format(rtype), source=self)
def is_empty(self):
+ """ Check if the region is empty.
+
+ Returns:
+ bool: True if the region is empty, False otherwise.
+ """
if self.width > 0 and self.height > 0:
return False
else:
return True
def draw(self, handle: DrawHandle):
+ """ Draw the region to the given handle.
+
+ Args:
+ handle (DrawHandle): Handle to draw to.
+ """
polygon = [(self.x, self.y), (self.x + self.width, self.y), \
(self.x + self.width, self.y + self.height), \
(self.x, self.y + self.height)]
handle.polygon(polygon)
def resize(self, factor=1):
+ """ Resize the region by the given factor.
+
+ Args:
+ factor (float, optional): Resize factor. Defaults to 1.
+
+ Returns:
+ Rectangle: Resized region.
+ """
return Rectangle(self.x * factor, self.y * factor,
self.width * factor, self.height * factor)
def center(self):
+ """ Get the center of the region.
+
+ Returns:
+ tuple: Center coordinates (x,y).
+ """
return (self.x + self.width / 2, self.y + self.height / 2)
def move(self, dx=0, dy=0):
+ """ Move the region by the given offset.
+
+ Args:
+ dx (float, optional): X offset. Defaults to 0.
+ dy (float, optional): Y offset. Defaults to 0.
+
+ Returns:
+ Rectangle: Moved region.
+ """
return Rectangle(self.x + dx, self.y + dy, self.width, self.height)
def rasterize(self, bounds: Tuple[int, int, int, int]):
+ """ Rasterize the region to a binary mask.
+
+ Args:
+ bounds (tuple): Bounds of the mask (x1,y1,x2,y2).
+ """
from vot.region.raster import rasterize_rectangle
return rasterize_rectangle(self._data, np.array(bounds))
def bounds(self):
+ """ Get the bounding box of the region.
+
+ Returns:
+ tuple: Bounding box (x1,y1,x2,y2).
+ """
return int(round(self.x)), int(round(self.y)), int(round(self.width + self.x)), int(round(self.height + self.y))
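
A short demonstration of the Rectangle API documented above (values made up):

```python
from vot.region import RegionType
from vot.region.shapes import Rectangle

r = Rectangle(10, 20, 30, 40)
print(r.center())                          # (25.0, 40.0)
print(r.move(5, -5).bounds())              # (15, 15, 45, 55)
print(r.convert(RegionType.POLYGON).type)  # RegionType.POLYGON, a lossless conversion
```
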
class Polygon(Shape):
"""
- Polygon region
-
- :var list points: List of points as tuples [(x1,y1), (x2,y2),...,(xN,yN)]
- :var int count: number of points
+ Polygon region defined by a list of points. The polygon is closed, i.e. the first and last point are connected.
"""
def __init__(self, points):
"""
Constructor
- :param list points: List of points as tuples [(x1,y1), (x2,y2),...,(xN,yN)]
+ Args:
+ points (list): List of points as tuples [(x1,y1), (x2,y2),...,(xN,yN)]
"""
super().__init__()
assert(points)
@@ -150,22 +228,39 @@ def __str__(self):
@property
def type(self):
+ """ Get the region type. """
return RegionType.POLYGON
@property
def size(self):
+ """ Get the number of points. """
return self._points.shape[0] # pylint: disable=E1136
def __getitem__(self, i):
+ """ Get the i-th point."""
return self._points[i, 0], self._points[i, 1]
def points(self):
+ """ Get the list of points.
+
+ Returns:
+ list: List of points as tuples [(x1,y1), (x2,y2),...,(xN,yN)]
+ """
return [self[i] for i in range(self.size)]
def copy(self):
+ """ Create a copy of the polygon. """
return copy(self)
def convert(self, rtype: RegionType):
+ """ Convert the polygon to another region type.
+
+ Args:
+ rtype (RegionType): Target region type.
+
+ Returns:
+ Region: Converted region.
+ """
if rtype == RegionType.POLYGON:
return self.copy()
elif rtype == RegionType.RECTANGLE:
@@ -183,15 +278,43 @@ def convert(self, rtype: RegionType):
raise ConversionException("Unable to convert polygon region to {}".format(rtype), source=self)
def draw(self, handle: DrawHandle):
+ """ Draw the polygon on the given handle.
+
+ Args:
+ handle (DrawHandle): Handle to draw on.
+ """
handle.polygon([(p[0], p[1]) for p in self._points])
def resize(self, factor=1):
+ """ Resize the polygon by a factor.
+
+ Args:
+ factor (float): Resize factor.
+
+ Returns:
+ Polygon: Resized polygon.
+ """
return Polygon([(p[0] * factor, p[1] * factor) for p in self._points])
def move(self, dx=0, dy=0):
+ """ Move the polygon by a given offset.
+
+ Args:
+ dx (float): X offset.
+ dy (float): Y offset.
+
+ Returns:
+ Polygon: Moved polygon.
+ """
return Polygon([(p[0] + dx, p[1] + dy) for p in self._points])
def is_empty(self):
+ """ Check if the polygon is empty.
+
+ Returns:
+ bool: True if the polygon is empty, False otherwise.
+
+ """
top = np.min(self._points[:, 1])
bottom = np.max(self._points[:, 1])
left = np.min(self._points[:, 0])
@@ -199,10 +322,23 @@ def is_empty(self):
return top == bottom or left == right
def rasterize(self, bounds: Tuple[int, int, int, int]):
+ """ Rasterize the polygon into a binary mask.
+
+ Args:
+ bounds (tuple): Bounding box of the mask as (left, top, right, bottom).
+
+ Returns:
+ numpy.ndarray: Binary mask.
+ """
from vot.region.raster import rasterize_polygon
return rasterize_polygon(self._points, bounds)
def bounds(self):
+ """ Get the bounding box of the polygon.
+
+ Returns:
+ tuple: Bounding box as (left, top, right, bottom).
+ """
top = np.min(self._points[:, 1])
bottom = np.max(self._points[:, 1])
left = np.min(self._points[:, 0])
@@ -213,49 +349,74 @@ def bounds(self):
from vot.region.io import mask_to_rle
class Mask(Shape):
- """Mask region
+ """Mask region defined by a binary mask. The mask is defined by a binary image and an offset.
"""
def __init__(self, mask: np.array, offset: Tuple[int, int] = (0, 0), optimize=False):
+ """ Constructor
+
+ Args:
+ mask (numpy.ndarray): Binary mask.
+ offset (tuple): Offset of the mask as (x, y).
+ optimize (bool): Optimize the mask by removing empty rows and columns.
+
+ """
super().__init__()
self._mask = mask.astype(np.uint8)
self._mask[self._mask > 0] = 1
self._offset = offset
if optimize: # optimize is used when mask without an offset is given (e.g. full-image mask)
self._optimize()
-
+
def __str__(self):
+        """ Serialize the mask to its plain-text trajectory representation:
+        'm<offset_x>,<offset_y>,<width>,<height>,<rle counts...>'.
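+
+        Examples:
+            Illustrative sketch; assumes the run-length encoding starts with
+            the count of zeros, so an all-ones 2x2 mask encodes as [0, 4].
+
+            >>> str(Mask(np.ones((2, 2), dtype=np.uint8)))
+            'm0,0,2,2,0,4'
+        """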
offset_str = '%d,%d' % self.offset
region_sz_str = '%d,%d' % (self.mask.shape[1], self.mask.shape[0])
rle_str = ','.join([str(el) for el in mask_to_rle(self.mask)])
return 'm%s,%s,%s' % (offset_str, region_sz_str, rle_str)
def _optimize(self):
+ """ Optimize the mask by removing empty rows and columns. If the mask is empty, the mask is set to zero size.
+ Do not call this method directly, it is called from the constructor. """
bounds = mask_bounds(self.mask)
- if bounds[0] is None:
+ if bounds[2] == 0:
# mask is empty
self._mask = np.zeros((0, 0), dtype=np.uint8)
self._offset = (0, 0)
else:
- self._mask = np.copy(self.mask[bounds[1]:bounds[3], bounds[0]:bounds[2]])
+
+ self._mask = np.copy(self.mask[bounds[1]:bounds[3]+1, bounds[0]:bounds[2]+1])
self._offset = (bounds[0] + self.offset[0], bounds[1] + self.offset[1])
@property
def mask(self):
+ """ Get the mask. Note that you should not modify the mask directly. Also make sure to
+ take into account the offset when using the mask."""
return self._mask
@property
def offset(self):
+ """ Get the offset of the mask in pixels."""
return self._offset
@property
def type(self):
+ """ Get the region type."""
return RegionType.MASK
def copy(self):
+ """ Create a copy of the mask."""
return copy(self)
def convert(self, rtype: RegionType):
+        """ Convert the mask to another region type. The mask is converted to a rectangle or polygon by approximating the bounding box of the mask.
+
+ Args:
+ rtype (RegionType): Target region type.
+
+ Returns:
+ Shape: Converted region.
+ """
if rtype == RegionType.MASK:
return self.copy()
elif rtype == RegionType.RECTANGLE:
@@ -275,19 +436,43 @@ def convert(self, rtype: RegionType):
raise ConversionException("Unable to convert mask region to {}".format(rtype), source=self)
def draw(self, handle: DrawHandle):
+ """ Draw the mask into an image.
+
+ Args:
+ handle (DrawHandle): Handle to the image.
+ """
handle.mask(self._mask, self.offset)
def rasterize(self, bounds: Tuple[int, int, int, int]):
+ """ Rasterize the mask into a binary mask. The mask is cropped to the given bounds.
+
+ Args:
+ bounds (tuple): Bounding box of the mask as (left, top, right, bottom).
+
+ Returns:
+ numpy.ndarray: Binary mask. The mask is a copy of the original mask.
+ """
from vot.region.raster import copy_mask
return copy_mask(self._mask, self._offset, np.array(bounds))
def is_empty(self):
- if self.mask.shape[1] > 0 and self.mask.shape[0] > 0:
- return False
- else:
- return True
+ """ Check if the mask is empty.
+
+ Returns:
+ bool: True if the mask is empty, False otherwise.
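+
+        Examples:
+            Mirrors the unit test in vot/region/tests.py.
+
+            >>> Mask(np.zeros((10, 10), dtype=np.uint8)).is_empty()
+            True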
+ """
+ bounds = mask_bounds(self.mask)
+ return bounds[2] == 0 or bounds[3] == 0
def resize(self, factor=1):
+ """ Resize the mask by a given factor. The mask is resized using nearest neighbor interpolation.
+
+ Args:
+ factor (float): Resize factor.
+
+ Returns:
+ Mask: Resized mask.
+ """
offset = (int(self.offset[0] * factor), int(self.offset[1] * factor))
height = max(1, int(self.mask.shape[0] * factor))
@@ -301,8 +486,22 @@ def resize(self, factor=1):
return Mask(mask, offset, False)
def move(self, dx=0, dy=0):
+ """ Move the mask by a given offset.
+
+ Args:
+ dx (int): Horizontal offset.
+ dy (int): Vertical offset.
+
+ Returns:
+ Mask: Moved mask.
+ """
return Mask(self._mask, (self.offset[0] + dx, self.offset[1] + dy))
def bounds(self):
+ """ Get the bounding box of the mask.
+
+ Returns:
+ tuple: Bounding box of the mask as (left, top, right, bottom).
+ """
bounds = mask_bounds(self.mask)
return bounds[0] + self.offset[0], bounds[1] + self.offset[1], bounds[2] + self.offset[0], bounds[3] + self.offset[1]
diff --git a/vot/region/tests.py b/vot/region/tests.py
index 561032a..b1a2ce9 100644
--- a/vot/region/tests.py
+++ b/vot/region/tests.py
@@ -1,4 +1,6 @@
+"""Tests for the region module. """
+
import unittest
import numpy as np
@@ -6,21 +8,85 @@
from vot.region.raster import rasterize_polygon, rasterize_rectangle, copy_mask, calculate_overlap
class TestRasterMethods(unittest.TestCase):
+ """Tests for the raster module."""
def test_rasterize_polygon(self):
+ """Tests if the polygon rasterization works correctly. """
points = np.array([[0, 0], [0, 100], [100, 100], [100, 0]], dtype=np.float32)
np.testing.assert_array_equal(rasterize_polygon(points, (0, 0, 99, 99)), np.ones((100, 100), dtype=np.uint8))
def test_rasterize_rectangle(self):
+ """Tests if the rectangle rasterization works correctly."""
np.testing.assert_array_equal(rasterize_rectangle(np.array([[0], [0], [100], [100]], dtype=np.float32), (0, 0, 99, 99)), np.ones((100, 100), dtype=np.uint8))
def test_copy_mask(self):
+ """Tests if the mask copy works correctly."""
mask = np.ones((100, 100), dtype=np.uint8)
np.testing.assert_array_equal(copy_mask(mask, (0, 0), (0, 0, 99, 99)), np.ones((100, 100), dtype=np.uint8))
def test_calculate_overlap(self):
+ """Tests if the overlap calculation works correctly."""
from vot.region import Rectangle
r1 = Rectangle(0, 0, 100, 100)
self.assertEqual(calculate_overlap(r1, r1), 1)
+ r1 = Rectangle(0, 0, 0, 0)
+ self.assertEqual(calculate_overlap(r1, r1), 1)
+
+ def test_empty_mask(self):
+ """Tests if the empty mask is correctly detected."""
+ from vot.region import Mask
+
+ mask = Mask(np.zeros((100, 100), dtype=np.uint8))
+ self.assertTrue(mask.is_empty())
+
+ mask = Mask(np.ones((100, 100), dtype=np.uint8))
+ self.assertFalse(mask.is_empty())
+
+ def test_binary_format(self):
+        """ Tests if the binary trajectory format matches the plain-text one."""
+ import io
+
+ from vot.region import Rectangle, Polygon, Mask
+ from vot.region.io import read_trajectory, write_trajectory
+ from vot.region.raster import calculate_overlaps
+
+ trajectory = [
+ Rectangle(0, 0, 100, 100),
+ Rectangle(0, 10, 100, 100),
+ Rectangle(0, 0, 200, 100),
+ Polygon([[0, 0], [0, 100], [100, 100], [100, 0]]),
+ Mask(np.ones((100, 100), dtype=np.uint8)),
+ Mask(np.zeros((100, 100), dtype=np.uint8)),
+ ]
+
+ binf = io.BytesIO()
+ txtf = io.StringIO()
+
+ write_trajectory(binf, trajectory)
+ write_trajectory(txtf, trajectory)
+
+ binf.seek(0)
+ txtf.seek(0)
+
+ bint = read_trajectory(binf)
+ txtt = read_trajectory(txtf)
+
+ o1 = calculate_overlaps(bint, txtt, None)
+ o2 = calculate_overlaps(bint, trajectory, None)
+
+ self.assertTrue(np.all(np.array(o1) == 1))
+ self.assertTrue(np.all(np.array(o2) == 1))
+
+ def test_rle(self):
+ """ Test if RLE encoding works for limited stride representation."""
+ from vot.region.io import rle_to_mask, mask_to_rle
+ rle = [0, 2, 122103, 9, 260, 19, 256, 21, 256, 22, 254, 24, 252, 26, 251, 27, 250, 28, 249, 28, 250, 28, 249, 28, 249, 29, 249, 30, 247, 33, 245, 33, 244, 34, 244, 37, 241, 39, 239, 41, 237, 41, 236, 43, 235, 45, 234, 47, 233, 47, 231, 48, 230, 48, 230, 11, 7, 29, 231, 9, 9, 29, 230, 8, 11, 28, 230, 7, 12, 28, 230, 7, 13, 27, 231, 5, 14, 27, 233, 2, 16, 26, 253, 23, 255, 22, 256, 20, 258, 19, 259, 17, 3]
+ rle = np.array(rle)
+ m1 = rle_to_mask(np.array(rle, dtype=np.int32), 277, 478)
+
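+        # Re-encode with runs capped at 255 (limited stride); splitting long
+        # runs this way must not change the decoded mask.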
+ r2 = mask_to_rle(m1, maxstride=255)
+ m2 = rle_to_mask(np.array(r2, dtype=np.int32), 277, 478)
+
+ np.testing.assert_array_equal(m1, m2)
\ No newline at end of file
diff --git a/vot/stack/__init__.py b/vot/stack/__init__.py
index 44da321..8aead08 100644
--- a/vot/stack/__init__.py
+++ b/vot/stack/__init__.py
@@ -1,29 +1,34 @@
+"""Stacks are collections of experiments that are grouped together for convenience. Stacks are used to organize experiments and to run them in
+batch mode.
+"""
import os
-import json
-import glob
-import collections
-from typing import List
+from typing import List, Mapping
import yaml
from attributee import Attributee, String, Boolean, Map, Object
from vot.experiment import Experiment, experiment_registry
-from vot.experiment.transformer import Transformer
from vot.utilities import import_class
-from vot.analysis import Analysis
def experiment_resolver(typename, context, **kwargs):
+    """Resolves experiment objects from stack definitions. Used internally by the
+    stack module; not intended to be called directly.
- if "key" in context:
- identifier = context["key"]
- else:
- identifier = None
+ Args:
+ typename (str): Name of the experiment class
+ context (Attributee): Context of the experiment
+ kwargs (dict): Additional arguments
+
+ Returns:
+ Experiment: Experiment object
+ """
+ identifier = context.key
storage = None
- if "parent" in context:
- if getattr(context["parent"], "workspace", None) is not None:
- storage = context["parent"].workspace.storage
+
+ if getattr(context.parent, "workspace", None) is not None:
+ storage = context.parent.workspace.storage
if typename in experiment_registry:
experiment = experiment_registry.get(typename, _identifier=identifier, _storage=storage, **kwargs)
@@ -35,14 +40,22 @@ def experiment_resolver(typename, context, **kwargs):
return experiment_class(_identifier=identifier, _storage=storage, **kwargs)
class Stack(Attributee):
+ """Stack class represents a collection of experiments. Stacks are used to organize experiments and to run them in batch mode.
+ """
- title = String()
- dataset = String(default="")
+ title = String(default="Stack")
+ dataset = String(default=None)
url = String(default="")
deprecated = Boolean(default=False)
experiments = Map(Object(experiment_resolver))
def __init__(self, name: str, workspace: "Workspace", **kwargs):
+ """Creates a new stack object.
+
+ Args:
+ name (str): Name of the stack
+ workspace (Workspace): Workspace object
+ """
self._workspace = workspace
self._name = name
@@ -50,22 +63,45 @@ def __init__(self, name: str, workspace: "Workspace", **kwargs):
@property
def workspace(self):
+ """Returns the workspace object for the stack."""
return self._workspace
@property
def name(self):
+ """Returns the name of the stack."""
return self._name
def __iter__(self):
+ """Iterates over experiments in the stack."""
return iter(self.experiments.values())
def __len__(self):
+ """Returns the number of experiments in the stack."""
return len(self.experiments)
def __getitem__(self, identifier):
+ """Returns the experiment with the given identifier.
+
+ Args:
+ identifier (str): Identifier of the experiment
+
+ Returns:
+ Experiment: Experiment object
+
+ """
return self.experiments[identifier]
-def resolve_stack(name, *directories):
+def resolve_stack(name: str, *directories: List[str]) -> str:
+    """Searches for a stack file in the given directories and returns its absolute path.
+    If given an absolute path as input, the path itself is returned if the file exists, otherwise None.
+
+ Args:
+ name (str): Name of the stack
+        directories (List[str]): Directories to search, in the order given
+
+ Returns:
+ str: Absolute path to stack file
+        str: Absolute path to stack file, or None if not found
+    """
if os.path.isabs(name):
return name if os.path.isfile(name) else None
for directory in directories:
@@ -77,11 +113,24 @@ def resolve_stack(name, *directories):
return full
return None
-def list_integrated_stacks():
+def list_integrated_stacks() -> Mapping[str, str]:
+ """List stacks that come with the toolkit
+
+ Returns:
+        Mapping[str, str]: A mapping of stack id and stack title pairs
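+
+    Examples:
+        Illustrative sketch; keys are paths relative to the stack directory,
+        so the separator depends on the host OS.
+
+        >>> stacks = list_integrated_stacks()
+        >>> stacks["vot2022/stb"]
+        'VOT-ST2022 bounding-box challenge'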
+ """
+
+ from pathlib import Path
+
stacks = {}
- for stack_file in glob.glob(os.path.join(os.path.dirname(__file__), "*.yaml")):
- with open(stack_file, 'r') as fp:
+ root = Path(os.path.join(os.path.dirname(__file__)))
+
+ for stack_path in root.rglob("*.yaml"):
+ with open(stack_path, 'r') as fp:
stack_metadata = yaml.load(fp, Loader=yaml.BaseLoader)
- stacks[os.path.splitext(os.path.basename(stack_file))[0]] = stack_metadata.get("title", "")
+ if stack_metadata is None:
+ continue
+ key = str(stack_path.relative_to(root).with_name(os.path.splitext(stack_path.name)[0]))
+ stacks[key] = stack_metadata.get("title", "")
return stacks
\ No newline at end of file
diff --git a/vot/stack/otb100.yaml b/vot/stack/otb100.yaml
new file mode 100644
index 0000000..ff604f5
--- /dev/null
+++ b/vot/stack/otb100.yaml
@@ -0,0 +1,10 @@
+title: OTB100 dataset experiment stack
+url: http://cvlab.hanyang.ac.kr/tracker_benchmark/index.html
+dataset: otb100
+experiments:
+ baseline:
+ type: unsupervised
+ analyses:
+ - type: average_accuracy
+ name: accuracy
+ burnin: 1
\ No newline at end of file
diff --git a/vot/stack/otb50.yaml b/vot/stack/otb50.yaml
new file mode 100644
index 0000000..1794bdd
--- /dev/null
+++ b/vot/stack/otb50.yaml
@@ -0,0 +1,10 @@
+title: OTB50 dataset experiment stack
+url: http://cvlab.hanyang.ac.kr/tracker_benchmark/index.html
+dataset: otb50
+experiments:
+ baseline:
+ type: unsupervised
+ analyses:
+ - type: average_accuracy
+ name: accuracy
+ burnin: 1
\ No newline at end of file
diff --git a/vot/stack/tests.py b/vot/stack/tests.py
index 083f5c5..0b5c68f 100644
--- a/vot/stack/tests.py
+++ b/vot/stack/tests.py
@@ -1,19 +1,27 @@
-import os
+"""Tests for the experiment stack module."""
+
import unittest
import yaml
-from vot.workspace import Workspace, NullStorage
+from vot.workspace import NullStorage
from vot.stack import Stack, list_integrated_stacks, resolve_stack
class NoWorkspace:
+    """Empty workspace that does not save anything
+ """
@property
def storage(self):
+ """Returns the storage object for the workspace. """
return NullStorage()
class TestStacks(unittest.TestCase):
+ """Tests for the experiment stack utilities
+ """
def test_stacks(self):
+ """Test loading integrated stacks
+ """
stacks = list_integrated_stacks()
for stack_name in stacks:
diff --git a/vot/stack/tests/basic.yaml b/vot/stack/tests/basic.yaml
new file mode 100644
index 0000000..368bce4
--- /dev/null
+++ b/vot/stack/tests/basic.yaml
@@ -0,0 +1,7 @@
+title: VOT Basic Test Stack
+url: http://www.votchallenge.net/
+dataset: https://data.votchallenge.net/toolkit/test.zip
+experiments:
+ baseline:
+ type: unsupervised
+ repetitions: 1
\ No newline at end of file
diff --git a/vot/stack/tests/multiobject.yaml b/vot/stack/tests/multiobject.yaml
new file mode 100644
index 0000000..6b6d6ba
--- /dev/null
+++ b/vot/stack/tests/multiobject.yaml
@@ -0,0 +1,21 @@
+title: VOTS2023 Test Stack
+dataset: https://data.votchallenge.net/vots2023/test/description.json
+experiments:
+ baseline:
+ type: unsupervised
+ repetitions: 1
+ multiobject: True
+ analyses:
+ - type: average_accuracy
+ name: Quality
+ burnin: 0
+ ignore_unknown: False
+ weighted: False
+ - type: average_success_plot
+ name: Quality plot
+ burnin: 0
+ ignore_unknown: False
+ - type: longterm_ar
+ name: AR
+ - type: average_quality_auxiliary
+ name: Auxiliary
\ No newline at end of file
diff --git a/vot/stack/testing.yaml b/vot/stack/tests/segmentation.yaml
similarity index 85%
rename from vot/stack/testing.yaml
rename to vot/stack/tests/segmentation.yaml
index 04c1459..43e5288 100644
--- a/vot/stack/testing.yaml
+++ b/vot/stack/tests/segmentation.yaml
@@ -1,6 +1,6 @@
-title: VOT testing
+title: VOT Segmentation testing
url: http://www.votchallenge.net/
-dataset: vot:segmentation
+dataset: http://box.vicos.si/tracking/vot20_test_dataset.zip
experiments:
baseline:
type: multistart
diff --git a/vot/stack/vot2013.yaml b/vot/stack/vot2013.yaml
index 85ad419..3006908 100644
--- a/vot/stack/vot2013.yaml
+++ b/vot/stack/vot2013.yaml
@@ -1,6 +1,6 @@
title: VOT2013 challenge
url: http://www.votchallenge.net/vot2013/
-dataset: vot:vot2013
+dataset: http://data.votchallenge.net/vot2013/dataset/description.json
deprecated: True
experiments:
baseline:
diff --git a/vot/stack/vot2014.yaml b/vot/stack/vot2014.yaml
index 171df7e..3713448 100644
--- a/vot/stack/vot2014.yaml
+++ b/vot/stack/vot2014.yaml
@@ -1,5 +1,5 @@
title: VOT2014 challenge
-dataset: vot:vot2014
+dataset: http://data.votchallenge.net/vot2014/dataset/description.json
url: http://www.votchallenge.net/vot2014/
deprecated: True
experiments:
diff --git a/vot/stack/vot2015.yaml b/vot/stack/vot2015/rgb.yaml
similarity index 82%
rename from vot/stack/vot2015.yaml
rename to vot/stack/vot2015/rgb.yaml
index eb30eec..8e3c2ab 100644
--- a/vot/stack/vot2015.yaml
+++ b/vot/stack/vot2015/rgb.yaml
@@ -1,5 +1,5 @@
title: VOT2015 challenge
-dataset: vot:vot2015
+dataset: http://data.votchallenge.net/vot2015/dataset/description.json
url: http://www.votchallenge.net/vot2015/
experiments:
baseline:
diff --git a/vot/stack/vottir2015.yaml b/vot/stack/vot2015/tir.yaml
similarity index 72%
rename from vot/stack/vottir2015.yaml
rename to vot/stack/vot2015/tir.yaml
index 5ab84de..e2d09cf 100644
--- a/vot/stack/vottir2015.yaml
+++ b/vot/stack/vot2015/tir.yaml
@@ -1,5 +1,5 @@
title: VOT-TIR2015 challenge
-dataset: vot:vot-tir2015
+dataset: http://www.cvl.isy.liu.se/research/datasets/ltir/version1.0/ltir_v1_0_8bit.zip
url: http://www.votchallenge.net/vot2015/
experiments:
baseline:
diff --git a/vot/stack/vot2016.yaml b/vot/stack/vot2016/rgb.yaml
similarity index 87%
rename from vot/stack/vot2016.yaml
rename to vot/stack/vot2016/rgb.yaml
index 7dd3de6..8a616ea 100644
--- a/vot/stack/vot2016.yaml
+++ b/vot/stack/vot2016/rgb.yaml
@@ -1,5 +1,5 @@
title: VOT2016 challenge
-dataset: vot:vot2016
+dataset: http://data.votchallenge.net/vot2016/main/description.json
url: http://www.votchallenge.net/vot2016/
experiments:
baseline:
diff --git a/vot/stack/vottir2016.yaml b/vot/stack/vot2016/tir.yaml
similarity index 79%
rename from vot/stack/vottir2016.yaml
rename to vot/stack/vot2016/tir.yaml
index a5b09a3..6d4366b 100644
--- a/vot/stack/vottir2016.yaml
+++ b/vot/stack/vot2016/tir.yaml
@@ -1,5 +1,5 @@
title: VOT-TIR2016 challenge
-dataset: vot:vot-tir2016
+dataset: http://data.votchallenge.net/vot2016/vot-tir2016.zip
url: http://www.votchallenge.net/vot2016/
experiments:
baseline:
diff --git a/vot/stack/vot2017.yaml b/vot/stack/vot2017.yaml
index 3a362ff..bc7a24b 100644
--- a/vot/stack/vot2017.yaml
+++ b/vot/stack/vot2017.yaml
@@ -1,5 +1,5 @@
title: VOT2017 challenge
-dataset: vot:vot2017
+dataset: http://data.votchallenge.net/vot2017/main/description.json
url: http://www.votchallenge.net/vot2017/
experiments:
baseline:
diff --git a/vot/stack/vot2018/longterm.yaml b/vot/stack/vot2018/longterm.yaml
new file mode 100644
index 0000000..b31e677
--- /dev/null
+++ b/vot/stack/vot2018/longterm.yaml
@@ -0,0 +1,20 @@
+title: VOT-LT2018 challenge
+dataset: http://data.votchallenge.net/vot2018/longterm/description.json
+url: http://www.votchallenge.net/vot2018/
+experiments:
+ longterm:
+ type: unsupervised
+ repetitions: 1
+ analyses:
+ - type: average_tpr
+ name: average_tpr
+ - type: pr_curve
+ - type: f_curve
+ redetection:
+ type: unsupervised
+ transformers:
+ - type: redetection
+ length: 200
+ initialization: 5
+ padding: 2
+ scaling: 3
diff --git a/vot/stack/vot2018.yaml b/vot/stack/vot2018/shortterm.yaml
similarity index 91%
rename from vot/stack/vot2018.yaml
rename to vot/stack/vot2018/shortterm.yaml
index fb7942a..60a97af 100644
--- a/vot/stack/vot2018.yaml
+++ b/vot/stack/vot2018/shortterm.yaml
@@ -1,5 +1,5 @@
title: VOT-ST2018 challenge
-dataset: vot:vot-st2018
+dataset: http://data.votchallenge.net/vot2018/main/description.json
url: http://www.votchallenge.net/vot2018/
experiments:
baseline:
diff --git a/vot/stack/votlt2019.yaml b/vot/stack/vot2019/longterm.yaml
similarity index 85%
rename from vot/stack/votlt2019.yaml
rename to vot/stack/vot2019/longterm.yaml
index 65dba8a..0bc2805 100644
--- a/vot/stack/votlt2019.yaml
+++ b/vot/stack/vot2019/longterm.yaml
@@ -1,5 +1,5 @@
title: VOT-LT2019 challenge
-dataset: vot:vot-lt2019
+dataset: http://data.votchallenge.net/vot2019/longterm/description.json
url: http://www.votchallenge.net/vot2019/
experiments:
longterm:
diff --git a/vot/stack/votrgbd2019.yaml b/vot/stack/vot2019/rgbd.yaml
similarity index 77%
rename from vot/stack/votrgbd2019.yaml
rename to vot/stack/vot2019/rgbd.yaml
index f54726d..481fcf7 100644
--- a/vot/stack/votrgbd2019.yaml
+++ b/vot/stack/vot2019/rgbd.yaml
@@ -1,5 +1,5 @@
title: VOT-RGBD2019 challenge
-dataset: vot:vot-rgbd2019
+dataset: http://data.votchallenge.net/vot2019/rgbd/description.json
url: http://www.votchallenge.net/vot2019/
experiments:
rgbd-unsupervised:
diff --git a/vot/stack/vot2019/rgbtir.yaml b/vot/stack/vot2019/rgbtir.yaml
new file mode 100644
index 0000000..15c60cf
--- /dev/null
+++ b/vot/stack/vot2019/rgbtir.yaml
@@ -0,0 +1,15 @@
+title: VOT-RGBTIR2019 challenge
+dataset: http://data.votchallenge.net/vot2019/rgbtir/meta/description.json
+url: http://www.votchallenge.net/vot2019/
+experiments:
+ baseline:
+ type: multistart
+ realtime:
+ grace: 3
+ analyses:
+ - type: multistart_average_ar
+ - type: multistart_eao_score
+ low: 115
+ high: 755
+ - type: multistart_eao_curve
+ high: 755
\ No newline at end of file
diff --git a/vot/stack/vot2019.yaml b/vot/stack/vot2019/shortterm.yaml
similarity index 91%
rename from vot/stack/vot2019.yaml
rename to vot/stack/vot2019/shortterm.yaml
index 387605b..1ee18e7 100644
--- a/vot/stack/vot2019.yaml
+++ b/vot/stack/vot2019/shortterm.yaml
@@ -1,5 +1,5 @@
title: VOT-ST2019 challenge
-dataset: vot:vot-st2019
+dataset: http://data.votchallenge.net/vot2019/main/description.json
url: http://www.votchallenge.net/vot2019/
experiments:
baseline:
diff --git a/vot/stack/votlt2020.yaml b/vot/stack/vot2020/longterm.yaml
similarity index 85%
rename from vot/stack/votlt2020.yaml
rename to vot/stack/vot2020/longterm.yaml
index d8418fd..e698f87 100644
--- a/vot/stack/votlt2020.yaml
+++ b/vot/stack/vot2020/longterm.yaml
@@ -1,5 +1,5 @@
title: VOT-LT2020 challenge
-dataset: vot:vot-lt2019
+dataset: http://data.votchallenge.net/vot2019/longterm/description.json
url: http://www.votchallenge.net/vot2020/
experiments:
longterm:
diff --git a/vot/stack/votrgbd2020.yaml b/vot/stack/vot2020/rgbd.yaml
similarity index 77%
rename from vot/stack/votrgbd2020.yaml
rename to vot/stack/vot2020/rgbd.yaml
index f3265e8..75a8061 100644
--- a/vot/stack/votrgbd2020.yaml
+++ b/vot/stack/vot2020/rgbd.yaml
@@ -1,5 +1,5 @@
title: VOT-RGBD2020 challenge
-dataset: vot:vot-rgbd2019
+dataset: http://data.votchallenge.net/vot2019/rgbd/description.json
url: http://www.votchallenge.net/vot2020/
experiments:
rgbd-unsupervised:
diff --git a/vot/stack/votrgbtir2020.yaml b/vot/stack/vot2020/rgbtir.yaml
similarity index 81%
rename from vot/stack/votrgbtir2020.yaml
rename to vot/stack/vot2020/rgbtir.yaml
index c656ec3..98945c3 100644
--- a/vot/stack/votrgbtir2020.yaml
+++ b/vot/stack/vot2020/rgbtir.yaml
@@ -1,5 +1,5 @@
title: VOT-RGBTIR2020 challenge
-dataset: vot:vot-rgbt2020
+dataset: http://data.votchallenge.net/vot2020/rgbtir/meta/description.json
url: http://www.votchallenge.net/vot2020/
experiments:
baseline:
diff --git a/vot/stack/vot2020.yaml b/vot/stack/vot2020/shortterm.yaml
similarity index 91%
rename from vot/stack/vot2020.yaml
rename to vot/stack/vot2020/shortterm.yaml
index b736d6b..0452531 100644
--- a/vot/stack/vot2020.yaml
+++ b/vot/stack/vot2020/shortterm.yaml
@@ -1,5 +1,5 @@
title: VOT-ST2020 challenge
-dataset: vot:vot-st2020
+dataset: https://data.votchallenge.net/vot2020/shortterm/description.json
url: http://www.votchallenge.net/vot2020/
experiments:
baseline:
diff --git a/vot/stack/votlt2021.yaml b/vot/stack/vot2021/lt.yaml
similarity index 85%
rename from vot/stack/votlt2021.yaml
rename to vot/stack/vot2021/lt.yaml
index 842820b..d31c092 100644
--- a/vot/stack/votlt2021.yaml
+++ b/vot/stack/vot2021/lt.yaml
@@ -1,5 +1,5 @@
title: VOT-LT2021 challenge
-dataset: vot:vot-lt2019
+dataset: http://data.votchallenge.net/vot2019/longterm/description.json
url: http://www.votchallenge.net/vot2021/
experiments:
longterm:
diff --git a/vot/stack/votrgbd2021.yaml b/vot/stack/vot2021/rgbd.yaml
similarity index 77%
rename from vot/stack/votrgbd2021.yaml
rename to vot/stack/vot2021/rgbd.yaml
index 44eaa3c..cfffa5e 100644
--- a/vot/stack/votrgbd2021.yaml
+++ b/vot/stack/vot2021/rgbd.yaml
@@ -1,5 +1,5 @@
title: VOT-RGBD2021 challenge
-dataset: vot:vot-rgbd2019
+dataset: http://data.votchallenge.net/vot2019/rgbd/description.json
url: http://www.votchallenge.net/vot2021/
experiments:
rgbd-unsupervised:
diff --git a/vot/stack/vot2021.yaml b/vot/stack/vot2021/st.yaml
similarity index 91%
rename from vot/stack/vot2021.yaml
rename to vot/stack/vot2021/st.yaml
index 566d68b..b1f8a7b 100644
--- a/vot/stack/vot2021.yaml
+++ b/vot/stack/vot2021/st.yaml
@@ -1,5 +1,5 @@
title: VOT-ST2021 challenge
-dataset: vot:vot-st2021
+dataset: https://data.votchallenge.net/vot2021/shortterm/description.json
url: http://www.votchallenge.net/vot2021/
experiments:
baseline:
diff --git a/vot/stack/vot2022/depth.yaml b/vot/stack/vot2022/depth.yaml
new file mode 100644
index 0000000..0b97f8e
--- /dev/null
+++ b/vot/stack/vot2022/depth.yaml
@@ -0,0 +1,16 @@
+title: VOT-D2022 challenge
+dataset: https://data.votchallenge.net/vot2022/depth/description.json
+url: https://www.votchallenge.net/vot2022/
+experiments:
+ baseline:
+ type: multistart
+ analyses:
+ - type: multistart_eao_score
+ name: eaoscore
+ low: 115
+ high: 755
+ - type: multistart_eao_curve
+ name: eaocurve
+ high: 755
+ - type: multistart_average_ar
+ name: ar
diff --git a/vot/stack/vot2022/lt.yaml b/vot/stack/vot2022/lt.yaml
new file mode 100644
index 0000000..fb80d22
--- /dev/null
+++ b/vot/stack/vot2022/lt.yaml
@@ -0,0 +1,20 @@
+title: VOT-LT2022 challenge
+dataset: https://data.votchallenge.net/vot2022/lt/description.json
+url: https://www.votchallenge.net/vot2022/
+experiments:
+ longterm:
+ type: unsupervised
+ repetitions: 1
+ analyses:
+ - type: average_tpr
+ name: average_tpr
+ - type: pr_curve
+ - type: f_curve
+ redetection:
+ type: unsupervised
+ transformers:
+ - type: redetection
+ length: 200
+ initialization: 5
+ padding: 2
+ scaling: 3
diff --git a/vot/stack/vot2022/rgbd.yaml b/vot/stack/vot2022/rgbd.yaml
new file mode 100644
index 0000000..c5a8f19
--- /dev/null
+++ b/vot/stack/vot2022/rgbd.yaml
@@ -0,0 +1,16 @@
+title: VOT-RGBD2022 challenge
+dataset: https://data.votchallenge.net/vot2022/rgbd/description.json
+url: https://www.votchallenge.net/vot2022/
+experiments:
+ baseline:
+ type: multistart
+ analyses:
+ - type: multistart_eao_score
+ name: eaoscore
+ low: 115
+ high: 755
+ - type: multistart_eao_curve
+ name: eaocurve
+ high: 755
+ - type: multistart_average_ar
+ name: ar
diff --git a/vot/stack/vot2022/stb.yaml b/vot/stack/vot2022/stb.yaml
new file mode 100644
index 0000000..5378a63
--- /dev/null
+++ b/vot/stack/vot2022/stb.yaml
@@ -0,0 +1,37 @@
+title: VOT-ST2022 bounding-box challenge
+dataset: https://data.votchallenge.net/vot2022/stb/description.json
+url: https://www.votchallenge.net/vot2022/
+experiments:
+ baseline:
+ type: multistart
+ analyses:
+ - type: multistart_eao_score
+ name: eaoscore
+ low: 115
+ high: 755
+ - type: multistart_eao_curve
+ name: eaocurve
+ high: 755
+ - type: multistart_average_ar
+ name: ar
+ realtime:
+ type: multistart
+ realtime:
+ grace: 3
+ analyses:
+ - type: multistart_eao_score
+ name: eaoscore
+ low: 115
+ high: 755
+ - type: multistart_eao_curve
+ name: eaocurve
+ high: 755
+ - type: multistart_average_ar
+ name: ar
+ unsupervised:
+ type: unsupervised
+ repetitions: 1
+ analyses:
+ - type: average_accuracy
+ name: accuracy
+ burnin: 1
\ No newline at end of file
diff --git a/vot/stack/vot2022/sts.yaml b/vot/stack/vot2022/sts.yaml
new file mode 100644
index 0000000..99f8de9
--- /dev/null
+++ b/vot/stack/vot2022/sts.yaml
@@ -0,0 +1,37 @@
+title: VOT-ST2022 segmentation challenge
+dataset: https://data.votchallenge.net/vot2022/sts/description.json
+url: https://www.votchallenge.net/vot2022/
+experiments:
+ baseline:
+ type: multistart
+ analyses:
+ - type: multistart_eao_score
+ name: eaoscore
+ low: 115
+ high: 755
+ - type: multistart_eao_curve
+ name: eaocurve
+ high: 755
+ - type: multistart_average_ar
+ name: ar
+ realtime:
+ type: multistart
+ realtime:
+ grace: 3
+ analyses:
+ - type: multistart_eao_score
+ name: eaoscore
+ low: 115
+ high: 755
+ - type: multistart_eao_curve
+ name: eaocurve
+ high: 755
+ - type: multistart_average_ar
+ name: ar
+ unsupervised:
+ type: unsupervised
+ repetitions: 1
+ analyses:
+ - type: average_accuracy
+ name: accuracy
+ burnin: 1
\ No newline at end of file
diff --git a/vot/stack/vots2023.yaml b/vot/stack/vots2023.yaml
new file mode 100644
index 0000000..6de44b1
--- /dev/null
+++ b/vot/stack/vots2023.yaml
@@ -0,0 +1,7 @@
+title: VOTS2023 Challenge Stack
+dataset: https://data.votchallenge.net/vots2023/dataset/description.json
+experiments:
+ baseline:
+ type: unsupervised
+ repetitions: 1
+ multiobject: True
\ No newline at end of file
diff --git a/vot/tracker/__init__.py b/vot/tracker/__init__.py
index b7119cf..eb6a468 100644
--- a/vot/tracker/__init__.py
+++ b/vot/tracker/__init__.py
@@ -1,11 +1,12 @@
+""" This module contains the base classes for trackers and the registry of known trackers. """
import os
import re
import configparser
import logging
import copy
-from typing import Tuple
-from collections import OrderedDict
+from typing import Tuple, List, Union
+from collections import OrderedDict, namedtuple
from abc import abstractmethod, ABC
import yaml
@@ -18,20 +19,35 @@
logger = logging.getLogger("vot")
class TrackerException(ToolkitException):
+ """ Base class for all tracker related exceptions."""
+
def __init__(self, *args, tracker, tracker_log=None):
+ """ Initialize the exception.
+
+ Args:
+ tracker (Tracker): Tracker that caused the exception.
+ tracker_log (str, optional): Optional log message. Defaults to None.
+ """
super().__init__(*args)
self._tracker_log = tracker_log
self._tracker = tracker
@property
- def log(self):
+ def log(self) -> str:
+ """ Returns the log message of the tracker.
+
+ Returns:
+            str: Log message of the tracker.
+ """
return self._tracker_log
@property
def tracker(self):
+ """ Returns the tracker that caused the exception."""
return self._tracker
class TrackerTimeoutException(TrackerException):
+ """ Exception raised when the tracker communication times out."""
pass
VALID_IDENTIFIER = re.compile("^[a-zA-Z0-9-_]+$")
@@ -39,12 +55,39 @@ class TrackerTimeoutException(TrackerException):
VALID_REFERENCE = re.compile("^([a-zA-Z0-9-_]+)(@[a-zA-Z0-9-_]*)?$")
def is_valid_identifier(identifier):
+ """Checks if the identifier is valid.
+
+ Args:
+ identifier (str): The identifier to check.
+
+ Returns:
+ bool: True if the identifier is valid, False otherwise.
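+
+    Examples:
+        >>> is_valid_identifier("my-tracker_01")
+        True
+        >>> is_valid_identifier("my tracker")
+        False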
+ """
return not VALID_IDENTIFIER.match(identifier) is None
def is_valid_reference(reference):
+ """Checks if the reference is valid.
+
+ Args:
+ reference (str): The reference to check.
+
+ Returns:
+ bool: True if the reference is valid, False otherwise.
+ """
return not VALID_REFERENCE.match(reference) is None
def parse_reference(reference):
+ """Parses the reference into identifier and version.
+
+ Args:
+ reference (str): The reference to parse.
+
+    Returns:
+        tuple: A tuple containing the identifier and the version, or
+        (None, None) if the reference is not valid.
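+
+    Examples:
+        Illustrative sketch; assumes the version is returned without the
+        leading "@".
+
+        >>> parse_reference("siamfc@2021")
+        ('siamfc', '2021')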
+ """
matches = VALID_REFERENCE.match(reference)
if not matches:
return None, None
@@ -53,8 +96,15 @@ def parse_reference(reference):
_runtime_protocols = {}
class Registry(object):
+    """ Repository of known trackers. Trackers are loaded from manifest files in one or more directories. """
def __init__(self, directories, root=os.getcwd()):
+ """ Initialize the registry.
+
+ Args:
+ directories (list): List of directories to scan for trackers.
+ root (str, optional): The root directory of the workspace. Defaults to os.getcwd().
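+
+        Examples:
+            Illustrative sketch; the directory is hypothetical.
+
+            >>> registry = Registry(["./trackers"])  # doctest: +SKIP
+            >>> registry.identifiers()  # doctest: +SKIP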
+ """
trackers = dict()
registries = []
@@ -107,19 +157,37 @@ def __init__(self, directories, root=os.getcwd()):
logger.debug("Found %d trackers", len(self._trackers))
def __getitem__(self, reference) -> "Tracker":
+ """ Returns the tracker for the given reference. """
+
return self.resolve(reference, skip_unknown=False, resolve_plural=False)[0]
def __contains__(self, reference) -> bool:
+ """ Checks if the tracker is registered. """
identifier, _ = parse_reference(reference)
return identifier in self._trackers
def __iter__(self):
+ """ Returns an iterator over the trackers."""
return iter(self._trackers.values())
def __len__(self):
+ """ Returns the number of trackers."""
return len(self._trackers)
def resolve(self, *references, storage=None, skip_unknown=True, resolve_plural=True):
+ """ Resolves the references to trackers.
+
+ Args:
+            storage (Storage, optional): Storage to use for resolving references. Defaults to None.
+ skip_unknown (bool, optional): Skip unknown trackers. Defaults to True.
+ resolve_plural (bool, optional): Resolve plural references. Defaults to True.
+
+ Raises:
+ ToolkitException: When a reference cannot be resolved.
+
+ Returns:
+ list: Resolved trackers.
+ """
trackers = []
@@ -151,7 +219,16 @@ def resolve(self, *references, storage=None, skip_unknown=True, resolve_plural=T
return trackers
- def _find_versions(self, identifier, storage):
+ def _find_versions(self, identifier: str, storage: "Storage"):
+ """ Finds all versions of the tracker in the storage.
+
+ Args:
+ identifier (str): The identifier of the tracker.
+ storage (Storage): The storage to use for finding the versions.
+
+ Returns:
+ list: List of trackers.
+ """
trackers = []
@@ -167,15 +244,33 @@ def _find_versions(self, identifier, storage):
return trackers
def references(self):
+ """ Returns a list of all tracker references.
+
+ Returns:
+ list: List of tracker references.
+ """
return [t.reference for t in self._trackers.values()]
def identifiers(self):
+ """ Returns a list of all tracker identifiers.
+
+ Returns:
+ list: List of tracker identifiers.
+ """
return [t.identifier for t in self._trackers.values()]
class Tracker(object):
+ """ Tracker definition class. """
@staticmethod
def _collect_envvars(**kwargs):
+ """ Collects environment variables from the keyword arguments.
+
+ Args:
+ **kwargs: Keyword arguments.
+
+ Returns:
+            tuple: Tuple of environment variables and other keyword arguments.
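+
+        Examples:
+            Illustrative sketch; assumes environment variables are passed with
+            an env_ prefix, mirroring the meta_ convention used for metadata.
+
+            >>> Tracker._collect_envvars(env_PATH="/usr/bin", other=1)
+            ({'PATH': '/usr/bin'}, {'other': 1})
+        """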
envvars = dict()
other = dict()
@@ -194,6 +289,14 @@ def _collect_envvars(**kwargs):
@staticmethod
def _collect_arguments(**kwargs):
+ """ Collects arguments from the keyword arguments.
+
+ Args:
+ **kwargs: Keyword arguments.
+
+ Returns:
+ tuple: Tuple of arguments and other keyword arguments.
+ """
arguments = dict()
other = dict()
@@ -212,6 +315,18 @@ def _collect_arguments(**kwargs):
@staticmethod
def _collect_metadata(**kwargs):
+ """ Collects metadata from the keyword arguments.
+
+ Args:
+ **kwargs: Keyword arguments.
+
+ Returns:
+ tuple: Tuple of metadata and other keyword arguments.
+
+ Examples:
+ >>> Tracker._collect_metadata(meta_author="John Doe", meta_year=2018)
+ ({'author': 'John Doe', 'year': 2018}, {})
+ """
metadata = dict()
other = dict()
@@ -229,6 +344,23 @@ def _collect_metadata(**kwargs):
return metadata, other
def __init__(self, _identifier, _source, command, protocol=None, label=None, version=None, tags=None, storage=None, **kwargs):
+ """ Initializes the tracker definition.
+
+ Args:
+ _identifier (str): The identifier of the tracker.
+ _source (str): The source of the tracker.
+ command (str): The command to execute.
+ protocol (str, optional): The protocol of the tracker. Defaults to None.
+ label (str, optional): The label of the tracker. Defaults to None.
+ version (str, optional): The version of the tracker. Defaults to None.
+            tags (list, optional): The tags of the tracker. Defaults to None.
+            storage (Storage, optional): The storage of the tracker results. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ ValueError: When the identifier is not valid.
+
+ """
from vot.workspace import LocalStorage
self._identifier = _identifier
self._source = _source
@@ -267,6 +399,7 @@ def reversion(self, version=None) -> "Tracker":
return tracker
def runtime(self, log=False) -> "TrackerRuntime":
+ """Creates a new runtime instance for this tracker instance."""
if not self._command:
raise TrackerException("Tracker does not have an attached executable", tracker=self)
@@ -276,31 +409,49 @@ def runtime(self, log=False) -> "TrackerRuntime":
return _runtime_protocols[self._protocol](self, self._command, log=log, envvars=self._envvars, arguments=self._arguments, **self._args)
def __eq__(self, other):
+ """ Checks if two trackers are equal.
+
+ Args:
+ other (Tracker): The other tracker.
+
+ Returns:
+ bool: True if the trackers are equal, False otherwise.
+ """
if other is None or not isinstance(other, Tracker):
return False
        return self.reference == other.reference
def __hash__(self):
+ """ Returns the hash of the tracker. """
return hash(self.reference)
def __repr__(self):
+ """ Returns the string representation of the tracker. """
return self.reference
@property
def source(self):
+ """Returns the source of the tracker."""
return self._source
@property
def storage(self) -> "Storage":
+ """Returns the storage of the tracker results."""
return self._storage
@property
def identifier(self) -> str:
+ """Returns the identifier of the tracker."""
return self._identifier
@property
def label(self):
+ """Returns the label of the tracker. If the version is specified, the label will contain the version as well.
+
+ Returns:
+ str: Label of the tracker.
+ """
if self._version is None:
return self._label
else:
@@ -308,10 +459,20 @@ def label(self):
@property
def version(self) -> str:
+ """Returns the version of the tracker. If the version is not specified, None is returned.
+
+ Returns:
+ str: Version of the tracker.
+ """
return self._version
@property
def reference(self) -> str:
+ """Returns the reference of the tracker. If the version is specified, the reference will contain the version as well.
+
+ Returns:
+ str: Reference of the tracker.
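+
+        Examples:
+            A tracker with identifier "siamfc" and version "2021" is
+            referenced as "siamfc@2021".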
+ """
if self._version is None:
return self._identifier
else:
@@ -319,55 +480,124 @@ def reference(self) -> str:
@property
def protocol(self) -> str:
+ """Returns the communication protocol used by this tracker.
+
+ Returns:
+ str: Communication protocol
+ """
return self._protocol
def describe(self):
+ """Returns a dictionary containing the tracker description.
+
+ Returns:
+ dict: Dictionary containing the tracker description.
+ """
data = dict(command=self._command, label=self.label, protocol=self.protocol, arguments=self._arguments, env=self._envvars)
data.update(self._args)
return data
def metadata(self, key):
+ """Returns the metadata value for specified key."""
if not key in self._metadata:
return None
return self._metadata[key]
def tagged(self, tag):
+ """Returns true if the tracker is tagged with specified tag.
+
+ Args:
+ tag (str): The tag to check.
+
+ Returns:
+ bool: True if the tracker is tagged with specified tag, False otherwise.
+ """
+
return tag in self._tags
+ObjectStatus = namedtuple("ObjectStatus", ["region", "properties"])
+
+Objects = Union[List[ObjectStatus], ObjectStatus]
class TrackerRuntime(ABC):
+ """Base class for tracker runtime implementations. Tracker runtime is responsible for running the tracker executable and communicating with it."""
def __init__(self, tracker: Tracker):
+ """Creates a new tracker runtime instance.
+
+ Args:
+ tracker (Tracker): The tracker instance.
+ """
self._tracker = tracker
@property
def tracker(self) -> Tracker:
+ """Returns the tracker instance associated with this runtime."""
return self._tracker
def __enter__(self):
+ """Starts the tracker runtime."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
+ """Stops the tracker runtime."""
self.stop()
+ @property
+ def multiobject(self):
+ """Returns True if the tracker supports multiple objects, False otherwise."""
+ return False
+
@abstractmethod
def stop(self):
+ """Stops the tracker runtime."""
pass
@abstractmethod
def restart(self):
+        """Restarts the tracker runtime, usually starting a new process."""
pass
@abstractmethod
- def initialize(self, frame: Frame, region: Region, properties: dict = None) -> Tuple[Region, dict, float]:
+ def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """Initializes the tracker runtime with specified frame and objects. Returns the initial objects and the time it took to initialize the tracker.
+
+ Arguments:
+ frame {Frame} -- The frame to initialize the tracker with.
+ new {Objects} -- The objects to initialize the tracker with.
+ properties {dict} -- The properties to initialize the tracker with.
+
+ Returns:
+ Tuple[Objects, float] -- The initial objects and the time it took to initialize the tracker.
+ """
pass
@abstractmethod
- def update(self, frame: Frame, properties: dict = None) -> Tuple[Region, dict, float]:
+ def update(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """Updates the tracker runtime with specified frame and objects. Returns the updated objects and the time it took to update the tracker.
+
+ Arguments:
+ frame {Frame} -- The frame to update the tracker with.
+ new {Objects} -- The objects to update the tracker with.
+ properties {dict} -- The properties to update the tracker with.
+
+ Returns:
+ Tuple[Objects, float] -- The updated objects and the time it took to update the tracker.
+ """
pass
class RealtimeTrackerRuntime(TrackerRuntime):
+    """Tracker runtime wrapper that runs the tracker executable and communicates
+    with it while simulating given real-time constraints."""
def __init__(self, runtime: TrackerRuntime, grace: int = 1, interval: float = 0.1):
+ """Initializes the realtime tracker runtime with specified tracker runtime, grace period and update interval.
+
+ Arguments:
+ runtime {TrackerRuntime} -- The tracker runtime to wrap.
+            grace {int} -- The grace period, in frames, during which the tracker may exceed the real-time interval without being penalized. (default: {1})
+ interval {float} -- The update interval in seconds. (default: {0.1})
+
+ """
super().__init__(runtime.tracker)
self._runtime = runtime
self._grace = grace
@@ -376,27 +606,38 @@ def __init__(self, runtime: TrackerRuntime, grace: int = 1, interval: float = 0.
self._time = 0
self._out = None
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.stop()
+ @property
+ def multiobject(self):
+ """Returns True if the tracker supports multiple objects, False otherwise."""
+ return self._runtime.multiobject
def stop(self):
+ """Stops the tracker runtime."""
self._runtime.stop()
self._time = 0
self._out = None
def restart(self):
+        """Restarts the tracker runtime, usually starting a new process."""
self._runtime.restart()
self._time = 0
self._out = None
- def initialize(self, frame: Frame, region: Region, properties: dict = None) -> Tuple[Region, dict, float]:
+ def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """Initializes the tracker runtime with specified frame and objects. Returns the initial objects and the time it took to initialize the tracker.
+
+ Arguments:
+ frame {Frame} -- The frame to initialize the tracker with.
+ new {Objects} -- The objects to initialize the tracker with.
+ properties {dict} -- The properties to initialize the tracker with.
+
+ Returns:
+ Tuple[Objects, float] -- The initial objects and the time it took to initialize the tracker.
+ """
self._countdown = self._grace
self._out = None
- out, prop, time = self._runtime.initialize(frame, region, properties)
+        out, time = self._runtime.initialize(frame, new, properties)
if time > self._interval:
if self._countdown > 0:
@@ -411,7 +652,17 @@ def initialize(self, frame: Frame, region: Region, properties: dict = None) -> T
        return out, time
- def update(self, frame: Frame, properties: dict = None) -> Tuple[Region, dict, float]:
+ def update(self, frame: Frame, _: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """Updates the tracker runtime with specified frame and objects. Returns the updated objects and the time it took to update the tracker.
+
+ Arguments:
+ frame {Frame} -- The frame to update the tracker with.
+ new {Objects} -- The objects to update the tracker with.
+ properties {dict} -- The properties to update the tracker with.
+
+ Returns:
+ Tuple[Objects, float] -- The updated objects and the time it took to update the tracker.
+ """
if self._time > self._interval:
self._time = self._time - self._interval
@@ -434,25 +685,43 @@ def update(self, frame: Frame, properties: dict = None) -> Tuple[Region, dict, f
class PropertyInjectorTrackerRuntime(TrackerRuntime):
+    """Tracker runtime wrapper that injects predefined properties into the wrapped tracker runtime."""
def __init__(self, runtime: TrackerRuntime, **kwargs):
+ """Initializes the property injector tracker runtime with specified tracker runtime and properties.
+
+ Arguments:
+ runtime {TrackerRuntime} -- The tracker runtime to wrap.
+ **kwargs -- The properties to inject into the tracker runtime.
+ """
super().__init__(runtime.tracker)
self._runtime = runtime
self._properties = {k : str(v) for k, v in kwargs.items()}
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.stop()
+ @property
+ def multiobject(self):
+ """Returns True if the tracker supports multiple objects, False otherwise."""
+ return self._runtime.multiobject
def stop(self):
+ """Stops the tracker runtime."""
self._runtime.stop()
def restart(self):
+        """Restarts the tracker runtime, usually starting a new process."""
self._runtime.restart()
- def initialize(self, frame: Frame, region: Region, properties: dict = None) -> Tuple[Region, dict, float]:
+ def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """Initializes the tracker runtime with specified frame and objects. Returns the initial objects and the time it took to initialize the tracker.
+ This method injects the properties into the tracker runtime.
+
+ Arguments:
+ frame {Frame} -- The frame to initialize the tracker with.
+ new {Objects} -- The objects to initialize the tracker with.
+ properties {dict} -- The properties to initialize the tracker with.
+
+ Returns:
+ Tuple[Objects, float] -- The initial objects and the time it took to initialize the tracker."""
if not properties is None:
tproperties = dict(properties)
@@ -461,12 +730,174 @@ def initialize(self, frame: Frame, region: Region, properties: dict = None) -> T
tproperties.update(self._properties)
- return self._runtime.initialize(frame, region, tproperties)
+ return self._runtime.initialize(frame, new, tproperties)
+
+
+ def update(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """Updates the tracker runtime with specified frame and objects. Returns the updated objects and the time it took to update the tracker.
+
+ Arguments:
+ frame {Frame} -- The frame to update the tracker with.
+ new {Objects} -- The objects to update the tracker with.
+ properties {dict} -- The properties to update the tracker with.
+
+ Returns:
+ Tuple[Objects, float] -- The updated objects and the time it took to update the tracker.
+
+ """
+ return self._runtime.update(frame, new, properties)
+
+
+class SingleObjectTrackerRuntime(TrackerRuntime):
+    """Wrapper that restricts a tracker runtime to single-object tracking. Used to enforce single object tracking even for multi-object trackers."""
+
+ def __init__(self, runtime: TrackerRuntime):
+ """Initializes the single object tracker runtime with specified tracker runtime.
+
+ Arguments:
+ runtime {TrackerRuntime} -- The tracker runtime to wrap.
+ """
+ super().__init__(runtime.tracker)
+ self._runtime = runtime
+
+ @property
+ def multiobject(self):
+ """Returns False, since the tracker runtime only supports single object tracking."""
+ return False
+
+ def stop(self):
+        """Stops the wrapped tracker runtime."""
+ self._runtime.stop()
+
+ def restart(self):
+        """Restarts the wrapped tracker runtime, usually starting a new process."""
+ self._runtime.restart()
+
+ def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """Initializes the tracker runtime with specified frame and objects. Returns the initial objects and the time it took to initialize the tracker.
+
+ Arguments:
+ frame {Frame} -- The frame to initialize the tracker with.
+ new {Objects} -- The objects to initialize the tracker with.
+ properties {dict} -- The properties to initialize the tracker with.
+
+ Returns:
+ Tuple[Objects, float] -- The initial objects and the time it took to initialize the tracker.
+ """
+
+        if isinstance(new, list) and len(new) != 1:
+            raise TrackerException("Only supports single object tracking", tracker=self.tracker)
+        status, time = self._runtime.initialize(frame, new, properties)
+        if isinstance(status, list):
+            status = status[0]
+        return status, time
+
+ def update(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """Updates the tracker runtime with specified frame and objects. Returns the updated objects and the time it took to update the tracker.
+
+ Arguments:
+ frame {Frame} -- The frame to update the tracker with.
+ new {Objects} -- The objects to update the tracker with.
+ properties {dict} -- The properties to update the tracker with.
+
+ Returns:
+ Tuple[Objects, float] -- The updated objects and the time it took to update the tracker.
+ """
+
+        if new is not None:
+            raise TrackerException("Only supports single object tracking", tracker=self.tracker)
+        status, time = self._runtime.update(frame, new, properties)
+        if isinstance(status, list):
+            status = status[0]
+        return status, time
+
+class MultiObjectTrackerRuntime(TrackerRuntime):
+    """ Wrapper that provides multi-object tracking on top of runtimes that only support a single object, by running one runtime instance per object. STILL IN DEVELOPMENT!"""
+
+ def __init__(self, runtime: TrackerRuntime):
+ """Initializes the multi object tracker runtime with specified tracker runtime.
+
+ Arguments:
+ runtime {TrackerRuntime} -- The tracker runtime to wrap.
+ """
+
+ super().__init__(runtime.tracker)
+ if runtime.multiobject:
+ self._runtime = runtime
+ else:
+ self._runtime = [runtime]
+ self._used = 0
+
+ @property
+ def multiobject(self):
+ """Always returns True, since the tracker runtime supports multi object tracking."""
+ return True
+
+ def stop(self):
+ """Stops the tracker runtime."""
+ if isinstance(self._runtime, TrackerRuntime):
+ self._runtime.stop()
+ else:
+ for r in self._runtime:
+ r.stop()
+
+ def restart(self):
+        """Restarts the tracker runtime, usually starting a new process."""
+ if isinstance(self._runtime, TrackerRuntime):
+ self._runtime.restart()
+ else:
+ for r in self._runtime:
+ r.restart()
+
+ def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """Initializes the tracker runtime with specified frame and objects. Returns the initial objects and the time it took to initialize the tracker.
+ Internally this method initializes the tracker runtime for each object in the objects list.
+
+ Arguments:
+ frame {Frame} -- The frame to initialize the tracker with.
+ new {Objects} -- The objects to initialize the tracker with.
+ properties {dict} -- The properties to initialize the tracker with.
+
+ Returns:
+ Tuple[Objects, float] -- The initial objects and the time it took to initialize the tracker.
+ """
+
+        if isinstance(self._runtime, TrackerRuntime):
+            return self._runtime.initialize(frame, new, properties)
+        if isinstance(new, ObjectStatus):
+            new = [new]
+
+        self._used = len(new)
+        status = []
+        elapsed = 0
+        for i, o in enumerate(new):
+            # Spawn an additional single-object runtime for every extra object
+            if i >= len(self._runtime):
+                self._runtime.append(self._tracker.runtime())
+            s, t = self._runtime[i].initialize(frame, o, properties)
+            status.append(s)
+            elapsed = max(elapsed, t)
+        return status, elapsed
- def update(self, frame: Frame, properties: dict = None) -> Tuple[Region, dict, float]:
- return self._runtime.update(frame, properties)
+ def update(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """Updates the tracker runtime with specified frame and objects. Returns the updated objects and the time it took to update the tracker.
+ Internally this method updates the tracker runtime for each object in the new objects list.
+
+ Arguments:
+ frame {Frame} -- The frame to update the tracker with.
+ new {Objects} -- The objects to update the tracker with.
+ properties {dict} -- The properties to update the tracker with.
+
+ Returns:
+ Tuple[Objects, float] -- The updated objects and the time it took to update the tracker.
+ """
+        if isinstance(self._runtime, TrackerRuntime):
+            return self._runtime.update(frame, new, properties)
+        if new is not None:
+            # Adding objects mid-sequence is not implemented by this wrapper yet
+            raise TrackerException("Adding new objects is not supported", tracker=self.tracker)
+        results = [r.update(frame, None, properties) for r in self._runtime[:self._used]]
+        return [s for s, _ in results], max(t for _, t in results)
try:
diff --git a/vot/tracker/dummy.py b/vot/tracker/dummy.py
index 2e06666..b80f31d 100644
--- a/vot/tracker/dummy.py
+++ b/vot/tracker/dummy.py
@@ -1,20 +1,23 @@
+""" Dummy tracker for testing purposes. """
+
from __future__ import absolute_import
import os
from sys import path
import time
-from trax import Image, Region, Server, TraxStatus
-
def _main():
- region = None
+ """ Dummy tracker main function for testing purposes."""
+ from trax import Image, Region, Server, TraxStatus
+
+ objects = None
with Server([Region.RECTANGLE], [Image.PATH]) as server:
while True:
request = server.wait()
if request.type in [TraxStatus.QUIT, TraxStatus.ERROR]:
break
if request.type == TraxStatus.INITIALIZE:
- region = request.region
- server.status(region)
+ objects = request.objects
+ server.status(objects)
time.sleep(0.1)
if __name__ == '__main__':
diff --git a/vot/tracker/results.py b/vot/tracker/results.py
index 90a59bf..564d180 100644
--- a/vot/tracker/results.py
+++ b/vot/tracker/results.py
@@ -1,42 +1,113 @@
+"""Results module for storing and retrieving tracker results."""
+
import os
import fnmatch
from typing import List
from copy import copy
-from vot.region import Region, RegionType, Special, write_file, read_file, calculate_overlap
+from vot.region import Region, RegionType, Special, calculate_overlap
+from vot.region.io import write_trajectory, read_trajectory
from vot.utilities import to_string
class Results(object):
+    """Generic interface for storing and retrieving tracker result files."""
def __init__(self, storage: "Storage"):
+ """Creates a new results interface.
+
+ Args:
+ storage (Storage): Storage interface
+ """
self._storage = storage
def exists(self, name):
+ """Returns true if the given file exists in the results storage.
+
+ Args:
+ name (str): File name
+
+ Returns:
+ bool: True if the file exists
+ """
return self._storage.isdocument(name)
def read(self, name):
+ """Returns a file handle for reading the given file from the results storage.
+
+ Args:
+ name (str): File name
+
+ Returns:
+ file: File handle
+ """
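+        # Binary trajectory files (.bin) must be opened in binary mode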
+ if name.endswith(".bin"):
+ return self._storage.read(name, binary=True)
return self._storage.read(name)
- def write(self, name):
+ def write(self, name: str):
+ """Returns a file handle for writing the given file to the results storage.
+
+ Args:
+ name (str): File name
+
+ Returns:
+ file: File handle
+ """
+ if name.endswith(".bin"):
+ return self._storage.write(name, binary=True)
return self._storage.write(name)
def find(self, pattern):
- return fnmatch.filter(self._storage.documents(), pattern)
+ """Returns a list of files matching the given pattern in the results storage.
+
+ Args:
+            pattern (str): Pattern
+
+        Returns:
+ list: List of files
+ """
+
+ return fnmatch.filter(self._storage.documents(), pattern)
+
class Trajectory(object):
+ """Trajectory class for storing and retrieving tracker trajectories."""
+
+ UNKNOWN = 0
+ INITIALIZATION = 1
+ FAILURE = 2
@classmethod
def exists(cls, results: Results, name: str) -> bool:
- return results.exists(name + ".txt")
+ """Returns true if the trajectory exists in the results storage.
+
+ Args:
+ results (Results): Results storage
+ name (str): Trajectory name (without extension)
+
+ Returns:
+ bool: True if the trajectory exists
+ """
+ return results.exists(name + ".bin") or results.exists(name + ".txt")
@classmethod
def gather(cls, results: Results, name: str) -> list:
-
- if not Trajectory.exists(results, name):
+ """Returns a list of files that are part of the trajectory.
+
+ Args:
+ results (Results): Results storage
+ name (str): Trajectory name (without extension)
+
+ Returns:
+ list: List of files
+ """
+
+ if results.exists(name + ".bin"):
+ files = [name + ".bin"]
+ elif results.exists(name + ".txt"):
+ files = [name + ".txt"]
+ else:
return []
- files = [name + ".txt"]
-
for propertyfile in results.find(name + "_*.value"):
files.append(propertyfile)
@@ -44,18 +115,38 @@ def gather(cls, results: Results, name: str) -> list:
@classmethod
def read(cls, results: Results, name: str) -> 'Trajectory':
+ """Reads a trajectory from the results storage.
+
+ Args:
+ results (Results): Results storage
+ name (str): Trajectory name (without extension)
+
+ Returns:
+ Trajectory: Trajectory
+ """
def parse_float(line):
+ """Parses a float from a line.
+
+ Args:
+ line (str): Line
+
+ Returns:
+ float: Float value
+ """
if not line.strip():
return None
return float(line.strip())
- if not results.exists(name + ".txt"):
+ if results.exists(name + ".txt"):
+ with results.read(name + ".txt") as fp:
+ regions = read_trajectory(fp)
+ elif results.exists(name + ".bin"):
+ with results.read(name + ".bin") as fp:
+ regions = read_trajectory(fp)
+ else:
raise FileNotFoundError("Trajectory data not found: {}".format(name))
- with results.read(name + ".txt") as fp:
- regions = read_file(fp)
-
trajectory = Trajectory(len(regions))
trajectory._regions = regions
@@ -67,14 +158,29 @@ def parse_float(line):
trajectory._properties[propertyname] = [parse_float(line) for line in lines]
except ValueError:
trajectory._properties[propertyname] = [line.strip() for line in lines]
-
+
return trajectory
- def __init__(self, length:int):
- self._regions = [Special(Special.UNKNOWN)] * length
+ def __init__(self, length: int):
+ """Creates a new trajectory of the given length.
+
+ Args:
+ length (int): Trajectory length
+ """
+ self._regions = [Special(Trajectory.UNKNOWN)] * length
self._properties = dict()
def set(self, frame: int, region: Region, properties: dict = None):
+ """Sets the region for the given frame.
+
+ Args:
+ frame (int): Frame index
+ region (Region): Region
+ properties (dict, optional): Frame properties. Defaults to None.
+
+ Raises:
+ IndexError: Frame index out of bounds
+ """
if frame < 0 or frame >= len(self._regions):
raise IndexError("Frame index out of bounds")
@@ -89,14 +195,41 @@ def set(self, frame: int, region: Region, properties: dict = None):
self._properties[k][frame] = v
def region(self, frame: int) -> Region:
+ """Returns the region for the given frame.
+
+ Args:
+ frame (int): Frame index
+
+ Raises:
+ IndexError: Frame index out of bounds
+
+ Returns:
+ Region: Region
+ """
if frame < 0 or frame >= len(self._regions):
raise IndexError("Frame index out of bounds")
return self._regions[frame]
def regions(self) -> List[Region]:
+ """ Returns the list of regions.
+
+ Returns:
+ List[Region]: List of regions
+ """
return copy(self._regions)
def properties(self, frame: int = None) -> dict:
+ """Returns the properties for the given frame or all properties if frame is None.
+
+ Args:
+ frame (int, optional): Frame index. Defaults to None.
+
+ Raises:
+ IndexError: Frame index out of bounds
+
+ Returns:
+ dict: Properties
+ """
if frame is None:
return tuple(self._properties.keys())
@@ -107,12 +240,29 @@ def properties(self, frame: int = None) -> dict:
return {k : v[frame] for k, v in self._properties.items() if not v[frame] is None}
def __len__(self):
+ """Returns the length of the trajectory.
+
+ Returns:
+ int: Length
+ """
return len(self._regions)
def write(self, results: Results, name: str):
-
- with results.write(name + ".txt") as fp:
- write_file(fp, self._regions)
+ """Writes the trajectory to the results storage.
+
+ Args:
+ results (Results): Results storage
+ name (str): Trajectory name (without extension)
+ """
+ from vot import config
+
+ if config.results_binary:
+ with results.write(name + ".bin") as fp:
+ write_trajectory(fp, self._regions)
+ else:
+ with results.write(name + ".txt") as fp:
+ write_trajectory(fp, self._regions)
for k, v in self._properties.items():
with results.write(name + "_" + k + ".value") as fp:
@@ -120,6 +270,16 @@ def write(self, results: Results, name: str):
def equals(self, trajectory: 'Trajectory', check_properties: bool = False, overlap_threshold: float = 0.99999):
+ """Returns true if the trajectories are equal.
+
+ Args:
+            trajectory (Trajectory): The trajectory to compare against.
+            check_properties (bool, optional): Also compare per-frame properties. Defaults to False.
+            overlap_threshold (float, optional): Minimum per-frame region overlap for regions to be considered equal. Defaults to 0.99999.
+
+        Returns:
+            bool: True if the trajectories are equal.
+        """
if not len(self) == len(trajectory):
return False
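A minimal usage sketch of the trajectory round-trip introduced above; `storage` is an assumed Storage implementation (for example a workspace results substorage) and the trajectory name is hypothetical:

    from vot.region import Rectangle
    from vot.tracker.results import Results, Trajectory

    results = Results(storage)  # storage: assumed Storage instance

    trajectory = Trajectory(3)  # three frames, initialized to Special(Trajectory.UNKNOWN)
    trajectory.set(0, Rectangle(10, 10, 50, 40))
    trajectory.set(1, Rectangle(12, 11, 50, 40), {"confidence": 0.9})

    # Written as .bin or .txt, depending on config.results_binary
    trajectory.write(results, "sequence_001")

    if Trajectory.exists(results, "sequence_001"):
        restored = Trajectory.read(results, "sequence_001")
        assert len(restored) == len(trajectory)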
diff --git a/vot/tracker/tests.py b/vot/tracker/tests.py
index 9126bd9..1334c44 100644
--- a/vot/tracker/tests.py
+++ b/vot/tracker/tests.py
@@ -1,18 +1,20 @@
-
+""" Unit tests for the tracker module. """
import unittest
-from ..dataset.dummy import DummySequence
+from ..dataset.dummy import generate_dummy
from ..tracker.dummy import DummyTracker
class TestStacks(unittest.TestCase):
+ """Tests for the stacks module."""
def test_tracker_test(self):
+ """Test tracker runtime with dummy sequence and dummy tracker."""
tracker = DummyTracker
- sequence = DummySequence(10)
+ sequence = generate_dummy(10)
with tracker.runtime(log=False) as runtime:
runtime.initialize(sequence.frame(0), sequence.groundtruth(0))
- for i in range(1, sequence.length):
+ for i in range(1, len(sequence)):
runtime.update(sequence.frame(i))
diff --git a/vot/tracker/trax.py b/vot/tracker/trax.py
index 055e828..8fb0618 100644
--- a/vot/tracker/trax.py
+++ b/vot/tracker/trax.py
@@ -1,4 +1,8 @@
+""" TraX protocol implementation for the toolkit. TraX is a communication protocol for visual object tracking.
+ It enables communication between a tracker and a client. The protocol was originally developed for the VOT challenge to address
+ the need for a unified communication interface between trackers and benchmarking tools.
+"""
import sys
import os
import time
@@ -27,7 +31,7 @@
from vot.dataset import Frame, DatasetException
from vot.region import Region, Polygon, Rectangle, Mask
-from vot.tracker import Tracker, TrackerRuntime, TrackerException
+from vot.tracker import Tracker, TrackerRuntime, TrackerException, Objects, ObjectStatus
from vot.utilities import to_logical, to_number, normalize_path
PORT_POOL_MIN = 9090
@@ -36,27 +40,40 @@
logger = logging.getLogger("vot")
class LogAggregator(object):
+ """ Aggregates log messages from the tracker. """
def __init__(self):
+ """ Initializes the aggregator."""
self._fragments = []
def __call__(self, fragment):
+ """ Appends a new fragment to the log."""
self._fragments.append(fragment)
def __str__(self):
+ """ Returns the aggregated log."""
return "".join(self._fragments)
class ColorizedOutput(object):
+ """ Colorized output for the tracker."""
def __init__(self):
+ """ Initializes the colorized output."""
colorama.init()
def __call__(self, fragment):
+ """ Prints a new fragment to the output.
+
+ Args:
+ fragment: The fragment to be printed.
+ """
print(colorama.Fore.CYAN + fragment + colorama.Fore.RESET, end="")
class PythonCrashHelper(object):
+ """ Helper class for detecting Python crashes in the tracker."""
def __init__(self):
+ """ Initializes the crash helper."""
self._matcher = re.compile(r'''
^Traceback
[\s\S]+?
@@ -64,12 +81,27 @@ def __init__(self):
''', re.M | re.X)
def __call__(self, log, directory):
+ """ Detects Python crashes in the log.
+
+ Args:
+ log: The log to be checked.
+ directory: The directory where the log is stored.
+ """
matches = self._matcher.findall(log)
if len(matches) > 0:
return matches[-1].group(0)
return None
def convert_frame(frame: Frame, channels: list) -> dict:
+ """ Converts a frame to a dictionary of Trax images.
+
+ Args:
+ frame: The frame to be converted.
+ channels: The list of channels to be converted.
+
+ Returns:
+ A dictionary of Trax images.
+ """
tlist = dict()
for channel in channels:
@@ -82,16 +114,31 @@ def convert_frame(frame: Frame, channels: list) -> dict:
return tlist
def convert_region(region: Region) -> TraxRegion:
+ """ Converts a region to a Trax region.
+
+ Args:
+ region: The region to be converted.
+
+ Returns:
+ A Trax region.
+ """
if isinstance(region, Rectangle):
return TraxRectangle.create(region.x, region.y, region.width, region.height)
elif isinstance(region, Polygon):
return TraxPolygon.create([region[i] for i in range(region.size)])
elif isinstance(region, Mask):
return TraxMask.create(region.mask, x=region.offset[0], y=region.offset[1])
-
return None
def convert_traxregion(region: TraxRegion) -> Region:
+ """ Converts a Trax region to a region.
+
+ Args:
+ region: The Trax region to be converted.
+
+ Returns:
+ A region.
+ """
if region.type == TraxRegion.RECTANGLE:
x, y, width, height = region.bounds()
return Rectangle(x, y, width, height)
@@ -99,22 +146,61 @@ def convert_traxregion(region: TraxRegion) -> Region:
return Polygon(list(region))
elif region.type == TraxRegion.MASK:
return Mask(region.array(), region.offset(), optimize=True)
+ return None
+
+def convert_objects(objects: Objects) -> list:
+    """ Converts tracked objects to a list of Trax region and properties pairs.
+
+    Args:
+        objects: The objects to convert; a list of ObjectStatus, a single ObjectStatus, a bare region, or None.
+
+    Returns:
+        A list of (TraxRegion, dict) pairs.
+    """
+ if objects is None: return []
+ if isinstance(objects, (list, )):
+ return [(convert_region(o.region), dict(o.properties)) for o in objects]
+    if isinstance(objects, ObjectStatus):
+        return [(convert_region(objects.region), dict(objects.properties))]
+    return [(convert_region(objects), dict())]
+
+def convert_traxobjects(region: TraxRegion) -> Region:
+ """ Converts a Trax region to a region.
+ Args:
+ region: The Trax region to be converted.
+
+ Returns:
+ A region.
+
+ """
+ if region.type == TraxRegion.RECTANGLE:
+ x, y, width, height = region.bounds()
+ return Rectangle(x, y, width, height)
+ elif region.type == TraxRegion.POLYGON:
+ return Polygon(list(region))
+ elif region.type == TraxRegion.MASK:
+ return Mask(region.array(), region.offset(), optimize=True)
return None
class TestRasterMethods(unittest.TestCase):
+ """ Tests for the raster methods. """
def test_convert_traxregion(self):
+ """ Tests the conversion of Trax regions."""
convert_traxregion(TraxRectangle.create(0, 0, 10, 10))
convert_traxregion(TraxPolygon.create([(0, 0), (10, 0), (10, 10), (0, 10)]))
convert_traxregion(TraxMask.create(np.ones((100, 100), dtype=np.uint8)))
def test_convert_region(self):
+ """ Tests the conversion of regions."""
convert_region(Rectangle(0, 0, 10, 10))
convert_region(Polygon([(0, 0), (10, 0), (10, 10), (0, 10)]))
convert_region(Mask(np.ones((100, 100), dtype=np.uint8)))
def open_local_port(port: int):
+ """ Opens a local port for listening."""
socket = socketio.socket(socketio.AF_INET, socketio.SOCK_STREAM)
try:
socket.setsockopt(socketio.SOL_SOCKET, socketio.SO_REUSEADDR, 1)
@@ -129,12 +215,25 @@ def open_local_port(port: int):
return None
def normalize_paths(paths, tracker):
+ """ Normalizes a list of paths relative to the tracker source."""
root = os.path.dirname(tracker.source)
return [normalize_path(path, root) for path in paths]
class TrackerProcess(object):
-
+ """ A tracker process. This class is used to run trackers in a separate process and handles
+ starting, stopping and communication with the process. """
+
def __init__(self, command: str, envvars=dict(), timeout=30, log=False, socket=False):
+ """ Initializes a new tracker process.
+
+ Args:
+ command: The command to run the tracker.
+ envvars: A dictionary of environment variables to be set for the tracker process.
+ timeout: The timeout for the tracker process.
+ log: Whether to log the tracker output.
+ socket: Whether to use a socket for communication.
+
+ """
environment = dict(os.environ)
environment.update(envvars)
@@ -188,6 +287,7 @@ def __init__(self, command: str, envvars=dict(), timeout=30, log=False, socket=F
self._client = Client(
stream=(self._process.stdin.fileno(), self._process.stdout.fileno()), log=log
)
+
except TraxException as e:
self.terminate()
self._watchdog_reset(False)
@@ -195,8 +295,15 @@ def __init__(self, command: str, envvars=dict(), timeout=30, log=False, socket=F
self._watchdog_reset(False)
self._has_vot_wrapper = not self._client.get("vot") is None
+ self._multiobject = self._client.get("multiobject")
def _watchdog_reset(self, enable=True):
+ """ Resets the watchdog.
+
+ Args:
+ enable: Whether to enable the watchdog.
+
+ """
if self._watchdog_counter == 0:
return
@@ -206,6 +313,7 @@ def _watchdog_reset(self, enable=True):
self._watchdog_counter = -1
def _watchdog_loop(self):
+ """ The watchdog loop. This loop is used to monitor the tracker process and terminate it if it does not respond anymore."""
while self.alive:
time.sleep(0.1)
@@ -219,28 +327,46 @@ def _watchdog_loop(self):
@property
def has_vot_wrapper(self):
+ """ Whether the tracker has a VOT wrapper. VOT wrapper limits TraX functionality and injects a property at handshake to let the client know this."""
return self._has_vot_wrapper
@property
def returncode(self):
+ """ The return code of the tracker process."""
return self._returncode
@property
def workdir(self):
+ """ The working directory of the tracker process."""
return self._workdir
@property
def interrupted(self):
+ """ Whether the tracker process was interrupted."""
return self._watchdog_counter == 0
@property
def alive(self):
+ """ Whether the tracker process is alive."""
if self._process is None:
return False
self._returncode = self._process.returncode
return self._returncode is None
- def initialize(self, frame: Frame, region: Region, properties: dict = None) -> Tuple[Region, dict, float]:
+ def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """ Initializes the tracker. This method is used to initialize the tracker with the first frame. It returns the initial state of the tracker.
+
+ Args:
+ frame: The first frame.
+ new: The initial state of the tracker.
+ properties: The properties to be set for the tracker.
+
+ Returns:
+ The initial state of the tracker.
+
+ Raises:
+ TraxException: If the tracker is not alive.
+ """
if not self.alive:
raise TraxException("Tracker not alive")
@@ -249,33 +375,55 @@ def initialize(self, frame: Frame, region: Region, properties: dict = None) -> T
properties = dict()
tlist = convert_frame(frame, self._client.channels)
- tregion = convert_region(region)
+ tobjects = convert_objects(new)
self._watchdog_reset(True)
- region, properties, elapsed = self._client.initialize(tlist, tregion, properties)
+ status, elapsed = self._client.initialize(tlist, tobjects, properties)
self._watchdog_reset(False)
- return convert_traxregion(region), properties.dict(), elapsed
+ status = [ObjectStatus(convert_traxregion(region), properties) for region, properties in status]
+
+ return status, elapsed
+
+
+ def update(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """ Updates the tracker with a new frame. This method is used to update the tracker with a new frame. It returns the new state of the tracker.
+ Args:
+ frame: The new frame.
+ new: The new state of the tracker.
+ properties: The properties to be set for the tracker.
- def frame(self, frame: Frame, properties: dict = dict()) -> Tuple[Region, dict, float]:
+ Returns:
+ The new state of the tracker.
+
+ Raises:
+ TraxException: If the tracker is not alive.
+
+ """
if not self.alive:
raise TraxException("Tracker not alive")
tlist = convert_frame(frame, self._client.channels)
+ tobjects = convert_objects(new)
+
self._watchdog_reset(True)
- region, properties, elapsed = self._client.frame(tlist, properties)
+ status, elapsed = self._client.frame(tlist, properties, tobjects)
self._watchdog_reset(False)
- return convert_traxregion(region), properties.dict(), elapsed
+ status = [ObjectStatus(convert_traxregion(region), properties) for region, properties in status]
+
+ return status, elapsed
def terminate(self):
+ """ Terminates the tracker. This method is used to terminate the tracker. It closes the connection to the tracker and terminates the tracker process.
+ """
with self._watchdog_lock:
if not self.alive:
@@ -314,10 +462,12 @@ def terminate(self):
self._process = None
def __del__(self):
+ """ Destructor. This method is used to terminate the tracker process if it is still alive."""
if hasattr(self, "_workdir"):
shutil.rmtree(self._workdir, ignore_errors=True)
def wait(self):
+ """ Waits for the tracker to terminate. This method is used to wait for the tracker to terminate. It waits until the tracker process terminates."""
self._watchdog_reset(True)
@@ -331,8 +481,23 @@ def wait(self):
class TraxTrackerRuntime(TrackerRuntime):
+ """ The TraX tracker runtime. This class is used to run a tracker using the TraX protocol."""
def __init__(self, tracker: Tracker, command: str, log: bool = False, timeout: int = 30, linkpaths=None, envvars=None, arguments=None, socket=False, restart=False, onerror=None):
+ """ Initializes the TraX tracker runtime.
+
+ Args:
+ tracker: The tracker to be run.
+ command: The command to run the tracker.
+ log: Whether to log the output of the tracker.
+ timeout: The timeout in seconds for the tracker to respond.
+ linkpaths: The paths to be added to the PATH environment variable.
+ envvars: The environment variables to be set for the tracker.
+ arguments: The arguments to be passed to the tracker.
+ socket: Whether to use a socket to communicate with the tracker.
+            restart: Whether to restart the tracker process for every (re)initialization.
+ onerror: The error handler to be called if the tracker crashes.
+ """
super().__init__(tracker)
self._command = command
self._process = None
@@ -365,9 +530,17 @@ def __init__(self, tracker: Tracker, command: str, log: bool = False, timeout: i
@property
def tracker(self) -> Tracker:
+ """ The associated tracker object. """
return self._tracker
+ @property
+ def multiobject(self):
+ """ Whether the tracker supports multiple objects."""
+ self._connect()
+ return self._process._multiobject
+
def _connect(self):
+ """ Connects to the tracker. This method is used to connect to the tracker. It starts the tracker process if it is not running yet."""
if not self._process:
if not self._output is None:
log = self._output
@@ -378,6 +551,7 @@ def _connect(self):
self._restart = True
def _error(self, exception):
+ """ Handles an error. This method is used to handle an error. It calls the error handler if it is set."""
workdir = None
timeout = False
if not self._output is None:
@@ -409,13 +583,24 @@ def _error(self, exception):
tracker_log=log if not self._output is None else None)
def restart(self):
+ """ Restarts the tracker. This method is used to restart the tracker. It stops the tracker process and starts it again."""
try:
self.stop()
self._connect()
except TraxException as e:
self._error(e)
- def initialize(self, frame: Frame, region: Region, properties: dict = None) -> Tuple[Region, dict, float]:
+ def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """ Initializes the tracker. This method is used to initialize the tracker. It starts the tracker process if it is not running yet.
+
+ Args:
+ frame: The initial frame.
+ new: The initial objects.
+ properties: The initial properties.
+
+ Returns:
+ A tuple containing the initial objects and the initial score.
+ """
try:
if self._restart:
self.stop()
@@ -426,35 +611,72 @@ def initialize(self, frame: Frame, region: Region, properties: dict = None) -> T
if not properties is None:
tproperties.update(properties)
- return self._process.initialize(frame, region, tproperties)
+ return self._process.initialize(frame, new, tproperties)
except TraxException as e:
self._error(e)
- def update(self, frame: Frame, properties: dict = None) -> Tuple[Region, dict, float]:
+ def update(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+ """ Updates the tracker. This method is used to update the tracker state with a new frame.
+
+ Args:
+ frame: The current frame.
+ new: The current objects.
+ properties: The current properties.
+
+ Returns:
+            A tuple containing the updated objects and the elapsed time.
+ """
try:
if properties is None:
properties = dict()
- return self._process.frame(frame, properties)
+ return self._process.update(frame, new, properties)
except TraxException as e:
self._error(e)
def stop(self):
+ """ Stops the tracker. This method is used to stop the tracker. It stops the tracker process."""
if not self._process is None:
self._process.terminate()
self._process = None
def __del__(self):
- if not self._process is None:
- self._process.terminate()
- self._process = None
+ """ Destructor. This method is used to stop the tracker process when the object is deleted."""
+ self.stop()
def escape_path(path):
+ """ Escapes a path. This method is used to escape a path.
+
+ Args:
+ path: The path to escape.
+
+ Returns:
+ The escaped path.
+ """
if sys.platform.startswith("win"):
return path.replace("\\\\", "\\").replace("\\", "\\\\")
else:
return path
def trax_python_adapter(tracker, command, envvars, paths="", log: bool = False, timeout: int = 30, linkpaths=None, arguments=None, python=None, socket=False, restart=False, **kwargs):
+ """ Creates a Python adapter for a tracker. This method is used to create a Python adapter for a tracker.
+
+ Args:
+ tracker: The tracker to create the adapter for.
+ command: The command to run the tracker.
+ envvars: The environment variables to set.
+ paths: The paths to add to the Python path.
+ log: Whether to log the tracker output.
+ timeout: The timeout in seconds.
+ linkpaths: The paths to link.
+ arguments: The arguments to pass to the tracker.
+ python: The Python interpreter to use.
+ socket: Whether to use a socket to communicate with the tracker.
+        restart: Whether to restart the tracker process for every (re)initialization.
+ kwargs: Additional keyword arguments.
+
+ Returns:
+ The Python TraX runtime object.
+ """
if not isinstance(paths, list):
paths = paths.split(os.pathsep)
@@ -475,6 +697,25 @@ def trax_python_adapter(tracker, command, envvars, paths="", log: bool = False,
return TraxTrackerRuntime(tracker, command, log=log, timeout=timeout, linkpaths=linkpaths, envvars=envvars, arguments=arguments, socket=socket, restart=restart)
def trax_matlab_adapter(tracker, command, envvars, paths="", log: bool = False, timeout: int = 30, linkpaths=None, arguments=None, matlab=None, socket=False, restart=False, **kwargs):
+ """ Creates a Matlab adapter for a tracker. This method is used to create a Matlab adapter for a tracker.
+
+ Args:
+ tracker: The tracker to create the adapter for.
+ command: The command to run the tracker.
+ envvars: The environment variables to set.
+ paths: The paths to add to the Matlab path.
+ log: Whether to log the tracker output.
+ timeout: The timeout in seconds.
+ linkpaths: The paths to link.
+ arguments: The arguments to pass to the tracker.
+ matlab: The Matlab executable to use.
+ socket: Whether to use a socket to communicate with the tracker.
+        restart: Whether to restart the tracker process for every (re)initialization.
+ kwargs: Additional keyword arguments.
+
+ Returns:
+ The Matlab TraX runtime object.
+ """
if not isinstance(paths, list):
paths = paths.split(os.pathsep)
@@ -514,6 +755,25 @@ def trax_matlab_adapter(tracker, command, envvars, paths="", log: bool = False,
return TraxTrackerRuntime(tracker, command, log=log, timeout=timeout, linkpaths=linkpaths, envvars=envvars, arguments=arguments, socket=socket, restart=restart)
def trax_octave_adapter(tracker, command, envvars, paths="", log: bool = False, timeout: int = 30, linkpaths=None, arguments=None, socket=False, restart=False, **kwargs):
+ """ Creates an Octave adapter for a tracker. This method is used to create an Octave adapter for a tracker.
+
+ Args:
+ tracker: The tracker to create the adapter for.
+ command: The command to run the tracker.
+ envvars: The environment variables to set.
+ paths: The paths to add to the Octave path.
+ log: Whether to log the tracker output.
+ timeout: The timeout in seconds.
+ linkpaths: The paths to link.
+ arguments: The arguments to pass to the tracker.
+ socket: Whether to use a socket to communicate with the tracker.
+        restart: Whether to restart the tracker process for every (re)initialization.
+ kwargs: Additional keyword arguments.
+
+ Returns:
+ The Octave TraX runtime object.
+ """
+
if not isinstance(paths, list):
paths = paths.split(os.pathsep)
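For reference, a small sketch of the input forms `convert_objects` accepts (None, bare region, single ObjectStatus, list of ObjectStatus); each form is normalized to a list of (TraxRegion, properties) pairs:

    from vot.region import Rectangle
    from vot.tracker import ObjectStatus
    from vot.tracker.trax import convert_objects

    assert convert_objects(None) == []
    single = convert_objects(Rectangle(0, 0, 10, 10))  # one pair, empty properties
    status = convert_objects(ObjectStatus(Rectangle(0, 0, 10, 10), {"confidence": 1.0}))
    many = convert_objects([ObjectStatus(Rectangle(0, 0, 10, 10), {}),
                            ObjectStatus(Rectangle(5, 5, 10, 10), {})])
    assert len(many) == 2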
diff --git a/vot/utilities/__init__.py b/vot/utilities/__init__.py
index aebccaa..db90717 100644
--- a/vot/utilities/__init__.py
+++ b/vot/utilities/__init__.py
@@ -1,16 +1,18 @@
+""" This module contains various utility functions and classes used throughout the toolkit. """
+
import os
import sys
import csv
import re
import hashlib
-import errno
import logging
import inspect
+import time
import concurrent.futures as futures
from logging import Formatter, LogRecord
from numbers import Number
-from typing import Tuple
+from typing import Any, Mapping, Tuple
import typing
from vot import get_logger
@@ -19,8 +21,18 @@
__ALIASES = dict()
+def import_class(classpath: str) -> typing.Type:
+ """Import a class from a string by importing parent packages.
+
+ Args:
+ classpath (str): String representing a canonical class name with all parent packages.
+
+ Raises:
+        ImportError: Raised when the class or one of its parent modules cannot be imported.
-def import_class(classpath):
+
+    Returns:
+        typing.Type: The imported class object.
+ """
delimiter = classpath.rfind(".")
if delimiter == -1:
if classpath in __ALIASES:
@@ -33,7 +45,14 @@ def import_class(classpath):
return getattr(module, classname)
def alias(*args):
- def register(cls):
+ """ Decorator for registering class aliases. Aliases are used to refer to classes by a short name.
+
+ Args:
+ *args: A list of strings representing aliases for the class.
+ """
+
+ def register(cls: typing.Type):
+ """ Register the class with the given aliases. """
assert cls is not None
for name in args:
if name in __ALIASES:
@@ -45,9 +64,25 @@ def register(cls):
return register
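A short sketch of how the alias registry interacts with import_class; MyAnalysis is a hypothetical class defined only for illustration, and the sketch assumes the decorator returns the class it registers:

    from vot.utilities import alias, import_class

    @alias("my_analysis")
    class MyAnalysis:  # hypothetical class
        pass

    # A name without a dot is resolved through the alias registry,
    # a dotted path is imported as package.module.ClassName:
    cls = import_class("my_analysis")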
def class_fullname(o):
+ """Returns the full name of the class of the given object.
+
+ Args:
+ o: The object to get the class name from.
+
+ Returns:
+ The full name of the class of the given object.
+ """
return class_string(o.__class__)
def class_string(kls):
+ """Returns the full name of the given class.
+
+ Args:
+ kls: The class to get the name from.
+
+ Returns:
+ The full name of the given class.
+ """
assert inspect.isclass(kls)
module = kls.__module__
if module is None or module == str.__class__.__module__:
@@ -56,8 +91,27 @@ def class_string(kls):
return module + '.' + kls.__name__
def flip(size: Tuple[Number, Number]) -> Tuple[Number, Number]:
+ """Flips the given size tuple.
+
+ Args:
+ size: The size tuple to flip.
+
+ Returns:
+ The flipped size tuple.
+ """
return (size[1], size[0])
+def flatten(nested_list):
+ """Flattens a nested list.
+
+ Args:
+ nested_list: The nested list to flatten.
+
+ Returns:
+ The flattened list.
+ """
+ return [item for sublist in nested_list for item in sublist]
+
from vot.utilities.notebook import is_notebook
if is_notebook():
@@ -70,22 +124,36 @@ def flip(size: Tuple[Number, Number]) -> Tuple[Number, Number]:
from tqdm import tqdm
class Progress(object):
+ """Wrapper around tqdm progress bar, enables silecing the progress output and some more
+ costumizations.
+ """
class StreamProxy(object):
+ """Proxy class for tqdm to enable silent mode."""
def write(self, x):
+ """Write function used by tqdm."""
# Avoid print() second call (useless \n)
if len(x.rstrip()) > 0:
tqdm.write(x)
+
def flush(self):
+ """Flush function used by tqdm."""
#return getattr(self.file, "flush", lambda: None)()
pass
@staticmethod
def logstream():
+ """Returns a stream proxy that can be used to redirect output to the progress bar."""
return Progress.StreamProxy()
def __init__(self, description="Processing", total=100):
+ """Creates a new progress bar.
+
+ Args:
+ description: The description of the progress bar.
+ total: The total number of steps.
+ """
silent = get_logger().level > logging.INFO
if not silent:
@@ -99,9 +167,22 @@ def __init__(self, description="Processing", total=100):
self._total = total if not silent else 0
def _percent(self, n):
+ """Returns the percentage of the given value.
+
+ Args:
+ n: The value to compute the percentage of.
+
+ Returns:
+ The percentage of the given value.
+ """
return int((n * 100) / self._total)
def absolute(self, value):
+ """Sets the progress to the given value.
+
+ Args:
+ value: The value to set the progress to.
+ """
if self._tqdm is None:
if self._total == 0:
return
@@ -113,6 +194,11 @@ def absolute(self, value):
self._tqdm.update(value - self._tqdm.n) # will also set self.n = b * bsize
def relative(self, n):
+ """Increments the progress by the given value.
+
+ Args:
+ n: The value to increment the progress by.
+ """
if self._tqdm is None:
if self._total == 0:
return
@@ -124,6 +210,11 @@ def relative(self, n):
self._tqdm.update(n) # will also set self.n = b * bsize
def total(self, t):
+ """Sets the total number of steps.
+
+ Args:
+ t: The total number of steps.
+ """
if self._tqdm is None:
if self._total == 0:
return
@@ -135,16 +226,26 @@ def total(self, t):
self._tqdm.refresh()
def __enter__(self):
+ """Enters the context manager."""
return self
def __exit__(self, exc_type, exc_value, traceback):
+ """Exits the context manager."""
self.close()
def close(self):
+ """Closes the progress bar."""
if self._tqdm:
self._tqdm.close()
def extract_files(archive, destination, callback = None):
+ """Extracts all files from the given archive to the given destination.
+
+ Args:
+ archive: The archive to extract the files from.
+ destination: The destination to extract the files to.
+ callback: An optional callback function that is called after each file is extracted.
+ """
from zipfile import ZipFile
with ZipFile(file=archive) as zip_file:
@@ -181,21 +282,32 @@ def read_properties(filename: str, delimiter: str = '=') -> typing.Dict[str, str
properties[groups.group(1)] = groups.group(2)
return properties
-def write_properties(filename, dictionary, delimiter='='):
- ''' Writes the provided dictionary in key sorted order to a properties
+def write_properties(filename: str, dictionary: Mapping[str, Any], delimiter: str = '='):
+ """Writes the provided dictionary in key sorted order to a properties
+    file with each line in the format: key<delimiter>value
- filename -- the name of the file to be written
- dictionary -- a dictionary containing the key/value pairs.
- '''
+
+ Args:
+ filename (str): the name of the file to be written
+ dictionary (Mapping[str, str]): a dictionary containing the key/value pairs.
+        delimiter (str, optional): Delimiter between key and value. Defaults to '='.
+ """
+
open_kwargs = {'mode': 'w', 'newline': ''} if six.PY3 else {'mode': 'wb'}
with open(filename, **open_kwargs) as csvfile:
writer = csv.writer(csvfile, delimiter=delimiter, escapechar='\\',
quoting=csv.QUOTE_NONE)
writer.writerows(sorted(dictionary.items()))
-def file_hash(filename):
+def file_hash(filename: str) -> Tuple[str, str]:
+ """Calculates MD5 and SHA1 hashes based on file content
- # BUF_SIZE is totally arbitrary, change for your app!
+ Args:
+ filename (str): Filename of the file to open and analyze
+
+ Returns:
+ Tuple[str, str]: MD5 and SHA1 hashes as hexadecimal strings.
+ """
+
     bufsize = 65536 # read the file in 64 KiB chunks
md5 = hashlib.md5()
@@ -211,7 +323,16 @@ def file_hash(filename):
return md5.hexdigest(), sha1.hexdigest()
-def arg_hash(*args, **kwargs):
+def arg_hash(*args, **kwargs) -> str:
+ """Computes hash based on input positional and keyword arguments.
+
+    The algorithm converts all arguments to strings and encloses them with delimiters. The
+    positional arguments are listed as-is; keyword arguments are sorted and encoded with their keys
+    as well as their values.
+
+ Returns:
+ str: SHA1 hash as hexadecimal string
+ """
sha1 = hashlib.sha1()
for arg in args:
@@ -222,9 +343,25 @@ def arg_hash(*args, **kwargs):
return sha1.hexdigest()
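Because keyword arguments are sorted before hashing, argument order does not change the digest, which makes arg_hash usable for deriving stable cache keys; a quick sketch:

    from vot.utilities import arg_hash

    key_a = arg_hash("experiment", sequence="ball", repeat=3)
    key_b = arg_hash("experiment", repeat=3, sequence="ball")
    assert key_a == key_b   # kwargs are sorted, so order is irrelevant
    assert len(key_a) == 40  # SHA1 hexadecimal digest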
-def which(program):
+def which(program: str) -> str:
+ """Locates an executable in system PATH list by its name.
+
+ Args:
+ program (str): Name of the executable
+
+ Returns:
+ str: Full path or None if not found
+ """
def is_exe(fpath):
+ """Checks if the given path is an executable file.
+
+ Args:
+ fpath (str): Path to check
+
+ Returns:
+ bool: True if the path is an executable file
+ """
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
fpath, _ = os.path.split(program)
@@ -240,6 +377,15 @@ def is_exe(fpath):
return None
def normalize_path(path, root=None):
+ """Normalizes the given path by making it absolute and removing redundant parts.
+
+ Args:
+ path (str): Path to normalize
+ root (str, optional): Root path to use if the given path is relative. Defaults to None.
+
+ Returns:
+ str: Normalized path
+ """
if os.path.isabs(path):
return path
if not root:
@@ -247,18 +393,47 @@ def normalize_path(path, root=None):
return os.path.normpath(os.path.join(root, path))
def localize_path(path):
+ """Converts path to local format (backslashes on Windows, slashes on Linux)
+
+ Args:
+ path (str): Path to convert
+
+ Returns:
+ str: Converted path
+ """
if sys.platform.startswith("win"):
return path.replace("/", "\\")
else:
return path.replace("\\", "/")
-def to_string(n):
+def to_string(n: Any) -> str:
+ """Converts object to string, returs empty string if object is None (so a bit different behaviour than
+ the original string conversion).
+
+ Args:
+ n (Any): Object of any kind
+
+ Returns:
+ str: String representation (using built-in conversion)
+ """
if n is None:
return ""
else:
return str(n)
def to_number(val, max_n = None, min_n = None, conversion=int):
+ """Converts the given value to a number and checks if it is within the given range. If the value is not a number,
+ a RuntimeError is raised.
+
+ Args:
+ val (Any): Value to convert
+ max_n (int, optional): Maximum allowed value. Defaults to None.
+ min_n (int, optional): Minimum allowed value. Defaults to None.
+ conversion (function, optional): Conversion function. Defaults to int.
+
+ Returns:
+        Number: Converted value (type depends on the conversion function)
+ """
try:
n = conversion(val)
@@ -274,6 +449,15 @@ def to_number(val, max_n = None, min_n = None, conversion=int):
raise RuntimeError("Number conversion error")
def to_logical(val):
+ """Converts the given value to a logical value (True/False). If the value is not a logical value,
+ a RuntimeError is raised.
+
+ Args:
+ val (Any): Value to convert
+
+ Returns:
+ bool: Converted value
+ """
try:
if isinstance(val, str):
return val.lower() in ['true', '1', 't', 'y', 'yes']
@@ -283,24 +467,63 @@ def to_logical(val):
except ValueError:
raise RuntimeError("Logical value conversion error")
+def format_size(num, suffix="B"):
+ """Formats the given number as a human-readable size string.
+
+ Args:
+ num (int): Number to format
+ suffix (str, optional): Suffix to use. Defaults to "B".
+
+ Returns:
+ str: Formatted string
+ """
+ for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
+ if abs(num) < 1024.0:
+ return f"{num:3.1f}{unit}{suffix}"
+ num /= 1024.0
+ return f"{num:.1f}Yi{suffix}"
+
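Expected behaviour of format_size, for reference: the value is divided by 1024 until its magnitude drops below one unit, then formatted with one decimal place.

    assert format_size(1023) == "1023.0B"
    assert format_size(1536) == "1.5KiB"
    assert format_size(5 * 1024 ** 3) == "5.0GiB"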
def singleton(class_):
+ """Singleton decorator for classes.
+
+ Args:
+ class_ (class): Class to decorate
+
+ Returns:
+ class: Decorated class
+
+ Example:
+ @singleton
+ class MyClass:
+ pass
+
+ a = MyClass()
+ """
instances = {}
def getinstance(*args, **kwargs):
+ """Returns the singleton instance of the class. If the instance does not exist, it is created."""
if class_ not in instances:
instances[class_] = class_(*args, **kwargs)
return instances[class_]
return getinstance
class ColoredFormatter(Formatter):
+ """Colored log formatter using colorama package.
+ """
class Empty(object):
"""An empty class used to copy :class:`~logging.LogRecord` objects without reinitializing them."""
def __init__(self, **kwargs):
+ """Initializes the formatter.
+
+ Args:
+ **kwargs: Keyword arguments passed to the base class
+
+ """
super().__init__(**kwargs)
colorama.init()
-
self._styles = dict(
debug=colorama.Fore.GREEN,
verbose=colorama.Fore.BLACK,
@@ -311,8 +534,15 @@ def __init__(self, **kwargs):
critical=colorama.Fore.RED + colorama.Style.BRIGHT,
)
+ def format(self, record: LogRecord) -> str:
+ """Formats message by injecting colorama terminal codes for text coloring.
+
+ Args:
+ record (LogRecord): Input log record
- def format(self, record: LogRecord):
+ Returns:
+ str: Formatted string
+ """
style = self._styles[record.levelname.lower()]
copy = ColoredFormatter.Empty()
@@ -326,11 +556,20 @@ def format(self, record: LogRecord):
class ThreadPoolExecutor(futures.ThreadPoolExecutor):
+ """Thread pool executor with a shutdown method that waits for all threads to finish.
+ """
+
def __init__(self, *args, **kwargs):
+ """Initializes the thread pool executor."""
super().__init__(*args, **kwargs)
#self._work_queue = Queue.Queue(maxsize=maxsize)
def shutdown(self, wait=True):
+ """Shuts down the thread pool executor. If wait is True, waits for all threads to finish.
+
+ Args:
+ wait (bool, optional): Wait for all threads to finish. Defaults to True.
+ """
import queue
with self._shutdown_lock:
self._shutdown = True
@@ -344,3 +583,26 @@ def shutdown(self, wait=True):
if wait:
for t in self._threads:
t.join()
+
+class Timer(object):
+ """Simple timer class for measuring elapsed time."""
+
+ def __init__(self, name=None):
+ """Initializes the timer.
+
+ Args:
+ name (str, optional): Name of the timer. Defaults to None.
+ """
+ self.name = name
+
+ def __enter__(self):
+ """Starts the timer."""
+ self._tstart = time.time()
+
+ def __exit__(self, type, value, traceback):
+ """Stops the timer and prints the elapsed time."""
+ elapsed = time.time() - self._tstart
+ if self.name:
+ print('[%s]: %.4fs' % (self.name, elapsed))
+ else:
+ print('Elapsed: %.4fs' % elapsed)
\ No newline at end of file
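A usage sketch for the Timer context manager added above; the timed workload is hypothetical:

    from vot.utilities import Timer

    with Timer("update"):
        process_frame()  # hypothetical workload
    # prints something like: [update]: 0.0123s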
diff --git a/vot/utilities/cli.py b/vot/utilities/cli.py
index a229868..972122b 100644
--- a/vot/utilities/cli.py
+++ b/vot/utilities/cli.py
@@ -1,3 +1,5 @@
+"""Command line interface for the toolkit. This module provides a command line interface for the toolkit. It is used to run experiments, manage trackers and datasets, and to perform other tasks."""
+
import os
import sys
import argparse
@@ -18,6 +20,7 @@ class EnvDefault(argparse.Action):
"""Argparse action that resorts to a value in a specified envvar if no value is provided via program arguments.
"""
def __init__(self, envvar, required=True, default=None, separator=None, **kwargs):
+ """Initialize the action"""
if not default and envvar:
if envvar in os.environ:
default = os.environ[envvar]
@@ -30,6 +33,7 @@ def __init__(self, envvar, required=True, default=None, separator=None, **kwargs
**kwargs)
def __call__(self, parser, namespace, values, option_string=None):
+ """Call the action"""
if self.separator:
values = values.split(self.separator)
setattr(namespace, self.dest, values)
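A sketch of how EnvDefault is wired into a parser, mirroring the registry argument registered in main() below; the value falls back to the environment variable when the option is omitted on the command line:

    import argparse, os
    from vot.utilities.cli import EnvDefault

    parser = argparse.ArgumentParser()
    parser.add_argument("--registry", default=".", required=False, action=EnvDefault,
                        envvar="VOT_REGISTRY", separator=os.pathsep)
    args = parser.parse_args([])  # picks up VOT_REGISTRY from the environment if set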
@@ -40,8 +44,11 @@ def do_test(config: argparse.Namespace):
Args:
config (argparse.Namespace): Configuration
"""
- from vot.dataset.dummy import DummySequence
- from vot.dataset import load_sequence
+ from vot.dataset.dummy import generate_dummy
+ from vot.dataset import load_sequence, Frame
+ from vot.tracker import ObjectStatus
+ from vot.experiment.helpers import MultiObjectHelper
+
trackers = Registry(config.registry)
if not config.tracker:
@@ -57,57 +64,86 @@ def do_test(config: argparse.Namespace):
tracker = trackers[config.tracker]
- logger.info("Generating dummy sequence")
-
- if config.sequence is None:
- sequence = DummySequence(50)
- else:
- sequence = load_sequence(normalize_path(config.sequence))
-
- logger.info("Obtaining runtime for tracker %s", tracker.identifier)
-
- if config.visualize:
- import matplotlib.pylab as plt
- from vot.utilities.draw import MatplotlibDrawHandle
- figure = plt.figure()
- figure.canvas.set_window_title('VOT Test')
- axes = figure.add_subplot(1, 1, 1)
- axes.set_aspect("equal")
- handle = MatplotlibDrawHandle(axes, size=sequence.size)
- handle.style(fill=False)
- figure.show()
-
- runtime = None
-
+ def visualize(axes, frame: Frame, reference, state):
+ """Visualize the frame and the state of the tracker.
+
+ Args:
+ axes (matplotlib.axes.Axes): The axes to draw on.
+ frame (Frame): The frame to draw.
+ reference (list): List of references.
+ state (ObjectStatus): The state of the tracker.
+
+ """
+ axes.clear()
+ handle.image(frame.channel())
+ if not isinstance(state, list):
+ state = [state]
+ for gt, st in zip(reference, state):
+ handle.style(color="green").region(gt)
+ handle.style(color="red").region(st.region)
+
try:
runtime = tracker.runtime(log=True)
- for repeat in range(1, 4):
-
- logger.info("Initializing tracker ({}/{})".format(repeat, 3))
+ logger.info("Generating dummy sequence")
- region, _, _ = runtime.initialize(sequence.frame(0), sequence.groundtruth(0))
+ if config.sequence is None:
+ sequence = generate_dummy(50, objects=3 if runtime.multiobject else 1)
+ else:
+ sequence = load_sequence(normalize_path(config.sequence))
+
+ logger.info("Obtaining runtime for tracker %s", tracker.identifier)
+
+ context = {"continue" : True}
+
+ def on_press(event):
+ """Callback for key press event.
+
+ Args:
+ event (matplotlib.backend_bases.Event): The event.
+ """
+ if event.key == 'q':
+ context["continue"] = False
+
+ if config.visualize:
+ import matplotlib.pylab as plt
+ from vot.utilities.draw import MatplotlibDrawHandle
+ figure = plt.figure()
+ figure.canvas.set_window_title('VOT Test')
+ axes = figure.add_subplot(1, 1, 1)
+ axes.set_aspect("equal")
+ handle = MatplotlibDrawHandle(axes, size=sequence.size)
+ context["click"] = figure.canvas.mpl_connect('key_press_event', on_press)
+ handle.style(fill=False)
+ figure.show()
+
+ helper = MultiObjectHelper(sequence)
+
+ logger.info("Initializing tracker")
+
+ frame = sequence.frame(0)
+ state, _ = runtime.initialize(frame, [ObjectStatus(frame.object(x), {}) for x in helper.new(0)])
+
+ if config.visualize:
+ visualize(axes, frame, [frame.object(x) for x in helper.objects(0)], state)
+ figure.canvas.draw()
+
+ for i in range(1, len(sequence)):
+
+ logger.info("Processing frame %d/%d", i, len(sequence)-1)
+ frame = sequence.frame(i)
+ state, _ = runtime.update(frame, [ObjectStatus(frame.object(x), {}) for x in helper.new(i)])
if config.visualize:
- axes.clear()
- handle.image(sequence.frame(0).channel())
- handle.style(color="green").region(sequence.frame(0).groundtruth())
- handle.style(color="red").region(region)
+ visualize(axes, frame, [frame.object(x) for x in helper.objects(i)], state)
figure.canvas.draw()
+ figure.canvas.flush_events()
- for i in range(1, sequence.length):
- logger.info("Updating on frame %d/%d", i, sequence.length-1)
- region, _, _ = runtime.update(sequence.frame(i))
-
- if config.visualize:
- axes.clear()
- handle.image(sequence.frame(i).channel())
- handle.style(color="green").region(sequence.frame(i).groundtruth())
- handle.style(color="red").region(region)
- figure.canvas.draw()
+ if not context["continue"]:
+ break
- logger.info("Stopping tracker")
+ logger.info("Stopping tracker")
runtime.stop()
@@ -122,6 +158,11 @@ def do_test(config: argparse.Namespace):
runtime.stop()
def do_workspace(config: argparse.Namespace):
+ """Initialize / manage a workspace.
+
+ Args:
+ config (argparse.Namespace): Configuration
+ """
from vot.workspace import WorkspaceException
@@ -153,18 +194,23 @@ def do_workspace(config: argparse.Namespace):
logger.error("Error during workspace initialization: %s", we)
def do_evaluate(config: argparse.Namespace):
+ """Run an evaluation for a tracker on an experiment stack and a set of sequences.
+
+ Args:
+ config (argparse.Namespace): Configuration
+ """
from vot.experiment import run_experiment
workspace = Workspace.load(config.workspace)
- logger.info("Loaded workspace in '%s'", config.workspace)
+ logger.debug("Loaded workspace in '%s'", config.workspace)
global_registry = [os.path.abspath(x) for x in config.registry]
- registry = Registry(workspace.registry + global_registry, root=config.workspace)
+ registry = Registry(list(workspace.registry) + global_registry, root=config.workspace)
- logger.info("Found data for %d trackers", len(registry))
+ logger.debug("Found data for %d trackers", len(registry))
trackers = registry.resolve(*config.trackers, storage=workspace.storage.substorage("results"), skip_unknown=False)
@@ -177,7 +223,7 @@ def do_evaluate(config: argparse.Namespace):
try:
for tracker in trackers:
- logger.info("Evaluating tracker %s", tracker.identifier)
+ logger.debug("Evaluating tracker %s", tracker.identifier)
for experiment in workspace.stack:
run_experiment(experiment, tracker, workspace.dataset, config.force, config.persist)
@@ -189,19 +235,24 @@ def do_evaluate(config: argparse.Namespace):
logger.error("Evaluation interrupted by tracker error: {}".format(te))
def do_analysis(config: argparse.Namespace):
+ """Run an analysis for a tracker on an experiment stack and a set of sequences.
+
+ Args:
+ config (argparse.Namespace): Configuration
+ """
from vot.analysis import AnalysisProcessor, process_stack_analyses
from vot.document import generate_document
workspace = Workspace.load(config.workspace)
- logger.info("Loaded workspace in '%s'", config.workspace)
+ logger.debug("Loaded workspace in '%s'", config.workspace)
global_registry = [os.path.abspath(x) for x in config.registry]
- registry = Registry(workspace.registry + global_registry, root=config.workspace)
+ registry = Registry(list(workspace.registry) + global_registry, root=config.workspace)
- logger.info("Found data for %d trackers", len(registry))
+ logger.debug("Found data for %d trackers", len(registry))
if not config.trackers:
trackers = workspace.list_results(registry)
@@ -217,7 +268,7 @@ def do_analysis(config: argparse.Namespace):
if config.workers == 1:
if config.debug:
- from vot.analysis._processor import DebugExecutor
+ from vot.analysis.processor import DebugExecutor
logging.getLogger("concurrent.futures").setLevel(logging.DEBUG)
executor = DebugExecutor()
else:
@@ -263,7 +314,7 @@ def do_pack(config: argparse.Namespace):
"""Package results to a ZIP file so that they can be submitted to a challenge.
Args:
- config ([type]): [description]
+ config (argparse.Namespace): Configuration
"""
import zipfile, io
@@ -271,11 +322,9 @@ def do_pack(config: argparse.Namespace):
workspace = Workspace.load(config.workspace)
- logger.info("Loaded workspace in '%s'", config.workspace)
-
- registry = Registry(workspace.registry + config.registry, root=config.workspace)
+ logger.debug("Loaded workspace in '%s'", config.workspace)
- logger.info("Found data for %d trackers", len(registry))
+ registry = Registry(list(workspace.registry) + config.registry, root=config.workspace)
tracker = registry[config.tracker]
@@ -287,8 +336,8 @@ def do_pack(config: argparse.Namespace):
with Progress("Scanning", len(workspace.dataset) * len(workspace.stack)) as progress:
for experiment in workspace.stack:
- for sequence in workspace.dataset:
- sequence = experiment.transform(sequence)
+ sequences = experiment.transform(workspace.dataset)
+ for sequence in sequences:
complete, files, results = experiment.scan(tracker, sequence)
all_files.extend([(f, experiment.identifier, sequence.name, results) for f in files])
if not complete:
@@ -300,7 +349,7 @@ def do_pack(config: argparse.Namespace):
logger.error("Unable to continue, experiments not complete")
return
- logger.info("Collected %d files, compressing to archive ...", len(all_files))
+ logger.debug("Collected %d files, compressing to archive ...", len(all_files))
timestamp = datetime.now()
@@ -315,8 +364,11 @@ def do_pack(config: argparse.Namespace):
with zipfile.ZipFile(workspace.storage.write(archive_name, binary=True), mode="w") as archive:
for f in all_files:
info = zipfile.ZipInfo(filename=os.path.join(f[1], f[2], f[0]), date_time=timestamp.timetuple())
- with io.TextIOWrapper(archive.open(info, mode="w")) as fout, f[3].read(f[0]) as fin:
- copyfileobj(fin, fout)
+ with archive.open(info, mode="w") as fout, f[3].read(f[0]) as fin:
+ if isinstance(fin, io.TextIOBase):
+ copyfileobj(fin, io.TextIOWrapper(fout))
+ else:
+ copyfileobj(fin, fout)
progress.relative(1)
info = zipfile.ZipInfo(filename="manifest.yml", date_time=timestamp.timetuple())
@@ -326,13 +378,13 @@ def do_pack(config: argparse.Namespace):
logger.info("Result packaging successful, archive available in %s", archive_name)
def main():
- """Entrypoint to the VOT Command Line Interface utility, should be executed as a program and provided with arguments.
+ """Entrypoint to the toolkit Command Line Interface utility, should be executed as a program and provided with arguments.
"""
stream = logging.StreamHandler()
stream.setFormatter(ColoredFormatter())
logger.addHandler(stream)
- parser = argparse.ArgumentParser(description='VOT Toolkit Command Line Utility', prog="vot")
+ parser = argparse.ArgumentParser(description='VOT Toolkit Command Line Interface', prog="vot")
parser.add_argument("--debug", "-d", default=False, help="Backup backend", required=False, action='store_true')
parser.add_argument("--registry", default=".", help='Tracker registry paths', required=False, action=EnvDefault, \
separator=os.path.pathsep, envvar='VOT_REGISTRY')
@@ -373,22 +425,29 @@ def main():
logger.setLevel(logging.INFO)
- if args.debug or check_debug:
+ if args.debug or check_debug():
logger.setLevel(logging.DEBUG)
- update, version = check_updates()
- if update:
- logger.warning("A newer version of the VOT toolkit is available (%s), please update.", version)
+ def check_version():
+ """Check if a newer version of the toolkit is available."""
+ update, version = check_updates()
+ if update:
+ logger.warning("A newer version of the VOT toolkit is available (%s), please update.", version)
if args.action == "test":
+ check_version()
do_test(args)
elif args.action == "initialize":
+ check_version()
do_workspace(args)
elif args.action == "evaluate":
+ check_version()
do_evaluate(args)
elif args.action == "analysis":
+ check_version()
do_analysis(args)
elif args.action == "pack":
+ check_version()
do_pack(args)
else:
parser.print_help()
diff --git a/vot/utilities/data.py b/vot/utilities/data.py
index 2c3c207..1b1aa02 100644
--- a/vot/utilities/data.py
+++ b/vot/utilities/data.py
@@ -1,23 +1,41 @@
+""" Data structures for storing data in a grid."""
+
import functools
-from typing import Tuple, Union, Iterable
import unittest
-import numpy as np
-
class Grid(object):
+ """ A grid is a multidimensional array with named dimensions. """
@staticmethod
def scalar(obj):
+ """ Creates a grid with a single cell containing the given object.
+
+ Args:
+ obj (object): The object to store in the grid.
+ """
grid = Grid(1,1)
grid[0, 0] = obj
return grid
def __init__(self, *size):
+ """ Creates a grid with the given dimensions.
+
+ Args:
+ size (int): The size of each dimension.
+ """
assert len(size) > 0
self._size = size
self._data = [None] * functools.reduce(lambda x, y: x * y, size)
def _ravel(self, pos):
+ """ Converts a multidimensional index to a single index.
+
+ Args:
+ pos (tuple): The multidimensional index.
+
+ Returns:
+ int: The single index.
+ """
if not isinstance(pos, tuple):
pos = (pos, )
assert len(pos) == len(self._size)
@@ -32,6 +50,14 @@ def _ravel(self, pos):
return raveled
def _unravel(self, index):
+ """ Converts a single index to a multidimensional index.
+
+ Args:
+ index (int): The single index.
+
+ Returns:
+ tuple: The multidimensional index.
+ """
unraveled = []
for n in reversed(self._size):
unraveled.append(index % n)
@@ -39,34 +65,76 @@ def _unravel(self, index):
return tuple(reversed(unraveled))
def __str__(self):
+ """ Returns a string representation of the grid."""
return str(self._data)
@property
def dimensions(self):
+ """ Returns the number of dimensions of the grid. """
return len(self._size)
def size(self, i: int = None):
+ """ Returns the size of the grid or the size of a specific dimension.
+
+ Args:
+ i (int): The dimension to query. If None, the size of the grid is returned.
+
+ Returns:
+ int: The size of the grid or the size of the given dimension.
+ """
if i is None:
return tuple(self._size)
assert i >= 0 and i < len(self._size)
return self._size[i]
def __len__(self):
+ """ Returns the number of elements in the grid. """
return len(self._data)
def __getitem__(self, i):
+ """ Returns the element at the given index.
+
+ Args:
+ i (tuple): The index of the element. If the grid is one-dimensional, the index can be an integer.
+
+ Returns:
+ object: The element at the given index.
+ """
return self._data[self._ravel(i)]
def __setitem__(self, i, data):
+ """ Sets the element at the given index.
+
+ Args:
+ i (tuple): The index of the element. If the grid is one-dimensional, the index can be an integer.
+ data (object): The data to store at the given index.
+ """
self._data[self._ravel(i)] = data
def __iter__(self):
+ """ Returns an iterator over the elements of the grid. """
return iter(self._data)
def cell(self, *i):
+ """ Returns the element at the given index packed in a scalar grid.
+
+ Args:
+ i (int): The index of the element. If the grid is one-dimensional, the index can be an integer.
+
+ Returns:
+ object: The element at the given index packed in a scalar grid.
+ """
return Grid.scalar(self[i])
def column(self, i):
+ """ Returns the column at the given index.
+
+ Args:
+ i (int): The index of the column.
+
+ Returns:
+ Grid: The column at the given index.
+ """
assert self.dimensions == 2
column = Grid(1, self.size()[0])
for j in range(self.size()[0]):
@@ -74,6 +142,14 @@ def column(self, i):
return column
def row(self, i):
+ """ Returns the row at the given index.
+
+ Args:
+ i (int): The index of the row.
+
+ Returns:
+ Grid: The row at the given index.
+ """
assert self.dimensions == 2
row = Grid(self.size()[1], 1)
for j in range(self.size()[1]):
@@ -81,6 +157,15 @@ def row(self, i):
return row
def foreach(self, cb) -> "Grid":
+ """ Applies a function to each element of the grid.
+
+ Args:
+ cb (function): The function to apply to each element. The first argument is the element, the following
+ arguments are the indices of the element.
+
+ Returns:
+ Grid: A grid containing the results of the function.
+ """
result = Grid(*self._size)
for i, x in enumerate(self._data):
@@ -90,8 +175,10 @@ def foreach(self, cb) -> "Grid":
return result
class TestGrid(unittest.TestCase):
+ """ Unit tests for the Grid class. """
def test_foreach1(self):
+ """ Tests the foreach method. """
a = Grid(5, 3)
@@ -100,6 +187,7 @@ def test_foreach1(self):
self.assertTrue(all([x == 5 for x in b]), "Output incorrect")
def test_foreach2(self):
+ """ Tests the foreach method. """
a = Grid(5, 6, 3)
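A compact sketch of the Grid container defined above; cells start as None, indexing uses full multidimensional tuples, and foreach passes the cell value followed by its indices:

    from vot.utilities.data import Grid

    g = Grid(2, 3)  # 2x3 grid, all cells None
    g[0, 1] = "a"
    g[1, 2] = "bc"

    lengths = g.foreach(lambda value, i, j: 0 if value is None else len(value))
    assert lengths[0, 1] == 1 and lengths[1, 2] == 2 and lengths[0, 0] == 0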
diff --git a/vot/utilities/draw.py b/vot/utilities/draw.py
index da9215a..cb5fcc6 100644
--- a/vot/utilities/draw.py
+++ b/vot/utilities/draw.py
@@ -1,4 +1,4 @@
-
+""" Drawing utilities for visualizing results."""
from typing import Tuple, List, Union
@@ -13,6 +13,7 @@
from io import BytesIO
def show_image(a):
+ """Shows an image in the IPython notebook."""
try:
import IPython.display
except ImportError:
@@ -36,19 +37,37 @@ def show_image(a):
}
def resolve_color(color: Union[Tuple[float, float, float], str]):
+ """Resolves a color to a tuple of floats. If the color is a string, it is resolved from an internal palette."""
if isinstance(color, str):
return _PALETTE.get(color, (0, 0, 0, 1))
return (np.clip(color[0], 0, 1), np.clip(color[1], 0, 1), np.clip(color[2], 0, 1))
class DrawHandle(object):
+ """Base class for drawing handles."""
def __init__(self, color: Union[Tuple[float, float, float], str] = (1, 0, 0), width: int = 1, fill: bool = False):
+ """Initializes the drawing handle.
+
+ Args:
+ color (tuple or str): Color of the drawing handle.
+ width (int): Width of the drawing handle.
+ fill (bool): Whether to fill the drawing handle.
+ """
self._color = resolve_color(color)
self._width = width
self._fill = fill
- def style(self, color: Union[Tuple[float, float, float], str] = (1, 0, 0), width: int = 1, fill: bool = False):
+ def style(self, color: Union[Tuple[float, float, float], str] = (1, 0, 0), width: int = 1, fill: bool = False) -> 'DrawHandle':
+ """Sets the style of the drawing handle. Returns self for chaining.
+
+ Args:
+ color (tuple or str): Drawing color, either an RGB tuple or a named color from the internal palette.
+ width (int): Line width used for drawing.
+ fill (bool): Whether drawn shapes are filled.
+
+ Returns:
+ self"""
color = resolve_color(color)
self._color = (color[0], color[1], color[2], 1)
self._width = width
@@ -59,34 +78,47 @@ def style(self, color: Union[Tuple[float, float, float], str] = (1, 0, 0), width
return self
def region(self, region):
+ """Draws a region."""
region.draw(self)
return self
def image(self, image: Union[np.ndarray, Image.Image], offset: Tuple[int, int] = None):
+ """Draws an image at the given offset."""
return self
def line(self, p1: Tuple[float, float], p2: Tuple[float, float]):
+ """Draws a line between two points."""
return self
def lines(self, points: List[Tuple[float, float]]):
+ """Draws a line between multiple points."""
return self
def polygon(self, points: List[Tuple[float, float]]):
+ """Draws a polygon."""
return self
def points(self, points: List[Tuple[float, float]]):
+ """Draws points."""
return self
def rectangle(self, left: float, top: float, right: float, bottom: float):
+ """Draws a rectangle.
+
+ The rectangle is defined by the top-left and bottom-right corners.
+ """
self.polygon([(left, top), (right, top), (right, bottom), (left, bottom)])
return self
def mask(self, mask: np.array, offset: Tuple[int, int] = (0, 0)):
+ """Draws a mask."""
return self
class MatplotlibDrawHandle(DrawHandle):
+ """Draw handle for Matplotlib. This handle is used for drawing to a Matplotlib axis."""
def __init__(self, axis, color: Tuple[float, float, float] = (1, 0, 0), width: int = 1, fill: bool = False, size: Tuple[int, int] = None):
+ """Initializes a new instance of the MatplotlibDrawHandle class."""
super().__init__(color, width, fill)
self._axis = axis
self._size = size
@@ -96,6 +128,7 @@ def __init__(self, axis, color: Tuple[float, float, float] = (1, 0, 0), width: i
def image(self, image: Union[np.ndarray, Image.Image], offset: Tuple[int, int] = None):
+ """Draws an image at the given offset."""
if offset is None:
offset = (0, 0)
@@ -113,16 +146,19 @@ def image(self, image: Union[np.ndarray, Image.Image], offset: Tuple[int, int] =
return self
def line(self, p1: Tuple[float, float], p2: Tuple[float, float]):
+ """Draws a line between two points."""
self._axis.plot((p1[0], p2[0]), (p1[1], p2[1]), linewidth=self._width, color=self._color)
return self
def lines(self, points: List[Tuple[float, float]]):
+ """Draws a line between multiple points."""
x = [x for x, _ in points]
y = [y for _, y in points]
self._axis.plot(x, y, linewidth=self._width, color=self._color)
return self
def polygon(self, points: List[Tuple[float, float]]):
+ """Draws a polygon."""
if self._fill:
poly = Polygon(points, edgecolor=self._color, linewidth=self._width, fill=True, color=self._fill)
else:
@@ -131,11 +167,13 @@ def polygon(self, points: List[Tuple[float, float]]):
return self
def points(self, points: List[Tuple[float, float]]):
+ """Draws points."""
x, y = zip(*points)
self._axis.plot(x, y, markeredgecolor=self._color, markeredgewidth=self._width, linewidth=0)
return self
def mask(self, mask: np.array, offset: Tuple[int, int] = (0, 0)):
+ """Draws a mask."""
# TODO: segmentation should also have option of non-filled
mask[mask != 0] = 1
if self._fill:
@@ -156,12 +194,15 @@ def mask(self, mask: np.array, offset: Tuple[int, int] = (0, 0)):
class ImageDrawHandle(DrawHandle):
+ """Draw handle for Pillow. This handle is used for drawing to a Pillow image."""
@staticmethod
def _convert_color(c, alpha=255):
+ """Converts a color from a float tuple to an integer tuple."""
return (int(c[0] * 255), int(c[1] * 255), int(c[2] * 255), alpha)
def __init__(self, image: Union[np.ndarray, Image.Image], color: Tuple[float, float, float] = (1, 0, 0), width: int = 1, fill: bool = False):
+ """Initializes a new instance of the ImageDrawHandle class."""
super().__init__(color, width, fill)
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
@@ -171,13 +212,16 @@ def __init__(self, image: Union[np.ndarray, Image.Image], color: Tuple[float, fl
@property
def array(self) -> np.ndarray:
+ """Returns the image as a numpy array."""
return np.asarray(self._image)
@property
def snapshot(self) -> Image.Image:
+ """Returns a snapshot of the current image."""
return self._image.copy()
def image(self, image: Union[np.ndarray, Image.Image], offset: Tuple[int, int] = None):
+ """Draws an image at the given offset."""
if isinstance(image, np.ndarray):
if image.dtype == np.float32 or image.dtype == np.float64:
image = (image * 255).astype(np.uint8)
@@ -189,11 +233,13 @@ def image(self, image: Union[np.ndarray, Image.Image], offset: Tuple[int, int] =
return self
def line(self, p1, p2):
+ """Draws a line between two points."""
color = ImageDrawHandle._convert_color(self._color)
self._handle.line([p1, p2], fill=color, width=self._width)
return self
def lines(self, points: List[Tuple[float, float]]):
+ """Draws a line between multiple points."""
if len(points) == 0:
return
color = ImageDrawHandle._convert_color(self._color)
@@ -201,6 +247,7 @@ def lines(self, points: List[Tuple[float, float]]):
return self
def polygon(self, points: List[Tuple[float, float]]):
+ """Draws a polygon."""
if len(points) == 0:
return self
@@ -213,12 +260,14 @@ def polygon(self, points: List[Tuple[float, float]]):
return self
def points(self, points: List[Tuple[float, float]]):
+ """Draws points."""
color = ImageDrawHandle._convert_color(self._color)
for (x, y) in points:
self._handle.ellipse((x - 2, y - 2, x + 2, y + 2), outline=color, width=self._width)
return self
def mask(self, mask: np.array, offset: Tuple[int, int] = (0, 0)):
+ """Draws a mask."""
if mask.size == 0:
return self
diff --git a/vot/utilities/migration.py b/vot/utilities/migration.py
index 9e975cd..8dc1ff8 100644
--- a/vot/utilities/migration.py
+++ b/vot/utilities/migration.py
@@ -1,4 +1,4 @@
-
+""" Migration utilities for old workspaces (legacy Matlab toolkit)"""
import os
import re
@@ -11,13 +11,32 @@
from vot.stack import resolve_stack
from vot.workspace import WorkspaceException
-def migrate_matlab_workspace(directory):
+def migrate_matlab_workspace(directory: str):
+ """ Migrates a legacy matlab workspace to the new format.
+
+ Args:
+ directory (str): The directory of the workspace.
+
+ Raises:
+ WorkspaceException: If the workspace is already initialized.
+ WorkspaceException: If the workspace is not a legacy workspace.
+ """
logger = logging.getLogger("vot")
logger.info("Attempting to migrate workspace in %s", directory)
def scan_text(pattern, content, default=None):
+ """ Scans the text for a pattern and returns the first match.
+
+ Args:
+ pattern (str): The pattern to search for.
+ content (str): The content to search in.
+ default (str): The default value if no match is found.
+
+ Returns:
+ str: The first match or the default value.
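+
+ Example (illustrative)::
+
+ scan_text(r"version=(\d+)", "version=12", default="0")  # -> "12"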
+ """
matches = re.findall(pattern, content)
if not len(matches) == 1:
return default
diff --git a/vot/utilities/net.py b/vot/utilities/net.py
index 488485b..0da9394 100644
--- a/vot/utilities/net.py
+++ b/vot/utilities/net.py
@@ -1,3 +1,4 @@
+""" Network utilities for the toolkit. """
import os
import re
@@ -7,23 +8,57 @@
import requests
-from vot import ToolkitException
+from vot import ToolkitException, get_logger
class NetworkException(ToolkitException):
+ """ Exception raised when a network error occurs. """
pass
def get_base_url(url):
+ """ Returns the base url of a given url.
+
+ Args:
+ url (str): The url to parse.
+
+ Returns:
+ str: The base url."""
return url.rsplit('/', 1)[0]
def is_absolute_url(url):
+ """ Returns True if the given url is absolute.
+
+ Args:
+ url (str): The url to parse.
+
+ Returns:
+ bool: True if the url is absolute, False otherwise.
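+
+ Example (illustrative)::
+
+ is_absolute_url("https://example.com/file.zip")  # True
+ is_absolute_url("file.zip")  # False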
+ """
+
return bool(urlparse(url).netloc)
def join_url(url_base, url_path):
+ """ Joins a base url with a path.
+
+ Args:
+ url_base (str): The base url.
+ url_path (str): The path to join.
+
+ Returns:
+ str: The joined url.
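+
+ Example (illustrative)::
+
+ join_url("https://example.com/data/", "file.zip")  # -> "https://example.com/data/file.zip"
+ join_url("https://example.com/data/", "https://other.org/file.zip")  # -> "https://other.org/file.zip"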
+ """
if is_absolute_url(url_path):
return url_path
return urljoin(url_base, url_path)
def get_url_from_gdrive_confirmation(contents):
+ """ Returns the url of a google drive file from the confirmation page.
+
+ Args:
+ contents (str): The contents of the confirmation page.
+
+ Returns:
+ str: The url of the file.
+ """
url = ''
for line in contents.splitlines():
m = re.search(r'href="(\/uc\?export=download[^"]+)', line)
@@ -45,17 +80,49 @@ def get_url_from_gdrive_confirmation(contents):
def is_google_drive_url(url):
+ """ Returns True if the given url is a google drive url.
+
+ Args:
+ url (str): The url to parse.
+
+ Returns:
+ bool: True if the url is a Google Drive url, False otherwise.
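+
+ Example (illustrative)::
+
+ is_google_drive_url("https://drive.google.com/uc?id=abc123")  # True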
+ """
m = re.match(r'^https?://drive.google.com/uc\?id=.*$', url)
return m is not None
def download_json(url):
+ """ Downloads a JSON file from the given url.
+
+ Args:
+ url (str): The url to download the JSON document from.
+
+ Returns:
+ dict: The JSON content.
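+
+ Example (a sketch; assumes the url serves a JSON document)::
+
+ metadata = download_json("https://example.com/dataset.json")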
+ """
try:
return requests.get(url).json()
except requests.exceptions.RequestException as e:
raise NetworkException("Unable to read JSON file {}".format(e))
-def download(url, output, callback=None, chunk_size=1024*32):
+def download(url, output, callback=None, chunk_size=1024*32, retry=10):
+ """ Downloads a file from the given url. Supports google drive urls.
+ callback for progress report, automatically resumes download if connection is closed.
+
+ Args:
+ url (str): The url to parse.
+ output (str): The output file path or file handle.
+ callback (function): The callback function for progress report.
+ chunk_size (int): The chunk size for download.
+ retry (int): The number of retries.
+
+ Raises:
+ NetworkException: If the file is not available.
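+
+ Example (a sketch with an illustrative progress callback)::
+
+ download("https://example.com/file.zip", "file.zip", callback=lambda position, total: print(position, total))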
+ """
+
+ logger = get_logger()
+
with requests.session() as sess:
is_gdrive = is_google_drive_url(url)
@@ -79,6 +146,7 @@ def download(url, output, callback=None, chunk_size=1024*32):
raise NetworkException("Permission denied for {}".format(gurl))
url = gurl
+
if output is None:
if is_gdrive:
m = re.search('filename="(.*)"',
@@ -96,21 +164,49 @@ def download(url, output, callback=None, chunk_size=1024*32):
tmp_file = None
filehandle = output
+ position = 0
+ progress = False
+
try:
total = res.headers.get('Content-Length')
-
if total is not None:
total = int(total)
-
- for chunk in res.iter_content(chunk_size=chunk_size):
- filehandle.write(chunk)
- if callback:
- callback(len(chunk), total)
- if tmp_file:
- filehandle.close()
- shutil.copy(tmp_file, output)
- except IOError:
- raise NetworkException("Error when downloading file")
+ while True:
+ try:
+ for chunk in res.iter_content(chunk_size=chunk_size):
+ filehandle.write(chunk)
+ position += len(chunk)
+ progress = True
+ if callback:
+ callback(position, total)
+
+ if total is not None and position < total:
+ raise requests.exceptions.RequestException("Connection closed")
+
+ if tmp_file:
+ filehandle.close()
+ shutil.copy(tmp_file, output)
+ break
+
+ except requests.exceptions.RequestException as e:
+ if not progress:
+ logger.warning("Error when downloading file, retrying")
+ retry -= 1
+ if retry < 1:
+ raise NetworkException("Unable to download file {}".format(e))
+ res = sess.get(url, stream=True)
+ filehandle.seek(0)
+ position = 0
+ else:
+ logger.warning("Error when downloading file, trying to resume download")
+ res = sess.get(url, stream=True, headers=({'Range': f'bytes={position}-'} if position > 0 else None))
+ progress = False
+
+ if total is not None and position < total:
+ raise NetworkException("Unable to download file")
+
+ except IOError as e:
+ raise NetworkException("Local I/O Error when downloading file: %s" % e)
finally:
try:
if tmp_file:
@@ -120,7 +216,17 @@ def download(url, output, callback=None, chunk_size=1024*32):
return output
+
def download_uncompress(url, path):
+ """ Downloads a file from the given url and uncompress it to the given path.
+
+ Args:
+ url (str): The url to download from.
+ path (str): The directory to uncompress the file into.
+
+ Raises:
+ NetworkException: If the file is not available.
+ """
from vot.utilities import extract_files
_, ext = os.path.splitext(urlparse(url).path)
tmp_file = tempfile.mktemp(suffix=ext)
diff --git a/vot/utilities/notebook.py b/vot/utilities/notebook.py
index aedb7b3..f292fc0 100644
--- a/vot/utilities/notebook.py
+++ b/vot/utilities/notebook.py
@@ -1,9 +1,15 @@
+""" This module contains functions for visualization in Jupyter notebooks. """
import os
import io
from threading import Thread, Condition
def is_notebook():
+ """ Returns True if the current environment is a Jupyter notebook.
+
+ Returns:
+ bool: True if the current environment is a Jupyter notebook.
+ """
try:
from IPython import get_ipython
if get_ipython() is None:
@@ -17,93 +23,259 @@ def is_notebook():
else:
return True
-
-def run_tracker(tracker: "Tracker", sequence: "Sequence"):
+if is_notebook():
+
from IPython.display import display
from ipywidgets import widgets
from vot.utilities.draw import ImageDrawHandle
- def encode_image(handle):
- with io.BytesIO() as output:
- handle.snapshot.save(output, format="PNG")
- return output.getvalue()
+ class SequenceView(object):
+ """ A widget for visualizing a sequence. """
- handle = ImageDrawHandle(sequence.frame(0).image())
+ def __init__(self, sequence: "Sequence"):
+ """ Initializes a new instance of the SequenceView class.
+
+ Args:
+ sequence (Sequence): The sequence to visualize.
+ """
- button_restart = widgets.Button(description='Restart')
- button_next = widgets.Button(description='Next')
- button_play = widgets.Button(description='Run')
- frame = widgets.Label(value="")
- frame.layout.display = "none"
- frame2 = widgets.Label(value="")
- image = widgets.Image(value=encode_image(handle), format="png", width=sequence.size[0] * 2, height=sequence.size[1] * 2)
+ self._handle = ImageDrawHandle(sequence.frame(0).image())
- state = dict(frame=0, auto=False, alive=True, region=None)
- condition = Condition()
+ self._button_restart = widgets.Button(description='Restart')
+ self._button_next = widgets.Button(description='Next')
+ self._button_play = widgets.Button(description='Run')
+ self._frame = widgets.Label(value="")
+ self._frame.layout.display = "none"
+ self._frame_feedback = widgets.Label(value="")
+ self._image = widgets.Image(value="", format="png", width=sequence.size[0] * 2, height=sequence.size[1] * 2)
- buttons = widgets.HBox(children=(frame, button_restart, button_next, button_play, frame2))
+ self._state = dict(frame=0, auto=False, alive=True, region=None)
+ self._condition = Condition()
- image.value = encode_image(handle)
+ self._buttons = widgets.HBox(children=(self._frame, self._button_restart, self._button_next, self._button_play, self._frame_feedback))
- def run():
+ @staticmethod
+ def _push_image(handle):
+ """ Pushes an image to the widget.
- runtime = tracker.runtime()
+ Args:
+ handle (ImageDrawHandle): The image handle.
+ """
+ with io.BytesIO() as output:
+ handle.snapshot.save(output, format="PNG")
+ return output.getvalue()
- while state["alive"]:
+ def visualize_tracker(tracker: "Tracker", sequence: "Sequence"):
+ """ Visualizes a tracker in a Jupyter notebook.
- if state["frame"] == 0:
- state["region"], _, _ = runtime.initialize(sequence.frame(0), sequence.groundtruth(0))
- else:
- state["region"], _, _ = runtime.update(sequence.frame(state["frame"]))
+ Args:
+ tracker (Tracker): The tracker to visualize.
+ sequence (Sequence): The sequence to visualize.
+ """
+ from IPython.display import display
+ from ipywidgets import widgets
+ from vot.utilities.draw import ImageDrawHandle
- update_image()
+ def encode_image(handle):
+ """ Encodes an image so that it can be displayed in a Jupyter notebook.
+
+ Args:
+ handle (ImageDrawHandle): The image handle.
+
+ Returns:
+ bytes: The encoded image."""
+ with io.BytesIO() as output:
+ handle.snapshot.save(output, format="PNG")
+ return output.getvalue()
- with condition:
- condition.wait()
+ handle = ImageDrawHandle(sequence.frame(0).image())
- if state["frame"] == sequence.length:
- state["alive"] = False
- continue
+ button_restart = widgets.Button(description='Restart')
+ button_next = widgets.Button(description='Next')
+ button_play = widgets.Button(description='Run')
+ frame = widgets.Label(value="")
+ frame.layout.display = "none"
+ frame2 = widgets.Label(value="")
+ image = widgets.Image(value=encode_image(handle), format="png", width=sequence.size[0] * 2, height=sequence.size[1] * 2)
- state["frame"] = state["frame"] + 1
+ state = dict(frame=0, auto=False, alive=True, region=None)
+ condition = Condition()
+ buttons = widgets.HBox(children=(frame, button_restart, button_next, button_play, frame2))
- def update_image():
- handle.image(sequence.frame(state["frame"]).image())
- handle.style(color="green").region(sequence.frame(state["frame"]).groundtruth())
- if state["region"]:
- handle.style(color="red").region(state["region"])
image.value = encode_image(handle)
- frame.value = "Frame: " + str(state["frame"] - 1)
- def on_click(button):
- if button == button_next:
- with condition:
- state["auto"] = False
- condition.notify()
- if button == button_restart:
- with condition:
- state["frame"] = 0
- condition.notify()
- if button == button_play:
+ def run():
+ """ Runs the tracker. """
+
+ runtime = tracker.runtime()
+
+ while state["alive"]:
+
+ if state["frame"] == 0:
+ state["region"], _, _ = runtime.initialize(sequence.frame(0), sequence.groundtruth(0))
+ else:
+ state["region"], _, _ = runtime.update(sequence.frame(state["frame"]))
+
+ update_image()
+
+ with condition:
+ condition.wait()
+
+ if state["frame"] == len(sequence):
+ state["alive"] = False
+ continue
+
+ state["frame"] = state["frame"] + 1
+
+
+ def update_image():
+ """ Updates the image. """
+ handle.image(sequence.frame(state["frame"]).image())
+ handle.style(color="green").region(sequence.frame(state["frame"]).groundtruth())
+ if state["region"]:
+ handle.style(color="red").region(state["region"])
+ image.value = encode_image(handle)
+ frame.value = "Frame: " + str(state["frame"] - 1)
+
+ def on_click(button):
+ """ Handles a button click. """
+ if button == button_next:
+ with condition:
+ state["auto"] = False
+ condition.notify()
+ if button == button_restart:
+ with condition:
+ state["frame"] = 0
+ condition.notify()
+ if button == button_play:
+ with condition:
+ state["auto"] = not state["auto"]
+ button.description = "Stop" if state["auto"] else "Run"
+ condition.notify()
+
+ button_next.on_click(on_click)
+ button_restart.on_click(on_click)
+ button_play.on_click(on_click)
+ widgets.jslink((frame, "value"), (frame2, "value"))
+
+ def on_update(_):
+ """ Handles a widget update."""
with condition:
- state["auto"] = not state["auto"]
- button.description = "Stop" if state["auto"] else "Run"
- condition.notify()
+ if state["auto"]:
+ condition.notify()
+
+ frame2.observe(on_update, names=("value", ))
+
+ thread = Thread(target=run)
+ display(widgets.Box([widgets.VBox(children=(image, buttons))]))
+ thread.start()
+
+ def visualize_results(experiment: "Experiment", tracker: "Tracker", sequence: "Sequence"):
+ """ Visualizes the results of an experiment in a Jupyter notebook.
+
+ Args:
+ experiment (Experiment): The experiment to visualize.
+ tracker (Tracker): The tracker whose run is visualized.
+ sequence (Sequence): The sequence to visualize.
+
+ """
+
+ from IPython.display import display
+ from ipywidgets import widgets
+ from vot.utilities.draw import ImageDrawHandle
- button_next.on_click(on_click)
- button_restart.on_click(on_click)
- button_play.on_click(on_click)
- widgets.jslink((frame, "value"), (frame2, "value"))
+ def encode_image(handle):
+ """ Encodes an image so that it can be displayed in a Jupyter notebook.
+
+ Args:
+ handle (ImageDrawHandle): The image handle.
+
+ Returns:
+ bytes: The encoded image.
+ """
- def on_update(_):
- with condition:
- if state["auto"]:
- condition.notify()
+ with io.BytesIO() as output:
+ handle.snapshot.save(output, format="PNG")
+ return output.getvalue()
- frame2.observe(on_update, names=("value", ))
+ handle = ImageDrawHandle(sequence.frame(0).image())
+
+ button_restart = widgets.Button(description='Restart')
+ button_next = widgets.Button(description='Next')
+ button_play = widgets.Button(description='Run')
+ frame = widgets.Label(value="")
+ frame.layout.display = "none"
+ frame2 = widgets.Label(value="")
+ image = widgets.Image(value=encode_image(handle), format="png", width=sequence.size[0] * 2, height=sequence.size[1] * 2)
+
+ state = dict(frame=0, auto=False, alive=True, region=None)
+ condition = Condition()
+
+ buttons = widgets.HBox(children=(frame, button_restart, button_next, button_play, frame2))
+
+ image.value = encode_image(handle)
+
+ def run():
+ """ Runs the tracker. """
+
+ runtime = tracker.runtime()
+
+ while state["alive"]:
+
+ if state["frame"] == 0:
+ state["region"], _, _ = runtime.initialize(sequence.frame(0), sequence.groundtruth(0))
+ else:
+ state["region"], _, _ = runtime.update(sequence.frame(state["frame"]))
+
+ update_image()
+
+ with condition:
+ condition.wait()
+
+ if state["frame"] == len(sequence):
+ state["alive"] = False
+ continue
+
+ state["frame"] = state["frame"] + 1
+
+
+ def update_image():
+ """ Updates the image. """
+ handle.image(sequence.frame(state["frame"]).image())
+ handle.style(color="green").region(sequence.frame(state["frame"]).groundtruth())
+ if state["region"]:
+ handle.style(color="red").region(state["region"])
+ image.value = encode_image(handle)
+ frame.value = "Frame: " + str(state["frame"] - 1)
+
+ def on_click(button):
+ """ Handles a button click. """
+ if button == button_next:
+ with condition:
+ state["auto"] = False
+ condition.notify()
+ if button == button_restart:
+ with condition:
+ state["frame"] = 0
+ condition.notify()
+ if button == button_play:
+ with condition:
+ state["auto"] = not state["auto"]
+ button.description = "Stop" if state["auto"] else "Run"
+ condition.notify()
+
+ button_next.on_click(on_click)
+ button_restart.on_click(on_click)
+ button_play.on_click(on_click)
+ widgets.jslink((frame, "value"), (frame2, "value"))
+
+ def on_update(_):
+ """ Handles a widget update."""
+ with condition:
+ if state["auto"]:
+ condition.notify()
- thread = Thread(target=run)
- display(widgets.Box([widgets.VBox(children=(image, buttons))]))
- thread.start()
+ frame2.observe(on_update, names=("value", ))
+ thread = Thread(target=run)
+ display(widgets.Box([widgets.VBox(children=(image, buttons))]))
+ thread.start()
\ No newline at end of file
diff --git a/vot/version.py b/vot/version.py
index 08d79c0..6dbe830 100644
--- a/vot/version.py
+++ b/vot/version.py
@@ -1 +1,4 @@
-__version__ = '0.5.1'
\ No newline at end of file
+"""
+Toolkit version
+"""
+__version__ = '0.6.4'
\ No newline at end of file
diff --git a/vot/workspace/__init__.py b/vot/workspace/__init__.py
index 7fc0626..08600d7 100644
--- a/vot/workspace/__init__.py
+++ b/vot/workspace/__init__.py
@@ -1,46 +1,68 @@
+"""This module contains the Workspace class that represents the main junction of trackers, datasets and experiments."""
import os
import typing
import importlib
import yaml
+from lazy_object_proxy import Proxy
-from attributee import Attribute, Attributee, Nested, List, String
+from attributee import Attribute, Attributee, Nested, List, String, CoerceContext
from .. import ToolkitException, get_logger
from ..dataset import Dataset, load_dataset
from ..tracker import Registry, Tracker
from ..stack import Stack, resolve_stack
-from ..utilities import normalize_path, class_fullname
+from ..utilities import normalize_path
from ..document import ReportConfiguration
from .storage import LocalStorage, Storage, NullStorage
+
_logger = get_logger()
class WorkspaceException(ToolkitException):
+ """Errors related to workspace raise this exception
+ """
pass
class StackLoader(Attribute):
"""Special attribute that converts a string or a dictionary input to a Stack object.
"""
- def coerce(self, value, ctx):
+ def coerce(self, value, context: typing.Optional[CoerceContext]):
+ """Coerce a value to a Stack object
+
+ Args:
+ value (typing.Any): Value to coerce
+ context (typing.Optional[CoerceContext]): Coercion context
+
+ Returns:
+ Stack: Coerced value
+ """
importlib.import_module("vot.analysis")
importlib.import_module("vot.experiment")
if isinstance(value, str):
- stack_file = resolve_stack(value, ctx["parent"].directory)
+ stack_file = resolve_stack(value, context.parent.directory)
if stack_file is None:
raise WorkspaceException("Experiment stack does not exist")
with open(stack_file, 'r') as fp:
stack_metadata = yaml.load(fp, Loader=yaml.BaseLoader)
- return Stack(value, ctx["parent"], **stack_metadata)
+ return Stack(value, context.parent, **stack_metadata)
else:
- return Stack(None, ctx["parent"], **value)
+ return Stack(None, context.parent, **value)
- def dump(self, value):
+ def dump(self, value: "Stack") -> str:
+ """Dump a Stack object to a string or a dictionary
+
+ Args:
+ value (Stack): Value to dump
+
+ Returns:
+ str or dict: The stack name if the stack is named, otherwise its full definition.
+ """
if value.name is None:
return value.dump()
else:
@@ -51,14 +73,14 @@ class Workspace(Attributee):
given experiments on a provided dataset.
"""
- registry = List(String(transformer=lambda x, ctx: normalize_path(x, ctx["parent"].directory)))
+ registry = List(String(transformer=lambda x, ctx: normalize_path(x, ctx.parent.directory)))
stack = StackLoader()
sequences = String(default="sequences")
report = Nested(ReportConfiguration)
@staticmethod
def initialize(directory: str, config: typing.Optional[typing.Dict] = None, download: bool = True) -> None:
- """[summary]
+ """Initialize a new workspace in a given directory with the given config
Args:
directory (str): Root for workspace storage
@@ -145,9 +167,12 @@ def __init__(self, directory: str, **kwargs):
directory ([type]): [description]
"""
self._directory = directory
- self._storage = LocalStorage(directory) if directory is not None else NullStorage()
+ self._storage = Proxy(lambda: LocalStorage(directory) if directory is not None else NullStorage())
+
super().__init__(**kwargs)
+
+
dataset_directory = normalize_path(self.sequences, directory)
if not self.stack.dataset is None:
diff --git a/vot/workspace/storage.py b/vot/workspace/storage.py
index c1c5e41..df8e7f9 100644
--- a/vot/workspace/storage.py
+++ b/vot/workspace/storage.py
@@ -1,3 +1,4 @@
+"""Storage abstraction for the workspace."""
import os
import pickle
@@ -13,6 +14,8 @@
from ..dataset import Sequence
from ..tracker import Tracker, Results
+from attributee import Attributee, Boolean
+
class Storage(ABC):
"""Abstract superclass for workspace storage abstraction
"""
@@ -120,71 +123,133 @@ def copy(self, localfile: str, destination: str):
pass
class NullStorage(Storage):
- """An implementation of dummy storage that does not save anything
- """
+ """An implementation of dummy storage that does not save anything."""
def results(self, tracker: Tracker, experiment: Experiment, sequence: Sequence):
+ """Returns results object for the given tracker, experiment, sequence combination."""
return Results(self)
def __repr__(self) -> str:
+ """Returns a string representation of the storage object."""
return "".format(self._root)
def write(self, name, binary=False):
+ """Opens the given file entry for writing, returns opened handle."""
if binary:
return open(os.devnull, "wb")
else:
return open(os.devnull, "w")
def documents(self):
+ """Lists documents in the storage."""
return []
def folders(self):
+ """Lists folders in the storage. Reuturns an empty list.
+
+ Returns:
+ list: Empty list"""
return []
def read(self, name, binary=False):
+ """Opens the given file entry for reading, returns opened handle.
+
+ Returns:
+ None: Returns None.
+ """
return None
def isdocument(self, name):
+ """Checks if given name is a document/file in this storage.
+
+ Returns:
+ bool: Returns False."""
return False
def isfolder(self, name):
+ """Checks if given name is a folder in this storage.
+
+ Returns:
+ bool: Returns False.
+ """
return False
def delete(self, name) -> bool:
+ """Deletes a given document.
+
+ Returns:
+ bool: Returns False since nothing is deleted."""
return False
def substorage(self, name):
+ """Returns a substorage, storage object with root in a subfolder."""
return NullStorage()
def copy(self, localfile, destination):
+ """Copy a document to another location. Does nothing."""
return
class LocalStorage(Storage):
- """Storage backed by the local filesystem.
- """
+ """Storage backed by the local filesystem. This is the default real storage implementation."""
def __init__(self, root: str):
+ """Creates a new local storage object.
+
+ Args:
+ root (str): Root path of the storage.
+ """
self._root = root
self._results = os.path.join(root, "results")
def __repr__(self) -> str:
+ """Returns a string representation of the storage object."""
return "".format(self._root)
@property
def base(self) -> str:
+ """Returns the base path of the storage."""
return self._root
def results(self, tracker: Tracker, experiment: Experiment, sequence: Sequence):
+ """Returns results object for the given tracker, experiment, sequence combination.
+
+ Args:
+ tracker (Tracker): Selected tracker
+ experiment (Experiment): Selected experiment
+ sequence (Sequence): Selected sequence
+
+ Returns:
+ Results: Results object
+ """
storage = LocalStorage(os.path.join(self._results, tracker.reference, experiment.identifier, sequence.name))
return Results(storage)
def documents(self):
+ """Lists documents in the storage.
+
+ Returns:
+ list: List of document names.
+ """
return [name for name in os.listdir(self._root) if os.path.isfile(os.path.join(self._root, name))]
def folders(self):
+ """Lists folders in the storage.
+
+ Returns:
+ list: List of folder names.
+ """
return [name for name in os.listdir(self._root) if os.path.isdir(os.path.join(self._root, name))]
def write(self, name: str, binary: bool = False):
+ """Opens the given file entry for writing, returns opened handle.
+
+ Args:
+ name (str): File name.
+ binary (bool, optional): Open file in binary mode. Defaults to False.
+
+ Returns:
+ file: Opened file handle.
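+
+ Example (a sketch; assumes a writable root directory)::
+
+ storage = LocalStorage("/tmp/workspace")
+ with storage.write("notes/readme.txt") as handle:
+ handle.write("hello")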
+ """
full = os.path.join(self.base, name)
os.makedirs(os.path.dirname(full), exist_ok=True)
@@ -194,6 +259,15 @@ def write(self, name: str, binary: bool = False):
return open(full, mode="w", newline="")
def read(self, name, binary=False):
+ """Opens the given file entry for reading, returns opened handle.
+
+ Args:
+ name (str): File name.
+ binary (bool, optional): Open file in binary mode. Defaults to False.
+
+ Returns:
+ file: Opened file handle.
+ """
full = os.path.join(self.base, name)
if binary:
@@ -202,6 +276,14 @@ def read(self, name, binary=False):
return open(full, mode="r", newline="")
def delete(self, name) -> bool:
+ """Deletes a given document. Returns True if successful, False otherwise.
+
+ Args:
+ name (str): File name.
+
+ Returns:
+ bool: Returns True if successful, False otherwise.
+ """
full = os.path.join(self.base, name)
if os.path.isfile(full):
os.unlink(full)
@@ -209,15 +291,49 @@ def delete(self, name) -> bool:
return False
def isdocument(self, name):
+ """Checks if given name is a document/file in this storage.
+
+ Args:
+ name (str): Name of the entry to check
+
+ Returns:
+ bool: Returns True if entry is a document, False otherwise.
+ """
return os.path.isfile(os.path.join(self._root, name))
def isfolder(self, name):
+ """Checks if given name is a folder in this storage.
+
+ Args:
+ name (str): Name of the entry to check
+
+ Returns:
+ bool: Returns True if entry is a folder, False otherwise.
+ """
return os.path.isdir(os.path.join(self._root, name))
def substorage(self, name):
+ """Returns a substorage, storage object with root in a subfolder.
+
+ Args:
+ name (str): Name of the entry, must be a folder
+
+ Returns:
+ Storage: Storage object
+ """
return LocalStorage(os.path.join(self.base, name))
def copy(self, localfile, destination):
+ """Copy a document to another location in the storage.
+
+ Args:
+ localfile (str): Original location
+ destination (str): New location
+
+ Raises:
+ IOError: If the destination is an absolute path.
+
+ """
import shutil
if os.path.isabs(destination):
raise IOError("Only relative paths allowed")
@@ -228,6 +344,17 @@ def copy(self, localfile, destination):
shutil.move(localfile, os.path.join(self.base, full))
def directory(self, *args):
+ """Returns a path to a directory in the storage.
+
+ Args:
+ *args: Path segments.
+
+ Returns:
+ str: Path to the directory.
+
+ Raises:
+ ValueError: If the path is not a directory.
+ """
segments = []
for arg in args:
if arg is None:
@@ -278,7 +405,19 @@ def _filename(self, key: typing.Union[typing.Tuple, str]) -> str:
directory = ""
return os.path.join(directory, filename)
- def __getitem__(self, key: str):
+ def __getitem__(self, key: str) -> typing.Any:
+ """Retrieves an image from cache. If it does not exist, a KeyError is raised
+
+ Args:
+ key (str): Key of the item
+
+ Raises:
+ KeyError: Entry does not exist or cannot be deserialized.
+
+ Returns:
+ typing.Any: item value
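+
+ Example (a sketch; the cache is backed by a storage object)::
+
+ cache = Cache(LocalStorage("/tmp/cachedir"))
+ cache["key"] = [1, 2, 3]
+ values = cache["key"]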
+ """
try:
return super().__getitem__(key)
except KeyError as e:
@@ -290,20 +429,32 @@ def __getitem__(self, key: str):
data = pickle.load(filehandle)
super().__setitem__(key, data)
return data
- except pickle.PickleError:
- raise e
+ except pickle.PickleError as e:
+ raise KeyError(e)
- def __setitem__(self, key: str, value: typing.Any):
+ def __setitem__(self, key: str, value: typing.Any) -> None:
+ """Sets an item for given key
+
+ Args:
+ key (str): Item key
+ value (typing.Any): Item value
+
+ """
super().__setitem__(key, value)
filename = self._filename(key)
try:
with self._storage.write(filename, binary=True) as filehandle:
- return pickle.dump(value, filehandle)
+ pickle.dump(value, filehandle)
except pickle.PickleError:
pass
- def __delitem__(self, key: str):
+ def __delitem__(self, key: str) -> None:
+ """Operator for item deletion.
+
+ Args:
+ key (str): Key of object to remove
+ """
try:
super().__delitem__(key)
filename = self._filename(key)
@@ -314,6 +465,14 @@ def __delitem__(self, key: str):
except KeyError:
pass
- def __contains__(self, key: str):
+ def __contains__(self, key: str) -> bool:
+ """Magic method, does the cache include an item for a given key.
+
+ Args:
+ key (str): Item key
+
+ Returns:
+ bool: True if object exists for a given key
+ """
filename = self._filename(key)
return self._storage.isdocument(filename)
diff --git a/vot/workspace/tests.py b/vot/workspace/tests.py
index bfd1b64..4ead470 100644
--- a/vot/workspace/tests.py
+++ b/vot/workspace/tests.py
@@ -1,4 +1,4 @@
-
+"""Tests for workspace related methods and classes."""
import logging
import tempfile
@@ -9,8 +9,12 @@
from vot.workspace import Workspace, NullStorage
class TestStacks(unittest.TestCase):
+ """Tests for workspace related methods
+ """
def test_void_storage(self):
+ """Test if void storage works
+ """
storage = NullStorage()
@@ -20,6 +24,8 @@ def test_void_storage(self):
self.assertIsNone(storage.read("test.data"))
def test_local_storage(self):
+ """Test if local storage works
+ """
with tempfile.TemporaryDirectory() as testdir:
storage = LocalStorage(testdir)
@@ -32,16 +38,21 @@ def test_local_storage(self):
# TODO: more tests
def test_workspace_create(self):
+ """Test if workspace creation works
+ """
get_logger().setLevel(logging.WARN) # Disable progress bar
- default_config = dict(stack="testing", registry=["./trackers.ini"])
+ default_config = dict(stack="tests/basic", registry=["./trackers.ini"])
with tempfile.TemporaryDirectory() as testdir:
Workspace.initialize(testdir, default_config, download=True)
Workspace.load(testdir)
def test_cache(self):
+ """Test if local storage cache works
+ """
+
with tempfile.TemporaryDirectory() as testdir:
cache = Cache(LocalStorage(testdir))