diff --git a/.gitignore b/.gitignore index 486a089..51dbe93 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ venv/ MANIFEST *.pyc dist/ -*.egg-info \ No newline at end of file +*.egg-info +_build \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f288702 --- /dev/null +++ b/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. 
+States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. 
+ + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. 
+ + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. 
+ + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. 
+ + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. 
+ + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. 
+ + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + <program> Copyright (C) <year> <name of author> + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +<https://www.gnu.org/licenses/>. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +<https://www.gnu.org/licenses/why-not-lgpl.html>. diff --git a/MANIFEST.in b/MANIFEST.in index f057b7b..d8678ee 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,8 +1,8 @@ include requirements.txt include README.md -include vot/stack/*.yaml +recursive-include vot/stack/ *.yaml include vot/dataset/*.png include vot/dataset/*.jpg include vot/document/*.css include vot/document/*.js -include vot/document/*.tex \ No newline at end of file +include vot/document/*.tex diff --git a/README.md b/README.md index c3e2a63..b732bed 100644 --- a/README.md +++ b/README.md @@ -2,21 +2,32 @@ The VOT evaluation toolkit ========================== -This repository contains the official evaluation toolkit for the [Visual Object Tracking (VOT) challenge](http://votchallenge.net/). This is the official version of the toolkit, implemented in Python 3 language. If you are looking for the old Matlab version, you can find an archived repository [here](https://github.com/vicoslab/toolkit-legacy).
+[![PyPI package version](https://badge.fury.io/py/vot-toolkit.svg)](https://badge.fury.io/py/vot-toolkit) -For more detailed informations consult the documentation available in the source or a compiled version of the documentation [here](http://www.votchallenge.net/howto/). You can also subscribe to the VOT [mailing list](https://service.ait.ac.at/mailman/listinfo/votchallenge) to receive news about challenges and important software updates or join our [support form](https://groups.google.com/forum/?hl=en#!forum/votchallenge-help) to ask questions. +This repository contains the official evaluation toolkit for the [Visual Object Tracking (VOT) challenge](http://votchallenge.net/). This is the official version of the toolkit, implemented in Python 3. If you are looking for the old Matlab version, you can find an archived repository [here](https://github.com/votchallenge/toolkit-legacy). + +For more detailed information, consult the documentation available in the source or a compiled version of the documentation [here](http://www.votchallenge.net/howto/). You can also subscribe to the VOT [mailing list](https://liste.arnes.si/mailman3/lists/votchallenge.lists.arnes.si/) to receive news about challenges and important software updates or join our [support forum](https://groups.google.com/forum/?hl=en#!forum/votchallenge-help) to ask questions. Developers ---------- -* Luka Čehovin Zajc, University of Ljubljana (lead developer) -* Alan Lukežič, University of Ljubljana +The VOT toolkit is developed and maintained by [Luka Čehovin Zajc](https://vicos.si/lukacu) with the help of the VOT initiative members and the VOT community. + +Contributors: + +* [Luka Čehovin Zajc](https://vicos.si/lukacu), University of Ljubljana +* [Alan Lukežič](https://vicos.si/people/alan_lukezic/), University of Ljubljana * Yan Song, Tampere University +Acknowledgements +---------------- + +The development of this package was supported by the Slovenian Research Agency (ARRS), projects Z2-1866 and J2-316. + License ------- -Copyright (C) 2021 Luka Čehovin Zajc and the [VOT Challenge innitiative](http://votchallenge.net/). +Copyright (C) 2023 Luka Čehovin Zajc and the [VOT Challenge initiative](http://votchallenge.net/). This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. @@ -27,4 +38,4 @@ You should have received a copy of the GNU General Public License along with thi Enquiries, Question and Comments -------------------------------- -If you have any further enquiries, question, or comments, please refer to the contact information link on the [VOT homepage](http://votchallenge.net/). If you would like to file a bug report or a feature request, use the [Github issue tracker](https://github.com/vicoslab/toolkit/issues). **The issue tracker is for toolkit issues only**, if you have a problem with tracker integration or any other questions, please use our [support forum](https://groups.google.com/forum/?hl=en#!forum/votchallenge-help). +If you have any further enquiries, questions, or comments, please refer to the contact information link on the [VOT homepage](http://votchallenge.net/). If you would like to file a bug report or a feature request, use the [Github issue tracker](https://github.com/votchallenge/toolkit/issues).
**The issue tracker is for toolkit issues only**; if you have a problem with tracker integration or any other questions, please use our [support forum](https://groups.google.com/forum/?hl=en#!forum/votchallenge-help). diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..676b72b --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,196 @@ +# Makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +BUILDDIR = _build + +# Internal variables. +PAPEROPT_a4 = -D latex_paper_size=a4 +PAPEROPT_letter = -D latex_paper_size=letter +ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . +# the i18n builder cannot share the environment and doctrees with the others +I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . + +.PHONY: help +help: + @echo "Please use \`make <target>' where <target> is one of" + @echo " html to make standalone HTML files" + @echo " dirhtml to make HTML files named index.html in directories" + @echo " singlehtml to make a single large HTML file" + @echo " pickle to make pickle files" + @echo " json to make JSON files" + @echo " htmlhelp to make HTML files and a HTML help project" + @echo " qthelp to make HTML files and a qthelp project" + @echo " applehelp to make an Apple Help Book" + @echo " devhelp to make HTML files and a Devhelp project" + @echo " epub to make an epub" + @echo " epub3 to make an epub3" + @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" + @echo " latexpdf to make LaTeX files and run them through pdflatex" + @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" + @echo " text to make text files" + @echo " man to make manual pages" + @echo " texinfo to make Texinfo files" + @echo " info to make Texinfo files and run them through makeinfo" + @echo " gettext to make PO message catalogs" + @echo " changes to make an overview of all changed/added/deprecated items" + @echo " xml to make Docutils-native XML files" + @echo " pseudoxml to make pseudoxml-XML files for display purposes" + @echo " linkcheck to check all external links for integrity" + @echo " doctest to run all doctests embedded in the documentation (if enabled)" + @echo " coverage to run coverage check of the documentation (if enabled)" + @echo " dummy to check syntax errors of document sources" + +.PHONY: clean +clean: + rm -rf $(BUILDDIR)/* + +.PHONY: html +html: + $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +.PHONY: dirhtml +dirhtml: + $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." + +.PHONY: singlehtml +singlehtml: + $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml + @echo + @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." + +.PHONY: pickle +pickle: + $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle + @echo + @echo "Build finished; now you can process the pickle files." + +.PHONY: json +json: + $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json + @echo + @echo "Build finished; now you can process the JSON files." + +.PHONY: htmlhelp +htmlhelp: + $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp + @echo + @echo "Build finished; now you can run HTML Help Workshop with the" \ + ".hhp project file in $(BUILDDIR)/htmlhelp."
+ +.PHONY: epub +epub: + $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub + @echo + @echo "Build finished. The epub file is in $(BUILDDIR)/epub." + +.PHONY: epub3 +epub3: + $(SPHINXBUILD) -b epub3 $(ALLSPHINXOPTS) $(BUILDDIR)/epub3 + @echo + @echo "Build finished. The epub3 file is in $(BUILDDIR)/epub3." + +.PHONY: latex +latex: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo + @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." + @echo "Run \`make' in that directory to run these through (pdf)latex" \ + "(use \`make latexpdf' here to do that automatically)." + +.PHONY: latexpdf +latexpdf: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through pdflatex..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +.PHONY: latexpdfja +latexpdfja: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through platex and dvipdfmx..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +.PHONY: text +text: + $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text + @echo + @echo "Build finished. The text files are in $(BUILDDIR)/text." + +.PHONY: man +man: + $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man + @echo + @echo "Build finished. The manual pages are in $(BUILDDIR)/man." + +.PHONY: texinfo +texinfo: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo + @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." + @echo "Run \`make' in that directory to run these through makeinfo" \ + "(use \`make info' here to do that automatically)." + +.PHONY: info +info: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo "Running Texinfo files through makeinfo..." + make -C $(BUILDDIR)/texinfo info + @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." + +.PHONY: gettext +gettext: + $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale + @echo + @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." + +.PHONY: changes +changes: + $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes + @echo + @echo "The overview file is in $(BUILDDIR)/changes." + +.PHONY: linkcheck +linkcheck: + $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck + @echo + @echo "Link check complete; look for any errors in the above output " \ + "or in $(BUILDDIR)/linkcheck/output.txt." + +.PHONY: doctest +doctest: + $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest + @echo "Testing of doctests in the sources finished, look at the " \ + "results in $(BUILDDIR)/doctest/output.txt." + +.PHONY: coverage +coverage: + $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage + @echo "Testing of coverage in the sources finished, look at the " \ + "results in $(BUILDDIR)/coverage/python.txt." + +.PHONY: xml +xml: + $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml + @echo + @echo "Build finished. The XML files are in $(BUILDDIR)/xml." + +.PHONY: pseudoxml +pseudoxml: + $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml + @echo + @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." + +.PHONY: dummy +dummy: + $(SPHINXBUILD) -b dummy $(ALLSPHINXOPTS) $(BUILDDIR)/dummy + @echo + @echo "Build finished. Dummy builder generates no files." 
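For contributors without GNU Make (for example on Windows), the same HTML build that the Makefile's `html` target performs can be invoked through Sphinx's Python entry point. A minimal sketch, assuming Sphinx 1.7 or newer is installed and the script is run from the `docs/` directory:

```python
# Programmatic equivalent of `make html` from the Makefile above.
# Assumes Sphinx >= 1.7 (which provides sphinx.cmd.build.build_main)
# and that the current working directory is docs/.
from sphinx.cmd.build import build_main

# Mirrors ALLSPHINXOPTS: doctree cache in _build/doctrees,
# sources in the current directory, HTML output in _build/html.
status = build_main(["-b", "html", "-d", "_build/doctrees", ".", "_build/html"])
raise SystemExit(status)
```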
diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 0000000..61f3434 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,22 @@ +Documentation +============= + +The API section contains the generated documentation of individual structures and functions from source code docstrings. + +.. toctree:: + :maxdepth: 2 + + api/analysis + api/dataset + api/document + api/experiment + api/region + api/stack + api/tracker + api/utilities + +Core utilities +-------------- + +.. automodule:: vot + :members: \ No newline at end of file diff --git a/docs/api/analysis.rst b/docs/api/analysis.rst new file mode 100644 index 0000000..c29896a --- /dev/null +++ b/docs/api/analysis.rst @@ -0,0 +1,41 @@ +Analysis module +=============== + +The analysis module contains classes that implement various performance analysis methodologies. It also contains a parallel runtime with caching capabilities that +enables efficient execution of large-scale evaluations. + +.. automodule:: vot.analysis + :members: + +.. automodule:: vot.analysis.parallel + :members: + +Accuracy analysis +----------------- + +.. automodule:: vot.analysis.accuracy + :members: + +Failure analysis +---------------- + +.. automodule:: vot.analysis.failure + :members: + +Long-term measures +------------------ + +.. automodule:: vot.analysis.longterm + :members: + +Multi-start measures +-------------------- + +.. automodule:: vot.analysis.multistart + :members: + +Supervision analysis +-------------------- + +.. automodule:: vot.analysis.supervision + :members: \ No newline at end of file diff --git a/docs/api/dataset.rst b/docs/api/dataset.rst new file mode 100644 index 0000000..bd34f02 --- /dev/null +++ b/docs/api/dataset.rst @@ -0,0 +1,28 @@ +Dataset module +============== + +.. automodule:: vot.dataset + :members: + +.. automodule:: vot.dataset.common + :members: + +Extended dataset support +------------------------ + +Many datasets are supported by the toolkit using special adapters. + +**OTB** + +.. automodule:: vot.dataset.otb + :members: + +**GOT10k** + +.. automodule:: vot.dataset.got10k + :members: + +**TrackingNet** + +.. automodule:: vot.dataset.trackingnet + :members: diff --git a/docs/api/document.rst b/docs/api/document.rst new file mode 100644 index 0000000..09b1b2d --- /dev/null +++ b/docs/api/document.rst @@ -0,0 +1,21 @@ +Document module +=============== + +.. automodule:: vot.document + :members: + +.. automodule:: vot.document.common + :members: + +HTML document generation +------------------------ + +.. automodule:: vot.document + :members: + +LaTeX document generation +------------------------- + +.. automodule:: vot.document.latex + :members: + diff --git a/docs/api/experiment.rst b/docs/api/experiment.rst new file mode 100644 index 0000000..25b3353 --- /dev/null +++ b/docs/api/experiment.rst @@ -0,0 +1,5 @@ +Experiment module +================= + +.. automodule:: vot.experiment + :members: \ No newline at end of file diff --git a/docs/api/region.rst b/docs/api/region.rst new file mode 100644 index 0000000..cf8e48b --- /dev/null +++ b/docs/api/region.rst @@ -0,0 +1,23 @@ +Region module +============= + +.. automodule:: vot.region + :members: + +Shapes +------ + +.. automodule:: vot.region.shapes + :members: + +Raster utilities +---------------- + +.. automodule:: vot.region.raster + :members: + +IO functions +------------ + +..
automodule:: vot.region.io + :members: \ No newline at end of file diff --git a/docs/api/stack.rst b/docs/api/stack.rst new file mode 100644 index 0000000..6d360d7 --- /dev/null +++ b/docs/api/stack.rst @@ -0,0 +1,5 @@ +Stack module +============ + +.. automodule:: vot.stack + :members: \ No newline at end of file diff --git a/docs/api/tracker.rst b/docs/api/tracker.rst new file mode 100644 index 0000000..c84ac22 --- /dev/null +++ b/docs/api/tracker.rst @@ -0,0 +1,18 @@ +Tracker module +============== + +.. automodule:: vot.tracker + :members: + +TraX protocol module +-------------------- + +.. automodule:: vot.tracker.trax + :members: + +Results module +-------------- + +.. automodule:: vot.tracker.results + :members: + diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst new file mode 100644 index 0000000..c1a2157 --- /dev/null +++ b/docs/api/utilities.rst @@ -0,0 +1,5 @@ +Utilities module +================ + +.. automodule:: vot.utilities + :members: \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..dbc9f00 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- +# +import os +import sys +sys.path.insert(0, os.path.abspath('..')) + +# -- General configuration ------------------------------------------------ + + +extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon'] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +from recommonmark.parser import CommonMarkParser + +source_parsers = { + '.md': CommonMarkParser, +} + +source_suffix = ['.rst', '.md'] + +master_doc = 'index' + +# General information about the project. +project = u'VOT Toolkit' +copyright = u'2022, Luka Cehovin Zajc' +author = u'Luka Cehovin Zajc' + +try: + import sys + import os + + __version__ = "0.0.0" + + exec(open(os.path.join(os.path.dirname(__file__), '..', 'vot', 'version.py')).read()) + + version = __version__ +except Exception: + version = 'unknown' + +# The full version, including alpha/beta/rc tags. +release = version + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +# +# today = '' +# +# Else, today_fmt is used as the format for a strftime call. +# +# today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# These patterns also affect html_static_path and html_extra_path +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# The reST default role (used for this markup: `text`) to use for all +# documents. +# +# default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +# +# add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +# +# add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +# +# show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# A list of ignored prefixes for module index sorting.
+# modindex_common_prefix = [] + +# If true, keep warnings as "system message" paragraphs in the built documents. +# keep_warnings = False + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +html_static_path = ['_static'] + +htmlhelp_basename = 'vottoolkitdoc' + +# -- Options for LaTeX output --------------------------------------------- + +latex_documents = [ + (master_doc, 'vot-toolkit.tex', u'VOT Toolkit Documentation', + u'Luka Cehovin Zajc', 'manual'), +] + +man_pages = [ + (master_doc, 'vot-toolkit', u'VOT Toolkit Documentation', + [author], 1) +] + +# If true, show URL addresses after external links. +# +# man_show_urls = False + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'VOT Toolkit', u'VOT Toolkit Documentation', + author, 'VOT Toolkit', 'The official VOT Challenge evaluation toolkit', + 'Miscellaneous'), +] + diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..58b2683 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,24 @@ +Welcome to the VOT Toolkit documentation +======================================== + +The VOT toolkit is the official evaluation tool for the `Visual Object Tracking (VOT) challenge <http://votchallenge.net/>`_. +It is written in Python 3. The toolkit is designed to be easy to use and to have broad support for various trackers, +datasets and evaluation measures. + +Contributions and development +----------------------------- + +The VOT toolkit is developed by the VOT Committee, primarily by Luka Čehovin Zajc, and by the tracking community as an open-source project (GPLv3 license). + +Contributions to the VOT toolkit are welcome; the preferred way to contribute is by submitting an issue or a pull request on `GitHub <https://github.com/votchallenge/toolkit>`_. + +Index +----- + +.. toctree:: + :maxdepth: 1 + + overview + tutorials + api + diff --git a/docs/overview.rst b/docs/overview.rst new file mode 100644 index 0000000..ba24c58 --- /dev/null +++ b/docs/overview.rst @@ -0,0 +1,41 @@ +Overview +======== + +The toolkit is designed as a modular framework with several modules that address different aspects of the performance evaluation problem. + +Key concepts +------------ + +Key concepts that are used throughout the toolkit are: + +* **Dataset** - a collection of **sequences** that is used for performance evaluation. +* **Sequence** - a sequence of frames with corresponding ground truth annotations for one or more objects. A sequence is a collection of **frames**. +* **Tracker** - a tracker is an algorithm that takes frames from a sequence as input (one by one) and produces a set of **trajectories** as output. +* **Experiment** - an experiment is a method that applies a tracker to a given sequence in a specific way. +* **Analysis** - an analysis is a set of **measures** that are used to evaluate the performance of a tracker (compare predicted trajectories to groundtruth). +* **Stack** - a stack is a collection of **experiments** and **analyses** that are performed on a given dataset.
+* **Workspace** - a workspace is a local directory that ties together a **stack**, a **dataset** and the storage for raw results and analysis caches. + +Tracker support +--------------- + +The toolkit supports various ways of interacting with tracking methods. The primary manner (and the only one supported at the moment) is the TraX protocol. +The toolkit provides a wrapper for the TraX protocol that allows any tracker that supports the protocol to be used. + +Dataset support +--------------- + +The toolkit is capable of using any dataset that is provided in the toolkit format. +The toolkit format is a simple directory structure that contains a set of sequences. Each sequence is a directory that contains a set of frames and a groundtruth file. +The groundtruth file is a text file that contains one line per frame. Each line contains the bounding box of the object in the frame in the format `x,y,w,h`, e.g. `158.0,62.0,96.0,75.0`. The toolkit format is used by the toolkit itself and by the VOT challenges. + + +Performance methodology support +------------------------------- + +Various performance measures and visualizations are implemented, most of which were used in VOT challenges. + + * **Accuracy** - the accuracy measure is the overlap between the predicted and groundtruth bounding boxes. The overlap is measured using the intersection over union (IoU) measure. + * **Robustness** - the robustness measure is the number of failures of the tracker. A failure is detected when the overlap between the predicted and groundtruth bounding boxes drops below a given threshold. + * **Expected Average Overlap** - the expected average overlap (EAO) is a measure that combines accuracy and robustness into a single measure. The EAO is computed by averaging the expected overlap curve over an interval of typical sequence lengths. + * **Expected Overlap** - the expected overlap curve gives, for each sequence length, the average overlap that a tracker can be expected to achieve on sequences of that length; it is the basis for the EAO measure. diff --git a/docs/tutorials.rst b/docs/tutorials.rst new file mode 100644 index 0000000..2e38175 --- /dev/null +++ b/docs/tutorials.rst @@ -0,0 +1,15 @@ +Tutorials +========= + +The main purpose of the toolkit is to facilitate tracker evaluation for VOT challenges and benchmarks. But there are many other +ways that the toolkit can be used and extended. The following tutorials are provided to help you get started with the toolkit. + +.. toctree:: + :maxdepth: 1 + + tutorial_introduction + tutorial_integration + tutorial_evaluation + tutorial_dataset + tutorial_stack + tutorial_jupyter diff --git a/requirements.txt b/requirements.txt index a509fff..6e58b59 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -vot-trax>=3.0.2 +vot-trax>=4.0.1 tqdm>=4.37 numpy>=1.16 opencv-python>=4.0 @@ -16,4 +16,5 @@ dominate>=2.5 cachetools>=4.1 bidict>=0.19 phx-class-registry>=3.0 -attributee>=0.1.3 \ No newline at end of file +attributee>=0.1.8 +lazy-object-proxy>=1.9 \ No newline at end of file diff --git a/setup.py b/setup.py index 5163392..e2fdda1 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ "Development Status :: 4 - Beta", "Intended Audience :: Science/Research", ], - python_requires='>=3.6', + python_requires='>=3.7', entry_points={ 'console_scripts': ['vot=vot.utilities.cli:main'], }, diff --git a/vot/__init__.py b/vot/__init__.py index ea9adbc..2db5dfc 100644 --- a/vot/__init__.py +++ b/vot/__init__.py @@ -1,3 +1,4 @@ +""" Some basic functions and classes used by the toolkit.
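+ +Global behavior can be tuned through environment variables that map to the attributes of the GlobalConfiguration class defined below: VOT_DEBUG_MODE, VOT_SEQUENCE_CACHE_SIZE, VOT_RESULTS_BINARY and VOT_MASK_OPTIMIZE_READ.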
""" import os import logging @@ -44,7 +45,7 @@ def check_updates() -> bool: try: get_logger().debug("Checking for new version") - response = requests.get(version_url, timeout=2) + response = requests.get(version_url, timeout=5, allow_redirects=True) except Exception as e: get_logger().debug("Unable to retrieve version information %s", e) return False, None @@ -62,12 +63,44 @@ def check_updates() -> bool: else: return False, None +from attributee import Attributee, Integer, Boolean + +class GlobalConfiguration(Attributee): + """Global configuration object for the toolkit. It is used to store global configuration options. + """ + + debug_mode = Boolean(default=False, description="Enables debug mode for the toolkit.") + sequence_cache_size = Integer(default=1000, description="Maximum number of sequences to keep in cache.") + results_binary = Boolean(default=True, description="Enables binary results format.") + mask_optimize_read = Boolean(default=True, description="Enables mask optimization when reading masks.") + + def __init__(self): + """Initializes the global configuration object. It reads the configuration from environment variables. + + Raises: + ValueError: When an invalid value is provided for an attribute. + """ + kwargs = {} + for k in self.attributes(): + envname = "VOT_{}".format(k.upper()) + if envname in os.environ: + kwargs[k] = os.environ[envname] + super().__init__(**kwargs) + _logger.debug("Global configuration: %s", self) + + def __repr__(self): + """Returns a string representation of the global configuration object.""" + return "debug_mode={} sequence_cache_size={} results_binary={} mask_optimize_read={}".format( + self.debug_mode, self.sequence_cache_size, self.results_binary, self.mask_optimize_read + ) + +config = GlobalConfiguration() + def check_debug() -> bool: """Checks if debug is enabled for the toolkit via an environment variable. Returns: bool: True if debug is enabled, False otherwise """ - var = os.environ.get("VOT_TOOLKIT_DEBUG", "false").lower() - return var in ["true", "1"] + return config.debug_mode diff --git a/vot/__main__.py b/vot/__main__.py index ab9bbe2..9db15b3 100644 --- a/vot/__main__.py +++ b/vot/__main__.py @@ -1,4 +1,4 @@ - +""" This module is a shortcut for the CLI interface so that it can be run as a "vot" module. """ # Just a shortcut for the CLI interface so that it can be run as a "vot" module. diff --git a/vot/analysis/__init__.py b/vot/analysis/__init__.py index 83baddc..ffec339 100644 --- a/vot/analysis/__init__.py +++ b/vot/analysis/__init__.py @@ -1,14 +1,11 @@ -import logging -import functools -import threading +"""This module contains classes and functions for analysis of tracker performance. 
The analysis is performed on the results of an experiment.""" + from collections import namedtuple -from enum import Enum, Flag, auto -from typing import List, Optional, Tuple, Dict, Any, Set, Union, NamedTuple +from enum import Enum, auto +from typing import List, Optional, Tuple, Any from abc import ABC, abstractmethod -from concurrent.futures import Executor import importlib -from cachetools import Cache from class_registry import ClassRegistry from attributee import Attributee, String @@ -26,8 +23,12 @@ class MissingResultsException(ToolkitException): """Exception class that denotes missing results during analysis """ - pass - + def __init__(self, *args: object) -> None: + """Constructor""" + if not args: + args = ["Missing results"] + super().__init__(*args) + class Sorting(Enum): """Sorting direction enumeration class """ @@ -67,19 +68,24 @@ def __init__(self, name: str, abbreviation: Optional[str] = None, description: O @property def name(self) -> str: + """Name of the result, used in reports""" return self._name @property def abbreviation(self) -> str: + """Abbreviation, if empty, then name is used. Can be used to define a shorter text representation.""" return self._abbreviation @property def description(self) -> str: + """Description of the result, used in reports""" return self._description class Label(Result): + """Label describes a single categorical output of an analysis. Can have a set of possible values.""" def __init__(self, *args, **kwargs): + """Constructor.""" super().__init__(*args, **kwargs) class Measure(Result): @@ -89,6 +95,19 @@ class Measure(Result): def __init__(self, name: str, abbreviation: Optional[str] = None, minimal: Optional[float] = None, \ maximal: Optional[float] = None, direction: Optional[Sorting] = Sorting.UNSORTABLE): + """Constructor for Measure class. + + Arguments: + name {str} -- Name of the measure, used in reports + + Keyword Arguments: + abbreviation {Optional[str]} -- Abbreviation, if empty, then name is used. + Can be used to define a shorter text representation. (default: {None}) + minimal {Optional[float]} -- Minimal value of the measure. If None, then the measure is not bounded from below. (default: {None}) + maximal {Optional[float]} -- Maximal value of the measure. If None, then the measure is not bounded from above. (default: {None}) + direction {Optional[Sorting]} -- Direction of sorting. If Sorting.UNSORTABLE, then the measure is not sortable. (default: {Sorting.UNSORTABLE}) + + """ super().__init__(name, abbreviation) self._minimal = minimal @@ -97,14 +116,17 @@ def __init__(self, name: str, abbreviation: Optional[str] = None, minimal: Optio @property def minimal(self) -> float: + """Minimal value of the measure. If None, then the measure is not bounded from below.""" return self._minimal @property def maximal(self) -> float: + """Maximal value of the measure. If None, then the measure is not bounded from above.""" return self._maximal @property def direction(self) -> Sorting: + """Direction of sorting. If Sorting.UNSORTABLE, then the measure is not sortable.""" return self._direction class Drawable(Result): @@ -124,11 +146,30 @@ def __init__(self, name: str, abbreviation: Optional[str] = None, trait: Optiona @property def trait(self): + """Trait of the data, used for specification""" return self._trait class Multidimensional(Drawable): + """Base class for multidimensional results. 
This class is used to describe results that can be visualized in a scatter plot.""" + def __init__(self, name: str, dimensions: int, abbreviation: Optional[str] = None, minimal: Optional[Tuple[float]] = None, \ + maximal: Optional[Tuple[float]] = None, labels: Optional[Tuple[str]] = None, trait: Optional[str] = None): + """Constructor for Multidimensional class. + + Arguments: + name {str} -- Name of the measure, used in reports + dimensions {int} -- Number of dimensions of the result + + Keyword Arguments: + abbreviation {Optional[str]} -- Abbreviation, if empty, then name is used. + Can be used to define a shorter text representation. (default: {None}) + minimal {Optional[Tuple[float]]} -- Minimal value of the measure. If None, then the measure is not bounded from below. (default: {None}) + maximal {Optional[Tuple[float]]} -- Maximal value of the measure. If None, then the measure is not bounded from above. (default: {None}) + labels {Optional[Tuple[str]]} -- Labels for each dimension. (default: {None}) + trait {Optional[str]} -- Trait of the data, used for specification. Defaults to None. + + """ + assert dimensions > 1 super().__init__(name, abbreviation, trait) self._dimensions = dimensions @@ -138,15 +179,19 @@ def __init__(self, name: str, dimensions: int, abbreviation: Optional[str] = Non @property def dimensions(self): + """Number of dimensions of the result""" return self._dimensions def minimal(self, i): + """Minimal value of the i-th dimension. If None, then the measure is not bounded from below.""" return self._minimal[i] def maximal(self, i): + """Maximal value of the i-th dimension. If None, then the measure is not bounded from above.""" return self._maximal[i] def label(self, i): + """Label for the i-th dimension.""" return self._labels[i] class Point(Multidimensional): @@ -160,6 +205,20 @@ class Plot(Drawable): def __init__(self, name: str, abbreviation: Optional[str] = None, wrt: str = "frames", minimal: Optional[float] = None, \ + maximal: Optional[float] = None, trait: Optional[str] = None): + """Constructor for Plot class. + + Arguments: + name {str} -- Name of the measure, used in reports + + Keyword Arguments: + abbreviation {Optional[str]} -- Abbreviation, if empty, then name is used. + Can be used to define a shorter text representation. (default: {None}) + wrt {str} -- Unit of the independent variable. (default: {"frames"}) + minimal {Optional[float]} -- Minimal value of the measure. If None, then the measure is not bounded from below. (default: {None}) + maximal {Optional[float]} -- Maximal value of the measure. If None, then the measure is not bounded from above. (default: {None}) + trait {Optional[str]} -- Trait of the data, used for specification. Defaults to None. + + """ + super().__init__(name, abbreviation, trait) self._wrt = wrt self._minimal = minimal @@ -167,15 +226,17 @@ def __init__(self, name: str, abbreviation: Optional[str] = None, wrt: str = "fr @property def minimal(self): + """Minimal value of the measure. If None, then the measure is not bounded from below.""" return self._minimal @property def maximal(self): + """Maximal value of the measure. If None, then the measure is not bounded from above.""" return self._maximal - @property def wrt(self): + """Unit of the independent variable.""" return self._wrt class Curve(Multidimensional): @@ -183,42 +244,73 @@ class Curve(Multidimensional): """ class Analysis(Attributee): + """Base class for all analysis classes.
Analysis is a class that describes computation of one or more performance metrics for a given experiment.""" - name = String(default=None) + name = String(default=None, description="Name of the analysis") def __init__(self, **kwargs): + """Constructor for Analysis class. + + Keyword Arguments: + name {str} -- Name of the analysis (default: {None}) + """ super().__init__(**kwargs) self._identifier_cache = None def compatible(self, experiment: Experiment): + """Checks if the analysis is compatible with the experiment type.""" raise NotImplementedError() @property def title(self) -> str: + """Returns the title of the analysis. If name is not set, then the default title is returned.""" + + if self.name is None: + return self._title_default + else: + return self.name + + @property + def _title_default(self) -> str: + """Returns the default title of the analysis. This is used when name is not set.""" raise NotImplementedError() def dependencies(self) -> List["Analysis"]: + """Returns a list of dependencies of the analysis. This is used to determine the order of execution of the analyses.""" return [] @property def identifier(self) -> str: + """Returns a unique identifier of the analysis. This is used to determine if the analysis has already been computed.""" + if not self._identifier_cache is None: return self._identifier_cache params = self.dump() del params["name"] + confighash = arg_hash(**params) self._identifier_cache = class_fullname(self) + "@" + confighash - + return self._identifier_cache def describe(self) -> Tuple["Result"]: - """Returns a tuple of descriptions of results - """ + """Returns a tuple of descriptions of results of the analysis.""" raise NotImplementedError() def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid: + """Computes the analysis for the given experiment, trackers and sequences. The dependencies are the results of the dependent analyses. + The result is a grid with the results of the analysis. The grid is indexed by trackers and sequences. The axes are described by the axes() method. + + Args: + experiment (Experiment): Experiment to compute the analysis for. + trackers (List[Tracker]): List of trackers to compute the analysis for. + sequences (List[Sequence]): List of sequences to compute the analysis for. + dependencies (List[Grid]): List of dependencies of the analysis. + + Returns: Grid with the results of the analysis. + """ raise NotImplementedError() @property @@ -227,15 +319,18 @@ def axes(self) -> Axes: raise NotImplementedError() def commit(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence]): + """Commits the analysis for execution on the default processor.""" return AnalysisProcessor.commit_default(self, experiment, trackers, sequences) def run(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence]): + """Runs the analysis on the default processor.""" return AnalysisProcessor.run_default(self, experiment, trackers, sequences) class SeparableAnalysis(Analysis): """Analysis that is separable with respect to trackers and/or sequences, each part can be processed in parallel as a separate job. The separation is determined by the result of the axes() method: Axes.BOTH means separation - in tracker-sequence pairs, Axes.TRACKER means separation according to + in tracker-sequence pairs, Axes.TRACKERS means separation according to trackers and Axes.SEQUENCES means separation + according to sequences.
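+ + For example, an analysis with ``axes == Axes.BOTH`` that is evaluated on M trackers and N sequences is separated into M * N independent jobs, one per tracker-sequence pair.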
""" SeparablePart = namedtuple("SeparablePart", ["trackers", "sequences", "tid", "sid"]) @@ -252,15 +347,13 @@ def subcompute(self, experiment: Experiment, tracker, sequence, dependencies: Li note that each dependency is processed using select function to only contain information relevant for the current part of the analysis - Raises: - NotImplementedError: [description] - Returns: - Tuple[Any]: [description] + Tuple[Any]: Tuple of results of the analysis """ raise NotImplementedError() def __init__(self, **kwargs): + """Initializes the analysis. The axes semantic description is checked to be compatible with the dependencies.""" super().__init__(**kwargs) # All dependencies should be mappable to individual parts. If parts contain @@ -270,6 +363,15 @@ def __init__(self, **kwargs): assert all([dependency.axes != Axes.BOTH for dependency in self.dependencies()]) def separate(self, trackers: List[Tracker], sequences: List[Sequence]) -> List["SeparablePart"]: + """Separates the analysis into parts that can be processed separately. + + Args: + trackers (List[Tracker]): List of trackers to compute the analysis for. + sequences (List[Sequence]): List of sequences to compute the analysis for. + + Returns: List of parts of the analysis. + + """ if self.axes == Axes.BOTH: parts = [] for i, tracker in enumerate(trackers): @@ -288,6 +390,17 @@ def separate(self, trackers: List[Tracker], sequences: List[Sequence]) -> List[" return parts def join(self, trackers: List[Tracker], sequences: List[Sequence], results: List[Tuple[Any]]): + """Joins the results of the analysis into a single grid. The results are indexed by trackers and sequences. + + Args: + trackers (List[Tracker]): List of trackers to compute the analysis for. + sequences (List[Sequence]): List of sequences to compute the analysis for. + results (List[Tuple[Any]]): List of results of the analysis. + + Returns: + Grid: Grid with the results of the analysis. + """ + if self.axes == Axes.BOTH: transformed_results = Grid(len(trackers), len(sequences)) k = 0 @@ -356,7 +469,7 @@ def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: Li return Grid.scalar(self.subcompute(experiment, trackers[0], sequences[0], dependencies)) elif self.axes == Axes.TRACKERS and len(trackers) == 1: return Grid.scalar(self.subcompute(experiment, trackers[0], sequences, dependencies)) - elif self.axes == Axes.BOTH and len(sequences) == 1: + elif self.axes == Axes.SEQUENCES and len(sequences) == 1: return Grid.scalar(self.subcompute(experiment, trackers, sequences[0], dependencies)) else: parts = self.separate(trackers, sequences) @@ -370,11 +483,14 @@ def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: Li @property def axes(self) -> Axes: + """Returns the axes of the analysis. This is used to determine how the analysis is split into parts.""" return Axes.BOTH class SequenceAggregator(Analysis): # pylint: disable=W0223 + """Base class for sequence aggregators. Sequence aggregators take the results of a tracker and aggregate them over sequences.""" def __init__(self, **kwargs): + """Base constructor.""" super().__init__(**kwargs) # We only support one dependency in aggregator ... assert len(self.dependencies()) == 1 @@ -383,9 +499,27 @@ def __init__(self, **kwargs): @abstractmethod def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]: + """Aggregate the results of the analysis over sequences for a single tracker. + + Args: + tracker (Tracker): Tracker to aggregate the results for. 
+ sequences (List[Sequence]): List of sequences to aggregate the results for. + results (Grid): Results of the analysis for the tracker and sequences. + + """ raise NotImplementedError() def compute(self, _: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid: + """Compute the analysis for a list of trackers and sequences. + + Args: + trackers (List[Tracker]): List of trackers to compute the analysis for. + sequences (List[Sequence]): List of sequences to compute the analysis for. + dependencies (List[Grid]): List of dependencies, should be one grid with results of the dependency analysis. + + Returns: + Grid: Grid with the results of the analysis. + """ results = dependencies[0] transformed_results = Grid(len(trackers), 1) @@ -396,6 +530,7 @@ def compute(self, _: Experiment, trackers: List[Tracker], sequences: List[Sequen @property def axes(self) -> Axes: + """The analysis is separable in trackers.""" return Axes.TRACKERS class TrackerSeparableAnalysis(SeparableAnalysis): @@ -404,10 +539,12 @@ class TrackerSeparableAnalysis(SeparableAnalysis): @abstractmethod def subcompute(self, experiment: Experiment, tracker: Tracker, sequences: List[Sequence], dependencies: List[Grid]) -> Tuple[Any]: + """Compute the analysis for a single tracker.""" raise NotImplementedError() @property def axes(self) -> Axes: + """The analysis is separable in trackers.""" return Axes.TRACKERS class SequenceSeparableAnalysis(SeparableAnalysis): @@ -416,18 +553,20 @@ class SequenceSeparableAnalysis(SeparableAnalysis): @abstractmethod def subcompute(self, experiment: Experiment, trackers: List[Tracker], sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: + """Compute the analysis for a single sequence.""" raise NotImplementedError @property def axes(self) -> Axes: + """The analysis is separable in sequences.""" return Axes.SEQUENCES def is_special(region: Region, code=None) -> bool: + """Check if the region is special (not a shape) and optionally if it has a specific code.""" if code is None: return region.type == RegionType.SPECIAL return region.type == RegionType.SPECIAL and region.code == code -from ._processor import process_stack_analyses, AnalysisProcessor, AnalysisError - -for module in ["vot.analysis.multistart", "vot.analysis.supervised", "vot.analysis.basic", "vot.analysis.tpr"]: - importlib.import_module(module) +from .processor import process_stack_analyses, AnalysisProcessor, AnalysisError +for module in [".multistart", ".supervised", ".accuracy", ".failures", ".longterm"]: + importlib.import_module(module, package="vot.analysis") diff --git a/vot/analysis/accuracy.py b/vot/analysis/accuracy.py new file mode 100644 index 0000000..27b9e45 --- /dev/null +++ b/vot/analysis/accuracy.py @@ -0,0 +1,291 @@ +"""Accuracy analysis. 
Computes average overlap between predicted and groundtruth regions.""" + +from typing import List, Tuple, Any + +import numpy as np + +from attributee import Boolean, Integer, Include, Float + +from vot.analysis import (Measure, + MissingResultsException, + SequenceAggregator, Sorting, + is_special, SeparableAnalysis, + analysis_registry, Curve) +from vot.dataset import Sequence +from vot.experiment import Experiment +from vot.experiment.multirun import (MultiRunExperiment) +from vot.region import Region, calculate_overlaps +from vot.tracker import Tracker, Trajectory +from vot.utilities.data import Grid + +def gather_overlaps(trajectory: List[Region], groundtruth: List[Region], burnin: int = 10, + ignore_unknown: bool = True, ignore_invisible: bool = False, bounds = None, threshold: float = None) -> np.ndarray: + """Gather overlaps between trajectory and groundtruth regions. + + Args: + trajectory (List[Region]): List of regions predicted by the tracker. + groundtruth (List[Region]): List of groundtruth regions. + burnin (int, optional): Number of frames to skip after each (re)initialization. Defaults to 10. + ignore_unknown (bool, optional): Ignore unknown regions in the trajectory. Defaults to True. + ignore_invisible (bool, optional): Ignore frames where the groundtruth is invisible. Defaults to False. + bounds (optional): Bounds of the sequence used to clip regions when computing overlaps. Defaults to None. + threshold (float, optional): Minimum overlap to consider. Defaults to None. + + Returns: + np.ndarray: Array of overlaps for the frames that pass the filtering.""" + + overlaps = np.array(calculate_overlaps(trajectory, groundtruth, bounds)) + mask = np.ones(len(overlaps), dtype=bool) + + if threshold is None: threshold = -1 + + for i, (region_tr, region_gt) in enumerate(zip(trajectory, groundtruth)): + # Skip if groundtruth is unknown + if is_special(region_gt, Sequence.UNKNOWN): + mask[i] = False + # Skip if groundtruth is invisible (empty) and invisible frames are ignored + elif ignore_invisible and region_gt.is_empty(): + mask[i] = False + # Skip if predicted is unknown + elif is_special(region_tr, Trajectory.UNKNOWN) and ignore_unknown: + mask[i] = False + # Skip burnin frames after each (re)initialization + elif is_special(region_tr, Trajectory.INITIALIZATION): + for j in range(i, min(len(trajectory), i + burnin)): + mask[j] = False + elif is_special(region_tr, Trajectory.FAILURE): + mask[i] = False + elif overlaps[i] <= threshold: + mask[i] = False + + return overlaps[mask] + +@analysis_registry.register("accuracy") +class SequenceAccuracy(SeparableAnalysis): + """Sequence accuracy analysis.
Computes average overlap between predicted and groundtruth regions.""" + + burnin = Integer(default=10, val_min=0, description="Number of frames to skip after the initialization.") + ignore_unknown = Boolean(default=True, description="Ignore unknown regions in the trajectory.") + ignore_invisible = Boolean(default=False, description="Ignore invisible regions in the groundtruth.") + bounded = Boolean(default=True, description="Consider only the bounded region of the sequence.") + threshold = Float(default=None, val_min=0, val_max=1, description="Minimum overlap to consider.") + + def compatible(self, experiment: Experiment): + """Check if the experiment is compatible with the analysis.""" + return isinstance(experiment, MultiRunExperiment) + + @property + def _title_default(self): + """Default title of the analysis.""" + return "Sequence accuracy" + + def describe(self): + """Describe the analysis.""" + return Measure(self.title, "", 0, 1, Sorting.DESCENDING), + + def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: + """Compute the analysis for a single sequence. + + Args: + experiment (Experiment): Experiment. + tracker (Tracker): Tracker. + sequence (Sequence): Sequence. + dependencies (List[Grid]): List of dependencies. + + Returns: + Tuple[Any]: Tuple of results. + """ + assert isinstance(experiment, MultiRunExperiment) + + objects = sequence.objects() + objects_accuracy = 0 + bounds = (sequence.size) if self.bounded else None + + for object in objects: + trajectories = experiment.gather(tracker, sequence, objects=[object]) + if len(trajectories) == 0: + raise MissingResultsException() + + cumulative = 0 + + for trajectory in trajectories: + overlaps = gather_overlaps(trajectory.regions(), sequence.object(object), self.burnin, + ignore_unknown=self.ignore_unknown, ignore_invisible=self.ignore_invisible, bounds=bounds, threshold=self.threshold) + if overlaps.size > 0: + cumulative += np.mean(overlaps) + + objects_accuracy += cumulative / len(trajectories) + + return objects_accuracy / len(objects), + +@analysis_registry.register("average_accuracy") +class AverageAccuracy(SequenceAggregator): + """Average accuracy analysis. Computes average overlap between predicted and groundtruth regions.""" + + analysis = Include(SequenceAccuracy, description="Sequence accuracy analysis.") + weighted = Boolean(default=True, description="Weight accuracy by the number of frames.") + + def compatible(self, experiment: Experiment): + """Check if the experiment is compatible with the analysis. This analysis requires a multirun experiment.""" + return isinstance(experiment, MultiRunExperiment) + + @property + def _title_default(self): + """Default title of the analysis.""" + return "Accuracy" + + def dependencies(self): + """List of dependencies.""" + return self.analysis, + + def describe(self): + """Describe the analysis.""" + return Measure(self.title, "", 0, 1, Sorting.DESCENDING), + + def aggregate(self, _: Tracker, sequences: List[Sequence], results: Grid): + """Aggregate the results of the analysis. + + Args: + tracker (Tracker): Tracker. + sequences (List[Sequence]): List of sequences. + results (Grid): Grid of results. + + Returns: + Tuple[Any]: Tuple of results.
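+ + Note: + With ``weighted=True`` the aggregate is a frame-weighted mean, i.e. ``sum(a_i * n_i) / sum(n_i)`` where ``a_i`` is the accuracy and ``n_i`` the number of frames of sequence ``i``; otherwise it is a plain mean over sequences.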
+ """ + + accuracy = 0 + frames = 0 + + for i, sequence in enumerate(sequences): + if results[i, 0] is None: + continue + + if self.weighted: + accuracy += results[i, 0][0] * len(sequence) + frames += len(sequence) + else: + accuracy += results[i, 0][0] + frames += 1 + + return accuracy / frames, + +@analysis_registry.register("success_plot") +class SuccessPlot(SeparableAnalysis): + """Success plot analysis. Computes the success plot of the tracker.""" + + ignore_unknown = Boolean(default=True, description="Ignore unknown regions in the groundtruth.") + ignore_invisible = Boolean(default=False, description="Ignore invisible regions in the groundtruth.") + burnin = Integer(default=0, val_min=0, description="Number of frames to skip after the initialization.") + bounded = Boolean(default=True, description="Consider only the bounded region of the sequence.") + threshold = Float(default=None, val_min=0, val_max=1, description="Minimum overlap to consider.") + resolution = Integer(default=100, val_min=2, description="Number of points in the plot.") + + def compatible(self, experiment: Experiment): + """Check if the experiment is compatible with the analysis. This analysis is only compatible with multi-run experiments.""" + return isinstance(experiment, MultiRunExperiment) + + @property + def _title_default(self): + """Default title of the analysis.""" + return "Sequence success plot" + + def describe(self): + """Describe the analysis.""" + return Curve("Plot", 2, "S", minimal=(0, 0), maximal=(1, 1), labels=("Threshold", "Success"), trait="success"), + + def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: + """Compute the analysis for a single sequence. + + Args: + experiment (Experiment): Experiment. + tracker (Tracker): Tracker. + sequence (Sequence): Sequence. + dependencies (List[Grid]): List of dependencies. + + Returns: + Tuple[Any]: Tuple of results. + """ + + assert isinstance(experiment, MultiRunExperiment) + + objects = sequence.objects() + bounds = (sequence.size) if self.bounded else None + + axis_x = np.linspace(0, 1, self.resolution) + axis_y = np.zeros_like(axis_x) + + for object in objects: + trajectories = experiment.gather(tracker, sequence, objects=[object]) + if len(trajectories) == 0: + raise MissingResultsException() + + object_y = np.zeros_like(axis_x) + + for trajectory in trajectories: + overlaps = gather_overlaps(trajectory.regions(), sequence.object(object), burnin=self.burnin, ignore_unknown=self.ignore_unknown, ignore_invisible=self.ignore_invisible, bounds=bounds, threshold=self.threshold) + + for i, threshold in enumerate(axis_x): + if threshold == 1: + # Nicer handling of the edge case + object_y[i] += np.sum(overlaps >= threshold) / len(overlaps) + else: + object_y[i] += np.sum(overlaps > threshold) / len(overlaps) + + axis_y += object_y / len(trajectories) + + axis_y /= len(objects) + + return [(x, y) for x, y in zip(axis_x, axis_y)], + + +@analysis_registry.register("average_success_plot") +class AverageSuccessPlot(SequenceAggregator): + """Average success plot analysis. Computes the average success plot of the tracker.""" + + resolution = Integer(default=100, val_min=2) + analysis = Include(SuccessPlot) + + def dependencies(self): + """List of dependencies.""" + return self.analysis, + + def compatible(self, experiment: Experiment): + """Check if the experiment is compatible with the analysis. 
This analysis is only compatible with multi-run experiments.""" + return isinstance(experiment, MultiRunExperiment) + + @property + def _title_default(self): + """Default title of the analysis.""" + return "Success plot" + + def describe(self): + """Describe the analysis.""" + return Curve("Plot", 2, "S", minimal=(0, 0), maximal=(1, 1), labels=("Threshold", "Success"), trait="success"), + + def aggregate(self, _: Tracker, sequences: List[Sequence], results: Grid): + """Aggregate the results of the analysis. + + Args: + tracker (Tracker): Tracker. + sequences (List[Sequence]): List of sequences. + results (Grid): Grid of results. + + Returns: + Tuple[Any]: Tuple of results. + """ + + axis_x = np.linspace(0, 1, self.resolution) + axis_y = np.zeros_like(axis_x) + + for i, _ in enumerate(sequences): + if results[i, 0] is None: + continue + + curve = results[i, 0][0] + + for j, (_, y) in enumerate(curve): + axis_y[j] += y + + axis_y /= len(sequences) + + return [(x, y) for x, y in zip(axis_x, axis_y)], diff --git a/vot/analysis/basic.py b/vot/analysis/basic.py deleted file mode 100644 index de38c4b..0000000 --- a/vot/analysis/basic.py +++ /dev/null @@ -1,184 +0,0 @@ -from typing import List, Tuple, Any - -import numpy as np - -from attributee import Boolean, Integer, Include - -from vot.analysis import (Measure, - MissingResultsException, - SequenceAggregator, Sorting, - is_special, SeparableAnalysis, - analysis_registry) -from vot.dataset import Sequence -from vot.experiment import Experiment -from vot.experiment.multirun import (MultiRunExperiment, SupervisedExperiment) -from vot.region import Region, Special, calculate_overlaps -from vot.tracker import Tracker -from vot.utilities.data import Grid - -def compute_accuracy(trajectory: List[Region], sequence: Sequence, burnin: int = 10, - ignore_unknown: bool = True, bounded: bool = True) -> float: - - overlaps = np.array(calculate_overlaps(trajectory, sequence.groundtruth(), (sequence.size) if bounded else None)) - mask = np.ones(len(overlaps), dtype=bool) - - for i, region in enumerate(trajectory): - if is_special(region, Special.UNKNOWN) and ignore_unknown: - mask[i] = False - elif is_special(region, Special.INITIALIZATION): - for j in range(i, min(len(trajectory), i + burnin)): - mask[j] = False - elif is_special(region, Special.FAILURE): - mask[i] = False - - if any(mask): - return np.mean(overlaps[mask]), np.sum(mask) - else: - return 0, 0 - -def compute_eao_partial(overlaps: List, success: List[bool], curve_length: int): - phi = curve_length * [float(0)] - active = curve_length * [float(0)] - - for o, success in zip(overlaps, success): - - o_array = np.array(o) - - for j in range(1, curve_length): - - if j < len(o): - phi[j] += np.mean(o_array[1:j+1]) - active[j] += 1 - elif not success: - phi[j] += np.sum(o_array[1:len(o)]) / (j - 1) - active[j] += 1 - - phi = [p / a if a > 0 else 0 for p, a in zip(phi, active)] - return phi, active - -def count_failures(trajectory: List[Region]) -> Tuple[int, int]: - return len([region for region in trajectory if is_special(region, Special.FAILURE)]), len(trajectory) - -@analysis_registry.register("accuracy") -class SequenceAccuracy(SeparableAnalysis): - - burnin = Integer(default=10, val_min=0) - ignore_unknown = Boolean(default=True) - bounded = Boolean(default=True) - - def compatible(self, experiment: Experiment): - return isinstance(experiment, MultiRunExperiment) - - @property - def title(self): - return "Sequence accurarcy" - - def describe(self): - return Measure("Accuracy", "AUC", 0, 1, 
Sorting.DESCENDING), - - def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: - - assert isinstance(experiment, MultiRunExperiment) - - trajectories = experiment.gather(tracker, sequence) - - if len(trajectories) == 0: - raise MissingResultsException() - - cummulative = 0 - for trajectory in trajectories: - accuracy, _ = compute_accuracy(trajectory.regions(), sequence, self.burnin, self.ignore_unknown, self.bounded) - cummulative = cummulative + accuracy - - return cummulative / len(trajectories), - -@analysis_registry.register("average_accuracy") -class AverageAccuracy(SequenceAggregator): - - analysis = Include(SequenceAccuracy) - weighted = Boolean(default=True) - - def compatible(self, experiment: Experiment): - return isinstance(experiment, MultiRunExperiment) - - @property - def title(self): - return "Average accurarcy" - - def dependencies(self): - return self.analysis, - - def describe(self): - return Measure("Accuracy", "AUC", 0, 1, Sorting.DESCENDING), - - def aggregate(self, _: Tracker, sequences: List[Sequence], results: Grid): - accuracy = 0 - frames = 0 - - for i, sequence in enumerate(sequences): - if results[i, 0] is None: - continue - - if self.weighted: - accuracy += results[i, 0][0] * len(sequence) - frames += len(sequence) - else: - accuracy += results[i, 0][0] - frames += 1 - - return accuracy / frames, - -@analysis_registry.register("failures") -class FailureCount(SeparableAnalysis): - - def compatible(self, experiment: Experiment): - return isinstance(experiment, SupervisedExperiment) - - @property - def title(self): - return "Number of failures" - - def describe(self): - return Measure("Failures", "F", 0, None, Sorting.ASCENDING), - - def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: - - assert isinstance(experiment, SupervisedExperiment) - - trajectories = experiment.gather(tracker, sequence) - - if len(trajectories) == 0: - raise MissingResultsException() - - failures = 0 - for trajectory in trajectories: - failures = failures + count_failures(trajectory.regions())[0] - - return failures / len(trajectories), len(trajectories[0]) - - -@analysis_registry.register("cumulative_failures") -class CumulativeFailureCount(SequenceAggregator): - - analysis = Include(FailureCount) - - def compatible(self, experiment: Experiment): - return isinstance(experiment, SupervisedExperiment) - - def dependencies(self): - return self.analysis, - - @property - def title(self): - return "Number of failures" - - def describe(self): - return Measure("Failures", "F", 0, None, Sorting.ASCENDING), - - def aggregate(self, _: Tracker, sequences: List[Sequence], results: Grid): - failures = 0 - - for a in results: - failures = failures + a[0] - - return failures, diff --git a/vot/analysis/failures.py b/vot/analysis/failures.py new file mode 100644 index 0000000..615bdb7 --- /dev/null +++ b/vot/analysis/failures.py @@ -0,0 +1,101 @@ +"""This module contains the implementation of the FailureCount analysis. 
The analysis counts the number of failures in one or more sequences.""" + +from typing import List, Tuple, Any + +from attributee import Include + +from vot.analysis import (Measure, + MissingResultsException, + SequenceAggregator, Sorting, + is_special, SeparableAnalysis, + analysis_registry) +from vot.dataset import Sequence +from vot.experiment import Experiment +from vot.experiment.multirun import (SupervisedExperiment) +from vot.region import Region +from vot.tracker import Tracker +from vot.utilities.data import Grid + + +def count_failures(trajectory: List[Region]) -> Tuple[int, int]: + """Count the number of failures in a trajectory. A failure is a frame in which the predicted region is the special FAILURE region (``SupervisedExperiment.FAILURE``). Returns the number of failures and the length of the trajectory.""" + return len([region for region in trajectory if is_special(region, SupervisedExperiment.FAILURE)]), len(trajectory) + + +@analysis_registry.register("failures") +class FailureCount(SeparableAnalysis): + """Counts the average number of failures per object in a sequence. A failure is a frame in which the tracker reported the special FAILURE region.""" + + def compatible(self, experiment: Experiment): + """Check if the experiment is compatible with the analysis.""" + return isinstance(experiment, SupervisedExperiment) + + @property + def _title_default(self): + """Default title for the analysis.""" + return "Number of failures" + + def describe(self): + """Describe the analysis.""" + return Measure("Failures", "F", 0, None, Sorting.ASCENDING), + + def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: + """Compute the analysis for a single sequence.""" + + assert isinstance(experiment, SupervisedExperiment) + + objects = sequence.objects() + objects_failures = 0 + + for object in objects: + trajectories = experiment.gather(tracker, sequence, objects=[object]) + if len(trajectories) == 0: + raise MissingResultsException() + + failures = 0 + for trajectory in trajectories: + failures = failures + count_failures(trajectory.regions())[0] + objects_failures += failures / len(trajectories) + + return objects_failures / len(objects), len(sequence) + +@analysis_registry.register("cumulative_failures") +class CumulativeFailureCount(SequenceAggregator): + """Sums the number of failures over all sequences. A failure is a frame in which the tracker reported the special FAILURE region.""" + + analysis = Include(FailureCount) + + def compatible(self, experiment: Experiment): + """Check if the experiment is compatible with the analysis.""" + return isinstance(experiment, SupervisedExperiment) + + def dependencies(self): + """Return the dependencies of the analysis.""" + return self.analysis, + + @property + def _title_default(self): + """Default title for the analysis.""" + return "Number of failures" + + def describe(self): + """Describe the analysis.""" + return Measure("Failures", "F", 0, None, Sorting.ASCENDING), + + def aggregate(self, _: Tracker, sequences: List[Sequence], results: Grid): + """Aggregate the analysis for a list of sequences. The aggregation is done by summing the number of failures for each sequence. + + Args: + sequences (List[Sequence]): The list of sequences to aggregate. + results (Grid): The results of the analysis for each sequence. + + Returns: + Tuple[Any]: The aggregated number of failures.
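+ + Example: + If three sequences yield 2, 0 and 5 failures, the aggregated result is ``(7,)``.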
+ """ + + failures = 0 + + for a in results: + failures = failures + a[0] + + return failures, diff --git a/vot/analysis/longterm.py b/vot/analysis/longterm.py new file mode 100644 index 0000000..a64bc40 --- /dev/null +++ b/vot/analysis/longterm.py @@ -0,0 +1,677 @@ +"""This module contains the implementation of the long term tracking performance measures.""" +import math +import numpy as np +from typing import List, Iterable, Tuple, Any +import itertools + +from attributee import Float, Integer, Boolean, Include + +from vot.tracker import Tracker +from vot.dataset import Sequence +from vot.region import Region, RegionType, calculate_overlaps +from vot.experiment import Experiment +from vot.experiment.multirun import UnsupervisedExperiment, MultiRunExperiment +from vot.analysis import SequenceAggregator, Analysis, SeparableAnalysis, \ + MissingResultsException, Measure, Sorting, Curve, Plot, SequenceAggregator, \ + Axes, analysis_registry, Point, is_special, Analysis +from vot.utilities.data import Grid + +def determine_thresholds(scores: Iterable[float], resolution: int) -> List[float]: + """Determine thresholds for a given set of scores and a resolution. + The thresholds are determined by sorting the scores and selecting the thresholds that divide the sorted scores into equal sized bins. + + Args: + scores (Iterable[float]): Scores to determine thresholds for. + resolution (int): Number of thresholds to determine. + + Returns: + List[float]: List of thresholds. + """ + scores = [score for score in scores if not math.isnan(score)] #and not score is None] + scores = sorted(scores, reverse=True) + + if len(scores) > resolution - 2: + delta = math.floor(len(scores) / (resolution - 2)) + idxs = np.round(np.linspace(delta, len(scores) - delta, num=resolution - 2)).astype(np.int) + thresholds = [scores[idx] for idx in idxs] + else: + thresholds = scores + + thresholds.insert(0, math.inf) + thresholds.insert(len(thresholds), -math.inf) + + return thresholds + +def compute_tpr_curves(trajectory: List[Region], confidence: List[float], sequence: Sequence, thresholds: List[float], + ignore_unknown: bool = True, bounded: bool = True): + """Compute the TPR curves for a given trajectory and confidence scores. + + Args: + trajectory (List[Region]): Trajectory to compute the TPR curves for. + confidence (List[float]): Confidence scores for the trajectory. + sequence (Sequence): Sequence to compute the TPR curves for. + thresholds (List[float]): Thresholds to compute the TPR curves for. + ignore_unknown (bool, optional): Ignore unknown regions. Defaults to True. + bounded (bool, optional): Bounded evaluation. Defaults to True. + + Returns: + List[float], List[float]: TPR curves for the given thresholds. + """ + + overlaps = np.array(calculate_overlaps(trajectory, sequence.groundtruth(), (sequence.size) if bounded else None)) + confidence = np.array(confidence) + + n_visible = len([region for region in sequence.groundtruth() if region.type is not RegionType.SPECIAL]) + + precision = len(thresholds) * [float(0)] + recall = len(thresholds) * [float(0)] + + for i, threshold in enumerate(thresholds): + + subset = confidence >= threshold + + if np.sum(subset) == 0: + precision[i] = 1 + recall[i] = 0 + else: + precision[i] = np.mean(overlaps[subset]) + recall[i] = np.sum(overlaps[subset]) / n_visible + + return precision, recall + +class _ConfidenceScores(SeparableAnalysis): + """Computes the confidence scores for a tracker for given sequences. 
This is an internal analysis and should not be used directly.""" + + @property + def _title_default(self): + """Title of the analysis.""" + return "Aggregate confidence scores" + + def describe(self): + """Describes the analysis.""" + return None, + + def compatible(self, experiment: Experiment): + """Checks if the experiment is compatible with the analysis.""" + return isinstance(experiment, UnsupervisedExperiment) + + def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: + """Computes the confidence scores for a tracker for given sequences. + + Args: + experiment (Experiment): Experiment to compute the confidence scores for. + tracker (Tracker): Tracker to compute the confidence scores for. + sequence (Sequence): Sequence to compute the confidence scores for. + dependencies (List[Grid]): Dependencies of the analysis. + + Returns: + Tuple[Any]: Confidence scores for the given sequence. + """ + + scores_all = [] + trajectories = experiment.gather(tracker, sequence) + + if len(trajectories) == 0: + raise MissingResultsException("Missing results for sequence {}".format(sequence.name)) + + for trajectory in trajectories: + confidence = [trajectory.properties(i).get('confidence', 0) for i in range(len(trajectory))] + scores_all.extend(confidence) + + return scores_all, + +class _Thresholds(SequenceAggregator): + """Computes the thresholds for a tracker for given sequences. This is an internal analysis and should not be used directly.""" + + resolution = Integer(default=100) + + @property + def _title_default(self): + """Title of the analysis.""" + return "Thresholds for tracking precision/recall" + + def describe(self): + """Describes the analysis.""" + return None, + + def compatible(self, experiment: Experiment): + """Checks if the experiment is compatible with the analysis.""" + return isinstance(experiment, UnsupervisedExperiment) + + def dependencies(self): + """Dependencies of the analysis.""" + return _ConfidenceScores(), + + def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]: + """Computes the thresholds for a tracker for given sequences. + + Args: + tracker (Tracker): Tracker to compute the thresholds for. + sequences (List[Sequence]): Sequences to compute the thresholds for. + results (Grid): Results of the dependencies. + + Returns: + Tuple[Any]: Thresholds for the given sequences.""" + + # The trailing comma wraps the threshold list in a tuple; this matches the double unwrapping in PrecisionRecallCurves.subcompute. + thresholds = determine_thresholds(itertools.chain(*[result[0] for result in results]), self.resolution), + + return thresholds, + +@analysis_registry.register("pr_curves") +class PrecisionRecallCurves(SeparableAnalysis): + """ Computes the precision/recall curves for a tracker for given sequences.
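+ + At a confidence threshold tau, precision is the mean overlap over the frames with reported confidence of at least tau, while recall is the sum of those overlaps divided by the number of frames in which the object is visible (see ``compute_tpr_curves``).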
""" + + thresholds = Include(_Thresholds) + ignore_unknown = Boolean(default=True, description="Ignore unknown regions") + bounded = Boolean(default=True, description="Bounded evaluation") + + @property + def _title_default(self): + """Title of the analysis.""" + return "Tracking precision/recall" + + def describe(self): + """Describes the analysis.""" + return Curve("Precision Recall curve", dimensions=2, abbreviation="PR", minimal=(0, 0), maximal=(1, 1), labels=("Recall", "Precision")), None + + def compatible(self, experiment: Experiment): + """Checks if the experiment is compatible with the analysis.""" + return isinstance(experiment, UnsupervisedExperiment) + + def dependencies(self): + """Dependencies of the analysis.""" + return self.thresholds, + + def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: + """Computes the precision/recall curves for a tracker for given sequences. + + Args: + experiment (Experiment): Experiment to compute the precision/recall curves for. + tracker (Tracker): Tracker to compute the precision/recall curves for. + sequence (Sequence): Sequence to compute the precision/recall curves for. + dependencies (List[Grid]): Dependencies of the analysis. + + Returns: + Tuple[Any]: Precision/recall curves for the given sequence. + """ + + thresholds = dependencies[0, 0][0][0] # dependencies[0][0, 0] + + trajectories = experiment.gather(tracker, sequence) + + if len(trajectories) == 0: + raise MissingResultsException() + + precision = len(thresholds) * [float(0)] + recall = len(thresholds) * [float(0)] + for trajectory in trajectories: + confidence = [trajectory.properties(i).get('confidence', 0) for i in range(len(trajectory))] + pr, re = compute_tpr_curves(trajectory.regions(), confidence, sequence, thresholds, self.ignore_unknown, self.bounded) + for i in range(len(thresholds)): + precision[i] += pr[i] + recall[i] += re[i] + +# return [(re / len(trajectories), pr / len(trajectories)) for pr, re in zip(precision, recall)], thresholds + return [(pr / len(trajectories), re / len(trajectories)) for pr, re in zip(precision, recall)], thresholds + +@analysis_registry.register("pr_curve") +class PrecisionRecallCurve(SequenceAggregator): + """ Computes the average precision/recall curve for a tracker. """ + + curves = Include(PrecisionRecallCurves) + + @property + def _title_default(self): + """Title of the analysis.""" + return "Tracking precision/recall average curve" + + def describe(self): + """Describes the analysis.""" + return self.curves.describe() + + def compatible(self, experiment: Experiment): + """Checks if the experiment is compatible with the analysis. This analysis is compatible with unsupervised experiments.""" + return isinstance(experiment, UnsupervisedExperiment) + + def dependencies(self): + """Dependencies of the analysis.""" + return self.curves, + + def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]: + """Computes the average precision/recall curve for a tracker. + + Args: + tracker (Tracker): Tracker to compute the average precision/recall curve for. + sequences (List[Sequence]): Sequences to compute the average precision/recall curve for. + results (Grid): Results of the dependencies. + + Returns: + Tuple[Any]: Average precision/recall curve for the given sequences. 
+ """ + + curve = None + thresholds = None + + for partial, thresholds in results: + if curve is None: + curve = partial + continue + + curve = [(pr1 + pr2, re1 + re2) for (pr1, re1), (pr2, re2) in zip(curve, partial)] + + curve = [(re / len(results), pr / len(results)) for pr, re in curve] + + return curve, thresholds + + +@analysis_registry.register("f_curve") +class FScoreCurve(Analysis): + """ Computes the F-score curve for a tracker. """ + + beta = Float(default=1, description="Beta value for the F-score") + prcurve = Include(PrecisionRecallCurve) + + @property + def _title_default(self): + """Title of the analysis.""" + return "Tracking precision/recall" + + def describe(self): + """Describes the analysis.""" + return Plot("Tracking F-score curve", "F", wrt="normalized threshold", minimal=0, maximal=1), None + + def compatible(self, experiment: Experiment): + """Checks if the experiment is compatible with the analysis. This analysis is compatible with unsupervised experiments.""" + return isinstance(experiment, UnsupervisedExperiment) + + def dependencies(self): + """Dependencies of the analysis.""" + return self.prcurve, + + def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid: + """Computes the F-score curve for a tracker. + + Args: + experiment (Experiment): Experiment to compute the F-score curve for. + trackers (List[Tracker]): Trackers to compute the F-score curve for. + sequences (List[Sequence]): Sequences to compute the F-score curve for. + dependencies (List[Grid]): Dependencies of the analysis. + + Returns: + Grid: F-score curve for the given sequences. + """ + + processed_results = Grid(len(trackers), 1) + + for i, result in enumerate(dependencies[0]): + beta2 = (self.beta * self.beta) + f_curve = [((1 + beta2) * pr_ * re_) / (beta2 * pr_ + re_) for pr_, re_ in result[0]] + + processed_results[i, 0] = (f_curve, result[0][1]) + + return processed_results + + @property + def axes(self): + """Axes of the analysis.""" + return Axes.TRACKERS + +@analysis_registry.register("average_tpr") +class PrecisionRecall(Analysis): + """ Computes the average precision/recall for a tracker. """ + + prcurve = Include(PrecisionRecallCurve) + fcurve = Include(FScoreCurve) + + @property + def _title_default(self): + """Title of the analysis.""" + return "Tracking precision/recall" + + def describe(self): + """Describes the analysis.""" + return Measure("Precision", "Pr", minimal=0, maximal=1, direction=Sorting.DESCENDING), \ + Measure("Recall", "Re", minimal=0, maximal=1, direction=Sorting.DESCENDING), \ + Measure("F Score", "F", minimal=0, maximal=1, direction=Sorting.DESCENDING) + + def compatible(self, experiment: Experiment): + """Checks if the experiment is compatible with the analysis. This analysis is compatible with unsupervised experiments.""" + return isinstance(experiment, UnsupervisedExperiment) + + def dependencies(self): + """Dependencies of the analysis.""" + return self.prcurve, self.fcurve + + def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid: + """Computes the average precision/recall for a tracker. + + Args: + experiment (Experiment): Experiment to compute the average precision/recall for. + trackers (List[Tracker]): Trackers to compute the average precision/recall for. + sequences (List[Sequence]): Sequences to compute the average precision/recall for. + dependencies (List[Grid]): Dependencies of the analysis. 
+ + Returns: + Grid: Average precision/recall for the given sequences. + """ + + f_curves = dependencies[1] + pr_curves = dependencies[0] + + joined = Grid(len(trackers), 1) + + for i, (f_curve, pr_curve) in enumerate(zip(f_curves, pr_curves)): + # get the optimal F-score and the precision and recall at the same threshold + f_score = max(f_curve[0]) + best_i = f_curve[0].index(f_score) + re_score = pr_curve[0][best_i][0] + pr_score = pr_curve[0][best_i][1] + joined[i, 0] = (pr_score, re_score, f_score) + + return joined + + @property + def axes(self): + """Axes of the analysis.""" + return Axes.TRACKERS + + +def count_frames(trajectory: List[Region], groundtruth: List[Region], bounds = None, threshold: float = 0) -> Tuple[int, int, int, int, int]: + """Counts the frames in which the tracker correctly tracks the object (T), fails (F), misses the object (M), hallucinates it (H) or correctly reports its absence (N). + + Args: + trajectory (List[Region]): Trajectory of the tracker. + groundtruth (List[Region]): Groundtruth trajectory. + bounds (Optional[Region]): Bounds of the sequence. + threshold (float): Threshold for the overlap. + + Returns: + Tuple[int, int, int, int, int]: The frame counts (T, F, M, H, N). + """ + + overlaps = np.array(calculate_overlaps(trajectory, groundtruth, bounds)) + if threshold is None: threshold = -1 + + # Tracking, Failure, Miss, Hallucination, Notice + T, F, M, H, N = 0, 0, 0, 0, 0 + + for i, (region_tr, region_gt) in enumerate(zip(trajectory, groundtruth)): + if (is_special(region_gt, Sequence.UNKNOWN)): + continue + if region_gt.is_empty(): + if region_tr.is_empty(): + N += 1 + else: + H += 1 + else: + if overlaps[i] > threshold: + T += 1 + else: + if region_tr.is_empty(): + M += 1 + else: + F += 1 + + return T, F, M, H, N + +class CountFrames(SeparableAnalysis): + """Counts the frames in which the tracker correctly tracks the object (T), fails (F), misses the object (M), hallucinates it (H) or correctly reports its absence (N).""" + + threshold = Float(default=0.0, val_min=0, val_max=1) + bounded = Boolean(default=True) + + def compatible(self, experiment: Experiment): + """Checks if the experiment is compatible with the analysis.
This analysis is compatible with multi-run experiments.""" + return isinstance(experiment, MultiRunExperiment) + + def describe(self): + """Describes the analysis.""" + return None, + + def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: + """Computes the number of frames where the tracker is correct, fails, misses, hallucinates or notices an object.""" + + assert isinstance(experiment, MultiRunExperiment) + + objects = sequence.objects() + distribution = [] + bounds = (sequence.size) if self.bounded else None + + for object in objects: + trajectories = experiment.gather(tracker, sequence, objects=[object]) + if len(trajectories) == 0: + raise MissingResultsException() + + CN, CF, CM, CH, CT = 0, 0, 0, 0, 0 + + for trajectory in trajectories: + T, F, M, H, N = count_frames(trajectory.regions(), sequence.object(object), bounds=bounds) + CN += N + CF += F + CM += M + CH += H + CT += T + CN /= len(trajectories) + CF /= len(trajectories) + CM /= len(trajectories) + CH /= len(trajectories) + CT /= len(trajectories) + + distribution.append((CT, CF, CM, CH, CN)) + + return distribution, + + +@analysis_registry.register("quality_auxiliary") +class QualityAuxiliary(SeparableAnalysis): + """Computes the non-reported error, drift-rate error and absence-detection quality.""" + + threshold = Float(default=0.0, val_min=0, val_max=1) + bounded = Boolean(default=True) + absence_threshold = Integer(default=10, val_min=0) + + def compatible(self, experiment: Experiment): + """Checks if the experiment is compatible with the analysis. This analysis is compatible with multi-run experiments.""" + return isinstance(experiment, MultiRunExperiment) + + @property + def _title_default(self): + """Default title of the analysis.""" + return "Quality Auxiliary" + + def describe(self): + """Describes the analysis.""" + return Measure("Non-reported Error", "NRE", 0, 1, Sorting.DESCENDING), \ + Measure("Drift-rate Error", "DRE", 0, 1, Sorting.DESCENDING), \ + Measure("Absence-detection Quality", "ADQ", 0, 1, Sorting.DESCENDING), + + def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: + """Computes the non-reported error, drift-rate error and absence-detection quality. + + Args: + experiment (Experiment): Experiment. + tracker (Tracker): Tracker. + sequence (Sequence): Sequence. + dependencies (List[Grid]): Dependencies. + + Returns: + Tuple[Any]: Non-reported error, drift-rate error and absence-detection quality. 
+ + """ + + assert isinstance(experiment, MultiRunExperiment) + + not_reported_error = 0 + drift_rate_error = 0 + absence_detection = 0 + + objects = sequence.objects() + bounds = (sequence.size) if self.bounded else None + + absence_valid = 0 + + for object in objects: + trajectories = experiment.gather(tracker, sequence, objects=[object]) + if len(trajectories) == 0: + raise MissingResultsException() + + CN, CF, CM, CH, CT = 0, 0, 0, 0, 0 + + for trajectory in trajectories: + T, F, M, H, N = count_frames(trajectory.regions(), sequence.object(object), bounds=bounds) + CN += N + CF += F + CM += M + CH += H + CT += T + CN /= len(trajectories) + CF /= len(trajectories) + CM /= len(trajectories) + CH /= len(trajectories) + CT /= len(trajectories) + + not_reported_error += CM / (CT + CF + CM) + drift_rate_error += CF / (CT + CF + CM) + + if CN + CH > self.absence_threshold: + absence_detection += CN / (CN + CH) + absence_valid += 1 + + if absence_valid > 0: + absence_detection /= absence_valid + else: + absence_detection = None + + return not_reported_error / len(objects), drift_rate_error / len(objects), absence_detection, + + +@analysis_registry.register("average_quality_auxiliary") +class AverageQualityAuxiliary(SequenceAggregator): + """Computes the average non-reported error, drift-rate error and absence-detection quality.""" + + analysis = Include(QualityAuxiliary) + + @property + def _title_default(self): + """Default title of the analysis.""" + return "Quality Auxiliary" + + def dependencies(self): + """Returns the dependencies of the analysis.""" + return self.analysis, + + def describe(self): + """Describes the analysis.""" + return Measure("Non-reported Error", "NRE", 0, 1, Sorting.DESCENDING), \ + Measure("Drift-rate Error", "DRE", 0, 1, Sorting.DESCENDING), \ + Measure("Absence-detection Quality", "ADQ", 0, 1, Sorting.DESCENDING), + + def compatible(self, experiment: Experiment): + """Checks if the experiment is compatible with the analysis. This analysis is compatible with multi-run experiments.""" + return isinstance(experiment, MultiRunExperiment) + + def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid): + """Aggregates the non-reported error, drift-rate error and absence-detection quality. + + Args: + tracker (Tracker): Tracker. + sequences (List[Sequence]): Sequences. + results (Grid): Results. + + Returns: + Tuple[Any]: Non-reported error, drift-rate error and absence-detection quality. + """ + + not_reported_error = 0 + drift_rate_error = 0 + absence_detection = 0 + absence_count = 0 + + for nre, dre, ad in results: + not_reported_error += nre + drift_rate_error += dre + if ad is not None: + absence_count += 1 + absence_detection += ad + + if absence_count > 0: + absence_detection /= absence_count + + return not_reported_error / len(sequences), drift_rate_error / len(sequences), absence_detection + +from vot.analysis import SequenceAggregator +from vot.analysis.accuracy import SequenceAccuracy + +@analysis_registry.register("longterm_ar") +class AccuracyRobustness(Analysis): + """Longterm multi-object accuracy-robustness measure. 
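Given the averaged counts, the three auxiliary measures are plain ratios; a minimal sketch of the arithmetic used above, with illustrative count values:

def quality_auxiliary(T, F, M, H, N, absence_threshold=10):
    nre = M / (T + F + M)   # non-reported error
    dre = F / (T + F + M)   # drift-rate error
    # absence detection only counts when enough absence frames are present
    adq = N / (N + H) if (N + H) > absence_threshold else None
    return nre, dre, adq

print(quality_auxiliary(T=80, F=10, M=10, H=5, N=20))  # (0.1, 0.1, 0.8)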
""" + + threshold = Float(default=0.0, val_min=0, val_max=1) + bounded = Boolean(default=True) + counts = Include(CountFrames) + + def dependencies(self) -> List[Analysis]: + """Returns the dependencies of the analysis.""" + return self.counts, SequenceAccuracy(burnin=0, threshold=self.threshold, bounded=self.bounded, ignore_invisible=True, ignore_unknown=False) + + def compatible(self, experiment: Experiment): + """Checks if the experiment is compatible with the analysis. This analysis is compatible with multi-run experiments.""" + return isinstance(experiment, MultiRunExperiment) + + @property + def _title_default(self): + """Default title of the analysis.""" + return "Accuracy-robustness" + + def describe(self): + """Describes the analysis.""" + return Measure("Accuracy", "A", minimal=0, maximal=1, direction=Sorting.DESCENDING), \ + Measure("Robustness", "R", minimal=0, direction=Sorting.DESCENDING), \ + Point("AR plot", dimensions=2, abbreviation="AR", minimal=(0, 0), \ + maximal=(1, 1), labels=("Robustness", "Accuracy"), trait="ar") + + def compute(self, _: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid: + """Aggregate results from multiple sequences into a single value. + + Args: + experiment (Experiment): Experiment. + trackers (List[Tracker]): Trackers. + sequences (List[Sequence]): Sequences. + dependencies (List[Grid]): Dependencies. + + Returns: + Grid: Aggregated results. + """ + + frame_counts = dependencies[0] + accuracy_analysis = dependencies[1] + + results = Grid(len(trackers), 1) + + for j, _ in enumerate(trackers): + accuracy = 0 + robustness = 0 + count = 0 + + for i, _ in enumerate(sequences): + if accuracy_analysis[j, i] is None: + continue + + accuracy += accuracy_analysis[j, i][0] + + frame_counts_sequence = frame_counts[j, i][0] + + objects = len(frame_counts_sequence) + for o in range(objects): + robustness += (1/objects) * frame_counts_sequence[o][0] / (frame_counts_sequence[o][0] + frame_counts_sequence[o][1] + frame_counts_sequence[o][2]) + + count += 1 + + results[j, 0] = (accuracy / count, robustness / count, (robustness / count, accuracy / count)) + + return results + + @property + def axes(self) -> Axes: + """Returns the axes of the analysis.""" + return Axes.TRACKERS \ No newline at end of file diff --git a/vot/analysis/multistart.py b/vot/analysis/multistart.py index e03e5ff..21a1aeb 100644 --- a/vot/analysis/multistart.py +++ b/vot/analysis/multistart.py @@ -1,4 +1,5 @@ -import math +"""This module contains the implementation of the accuracy-robustness analysis and EAO analysis for the multistart experiment.""" + from typing import List, Tuple, Any import numpy as np @@ -16,6 +17,16 @@ from vot.utilities.data import Grid def compute_eao_partial(overlaps: List, success: List[bool], curve_length: int): + """Compute the EAO curve for a single sequence. The curve is computed as the average overlap at each frame. + + Args: + overlaps (List): List of overlaps for each frame. + success (List[bool]): List of success flags for each frame. + curve_length (int): Length of the curve. + + Returns: + List[float]: EAO curve. 
+ """ phi = curve_length * [float(0)] active = curve_length * [float(0)] @@ -38,6 +49,7 @@ def compute_eao_partial(overlaps: List, success: List[bool], curve_length: int): @analysis_registry.register("multistart_ar") class AccuracyRobustness(SeparableAnalysis): + """This analysis computes the accuracy-robustness curve for the multistart experiment.""" burnin = Integer(default=10, val_min=0) grace = Integer(default=10, val_min=0) @@ -45,10 +57,12 @@ class AccuracyRobustness(SeparableAnalysis): threshold = Float(default=0.1, val_min=0, val_max=1) @property - def title(self): + def _title_default(self): + """Title of the analysis.""" return "AR Analysis" def describe(self): + """Return the description of the analysis.""" return Measure("Accuracy", "A", minimal=0, maximal=1, direction=Sorting.DESCENDING), \ Measure("Robustness", "R", minimal=0, direction=Sorting.DESCENDING), \ Point("AR plot", dimensions=2, abbreviation="AR", @@ -56,9 +70,21 @@ def describe(self): None, None def compatible(self, experiment: Experiment): + """Check if the experiment is compatible with the analysis. The experiment must be a multistart experiment.""" return isinstance(experiment, MultiStartExperiment) def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: + """Compute the accuracy-robustness for each sequence. + + Args: + experiment (Experiment): Experiment. + tracker (Tracker): Tracker. + sequence (Sequence): Sequence. + dependencies (List[Grid]): List of dependencies. + + Returns: + Tuple[Any]: Accuracy, robustness, AR curve, robustness, length of the sequence. + """ results = experiment.results(tracker, sequence) @@ -79,7 +105,7 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc if reverse: proxy = FrameMapSequence(sequence, list(reversed(range(0, i + 1)))) else: - proxy = FrameMapSequence(sequence, list(range(i, sequence.length))) + proxy = FrameMapSequence(sequence, list(range(i, len(sequence)))) trajectory = Trajectory.read(results, name) @@ -107,17 +133,21 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc @analysis_registry.register("multistart_average_ar") class AverageAccuracyRobustness(SequenceAggregator): + """This analysis computes the average accuracy-robustness curve for the multistart experiment.""" analysis = Include(AccuracyRobustness) @property - def title(self): + def _title_default(self): + """Title of the analysis.""" return "AR Analysis" def dependencies(self): + """Return the dependencies of the analysis.""" return self.analysis, def describe(self): + """Return the description of the analysis.""" return Measure("Accuracy", "A", minimal=0, maximal=1, direction=Sorting.DESCENDING), \ Measure("Robustness", "R", minimal=0, direction=Sorting.DESCENDING), \ Point("AR plot", dimensions=2, abbreviation="AR", @@ -125,9 +155,20 @@ def describe(self): None, None def compatible(self, experiment: Experiment): + """Check if the experiment is compatible with the analysis. The experiment must be a multistart experiment.""" return isinstance(experiment, MultiStartExperiment) def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid): + """Aggregate the results of the analysis. + + Args: + tracker (Tracker): Tracker. + sequences (List[Sequence]): List of sequences. + results (Grid): Grid of results. + + Returns: + Tuple[Any]: Aggregated results. 
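The aggregation amounts to a weighted average over sequences; a minimal sketch, assuming per-sequence weights such as sequence lengths (the toolkit keeps separate accuracy and robustness weights, collapsed into a single weight here for brevity):

def aggregate_ar(entries):
    # entries: list of (accuracy, robustness, weight) triples, one per sequence
    w = sum(e[2] for e in entries)
    acc = sum(a * wi for a, _, wi in entries) / w
    rob = sum(r * wi for _, r, wi in entries) / w
    return acc, rob, (rob, acc)

print(aggregate_ar([(0.6, 0.8, 100), (0.8, 0.9, 300)]))  # (0.75, 0.875, (0.875, 0.75))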
+ """ total_accuracy = 0 total_robustness = 0 weight_accuracy = 0 @@ -145,6 +186,7 @@ def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid): @analysis_registry.register("multistart_fragments") class MultiStartFragments(SeparableAnalysis): + """This analysis computes the accuracy-robustness curve for the multistart experiment.""" burnin = Integer(default=10, val_min=0) grace = Integer(default=10, val_min=0) @@ -152,16 +194,29 @@ class MultiStartFragments(SeparableAnalysis): threshold = Float(default=0.1, val_min=0, val_max=1) @property - def title(self): + def _title_default(self): + """Title of the analysis.""" return "Fragment Analysis" def describe(self): + """Return the description of the analysis.""" return Curve("Success", 2, "Sc", minimal=(0, 0), maximal=(1,1), trait="points"), Curve("Accuracy", 2, "Ac", minimal=(0, 0), maximal=(1,1), trait="points") def compatible(self, experiment: Experiment): + """Check if the experiment is compatible with the analysis. The experiment must be a multistart experiment.""" return isinstance(experiment, MultiStartExperiment) def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: + """Compute the analysis for a single sequence. The sequence must contain at least one anchor. + + Args: + experiment (Experiment): Experiment. + tracker (Tracker): Tracker. + sequence (Sequence): Sequence. + dependencies (List[Grid]): List of dependencies. + + Returns: + Tuple[Any]: Results of the analysis.""" results = experiment.results(tracker, sequence) @@ -182,7 +237,7 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc if reverse: proxy = FrameMapSequence(sequence, list(reversed(range(0, i + 1)))) else: - proxy = FrameMapSequence(sequence, list(range(i, sequence.length))) + proxy = FrameMapSequence(sequence, list(range(i, len(sequence)))) trajectory = Trajectory.read(results, name) @@ -208,6 +263,7 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc # TODO: remove high @analysis_registry.register("multistart_eao_curves") class EAOCurves(SeparableAnalysis): + """This analysis computes the expected average overlap curve for the multistart experiment.""" burnin = Integer(default=10, val_min=0) grace = Integer(default=10, val_min=0) @@ -217,17 +273,31 @@ class EAOCurves(SeparableAnalysis): high = Integer() @property - def title(self): + def _title_default(self): + """Title of the analysis.""" return "EAO Curve" def describe(self): + """Return the description of the analysis.""" return Plot("Expected average overlap", "EAO", minimal=0, maximal=1, wrt="frames", trait="eao"), def compatible(self, experiment: Experiment): + """Check if the experiment is compatible with the analysis. The experiment must be a multistart experiment.""" return isinstance(experiment, MultiStartExperiment) def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: - + """Compute the analysis for a single sequence. The sequence must contain at least one anchor. + + Args: + experiment (Experiment): Experiment. + tracker (Tracker): Tracker. + sequence (Sequence): Sequence. + dependencies (List[Grid]): List of dependencies. + + Returns: + Tuple[Any]: Results of the analysis. 
+ """ + results = experiment.results(tracker, sequence) forward, backward = find_anchors(sequence, experiment.anchor) @@ -247,7 +317,7 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc if reverse: proxy = FrameMapSequence(sequence, list(reversed(range(0, i + 1)))) else: - proxy = FrameMapSequence(sequence, list(range(i, sequence.length))) + proxy = FrameMapSequence(sequence, list(range(i, len(sequence)))) trajectory = Trajectory.read(results, name) @@ -279,23 +349,38 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc #TODO: remove high @analysis_registry.register("multistart_eao_curve") class EAOCurve(SequenceAggregator): + """This analysis computes the expected average overlap curve for the multistart experiment. It is an aggregator of the curves for individual sequences.""" curves = Include(EAOCurves) @property - def title(self): + def _title_default(self): + """Title of the analysis.""" return "EAO Curve" def describe(self): + """Return the description of the analysis.""" return Plot("Expected average overlap", "EAO", minimal=0, maximal=1, wrt="frames", trait="eao"), def compatible(self, experiment: Experiment): + """Check if the experiment is compatible with the analysis. The experiment must be a multistart experiment.""" return isinstance(experiment, MultiStartExperiment) def dependencies(self): + """Return the dependencies of the analysis.""" return self.curves, def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]: + """Aggregate the results of the analysis for multiple sequences. The sequences must contain at least one anchor. + + Args: + tracker (Tracker): Tracker. + sequences (List[Sequence]): List of sequences. + results (Grid): Grid of results. + + Returns: + Tuple[Any]: Results of the analysis. + """ eao_curve = self.curves.high * [float(0)] eao_weights = self.curves.high * [float(0)] @@ -309,29 +394,46 @@ def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid) @analysis_registry.register("multistart_eao_score") class EAOScore(Analysis): + """This analysis computes the expected average overlap score for the multistart experiment. It does this by computing the EAO curve and then integrating it.""" low = Integer() high = Integer() eaocurve = Include(EAOCurve) @property - def title(self): + def _title_default(self): + """Title of the analysis.""" return "EAO analysis" def describe(self): + """Return the description of the analysis.""" return Measure("Expected average overlap", "EAO", minimal=0, maximal=1, direction=Sorting.DESCENDING), def compatible(self, experiment: Experiment): + """Check if the experiment is compatible with the analysis. The experiment must be a multistart experiment.""" return isinstance(experiment, MultiStartExperiment) def dependencies(self): + """Return the dependencies of the analysis.""" return self.eaocurve, def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid: + """Compute the analysis for multiple sequences. The sequences must contain at least one anchor. + + Args: + experiment (Experiment): Experiment. + trackers (List[Tracker]): List of trackers. + sequences (List[Sequence]): List of sequences. + dependencies (List[Grid]): List of dependencies. + + Returns: + Grid: Grid of results. 
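The EAO score itself is just the mean of the EAO curve over the [low, high] interval, as the compute method below shows; a minimal sketch with illustrative values:

import numpy as np

curve = np.array([0.5, 0.45, 0.42, 0.40, 0.39, 0.38])
low, high = 1, 4
eao = float(np.mean(curve[low:high + 1]))
print(round(eao, 3))  # 0.415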
+ """ return dependencies[0].foreach(lambda x, i, j: (float(np.mean(x[0][self.low:self.high + 1])), ) ) @property def axes(self): + """Return the axes of the analysis.""" return Axes.TRACKERS diff --git a/vot/analysis/_processor.py b/vot/analysis/processor.py similarity index 66% rename from vot/analysis/_processor.py rename to vot/analysis/processor.py index e7794e0..1015a45 100644 --- a/vot/analysis/_processor.py +++ b/vot/analysis/processor.py @@ -1,7 +1,13 @@ +"""This module contains the implementation of the analysis processor. The processor is responsible for executing the analysis tasks in parallel and caching the results.""" import logging +import sys import threading -from collections import Iterable, OrderedDict, namedtuple +from collections import OrderedDict, namedtuple +if sys.version_info >= (3, 3): + from collections.abc import Iterable +else: + from collections import Iterable from functools import partial from typing import List, Union, Mapping, Tuple, Any from concurrent.futures import Executor, Future, ThreadPoolExecutor @@ -22,7 +28,9 @@ logger = logging.getLogger("vot") def hashkey(analysis: Analysis, *args): + """Compute a hash key for the analysis and its arguments. The key is used for caching the results.""" def transform(arg): + """Transform an argument into a hashable object.""" if isinstance(arg, Sequence): return arg.name if isinstance(arg, Tracker): @@ -37,13 +45,24 @@ def transform(arg): return (analysis.identifier, *[transform(arg) for arg in args]) def unwrap(arg): + """Unwrap a single element list.""" + if isinstance(arg, list) and len(arg) == 1: return arg[0] else: return arg class AnalysisError(ToolkitException): + """An exception that is raised when an analysis fails.""" + def __init__(self, cause, task=None): + """Creates an analysis error. + + Args: + cause (Exception): The cause of the error. + task (AnalysisTask, optional): The task that caused the error. Defaults to None. 
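The caching scheme hinges on reducing rich objects to stable, hashable names, as hashkey does above; a simplified sketch of that idea (the name and reference attributes are assumptions based on the visible fragments):

from types import SimpleNamespace

def cache_key(identifier, *args):
    def transform(arg):
        if hasattr(arg, "name"):        # Sequence-like objects
            return arg.name
        if hasattr(arg, "reference"):   # Tracker-like objects
            return arg.reference
        if isinstance(arg, (list, tuple)):
            return tuple(transform(a) for a in arg)
        return arg
    return (identifier, *[transform(a) for a in args])

print(cache_key("pr_curve", [SimpleNamespace(name="ball")]))  # ('pr_curve', ('ball',))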
+
+        """
        self._tasks = []
        self._cause = cause
        super().__init__(cause, task)
@@ -51,12 +70,15 @@ def __init__(self, cause, task=None):
 
     @property
     def task(self):
+        """The task that caused the error."""
         return self._tasks[-1]
 
     def __str__(self):
+        """String representation of the error."""
         return "Error during analysis {}".format(self.task)
 
     def print(self, logoutput):
+        """Print the error to the log output."""
         logoutput.error(str(self))
         if len(self._tasks) > 1:
             for task in reversed(self._tasks[:-1]):
@@ -65,6 +87,7 @@ def print(self, logoutput):
 
     @property
     def root_cause(self):
+        """The root cause of the error."""
        cause = self._cause
        if cause is None:
            return None
@@ -94,6 +117,8 @@ def __init__(self, strict=True):
        self._thread.start()
 
    def submit(self, fn, *args, **kwargs):
+        """Submits a task to the executor."""
+
        promise = Future()
        with self._lock:
            self._queue.put(DebugExecutor.Task(fn, args, kwargs, promise))
@@ -102,6 +127,7 @@ def submit(self, fn, *args, **kwargs):
        return promise
 
    def _run(self):
+        """The main loop of the executor."""
 
        while True:
@@ -130,7 +156,7 @@ def _run(self):
            except TypeError as e:
-                logger.debug("Task %s call resulted in error: %s", task.fn, e)
+                logger.info("Task %s call resulted in error: %s", task.fn, e)
                error = e
@@ -138,7 +164,9 @@ def _run(self):
                error = e
-                logger.debug("Task %s resulted in exception: %s", task.fn, e)
+                logger.info("Task %s resulted in exception: %s", task.fn, e)
+                if logger.isEnabledFor(logging.DEBUG):
+                    logger.exception(e)
 
            if error is not None:
                task.promise.set_exception(error)
@@ -149,6 +179,7 @@ def _run(self):
                break
 
    def _clear(self):
+        """Clears the queue."""
        with self._lock:
            while True:
@@ -160,6 +191,12 @@ def _clear(self):
                break
 
    def shutdown(self, wait=True):
+        """Shuts down the executor. If wait is True, the method blocks until all tasks are completed.
+
+        Args:
+            wait (bool, optional): Wait for all tasks to complete. Defaults to True.
+        """
+
        self._alive = False
        self._clear()
        if wait:
@@ -169,8 +206,14 @@ def shutdown(self, wait=True):
            self._thread.join()
 
class ExecutorWrapper(object):
+    """A wrapper for an executor that allows submitting tasks with dependencies."""
 
    def __init__(self, executor: Executor):
+        """Creates an executor wrapper.
+
+        Args:
+            executor (Executor): The executor to wrap.
+        """
        self._lock = RLock()
        self._executor = executor
        self._pending = OrderedDict()
@@ -178,9 +221,20 @@ def __init__(self, executor: Executor):
 
    @property
    def total(self):
+        """The total number of tasks submitted to the executor."""
        return self._total
 
    def submit(self, fn, *futures: Tuple[Future], mapping=None) -> Future:
+        """Submits a task to the executor. The task will be executed when all futures are completed.
+
+        Args:
+            fn (Callable): The task to execute.
+            futures (Tuple[Future]): The futures that must be completed before the task is executed.
+            mapping (Dict[Future, Any], optional): A mapping of futures to values. Defaults to None.
+
+        Returns:
+            Future: A future that will be completed when the task is completed.
+        """
 
        with self._lock:
@@ -200,7 +254,7 @@ def submit(self, fn, *futures: Tuple[Future], mapping=None) -> Future:
 
        return proxy
 
    def _ready_callback(self, fn, mapping, proxy: Future, future: Future):
-        """ Internally handles completion of dependencies
+        """ Internally handles completion of dependencies. Submits the task to the executor.
""" with self._lock: @@ -229,9 +283,10 @@ def _ready_callback(self, fn, mapping, proxy: Future, future: Future): def _done_callback(self, proxy: Future, future: Future): """ Internally handles completion of executor future, copies result to proxy + Args: - fn (function): [description] - future (Future): [description] + proxy (Future): Proxy future + future (Future): Executor future """ if future.cancelled(): @@ -246,6 +301,10 @@ def _done_callback(self, proxy: Future, future: Future): def _proxy_done(self, future: Future): """ Internally handles events for proxy futures, this means handling cancellation. + + Args: + future (Future): Proxy future + """ with self._lock: @@ -261,8 +320,15 @@ def _proxy_done(self, future: Future): dependency.cancel() class FuturesAggregator(Future): + """A future that aggregates results from other futures.""" def __init__(self, *futures: Tuple[Future]): + """Initializes the aggregator. + + Args: + *futures (Tuple[Future]): The futures to aggregate. + """ + super().__init__() self._lock = RLock() self._results = [None] * len(futures) @@ -275,6 +341,8 @@ def __init__(self, *futures: Tuple[Future]): self.set_result([]) def _on_result(self, i, future): + """Handles completion of a dependency future.""" + with self._lock: if self.done(): return @@ -288,6 +356,8 @@ def _on_result(self, i, future): self.set_result(self._results) def _on_done(self, future): + """Handles completion of the future.""" + with self._lock: try: self.set_result(future.result()) @@ -295,6 +365,8 @@ def _on_done(self, future): self.set_exception(e) def cancel(self): + """Cancels the future and all dependencies.""" + with self._lock: for promise in self._tasks: promise.cancel() @@ -302,9 +374,19 @@ def cancel(self): class AnalysisTask(object): + """A task that computes an analysis.""" def __init__(self, analysis: Analysis, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence]): + """Initializes a new instance of the AnalysisTask class. + + Args: + analysis (Analysis): The analysis to compute. + experiment (Experiment): The experiment to compute the analysis for. + trackers (List[Tracker]): The trackers to compute the analysis for. + sequences (List[Sequence]): The sequences to compute the analysis for. + """ + self._analysis = analysis self._trackers = trackers self._experiment = experiment @@ -312,6 +394,15 @@ def __init__(self, analysis: Analysis, experiment: Experiment, self._key = hashkey(analysis, experiment, trackers, sequences) def __call__(self, dependencies: List[Grid] = None): + """Computes the analysis. + + Args: + dependencies (List[Grid], optional): The dependencies to use. Defaults to None. + + Returns: + Grid: The computed analysis. + """ + try: if dependencies is None: dependencies = [] @@ -320,9 +411,18 @@ def __call__(self, dependencies: List[Grid] = None): raise AnalysisError(cause=e, task=self._key) class AnalysisPartTask(object): + """A task that computes a part of a separable analysis.""" def __init__(self, analysis: SeparableAnalysis, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence]): + """Initializes a new instance of the AnalysisPartTask class. + + Args: + analysis (SeparableAnalysis): The analysis to compute. + experiment (Experiment): The experiment to compute the analysis for. + trackers (List[Tracker]): The trackers to compute the analysis for. + sequences (List[Sequence]): The sequences to compute the analysis for. 
+        """
        self._analysis = analysis
        self._trackers = trackers
        self._experiment = experiment
@@ -330,6 +430,14 @@ def __init__(self, analysis: SeparableAnalysis, experiment: Experiment,
        self._key = hashkey(analysis, experiment, unwrap(trackers), unwrap(sequences))
 
    def __call__(self, dependencies: List[Grid] = None):
+        """Computes the analysis.
+
+        Args:
+            dependencies (List[Grid], optional): The dependencies to use. Defaults to None.
+
+        Returns:
+            Grid: The computed analysis.
+        """
        try:
            if dependencies is None:
                dependencies = []
@@ -338,9 +446,19 @@ def __call__(self, dependencies: List[Grid] = None):
            raise AnalysisError(cause=e, task=self._key)
 
class AnalysisJoinTask(object):
+    """A task that joins the results of a separable analysis."""
 
    def __init__(self, analysis: SeparableAnalysis, experiment: Experiment,
                 trackers: List[Tracker], sequences: List[Sequence]):
+
+        """Initializes a new instance of the AnalysisJoinTask class.
+
+        Args:
+            analysis (Analysis): The analysis to join.
+            experiment (Experiment): The experiment to join the analysis for.
+            trackers (List[Tracker]): The trackers to join the analysis for.
+            sequences (List[Sequence]): The sequences to join the analysis for.
+        """
        self._analysis = analysis
        self._trackers = trackers
        self._experiment = experiment
@@ -348,29 +466,55 @@ def __init__(self, analysis: SeparableAnalysis, experiment: Experiment,
        self._key = hashkey(analysis, experiment, trackers, sequences)
 
    def __call__(self, results: List[Grid]):
+        """Joins the results of the analysis.
+
+        Args:
+            results (List[Grid]): The results to join.
+
+        Returns:
+            Grid: The joined analysis.
+        """
+
        try:
            return self._analysis.join(self._trackers, self._sequences, results)
        except BaseException as e:
            raise AnalysisError(cause=e, task=self._key)
 
class AnalysisFuture(Future):
+    """A future that represents the result of an analysis."""
 
    def __init__(self, key):
+        """Initializes a new instance of the AnalysisFuture class.
+
+        Args:
+            key (str): The key of the analysis.
+        """
+
        super().__init__()
        self._key = key
 
    @property
    def key(self):
+        """Gets the key of the analysis."""
        return self._key
 
    def __repr__(self) -> str:
+        """Gets a string representation of the future."""
+        return "<AnalysisFuture: {}>".format(self._key)
 
class AnalysisProcessor(object):
+    """A processor that computes analyses."""
 
    _context = threading.local()
 
    def __init__(self, executor: Executor = None, cache: Cache = None):
+        """Initializes a new instance of the AnalysisProcessor class.
+
+        Args:
+            executor (Executor, optional): The executor to use for computations. Defaults to None.
+            cache (Cache, optional): The cache to use for computations. Defaults to None.
+
+        """
 
        if executor is None:
            executor = ThreadPoolExecutor(1)
@@ -383,6 +527,17 @@ def __init__(self, executor: Executor = None, cache: Cache = None):
 
    def commit(self, analysis: Analysis, experiment: Experiment,
               trackers: Union[Tracker, List[Tracker]], sequences: Union[Sequence, List[Sequence]]) -> Future:
+        """Commits an analysis for computation. If the analysis is already being computed, the existing future is returned.
+
+        Args:
+            analysis (Analysis): The analysis to commit.
+            experiment (Experiment): The experiment to commit the analysis for.
+            trackers (Union[Tracker, List[Tracker]]): The trackers to commit the analysis for.
+            sequences (Union[Sequence, List[Sequence]]): The sequences to commit the analysis for.
+
+        Returns:
+            Future: A future that represents the result of the analysis.
+ """ key = hashkey(analysis, experiment, trackers, sequences) @@ -401,6 +556,7 @@ def commit(self, analysis: Analysis, experiment: Experiment, if isinstance(analysis, SeparableAnalysis): def select_dependencies(analysis: SeparableAnalysis, tracker: int, sequence: int, *dependencies): + """Selects the dependencies for a part of a separable analysis.""" return [analysis.select(meta, data, tracker, sequence) for meta, data in zip(analysis.dependencies(), dependencies)] promise = AnalysisFuture(key) @@ -440,7 +596,16 @@ def select_dependencies(analysis: SeparableAnalysis, tracker: int, sequence: int return promise - def _exists(self, key): + def _exists(self, key: str): + """Checks if an analysis is already being computed. + + Args: + key (str): The key of the analysis to check. + + Returns: + AnalysisFuture: The future that represents the analysis if it is already being computed, None otherwise. + """ + if self._cache is not None and key in self._cache: promise = AnalysisFuture(key) promise.set_result(self._cache[key]) @@ -454,6 +619,12 @@ def _exists(self, key): return None def _future_done(self, future: Future): + """Handles the completion of a future. + + Args: + future (Future): The future that completed. + + """ with self._lock: @@ -489,6 +660,16 @@ def _future_done(self, future: Future): def _promise_cancelled(self, future: Future): + """Handles the cancellation of a promise. If the promise is the last promise for a computation, the computation is cancelled. + + Args: + future (Future): The promise that was cancelled. + + Returns: + bool: True if the promise was the last promise for a computation, False otherwise. + + """ + if not future.cancelled(): return @@ -508,20 +689,27 @@ def _promise_cancelled(self, future: Future): @property def pending(self): + """The number of pending analyses.""" + with self._lock: return len(self._pending) @property def total(self): + """The total number of analyses.""" + with self._lock: return self._executor.total def cancel(self): + """Cancels all pending analyses.""" + with self._lock: for _, future in list(self._pending.items()): future.cancel() def wait(self): + """Waits for all pending analyses to complete. If no analyses are pending, this method returns immediately.""" if self.total == 0: return @@ -542,6 +730,11 @@ def wait(self): progress.close() def __enter__(self): + """Sets this analysis processor as the default for the current thread. + + Returns: + AnalysisProcessor: This analysis processor. + """ processor = getattr(AnalysisProcessor._context, 'analysis_processor', None) @@ -557,6 +750,8 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): + """Clears the default analysis processor for the current thread.""" + processor = getattr(AnalysisProcessor._context, 'analysis_processor', None) if processor == self: @@ -565,6 +760,11 @@ def __exit__(self, exc_type, exc_value, traceback): @staticmethod def default(): + """Returns the default analysis processor for the current thread. + + Returns: + AnalysisProcessor: The default analysis processor for the current thread. + """ processor = getattr(AnalysisProcessor._context, 'analysis_processor', None) @@ -581,11 +781,23 @@ def default(): @staticmethod def commit_default(analysis: Analysis, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence]): + """Commits an analysis to the default analysis processor. This method is thread-safe. 
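A hypothetical end-to-end use of this API, wrapped in a function so that the external inputs stay abstract (the executor size is illustrative; commit, wait and result follow the methods defined above):

from concurrent.futures import ThreadPoolExecutor
from vot.analysis.processor import AnalysisProcessor

def run_analysis(analysis, experiment, trackers, sequences):
    processor = AnalysisProcessor(executor=ThreadPoolExecutor(4))
    with processor:                 # becomes the default processor for this thread
        future = processor.commit(analysis, experiment, trackers, sequences)
        processor.wait()            # block until all pending analyses finish
        return future.result()      # a Grid, or raises AnalysisError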
If the analysis is already being computed, this method returns immediately.""" processor = AnalysisProcessor.default() return processor.commit(analysis, experiment, trackers, sequences) def run(self, analysis: Analysis, experiment: Experiment, trackers: Union[Tracker, List[Tracker]], sequences: Union[Sequence, List[Sequence]]) -> Grid: + """Runs an analysis on a set of trackers and sequences. This method is thread-safe. If the analysis is already being computed, this method returns immediately. + + Args: + analysis (Analysis): The analysis to run. + experiment (Experiment): The experiment to run the analysis on. + trackers (Union[Tracker, List[Tracker]]): The trackers to run the analysis on. + sequences (Union[Sequence, List[Sequence]]): The sequences to run the analysis on. + + Returns: + Grid: The results of the analysis. + """ assert self.pending == 0 @@ -597,23 +809,49 @@ def run(self, analysis: Analysis, experiment: Experiment, @staticmethod def run_default(analysis: Analysis, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence]): + """Runs an analysis on a set of trackers and sequences. This method is thread-safe. If the analysis is already being computed, this method returns immediately. + + Args: + analysis (Analysis): The analysis to run. + experiment (Experiment): The experiment to run the analysis on. + trackers (List[Tracker]): The trackers to run the analysis on. + sequences (List[Sequence]): The sequences to run the analysis on. + + Returns: + Grid: The results of the analysis.""" + processor = AnalysisProcessor.default() return processor.run(analysis, experiment, trackers, sequences) def process_stack_analyses(workspace: "Workspace", trackers: List[Tracker]): + """Process all analyses in the workspace stack. This function is used by the command line interface to run all the analyses provided in a stack. + + Args: + workspace (Workspace): The workspace to process. + trackers (List[Tracker]): The trackers to run the analyses on. + + """ processor = AnalysisProcessor.default() results = dict() condition = Condition() + errors = [] def insert_result(container: dict, key): + """Creates a callback that inserts the result of a computation into a container. The container is a dictionary that maps analyses to their results. + + Args: + container (dict): The container to insert the result into. + key (Analysis): The analysis to insert the result for. 
+        """
 
        def insert(future: Future):
+            """Inserts the result of a computation into a container."""
            try:
                container[key] = future.result()
            except AnalysisError as e:
-                e.print(logger)
+                errors.append(e)
            except Exception as e:
                logger.exception(e)
            with condition:
@@ -628,7 +866,7 @@ def insert(future: Future):
 
        results[experiment] = experiment_results
 
-        sequences = [experiment.transform(sequence) for sequence in workspace.dataset]
+        sequences = experiment.transform(workspace.dataset)
 
        for analysis in experiment.analyses:
 
@@ -665,4 +903,12 @@ def insert(future: Future):
        logger.info("Analysis interrupted by user, aborting.")
        return None
 
+    if len(errors) > 0:
+        logger.info("Errors occurred during analysis, results are incomplete.")
+        for e in errors:
+            logger.info("Failed task {}: {}".format(e.task, e.root_cause))
+            if logger.isEnabledFor(logging.DEBUG):
+                e.print(logger)
+        return None
+
    return results
\ No newline at end of file
diff --git a/vot/analysis/supervised.py b/vot/analysis/supervised.py
index 0077d45..93f89b6 100644
--- a/vot/analysis/supervised.py
+++ b/vot/analysis/supervised.py
@@ -11,28 +11,40 @@
 
 from vot.tracker import Tracker, Trajectory
 from vot.dataset import Sequence
-from vot.dataset.proxy import FrameMapSequence
 from vot.experiment import Experiment
 from vot.experiment.multirun import SupervisedExperiment
-from vot.experiment.multistart import MultiStartExperiment, find_anchors
-from vot.region import Region, Special, calculate_overlaps
+from vot.region import Region, calculate_overlaps
 from vot.analysis import MissingResultsException, Measure, Point, is_special, Plot, Analysis, \
     Sorting, SeparableAnalysis, SequenceAggregator, analysis_registry, TrackerSeparableAnalysis, Axes
 from vot.utilities.data import Grid
 
-def compute_accuracy(trajectory: List[Region], sequence: Sequence, burnin: int = 10, ignore_unknown: bool = True, bounded: bool = True) -> float:
+def compute_accuracy(trajectory: List[Region], sequence: Sequence, burnin: int = 10, ignore_unknown: bool = True, bounded: bool = True) -> Tuple[float, int]:
+    """ Computes accuracy of a tracker on a given sequence. Accuracy is defined as the mean overlap of the tracker
+    region with the groundtruth region. The overlap is computed only for frames where the tracker is not in
+    initialization or failure state. The overlap is computed only for frames after the burnin period.
+
+    Args:
+        trajectory (List[Region]): Tracker trajectory.
+        sequence (Sequence): Sequence to compute accuracy on.
+        burnin (int, optional): Burnin period. Defaults to 10.
+        ignore_unknown (bool, optional): Ignore unknown regions. Defaults to True.
+        bounded (bool, optional): Clip regions to the image bounds when computing overlap. Defaults to True.
+
+    Returns:
+        Tuple[float, int]: Mean overlap (accuracy) and the number of frames used to compute it.
+    """
 
    overlaps = np.array(calculate_overlaps(trajectory, sequence.groundtruth(), (sequence.size) if bounded else None))
    mask = np.ones(len(overlaps), dtype=bool)
 
    for i, region in enumerate(trajectory):
-        if is_special(region, Special.UNKNOWN) and ignore_unknown:
+        if is_special(region, Trajectory.UNKNOWN) and ignore_unknown:
            mask[i] = False
-        elif is_special(region, Special.INITIALIZATION):
+        elif is_special(region, Trajectory.INITIALIZATION):
            for j in range(i, min(len(trajectory), i + burnin)):
                mask[j] = False
-        elif is_special(region, Special.FAILURE):
+        elif is_special(region, Trajectory.FAILURE):
            mask[i] = False
 
    if any(mask):
@@ -41,14 +53,17 @@
        return 0, 0
 
 def count_failures(trajectory: List[Region]) -> Tuple[int, int]:
-    return len([region for region in trajectory if is_special(region, Special.FAILURE)]), len(trajectory)
+    """Counts the number of failures in a trajectory.
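The burn-in logic of compute_accuracy masks the frames right after each (re)initialization before averaging overlaps; a minimal sketch with illustrative flags and overlap values:

import numpy as np

overlaps = np.array([0.0, 0.1, 0.6, 0.7, 0.8])
init_frames, burnin = [0], 2          # one initialization at frame 0
mask = np.ones(len(overlaps), dtype=bool)
for i in init_frames:
    mask[i:i + burnin] = False        # exclude burn-in frames from the average
print(float(np.mean(overlaps[mask])))  # mean of frames 2..4 -> 0.7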
Failure is defined as a frame where the tracker is in failure state.""" + return len([region for region in trajectory if is_special(region, Trajectory.FAILURE)]), len(trajectory) def locate_failures_inits(trajectory: List[Region]) -> Tuple[int, int]: - return [i for i, region in enumerate(trajectory) if is_special(region, Special.FAILURE)], \ - [i for i, region in enumerate(trajectory) if is_special(region, Special.INITIALIZATION)] + """Locates failures and initializations in a trajectory. Failure is defined as a frame where the tracker is in failure state.""" + return [i for i, region in enumerate(trajectory) if is_special(region, Trajectory.FAILURE)], \ + [i for i, region in enumerate(trajectory) if is_special(region, Trajectory.INITIALIZATION)] def compute_eao_curve(overlaps: List, weights: List[float], success: List[bool]): + """Computes EAO curve from a list of overlaps, weights and success flags.""" max_length = max([len(el) for el in overlaps]) total_runs = len(overlaps) @@ -74,6 +89,11 @@ def compute_eao_curve(overlaps: List, weights: List[float], success: List[bool]) @analysis_registry.register("supervised_ar") class AccuracyRobustness(SeparableAnalysis): + """Accuracy-Robustness analysis. Computes accuracy and robustness of a tracker on a given sequence. + Accuracy is defined as mean overlap of the tracker region with the groundtruth region. The overlap is computed only for frames where the tracker is not in + initialization or failure state. The overlap is computed only for frames after the burnin period. + Robustness is defined as a number of failures divided by the total number of frames. + """ sensitivity = Float(default=30, val_min=1) burnin = Integer(default=10, val_min=0) @@ -81,10 +101,12 @@ class AccuracyRobustness(SeparableAnalysis): bounded = Boolean(default=True) @property - def title(self): + def _title_default(self): + """Returns title of the analysis.""" return "AR analysis" def describe(self): + """Returns description of the analysis.""" return Measure("Accuracy", "A", minimal=0, maximal=1, direction=Sorting.DESCENDING), \ Measure("Robustness", "R", minimal=0, direction=Sorting.ASCENDING), \ Point("AR plot", dimensions=2, abbreviation="AR", minimal=(0, 0), \ @@ -92,9 +114,21 @@ def describe(self): None def compatible(self, experiment: Experiment): + """Returns True if the analysis is compatible with the experiment. Only SupervisedExperiment is compatible.""" return isinstance(experiment, SupervisedExperiment) def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: + """Computes accuracy and robustness of a tracker on a given sequence. + + Args: + experiment (Experiment): Experiment. + tracker (Tracker): Tracker. + sequence (Sequence): Sequence. + dependencies (List[Grid]): Dependencies. + + Returns: + Tuple[Any]: Accuracy, robustness, AR, number of frames. + """ trajectories = experiment.gather(tracker, sequence) if len(trajectories) == 0: @@ -112,17 +146,27 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequenc @analysis_registry.register("supervised_average_ar") class AverageAccuracyRobustness(SequenceAggregator): + """Average accuracy-robustness analysis. Computes average accuracy and robustness of a tracker on a given sequence. + + Accuracy is defined as mean overlap of the tracker region with the groundtruth region. The overlap is computed only for frames where the tracker is not in + initialization or failure state. 
The overlap is computed only for frames after the burnin period. + Robustness is defined as a number of failures divided by the total number of frames. + The analysis is computed as an average of accuracy and robustness over all sequences. + """ analysis = Include(AccuracyRobustness) @property - def title(self): + def _title_default(self): + """Returns title of the analysis.""" return "AR Analysis" def dependencies(self): + """Returns dependencies of the analysis.""" return self.analysis, def describe(self): + """Returns description of the analysis.""" return Measure("Accuracy", "A", minimal=0, maximal=1, direction=Sorting.DESCENDING), \ Measure("Robustness", "R", minimal=0, direction=Sorting.ASCENDING), \ Point("AR plot", dimensions=2, abbreviation="AR", minimal=(0, 0), \ @@ -130,9 +174,21 @@ def describe(self): None def compatible(self, experiment: Experiment): + """Returns True if the analysis is compatible with the experiment. Only SupervisedExperiment is compatible.""" return isinstance(experiment, SupervisedExperiment) def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid): + """Aggregates results of the analysis. + + Args: + tracker (Tracker): Tracker. + sequences (List[Sequence]): List of sequences. + results (Grid): Results of the analysis. + + Returns: + Tuple[Any]: Accuracy, robustness, AR, number of frames. + """ + failures = 0 accuracy = 0 weight_total = 0 @@ -148,21 +204,40 @@ def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid): @analysis_registry.register("supervised_eao_curve") class EAOCurve(TrackerSeparableAnalysis): + """Expected Average Overlap curve analysis. Computes expected average overlap of a tracker on a given sequence. + The overlap is computed only for frames where the tracker is not in initialization or failure state. + The overlap is computed only for frames after the burnin period. + The analysis is computed as an average of accuracy and robustness over all sequences. + """ burnin = Integer(default=10, val_min=0) bounded = Boolean(default=True) @property - def title(self): + def _title_default(self): + """Returns title of the analysis.""" return "EAO Curve" def describe(self): + """Returns description of the analysis.""" return Plot("Expected Average Overlap", "EAO", minimal=0, maximal=1, trait="eao"), def compatible(self, experiment: Experiment): + """Returns True if the analysis is compatible with the experiment. Only SupervisedExperiment is compatible.""" return isinstance(experiment, SupervisedExperiment) def subcompute(self, experiment: Experiment, tracker: Tracker, sequences: List[Sequence], dependencies: List[Grid]) -> Tuple[Any]: + """Computes expected average overlap of a tracker on a given sequence. + + Args: + experiment (Experiment): Experiment. + tracker (Tracker): Tracker. + sequences (List[Sequence]): List of sequences. + dependencies (List[Grid]): Dependencies. + + Returns: + Tuple[Any]: Expected average overlap. + """ overlaps_all = [] weights_all = [] @@ -203,27 +278,45 @@ def subcompute(self, experiment: Experiment, tracker: Tracker, sequences: List[S @analysis_registry.register("supervised_eao_score") class EAOScore(Analysis): + """Expected Average Overlap score analysis. The analysis is computed as an average of EAO scores over multiple sequences. 
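For the AR plot, VOT traditionally maps the raw failure rate through exp(-sensitivity * R) to obtain a reliability-style robustness coordinate; a sketch of that mapping (the exact weighting in the aggregate above is not fully shown, so treat this as an assumption):

import math

def ar_point(accuracy, failures, frames, sensitivity=30):
    robustness = failures / frames                     # failures per frame
    return (math.exp(-sensitivity * robustness), accuracy)

print(ar_point(accuracy=0.62, failures=3, frames=1500))  # (~0.94, 0.62)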
+ """ eaocurve = Include(EAOCurve) low = Integer() high = Integer() @property - def title(self): + def _title_default(self): + """Returns title of the analysis.""" return "EAO analysis" def describe(self): + """Returns description of the analysis.""" return Measure("Expected average overlap", "EAO", 0, 1, Sorting.DESCENDING), def compatible(self, experiment: Experiment): + """Returns True if the analysis is compatible with the experiment. Only SupervisedExperiment is compatible.""" return isinstance(experiment, SupervisedExperiment) def dependencies(self): + """Returns dependencies of the analysis.""" return self.eaocurve, def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid: + """Computes expected average overlap of a tracker on a given sequence. + + Args: + experiment (Experiment): Experiment. + trackers (List[Tracker]): List of trackers. + sequences (List[Sequence]): List of sequences. + dependencies (List[Grid]): Dependencies. + + Returns: + Grid: Expected average overlap. + """ return dependencies[0].foreach(lambda x, i, j: (float(np.mean(x[0][self.low:self.high + 1])), ) ) @property def axes(self): + """Returns axes of the analysis.""" return Axes.TRACKERS \ No newline at end of file diff --git a/vot/analysis/tpr.py b/vot/analysis/tpr.py deleted file mode 100644 index d53d0d6..0000000 --- a/vot/analysis/tpr.py +++ /dev/null @@ -1,261 +0,0 @@ -import math -import numpy as np -from typing import List, Iterable, Tuple, Any -import itertools - -from attributee import Float, Integer, Boolean, Include - -from vot.tracker import Tracker -from vot.dataset import Sequence -from vot.region import Region, RegionType, calculate_overlaps -from vot.experiment import Experiment -from vot.experiment.multirun import UnsupervisedExperiment -from vot.analysis import SequenceAggregator, Analysis, SeparableAnalysis, \ - MissingResultsException, Measure, Sorting, Curve, Plot, SequenceAggregator, \ - Axes, analysis_registry -from vot.utilities.data import Grid - -def determine_thresholds(scores: Iterable[float], resolution: int) -> List[float]: - scores = [score for score in scores if not math.isnan(score)] #and not score is None] - scores = sorted(scores, reverse=True) - - if len(scores) > resolution - 2: - delta = math.floor(len(scores) / (resolution - 2)) - idxs = np.round(np.linspace(delta, len(scores) - delta, num=resolution - 2)).astype(np.int) - thresholds = [scores[idx] for idx in idxs] - else: - thresholds = scores - - thresholds.insert(0, math.inf) - thresholds.insert(len(thresholds), -math.inf) - - return thresholds - -def compute_tpr_curves(trajectory: List[Region], confidence: List[float], sequence: Sequence, thresholds: List[float], - ignore_unknown: bool = True, bounded: bool = True): - - overlaps = np.array(calculate_overlaps(trajectory, sequence.groundtruth(), (sequence.size) if bounded else None)) - confidence = np.array(confidence) - - n_visible = len([region for region in sequence.groundtruth() if region.type is not RegionType.SPECIAL]) - - - - precision = len(thresholds) * [float(0)] - recall = len(thresholds) * [float(0)] - - for i, threshold in enumerate(thresholds): - - subset = confidence >= threshold - - if np.sum(subset) == 0: - precision[i] = 1 - recall[i] = 0 - else: - precision[i] = np.mean(overlaps[subset]) - recall[i] = np.sum(overlaps[subset]) / n_visible - - return precision, recall - -class _ConfidenceScores(SeparableAnalysis): - - @property - def title(self): - return "Aggregate confidence scores" 
- - def describe(self): - return None, - - def compatible(self, experiment: Experiment): - return isinstance(experiment, UnsupervisedExperiment) - - def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: - - scores_all = [] - trajectories = experiment.gather(tracker, sequence) - - if len(trajectories) == 0: - raise MissingResultsException("Missing results for sequence {}".format(sequence.name)) - - for trajectory in trajectories: - confidence = [trajectory.properties(i).get('confidence', 0) for i in range(len(trajectory))] - scores_all.extend(confidence) - - return scores_all, - - -class _Thresholds(SequenceAggregator): - - resolution = Integer(default=100) - - @property - def title(self): - return "Thresholds for tracking precision/recall" - - def describe(self): - return None, - - def compatible(self, experiment: Experiment): - return isinstance(experiment, UnsupervisedExperiment) - - def dependencies(self): - return _ConfidenceScores(), - - def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]: - - thresholds = determine_thresholds(itertools.chain(*[result[0] for result in results]), self.resolution), - - return thresholds, - -@analysis_registry.register("pr_curves") -class PrecisionRecallCurves(SeparableAnalysis): - - thresholds = Include(_Thresholds) - ignore_unknown = Boolean(default=True) - bounded = Boolean(default=True) - - @property - def title(self): - return "Tracking precision/recall" - - def describe(self): - return Curve("Precision Recall curve", dimensions=2, abbreviation="PR", minimal=(0, 0), maximal=(1, 1), labels=("Recall", "Precision")), None - - def compatible(self, experiment: Experiment): - return isinstance(experiment, UnsupervisedExperiment) - - def dependencies(self): - return self.thresholds, - - def subcompute(self, experiment: Experiment, tracker: Tracker, sequence: Sequence, dependencies: List[Grid]) -> Tuple[Any]: - - thresholds = dependencies[0, 0][0][0] # dependencies[0][0, 0] - - trajectories = experiment.gather(tracker, sequence) - - if len(trajectories) == 0: - raise MissingResultsException() - - precision = len(thresholds) * [float(0)] - recall = len(thresholds) * [float(0)] - for trajectory in trajectories: - confidence = [trajectory.properties(i).get('confidence', 0) for i in range(len(trajectory))] - pr, re = compute_tpr_curves(trajectory.regions(), confidence, sequence, thresholds, self.ignore_unknown, self.bounded) - for i in range(len(thresholds)): - precision[i] += pr[i] - recall[i] += re[i] - -# return [(re / len(trajectories), pr / len(trajectories)) for pr, re in zip(precision, recall)], thresholds - return [(pr / len(trajectories), re / len(trajectories)) for pr, re in zip(precision, recall)], thresholds - -@analysis_registry.register("pr_curve") -class PrecisionRecallCurve(SequenceAggregator): - - curves = Include(PrecisionRecallCurves) - - @property - def title(self): - return "Tracking precision/recall average curve" - - def describe(self): - return self.curves.describe() - - def compatible(self, experiment: Experiment): - return isinstance(experiment, UnsupervisedExperiment) - - def dependencies(self): - return self.curves, - - # def collapse(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]: - def aggregate(self, tracker: Tracker, sequences: List[Sequence], results: Grid) -> Tuple[Any]: - - curve = None - thresholds = None - - for partial, thresholds in results: - if curve is None: - curve = partial - 
continue - - curve = [(pr1 + pr2, re1 + re2) for (pr1, re1), (pr2, re2) in zip(curve, partial)] - - curve = [(re / len(results), pr / len(results)) for pr, re in curve] - - return curve, thresholds - - -@analysis_registry.register("f_curve") -class FScoreCurve(Analysis): - - beta = Float(default=1) - prcurve = Include(PrecisionRecallCurve) - - @property - def title(self): - return "Tracking precision/recall" - - def describe(self): - return Plot("Tracking F-score curve", "F", wrt="normalized threshold", minimal=0, maximal=1), None - - def compatible(self, experiment: Experiment): - return isinstance(experiment, UnsupervisedExperiment) - - def dependencies(self): - return self.prcurve, - - def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid: - processed_results = Grid(len(trackers), 1) - - for i, result in enumerate(dependencies[0]): - beta2 = (self.beta * self.beta) - f_curve = [((1 + beta2) * pr_ * re_) / (beta2 * pr_ + re_) for pr_, re_ in result[0]] - - processed_results[i, 0] = (f_curve, result[0][1]) - - return processed_results - - @property - def axes(self): - return Axes.TRACKERS - -@analysis_registry.register("average_tpr") -class PrecisionRecall(Analysis): - - prcurve = Include(PrecisionRecallCurve) - fcurve = Include(FScoreCurve) - - @property - def title(self): - return "Tracking precision/recall" - - def describe(self): - return Measure("Precision", "Pr", minimal=0, maximal=1, direction=Sorting.DESCENDING), \ - Measure("Recall", "Re", minimal=0, maximal=1, direction=Sorting.DESCENDING), \ - Measure("F Score", "F", minimal=0, maximal=1, direction=Sorting.DESCENDING) - - def compatible(self, experiment: Experiment): - return isinstance(experiment, UnsupervisedExperiment) - - def dependencies(self): - return self.prcurve, self.fcurve - - def compute(self, experiment: Experiment, trackers: List[Tracker], sequences: List[Sequence], dependencies: List[Grid]) -> Grid: - - f_curves = dependencies[1] - pr_curves = dependencies[0] - - joined = Grid(len(trackers), 1) - - for i, (f_curve, pr_curve) in enumerate(zip(f_curves, pr_curves)): - # get optimal F-score and Pr and Re at this threshold - f_score = max(f_curve[0]) - best_i = f_curve[0].index(f_score) - re_score = pr_curve[0][best_i][0] - pr_score = pr_curve[0][best_i][1] - joined[i, 0] = (pr_score, re_score, f_score) - - return joined - - @property - def axes(self): - return Axes.TRACKERS diff --git a/vot/dataset/__init__.py b/vot/dataset/__init__.py index 78b9f2b..960dfc5 100644 --- a/vot/dataset/__init__.py +++ b/vot/dataset/__init__.py @@ -1,98 +1,242 @@ +"""Dataset module provides an interface for accessing the datasets and sequences. 
It also provides a set of utility functions for downloading and extracting datasets.""" + import os -import json -import glob +import logging +from numbers import Number +from collections import namedtuple from abc import abstractmethod, ABC +from typing import List, Mapping, Optional, Set, Tuple, Iterator + +from class_registry import ClassRegistry from PIL.Image import Image import numpy as np +from cachetools import cached, LRUCache + +from vot.region import Region from vot import ToolkitException -from vot.utilities import read_properties -from vot.region import parse import cv2 +logger = logging.getLogger("vot") + +dataset_downloader = ClassRegistry("vot_downloader") +sequence_indexer = ClassRegistry("vot_indexer") +sequence_reader = ClassRegistry("vot_sequence") + class DatasetException(ToolkitException): + """Dataset and sequence related exceptions + """ pass class Channel(ABC): + """Abstract representation of individual image channel, a sequence of images with + uniform dimensions. + """ def __init__(self): + """ Base constructor for channel""" pass - @property - @abstractmethod - def length(self): - pass + def __len__(self) -> int: + """Returns the length of channel + + Returns: + int: Length of channel + """ + raise NotImplementedError() @abstractmethod - def frame(self, index): + def frame(self, index: int) -> "Frame": + """Returns frame object for the given index + + Args: + index (int): Index of the frame + + Returns: + Frame: Frame object + """ pass @abstractmethod - def filename(self, index): + def filename(self, index: int) -> str: + """Returns filename for the given index of the channel sequence + + Args: + index (int): Index of the frame + + Returns: + str: Filename of the frame + """ pass @property @abstractmethod - def size(self): + def size(self) -> int: + """Returns the size of the channel in bytes""" pass class Frame(object): + """Frame object represents a single frame in the sequence. It provides access to the + image data, groundtruth, tags and values as a wrapper around the sequence object.""" def __init__(self, sequence, index): + """Base constructor for frame object + + Args: + sequence (Sequence): Sequence object + index (int): Index of the frame + + Returns: + Frame: Frame object + """ self._sequence = sequence self._index = index @property def index(self) -> int: + """Returns the index of the frame + + Returns: + int: Index of the frame + """ return self._index @property def sequence(self) -> 'Sequence': + """Returns the sequence object of the frame object + + Returns: + Sequence: Sequence object + """ return self._sequence def channels(self): + """Returns the list of channels in the sequence + + Returns: + List[str]: List of channels + """ return self._sequence.channels() - def channel(self, channel=None): + def channel(self, channel: Optional[str] = None): + """Returns the channel object for the given channel name + + Args: + channel (Optional[str], optional): Name of the channel. Defaults to None. + """ channelobj = self._sequence.channel(channel) if channelobj is None: return None return channelobj.frame(self._index) - def filename(self, channel=None): + def filename(self, channel: Optional[str] = None): + """Returns the filename for the given channel name and frame index + + Args: + channel (Optional[str], optional): Name of the channel. Defaults to None. 
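+
+        Example (illustrative; assumes `frame` is a loaded Frame object):
+            >>> frame.filename("color")   # path of the frame image, or None if the channel does not exist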
+ + Returns: + str: Filename of the frame + """ channelobj = self._sequence.channel(channel) if channelobj is None: return None return channelobj.filename(self._index) - def image(self, channel=None): + def image(self, channel: Optional[str] = None) -> np.ndarray: + """Returns the image for the given channel name and frame index + + Args: + channel (Optional[str], optional): Name of the channel. Defaults to None. + + Returns: + np.ndarray: Image object + """ channelobj = self._sequence.channel(channel) if channelobj is None: return None return channelobj.frame(self._index) - def groundtruth(self): + def objects(self) -> List[str]: + """Returns the list of objects in the frame + + Returns: + List[str]: List of object ids + """ + objects = {} + for o in self._sequence.objects(): + region = self._sequence.object(o, self._index) + if region is not None: + objects[o] = region + return objects + + def object(self, id: str) -> Region: + """Returns the object region for the given object id and frame index + + Args: + id (str): Id of the object + + Returns: + Region: Object region + """ + return self._sequence.object(id, self._index) + + def groundtruth(self) -> Region: + """Returns the groundtruth region for the frame + + Returns: + Region: Groundtruth region + + Raises: + DatasetException: If groundtruth is not available + """ return self._sequence.groundtruth(self._index) - def tags(self, index = None): + def tags(self) -> List[str]: + """Returns the tags for the frame + + Returns: + List[str]: List of tags + """ return self._sequence.tags(self._index) - def values(self, index=None): + def values(self) -> Mapping[str, float]: + """Returns the values for the frame + + Returns: + Mapping[str, float]: Mapping of values + """ return self._sequence.values(self._index) class SequenceIterator(object): - - def __init__(self, sequence): + """Sequence iterator provides an iterator interface for the sequence object""" + + def __init__(self, sequence: "Sequence"): + """Base constructor for sequence iterator + + Args: + sequence (Sequence): Sequence object + """ self._position = 0 self._sequence = sequence def __iter__(self): + """Returns the iterator object + + Returns: + SequenceIterator: Sequence iterator object + """ return self - def __next__(self): + def __next__(self) -> Frame: + """Returns the next frame object in the sequence iterator + + Returns: + Frame: Frame object + """ if self._position >= len(self._sequence): raise StopIteration() index = self._position @@ -100,8 +244,11 @@ def __next__(self): return Frame(self._sequence, index) class InMemoryChannel(Channel): + """In-memory channel represents a sequence of images with uniform dimensions. 
It is + used to represent a sequence of images in memory.""" def __init__(self): + """Base constructor for in-memory channel""" super().__init__() self._images = [] self._width = 0 @@ -109,6 +256,11 @@ def __init__(self): self._depth = 0 def append(self, image): + """Appends an image to the channel + + Args: + image (np.ndarray): Image object + """ if isinstance(image, Image): image = np.asarray(image) @@ -133,40 +285,104 @@ def append(self, image): self._images.append(image) @property - def length(self): + def length(self) -> int: + """Returns the length of the sequence channel in number of frames + + Returns: + int: Length of the sequence channel + """ return len(self._images) def frame(self, index): - if index < 0 or index >= self.length: + """Returns the frame object for the given index in the sequence channel + + Args: + index (int): Index of the frame + + Returns: + Frame: Frame object + """ + if index < 0 or index >= len(self): return None return self._images[index] @property def size(self): + """Returns the size of the channel in the format (width, height) + + Returns: + Tuple[int, int]: Size of the channel + """ return self._width, self._height def filename(self, index): + """ Thwows an exception as the sequence is available in memory and not in files. """ raise DatasetException("Sequence is available in memory, image files not available") class PatternFileListChannel(Channel): - - def __init__(self, path, start=1, step=1, end=None): + """Sequence channel implementation where each frame is stored in a file and all file names + follow a specific pattern. + """ + + def __init__(self, path, start=1, step=1, end=None, check_files=True): + """ Creates a new channel object + + Args: + path (str): Path to the sequence + start (int, optional): First frame index + step (int, optional): Step between frames + end (int, optional): Last frame index + check_files (bool, optional): Check that files exist + + Raises: + DatasetException: If the pattern is invalid + + Returns: + PatternFileListChannel: Channel object + """ super().__init__() base, pattern = os.path.split(path) self._base = base self._pattern = pattern - self.__scan(pattern, start, step, end) + self.__scan(pattern, start, step, end, check_files=check_files) @property - def base(self): + def base(self) -> str: + """Returns the base path of the sequence + + Returns: + str: Base path + """ return self._base @property def pattern(self): + """Returns the pattern of the sequence + + Returns: + str: Pattern + """ return self._pattern - def __scan(self, pattern, start, step, end): + def __scan(self, pattern, start, step, end, check_files=True): + """Scans the sequence directory for files matching the pattern and stores the file names in + the internal list. The pattern must contain a single %d placeholder for the frame index. The + placeholder must be at the end of the pattern. The pattern may contain a file extension. If + the pattern contains no file extension, .jpg is assumed. If end frame is specified, the + scanning stops when the end frame is reached. If check_files is True, and end frame is set + then files are checked to exist. 
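+
+        Example (illustrative construction of the owning channel; the path is a placeholder):
+            >>> channel = PatternFileListChannel("sequences/ball/color/%08d.jpg", end=100, check_files=False)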
+ + Args: + pattern (str): Pattern + start (int): First frame index + step (int): Step between frames + end (int): Last frame index + check_files (bool, optional): Check that files exist + + Raises: + DatasetException: If the pattern is invalid + """ extension = os.path.splitext(pattern)[1] if not extension in {'.jpg', '.png'}: @@ -177,10 +393,12 @@ def __scan(self, pattern, start, step, end): fullpattern = os.path.join(self.base, pattern) + assert end is not None or check_files + while True: image_file = os.path.join(fullpattern % i) - if not os.path.isfile(image_file): + if check_files and not os.path.isfile(image_file): break self._files.append(os.path.basename(image_file)) i = i + step @@ -191,234 +409,564 @@ def __scan(self, pattern, start, step, end): if i <= start: raise DatasetException("Empty sequence, no frames found.") - im = cv2.imread(self.filename(0)) - self._width = im.shape[1] - self._height = im.shape[0] - self._depth = im.shape[2] + if os.path.isfile(self.filename(0)): + im = cv2.imread(self.filename(0)) + self._width = im.shape[1] + self._height = im.shape[0] + self._depth = im.shape[2] + else: + self._depth = None + self._width = None + self._height = None - @property - def length(self): + def __len__(self) -> int: + """Returns the number of frames in the sequence + + Returns: + int: Number of frames + """ return len(self._files) - def frame(self, index): - if index < 0 or index >= self.length: + def frame(self, index: int) -> np.ndarray: + """Returns the frame at the specified index as a numpy array. The image is loaded using + OpenCV and converted to RGB color space if necessary. + + Args: + index (int): Frame index + + Returns: + np.ndarray: Frame + + Raises: + DatasetException: If the index is out of bounds + + """ + if index < 0 or index >= len(self): return None bgr = cv2.imread(self.filename(index)) + + # Check if the image is grayscale + if len(bgr.shape) == 2: + return bgr + return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) @property - def size(self): + def size(self) -> tuple: + """Returns the size of the frames in the sequence as a tuple (width, height) + + Returns: + tuple: Size of the frames + """ return self._width, self._height @property - def width(self): + def width(self) -> int: + """Returns the width of the frames in the sequence + + Returns: + int: Width of the frames + """ return self._width @property - def height(self): + def height(self) -> int: + """Returns the height of the frames in the sequence + + Returns: + int: Height of the frames + """ return self._height - def filename(self, index): - if index < 0 or index >= self.length: + def filename(self, index) -> str: + """Returns the filename of the frame at the specified index + + Args: + index (int): Frame index + + Returns: + str: Filename + """ + if index < 0 or index >= len(self): return None return os.path.join(self.base, self._files[index]) -class FrameList(ABC): +class FrameList(object): + """Abstract base for all sequences, just a list of frame objects + """ def __iter__(self): + """Returns an iterator over the frames in the sequence + + Returns: + SequenceIterator: Iterator + """ return SequenceIterator(self) - @abstractmethod def __len__(self) -> int: - pass - - @abstractmethod - def frame(self, index) -> Frame: - pass + """Returns the number of frames in the sequence. 
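+        Concrete frame lists are expected to override this method.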
+ + Returns: + int: Number of frames + """ + return NotImplementedError() + + def frame(self, index: int) -> Frame: + """Returns the frame at the specified index + + Args: + index (int): Frame index + """ + return NotImplementedError() class Sequence(FrameList): + """A sequence is a list of frames (multiple channels) and a list of one or more annotated objects. It also contains + additional metadata and per-frame information, such as tags and values. + """ - def __init__(self, name: str, dataset: "Dataset" = None): - self._name = name - self._dataset = dataset + UNKNOWN = 0 # object state is unknown in this frame + INVISIBLE = 1 # object is not visible in this frame - def __len__(self) -> int: - return self.length + def __init__(self, name: str): + """Creates a new sequence with the specified name""" + self._name = name @property def name(self) -> str: + """Returns the name of the sequence + + Returns: + str: Name + """ return self._name @property def identifier(self) -> str: + """Returns the identifier of the sequence. The identifier is a string that uniquely identifies the sequence in + the dataset. The identifier is usually the same as the name, but may be different if the name is not unique. + + Returns: + str: Identifier + """ return self._name - @property - def dataset(self): - return self._dataset - @abstractmethod def metadata(self, name, default=None): + """Returns the value of the specified metadata field. If the field does not exist, the default value is returned + + Args: + name (str): Name of the metadata field + default (object, optional): Default value + + Returns: + object: Value of the metadata field + """ pass @abstractmethod - def channel(self, channel=None): + def channel(self, channel=None) -> Channel: + """Returns the channel with the specified name or the default channel if no name is specified + + Args: + channel (str, optional): Name of the channel + + Returns: + Channel: Channel + """ pass @abstractmethod - def channels(self): + def channels(self) -> Set[str]: + """Returns the names of all channels in the sequence + + Returns: + set: Names of all channels + """ pass @abstractmethod - def groundtruth(self, index: int): + def objects(self) -> Set[str]: + """Returns the names of all objects in the sequence + + Returns: + set: Names of all objects + """ pass @abstractmethod - def tags(self, index=None): + def object(self, id, index=None): + """Returns the object with the specified name or identifier. If the index is specified, the object is returned + only if it is visible in the frame at the specified index. + + Args: + id (str): Name or identifier of the object + index (int, optional): Frame index + + Returns: + Region: Object + """ pass @abstractmethod - def values(self, index=None): - pass + def groundtruth(self, index: int) -> Region: + """Returns the ground truth region for the specified frame index or None if no ground truth is available for the + frame or the frame index is out of bounds. This is a legacy method for compatibility with single-object datasets + and should not be used in new code. 
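+        Prefer the object() accessor in code that may encounter multi-object sequences.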
- @property - @abstractmethod - def size(self): - pass + Args: + index (int): Frame index - @property - @abstractmethod - def length(self): + Returns: + Region: Ground truth region + """ pass - def describe(self): - data = dict(length=self.length, size=self.size) - return data - -class Dataset(ABC): - - def __init__(self, path): - self._path = path - - def __len__(self): - return self.length - - @property - def path(self): - return self._path - - @property @abstractmethod - def length(self): + def tags(self, index=None) -> List[str]: + """Returns the tags for the specified frame index or None if no tags are available for the frame or the frame + index is out of bounds. + + Args: + index (int, optional): Frame index + + Returns: + list: List of tags + """ pass @abstractmethod - def __getitem__(self, key): + def values(self, index=None) -> Mapping[str, Number]: + """Returns the values for the specified frame index or None if no values are available for the frame or the frame + index is out of bounds. + + Args: + index (int, optional): Frame index + + Returns: + dict: Dictionary of values""" pass + @property @abstractmethod - def __contains__(self, key): - return False + def width(self) -> int: + """Returns the width of the frames in the sequence in pixels - @abstractmethod - def __iter__(self): + Returns: + int: Width of the frames + """ pass + @property @abstractmethod - def list(self): - return [] + def height(self) -> int: + """Returns the height of the frames in the sequence in pixels - def keys(self): - return self.list() - -class BaseSequence(Sequence): + Returns: + int: Height of the frames + """ + pass - def __init__(self, name, dataset=None): - super().__init__(name, dataset) - self._metadata = self._read_metadata() - self._data = None + @property + def size(self) -> Tuple[int, int]: + """Returns the size of the frames in the sequence in pixels as a tuple (width, height) + + Returns: + tuple: Size of the frames + """ + return self.width, self.height - @abstractmethod - def _read_metadata(self): - raise NotImplementedError + def describe(self): + """Returns a dictionary with information about the sequence + + Returns: + dict: Dictionary with information + """ + return dict(length=len(self), width=self.width, height=self.height) + +class Dataset(object): + """Base abstract class for a tracking dataset, a list of image sequences addressable by their names and interatable. 
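+
+    Example (illustrative; assumes `ball` is a Sequence object):
+        >>> dataset = Dataset({"ball": ball})
+        >>> len(dataset), "ball" in dataset
+        (1, True)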
+ """ + + def __init__(self, sequences: Mapping[str, Sequence]) -> None: + """Creates a new dataset with the specified sequences + + Args: + sequences (dict): Dictionary of sequences + """ + self._sequences = sequences - @abstractmethod - def _read(self): - raise NotImplementedError + def __len__(self) -> int: + """Returns the number of sequences in the dataset + + Returns: + int: Number of sequences + """ + return len(self._sequences) + + def __getitem__(self, key: str) -> Sequence: + """Returns the sequence with the specified name + + Args: + key (str): Sequence name + + Returns: + Sequence: Sequence + """ + return self._sequences[key] + + def __contains__(self, key: str) -> bool: + """Returns true if the dataset contains a sequence with the specified name + + Args: + key (str): Sequence name + + Returns: + bool: True if the dataset contains the sequence + """ + return key in self._sequences + + def __iter__(self) -> Iterator[Sequence]: + """Returns an iterator over the sequences in the dataset + + Returns: + DatasetIterator: Iterator + """ + return iter(self._sequences.values()) + + def list(self) -> List[str]: + """Returns a list of unique sequence names + + Returns: + List[str]: List of sequence names + """ + return list(self._sequences.keys()) + + def keys(self) -> List[str]: + """Returns a list of unique sequence names + + Returns: + List[str]: List of sequence names + """ + return list(self._sequences.keys()) + +SequenceData = namedtuple("SequenceData", ["channels", "objects", "tags", "values", "length"]) + +from vot import config + +@cached(LRUCache(maxsize=config.sequence_cache_size)) +def _cached_loader(sequence): + """Loads the sequence data from the sequence object. This function serves as a cache for the sequence data and is only + called if the sequence data is not already loaded. The cache is implemented as a LRU cache with a maximum size + specified in the configuration.""" + return sequence._loader(sequence._metadata) + +class BasedSequence(Sequence): + """This class implements the caching of the sequence data. The sequence data is loaded only when it is needed. + """ + + def __init__(self, name: str, loader: callable, metadata: dict = None): + """Initializes the sequence + + Args: + name (str): Sequence name + loader (callable): Loader function that takes the metadata as an argument and returns a SequenceData object + metadata (dict, optional): Sequence metadata. Defaults to None. + + Raises: + ValueError: If the loader is not callable + """ + super().__init__(name) + self._loader = loader + self._metadata = metadata if metadata is not None else {} def __preload(self): - if self._data is None: - self._data = self._read() - - def metadata(self, name, default=None): + """Loads the sequence data if needed. This is an internal function that should not be called directly. + It calles a cached loader function that is implemented as a LRU cache with a configurable maximum size.""" + return _cached_loader(self) + + def metadata(self, name: str, default: object=None) -> object: + """Returns the metadata value with the specified name. + + Args: + name (str): Metadata name + default (object, optional): Default value. Defaults to None. 
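+
+        Example (illustrative):
+            >>> sequence.metadata("fps", 30)   # falls back to 30 when the field is not set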
+ + Returns: + object: Metadata value""" return self._metadata.get(name, default) - def channels(self): - self.__preload() - return self._data[0] - - def channel(self, channel=None): - self.__preload() + def channels(self) -> List[str]: + """Returns a list of channel names in the sequence + + Returns: + List[str]: List of channel names + """ + data = self.__preload() + return data.channels.keys() + + def channel(self, channel: str=None) -> Channel: + """Returns the channel with the specified name. If the channel name is not specified, the default channel is returned. + + Args: + channel (str, optional): Channel name. Defaults to None. + + Returns: + Channel: Channel + """ + data = self.__preload() if channel is None: channel = self.metadata("channel.default") - return self._data[0].get(channel, None) + return data.channels.get(channel, None) def frame(self, index): + """Returns the frame with the specified index in the sequence as a Frame object + + Args: + index (int): Frame index + + Returns: + Frame: Frame + """ return Frame(self, index) - def groundtruth(self, index=None): - self.__preload() + def objects(self) -> List[str]: + """Returns a list of object ids in the sequence + + Returns: + List[str]: List of object ids + """ + data = self.__preload() + return data.objects.keys() + + def object(self, id, index=None) -> Region: + """Returns the object with the specified id. If the index is specified, the object is returned as a Region object. + + Args: + id (str): Object id + index (int, optional): Frame index. Defaults to None. + + Returns: + Region: Object region + """ + data = self.__preload() + obj = data.objects.get(id, None) if index is None: - return self._data[1] - return self._data[1][index] + return obj + if obj is None: + return None + return obj[index] - def tags(self, index=None): - self.__preload() + def groundtruth(self, index=None): + """Returns the groundtruth object. If the index is specified, the object is returned as a Region object. If the + sequence contains more than one object, an exception is raised. + + Args: + index (int, optional): Frame index. Defaults to None. + + Returns: + Region: Groundtruth region + """ + data = self.__preload() + if len(self.objects()) != 1: + raise DatasetException("More than one object in sequence") + + id = next(iter(data.objects)) + return self.object(id, index) + + def tags(self, index: int = None) -> List[str]: + """Returns a list of tags in the sequence. If the index is specified, only the tags that are present in the frame + with the specified index are returned. + + Args: + index (int, optional): Frame index. Defaults to None. + + Returns: + List[str]: List of tags + """ + data = self.__preload() if index is None: - return self._data[2].keys() - return [t for t, sq in self._data[2].items() if sq[index]] - - def values(self, index=None): - self.__preload() + return data.tags.keys() + return [t for t, sq in data.tags.items() if sq[index]] + + def values(self, index: int = None) -> List[float]: + """Returns a list of values in the sequence. If the index is specified, only the values that are present in the frame + with the specified index are returned. + + Args: + index (int, optional): Frame index. Defaults to None. 
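+
+        Example (illustrative; value names depend on the dataset):
+            >>> sequence.values(0)   # per-frame mapping of values; without an index only the value names are returned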
+ + Returns: + List[float]: List of values + """ + data = self.__preload() if index is None: - return self._data[3].keys() - return {v: sq[index] for v, sq in self._data[3].items()} + return data.values.keys() + return {v: sq[index] for v, sq in data.values.items()} @property def size(self): - return self.channel().size + """Returns the sequence size as a tuple (width, height) + + Returns: + tuple: Sequence size + + """ + return self.width, self.height @property def width(self): - return self.channel().width + """Returns the sequence width""" + return self._metadata["width"] @property def height(self): - return self.channel().height + """Returns the sequence height""" + return self._metadata["height"] - @property - def length(self): - self.__preload() - return len(self._data[1]) - -class InMemorySequence(BaseSequence): + def __len__(self): + """Returns the sequence length in frames + + Returns: + int: Sequence length + """ + data = self.__preload() + return data.length + +class InMemorySequence(Sequence): + """ An in-memory sequence that can be used to construct a sequence programmatically and store it do disk. + Used mainly for testing and debugging. + + Only single object sequences are supported at the moment. + """ def __init__(self, name, channels): + """Creates a new in-memory sequence. + + Args: + name (str): Sequence name + channels (list): List of channel names + + Raises: + DatasetException: If images are not provided for all channels + """ super().__init__(name, None) self._channels = {c: InMemoryChannel() for c in channels} self._tags = {} self._values = {} self._groundtruth = [] - def _read_metadata(self): - return dict() - - def _read(self): - return self._channels, self._groundtruth, self._tags, self._values - def append(self, images: dict, region: "Region", tags: list = None, values: dict = None): + """Appends a new frame to the sequence. The frame is specified by a dictionary of images, a region and optional + tags and values. + + Args: + images (dict): Dictionary of images + region (Region): Region + tags (list, optional): List of tags + values (dict, optional): Dictionary of values + """ if not set(images.keys()).issuperset(self._channels.keys()): raise DatasetException("Images not provided for all channels") @@ -432,7 +980,7 @@ def append(self, images: dict, region: "Region", tags: list = None, values: dict tags = set(tags) for tag in tags: if not tag in self._tags: - self._tags[tag] = [False] * self.length + self._tags[tag] = [False] * len(self) self._tags[tag].append(True) for tag in set(self._tags.keys()).difference(tags): self._tags[tag].append(False) @@ -441,51 +989,290 @@ def append(self, images: dict, region: "Region", tags: list = None, values: dict values = dict() for name, value in values.items(): if not name in self._values: - self._values[name] = [0] * self.length + self._values[name] = [0] * len(self) self._values[tag].append(value) for name in set(self._values.keys()).difference(values.keys()): self._values[name].append(0) self._groundtruth.append(region) + def channel(self, channel : str) -> "Channel": + """Returns the specified channel object. + + Args: + channel (str): Channel name + + Returns: + Channel: Channel object + """ + return self._channels.get(channel, None) + + def channels(self) -> List[str]: + """Returns a list of channel names. + + Returns: + List[str]: List of channel names + + """ + return self._channels.keys() + + def frame(self, index : int) -> "Frame": + """Returns the specified frame. The frame is returned as a Frame object. 
+ + Args: + index (int): Frame index + + Returns: + Frame: Frame object + """ + return Frame(self, index) + + def groundtruth(self, index: int = None) -> "Region": + """Returns the groundtruth object. If the index is specified, the object is returned as a Region object. If the + sequence contains more than one object, an exception is raised. If the index is not specified, the groundtruth + object is returned as a Region object. If the sequence contains more than one object, an exception is raised. + + Args: + index (int, optional): Frame index. Defaults to None. + + Returns: + Region: Groundtruth object + """ + if index is None: + return self._groundtruth + return self._groundtruth[index] + + def object(self, id: str, index: int = None) -> "Region": + """Returns the specified object. If the index is specified, the object is returned as a Region object. If the + sequence contains more than one object, an exception is raised. If the index is not specified, the groundtruth + object is returned as a Region object. If the sequence contains more than one object, an exception is raised. + + Args: + id (str): Object id + index (int, optional): Frame index. Defaults to None. + + Returns: + Region: Object + """ + if id != "object": + return None + + if index is None: + return self._groundtruth + return self._groundtruth[index] + + def objects(self, index: str = None) -> List[str]: + """Returns a list of object ids. If the index is specified, only the objects that are present in the frame with + the specified index are returned. + + Since only single object sequences are supported, the only object id that is returned is "object". + + Args: + index (int, optional): Frame index. Defaults to None.""" + return ["object"] + + def tags(self, index=None): + """Returns a list of tags in the sequence. If the index is specified, only the tags that are present in the frame + with the specified index are returned. + + Args: + index (int, optional): Frame index. Defaults to None. + + Returns: + List[str]: List of tags + """ + if index is None: + return self._tags.keys() + return [t for t, sq in self._tags.items() if sq[index]] + + def values(self, index=None): + """Returns a list of values in the sequence. If the index is specified, only the values that are present in the + frame with the specified index are returned. + + Args: + index (int, optional): Frame index. Defaults to None. + + Returns: + List[str]: List of values + """ + if index is None: + return self._values.keys() + return {v: sq[index] for v, sq in self._values.items()} + + def __len__(self): + """Returns the sequence length in frames + + Returns: + int: Sequence length + """ + return len(self._groundtruth) + + @property + def width(self) -> int: + """Returns the sequence width + + Returns: + int: Sequence width + """ + return self.channel().width + + @property + def height(self) -> int: + """Returns the sequence height + + Returns: + int: Sequence height + """ + return self.channel().height + + @property + def size(self) -> tuple: + """Returns the sequence size as a tuple (width, height) + + Returns: + tuple: Sequence size + """ + return self.channel().size + + @property + def channels(self) -> list: + """Returns a list of channel names + + Returns: + list: List of channel names + """ + return self._channels.keys() + +def download_bundle(url: str, path: str = "."): + """Downloads a dataset bundle as a ZIP file and decompresses it. + + Args: + url (str): Source bundle URL + path (str, optional): Destination directory. Defaults to ".". 
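+
+    Example (illustrative; the URL is a placeholder, not a real bundle):
+        >>> download_bundle("https://example.com/sequences.zip", "datasets/")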
+
+    Raises:
+        DatasetException: If the bundle cannot be downloaded or is not supported.
+    """
+
+    from vot.utilities.net import download_uncompress, NetworkException
+
+    if not url.endswith(".zip"):
+        raise DatasetException("Unknown bundle format")
+
+    logger.info('Downloading sequence bundle from "%s". This may take a while ...', url)
+
+    try:
+        download_uncompress(url, path)
+    except NetworkException as e:
+        raise DatasetException("Unable to download dataset bundle. Please try to download the bundle manually from {} and uncompress it to {}".format(url, path))
+    except IOError as e:
+        raise DatasetException("Unable to extract dataset bundle; is the target directory writable and do you have enough space?")
+
+def download_dataset(url: str, path: str):
+    """Downloads a dataset from a given URL or an alias.
+
+    Args:
+        url (str): URL to the data bundle or metadata description file
+        path (str): Destination directory
+
+    Raises:
+        DatasetException: If the dataset is not found or a network error occurred
+    """
+    from urllib.parse import urlsplit
+
+    try:
+        res = urlsplit(url)
+
+        if res.scheme in ["http", "https"]:
+            if res.path.endswith(".json"):
+                from .common import download_dataset_meta
+                download_dataset_meta(url, path)
+                return
+            else:
+                download_bundle(url, path)
+                return
+
+        raise DatasetException("Unknown dataset domain: {}".format(res.scheme))
+
+    except ValueError:
+
+        if url in dataset_downloader:
+            dataset_downloader[url](path)
+            return
+
+        raise DatasetException("Illegal dataset identifier: {}".format(url))
+
+
+def load_dataset(path: str) -> Dataset:
+    """Loads a dataset from a local directory.
+
+    Args:
+        path (str): The path to the local dataset data
+
+    Raises:
+        DatasetException: When the folder does not exist or the format is not recognized.
+
+    Returns:
+        Dataset: Dataset object
+    """
+
+    from collections import OrderedDict
+
+    names = []
+
+    if os.path.isfile(path):
+        with open(path, 'r') as fd:
+            names = fd.readlines()
+
+    if os.path.isdir(path):
+        if os.path.isfile(os.path.join(path, "list.txt")):
+            with open(os.path.join(path, "list.txt"), 'r') as fd:
+                names = fd.readlines()
+
+    if len(names) == 0:
+        raise DatasetException("Dataset directory does not contain a list.txt file")
+
+    sequences = OrderedDict()
+
+    logger.debug("Loading sequences...")
+
+    for name in names:
+        root = os.path.join(path, name.strip())
+        sequences[name.strip()] = load_sequence(root)
+
+    logger.debug("Found %d sequences in dataset" % len(names))
-from .vot import VOTDataset, VOTSequence
-from .got10k import GOT10kSequence, GOT10kDataset
+    return Dataset(sequences)
-def download_dataset(identifier: str, path: str):
+def load_sequence(path: str) -> Sequence:
+    """Loads a sequence from a given path (directory) and tries to guess the format of the sequence.
-    split = identifier.find(":")
-    domain = "vot"
+    Args:
+        path (str): The path to the local sequence data
-    if split > 0:
-        domain = identifier[0:split].lower()
-        identifier = identifier[split+1:]
+    Raises:
+        DatasetException: If a loading error occurs, the format is unsupported, or other issues arise.
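+
+    Example (illustrative; the path is a placeholder):
+        >>> sequence = load_sequence("datasets/ball")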
- if domain == "vot": - from .vot import download_dataset - download_dataset(identifier, path) - elif domain == "otb": - from .otb import download_dataset - download_dataset(path, identifier == "otb50") - else: - raise DatasetException("Unknown dataset domain: {}".format(domain)) + Returns: + Sequence: Sequence object + """ -def load_dataset(path: str): + for _, loader in sequence_reader.items(): + logger.debug("Trying to load sequence with {}.{}".format(loader.__module__, loader.__name__)) + sequence = loader(path) + if sequence is not None: + return sequence - if not os.path.isdir(path): - raise DatasetException("Dataset directory does not exist") + raise DatasetException("Unable to load sequence, unknown format or unsupported sequence: {}".format(path)) - if VOTDataset.check(path): - return VOTDataset(path) - elif GOT10kDataset.check(path): - return GOT10kDataset(path) - else: - raise DatasetException("Unsupported dataset type") +import importlib +for module in [".common", ".otb", ".got10k", ".trackingnet"]: + importlib.import_module(module, package="vot.dataset") -def load_sequence(path: str): - if VOTSequence.check(path): - return VOTSequence(path) - elif GOT10kSequence.check(path): - return GOT10kSequence(path) - else: - raise DatasetException("Unsupported sequence type") \ No newline at end of file +# Legacy reader is registered last, otherwise it will cause problems +# TODO: implement explicit ordering of readers +@sequence_reader.register("legacy") +def read_legacy_sequence(path: str) -> Sequence: + """Wrapper around the legacy sequence reader.""" + from vot.dataset.common import read_sequence_legacy + return read_sequence_legacy(path) \ No newline at end of file diff --git a/vot/dataset/common.py b/vot/dataset/common.py new file mode 100644 index 0000000..439f647 --- /dev/null +++ b/vot/dataset/common.py @@ -0,0 +1,323 @@ +"""This module contains functionality for reading sequences from the storage using VOT compatible format.""" + +import os +import glob +import logging + +import six + +import cv2 + +from vot.dataset import Dataset, DatasetException, Sequence, BasedSequence, PatternFileListChannel, SequenceData +from vot.region.io import write_trajectory, read_trajectory +from vot.region import Special +from vot.utilities import Progress, localize_path, read_properties, write_properties + +logger = logging.getLogger("vot") + +def convert_int(value: str) -> int: + """Converts the given value to an integer. If the value is not a valid integer, None is returned. + + Args: + value (str): The value to convert. + + Returns: + int: The converted value or None if the value is not a valid integer. + """ + try: + if value is None: + return None + return int(value) + except ValueError: + return None + +def _load_channel(source, length=None): + """Loads a channel from the given source. + + Args: + source (str): The source to load the channel from. + length (int): The length of the channel. If not specified, the channel is loaded from a pattern file list. + + Returns: + Channel: The loaded channel. + """ + + extension = os.path.splitext(source)[1] + + if extension == '': + source = os.path.join(source, '%08d.jpg') + return PatternFileListChannel(source, end=length, check_files=length is None) + +def _read_data(metadata): + """Reads data from the given metadata. + + Args: + metadata (dict): The metadata to read data from. + + Returns: + dict: The data read from the metadata. 
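+            (concretely a SequenceData namedtuple of channels, objects, tags, values and length)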
+ """ + + channels = {} + tags = {} + values = {} + length = metadata["length"] + + root = metadata["root"] + + for c in ["color", "depth", "ir"]: + channel_path = metadata.get("channels.%s" % c, None) + if not channel_path is None: + channels[c] = _load_channel(os.path.join(root, localize_path(channel_path)), length) + + # Load default channel if no explicit channel data available + if len(channels) == 0: + channels["color"] = _load_channel(os.path.join(root, "color", "%08d.jpg"), length=length) + else: + metadata["channel.default"] = next(iter(channels.keys())) + + if metadata.get("width", None) is None or metadata.get("height", None) is None: + metadata["width"], metadata["height"] = six.next(six.itervalues(channels)).size + + lengths = [len(t) for t in channels.values()] + assert all([x == lengths[0] for x in lengths]), "Sequence channels have different lengths" + length = lengths[0] + + objectsfiles = glob.glob(os.path.join(root, 'groundtruth_*.txt')) + objects = {} + if len(objectsfiles) > 0: + for objectfile in objectsfiles: + groundtruth = read_trajectory(os.path.join(objectfile)) + if len(groundtruth) < length: groundtruth += [Special(Sequence.UNKNOWN)] * (length - len(groundtruth)) + objectid = os.path.basename(objectfile)[12:-4] + objects[objectid] = groundtruth + else: + groundtruth_file = os.path.join(root, metadata.get("groundtruth", "groundtruth.txt")) + groundtruth = read_trajectory(groundtruth_file) + if len(groundtruth) < length: groundtruth += [Special(Sequence.UNKNOWN)] * (length - len(groundtruth)) + objects["object"] = groundtruth + + metadata["length"] = length + + tagfiles = glob.glob(os.path.join(root, '*.tag')) + glob.glob(os.path.join(root, '*.label')) + + for tagfile in tagfiles: + with open(tagfile, 'r') as filehandle: + tagname = os.path.splitext(os.path.basename(tagfile))[0] + tag = [line.strip() == "1" for line in filehandle.readlines()] + while not len(tag) >= length: + tag.append(False) + tags[tagname] = tag + + valuefiles = glob.glob(os.path.join(root, '*.value')) + + for valuefile in valuefiles: + with open(valuefile, 'r') as filehandle: + valuename = os.path.splitext(os.path.basename(valuefile))[0] + value = [float(line.strip()) for line in filehandle.readlines()] + while not len(value) >= length: + value.append(0.0) + values[valuename] = value + + for name, tag in tags.items(): + if not len(tag) == length: + tag_tmp = length * [False] + tag_tmp[:len(tag)] = tag + tag = tag_tmp + + for name, value in values.items(): + if not len(value) == length: + raise DatasetException("Length mismatch for value %s" % name) + + return SequenceData(channels, objects, tags, values, length) + +def _read_metadata(path): + """Reads metadata from the given path. The metadata is read from the sequence file in the given path. + + Args: + path (str): The path to read metadata from. + + Returns: + dict: The metadata read from the given path. + """ + metadata = dict(fps=30, format="default") + metadata["channel.default"] = "color" + + metadata_file = os.path.join(path, 'sequence') + metadata.update(read_properties(metadata_file)) + + metadata["height"] = convert_int(metadata.get("height", None)) + metadata["width"] = convert_int(metadata.get("width", None)) + metadata["length"] = convert_int(metadata.get("length", None)) + metadata["fps"] = convert_int(metadata.get("fps", None)) + + metadata["root"] = path + + return metadata + +from vot.dataset import sequence_reader + +@sequence_reader.register("default") +def read_sequence(path): + """Reads a sequence from the given path. 
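+    Returns None when the directory does not contain a 'sequence' metadata file, so that other registered readers can be tried.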
+
+    Args:
+        path (str): The path to read the sequence from.
+
+    Returns:
+        Sequence: The sequence read from the given path.
+    """
+    if not os.path.isfile(os.path.join(path, "sequence")):
+        return None
+
+    return BasedSequence(os.path.basename(path), _read_data, _read_metadata(path))
+
+def read_sequence_legacy(path):
+    """Reads a sequence stored in the legacy format (a bare groundtruth.txt without a 'sequence' metadata file) from the given path.
+
+    Args:
+        path (str): The path to read the sequence from.
+
+    Returns:
+        Sequence: The sequence read from the given path, or None if the path does not contain a legacy sequence.
+    """
+    if not os.path.isfile(os.path.join(path, "groundtruth.txt")):
+        return None
+
+    metadata = dict(fps=30, format="default")
+    metadata["channel.default"] = "color"
+    metadata["channel.color"] = "%08d.jpg"
+
+    return BasedSequence(os.path.basename(path), _read_data, metadata=metadata)
+
+def download_dataset_meta(url: str, path: str) -> None:
+    """Downloads the metadata of a dataset from a given URL and stores it in the given path.
+
+    Args:
+        url (str): The URL to download the metadata from.
+        path (str): The path to store the metadata in.
+    """
+    from vot.utilities.net import download_uncompress, download_json, get_base_url, join_url, NetworkException
+    from vot.utilities import format_size
+
+    meta = download_json(url)
+
+    total_size = 0
+    for sequence in meta["sequences"]:
+        total_size += sequence["annotations"]["uncompressed"]
+        for channel in sequence["channels"].values():
+            total_size += channel["uncompressed"]
+
+    logger.info('Downloading sequence dataset "%s" with %s sequences (total %s).', meta["name"], len(meta["sequences"]), format_size(total_size))
+
+    base_url = get_base_url(url) + "/"
+
+    failed = []
+
+    with Progress("Downloading", len(meta["sequences"])) as progress:
+        for sequence in meta["sequences"]:
+            sequence_directory = os.path.join(path, sequence["name"])
+            os.makedirs(sequence_directory, exist_ok=True)
+
+            if os.path.isfile(os.path.join(sequence_directory, "sequence")):
+                refdata = read_properties(os.path.join(sequence_directory, "sequence"))
+                if "uid" in refdata and refdata["uid"] == sequence["annotations"]["uid"]:
+                    logger.info('Sequence "%s" already downloaded.', sequence["name"])
+                    progress.relative(1)
+                    continue
+
+            data = {'name': sequence["name"], 'fps': sequence["fps"], 'format': 'default'}
+
+            annotations_url = join_url(base_url, sequence["annotations"]["url"])
+
+            data["uid"] = sequence["annotations"]["uid"]
+
+            try:
+                download_uncompress(annotations_url, sequence_directory)
+            except NetworkException as e:
+                logger.exception(e)
+                failed.append(sequence["name"])
+                continue
+            except IOError as e:
+                logger.exception(e)
+                failed.append(sequence["name"])
+                continue
+
+            failure = False
+
+            for cname, channel in sequence["channels"].items():
+                channel_directory = os.path.join(sequence_directory, cname)
+                os.makedirs(channel_directory, exist_ok=True)
+
+                channel_url = join_url(base_url, channel["url"])
+
+                try:
+                    download_uncompress(channel_url, channel_directory)
+                except NetworkException as e:
+                    logger.exception(e)
+                    failed.append(sequence["name"])
+                    failure = True  # mark the sequence as failed so it is skipped below
+                except IOError as e:
+                    logger.exception(e)
+                    failed.append(sequence["name"])
+                    failure = True  # mark the sequence as failed so it is skipped below
+
+                if "pattern" in channel:
+                    data["channels." + cname] = cname + os.path.sep + channel["pattern"]
+                else:
+                    data["channels."
+ cname] = cname + os.path.sep + + if failure: + continue + + write_properties(os.path.join(sequence_directory, 'sequence'), data) + progress.relative(1) + + if len(failed) > 0: + logger.error('Failed to download %d sequences.', len(failed)) + logger.error('Failed sequences: %s', ', '.join(failed)) + else: + logger.info('Successfully downloaded all sequences.') + with open(os.path.join(path, "list.txt"), "w") as fp: + for sequence in meta["sequences"]: + fp.write('{}\n'.format(sequence["name"])) + +def write_sequence(directory: str, sequence: Sequence): + """Writes a sequence to a directory. The sequence is written as a set of images in a directory structure + corresponding to the channel names. The sequence metadata is written to a file called sequence in the root + directory. + + Args: + directory (str): The directory to write the sequence to. + sequence (Sequence): The sequence to write. + """ + + channels = sequence.channels() + + metadata = dict() + metadata["channel.default"] = sequence.metadata("channel.default", "color") + metadata["fps"] = sequence.metadata("fps", "30") + + for channel in channels: + cdir = os.path.join(directory, channel) + os.makedirs(cdir, exist_ok=True) + + metadata["channels.%s" % channel] = os.path.join(channel, "%08d.jpg") + + for i in range(len(sequence)): + frame = sequence.frame(i).channel(channel) + cv2.imwrite(os.path.join(cdir, "%08d.jpg" % (i + 1)), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + + for tag in sequence.tags(): + data = "\n".join(["1" if tag in sequence.tags(i) else "0" for i in range(len(sequence))]) + with open(os.path.join(directory, "%s.tag" % tag), "w") as fp: + fp.write(data) + + for value in sequence.values(): + data = "\n".join([ str(sequence.values(i).get(value, "")) for i in range(len(sequence))]) + with open(os.path.join(directory, "%s.value" % value), "w") as fp: + fp.write(data) + + write_trajectory(os.path.join(directory, "groundtruth.txt"), [f.groundtruth() for f in sequence]) + write_properties(os.path.join(directory, "sequence"), metadata) diff --git a/vot/dataset/dummy.py b/vot/dataset/dummy.py index 28fa57d..bb77261 100644 --- a/vot/dataset/dummy.py +++ b/vot/dataset/dummy.py @@ -1,73 +1,97 @@ +""" Dummy sequences for testing purposes.""" import os import math import tempfile -from vot.dataset import VOTSequence -from vot.region import Rectangle, write_file +from vot.dataset import BasedSequence +from vot.region import Rectangle +from vot.region.io import write_trajectory from vot.utilities import write_properties from PIL import Image import numpy as np -class DummySequence(VOTSequence): +def _generate(base, length, size, objects): + """Generate a new dummy sequence. + + Args: + base (str): The base directory for the sequence. + length (int): The length of the sequence. + size (tuple): The size of the sequence. + objects (int): The number of objects in the sequence. 
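+
+    Frames are written as %08d-numbered images into the color, depth and ir subdirectories of base,
+    together with groundtruth file(s) and a 'sequence' metadata file.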
+ """ - def __init__(self, length=100, size=(640, 480)): - base = os.path.join(tempfile.gettempdir(), "vot_dummy_%d_%d_%d" % (length, size[0], size[1])) - if not os.path.isdir(base) or not os.path.isfile(os.path.join(base, "groundtruth.txt")): - DummySequence._generate(base, length, size) - super().__init__(base, None) + background_color = Image.fromarray(np.random.normal(15, 5, (size[1], size[0], 3)).astype(np.uint8)) + background_depth = Image.fromarray(np.ones((size[1], size[0]), dtype=np.uint8) * 200) + background_ir = Image.fromarray(np.zeros((size[1], size[0]), dtype=np.uint8)) - @staticmethod - def _generate(base, length, size): + template = Image.open(os.path.join(os.path.dirname(__file__), "cow.png")) - background_color = Image.fromarray(np.random.normal(15, 5, (size[1], size[0], 3)).astype(np.uint8)) - background_depth = Image.fromarray(np.ones((size[1], size[0]), dtype=np.uint8) * 200) - background_ir = Image.fromarray(np.zeros((size[1], size[0]), dtype=np.uint8)) + dir_color = os.path.join(base, "color") + dir_depth = os.path.join(base, "depth") + dir_ir = os.path.join(base, "ir") - template = Image.open(os.path.join(os.path.dirname(__file__), "cow.png")) + os.makedirs(dir_color, exist_ok=True) + os.makedirs(dir_depth, exist_ok=True) + os.makedirs(dir_ir, exist_ok=True) - dir_color = os.path.join(base, "color") - dir_depth = os.path.join(base, "depth") - dir_ir = os.path.join(base, "ir") + path_color = os.path.join(dir_color, "%08d.jpg") + path_depth = os.path.join(dir_depth, "%08d.png") + path_ir = os.path.join(dir_ir, "%08d.png") - os.makedirs(dir_color, exist_ok=True) - os.makedirs(dir_depth, exist_ok=True) - os.makedirs(dir_ir, exist_ok=True) + groundtruth = {i : [] for i in range(objects)} - path_color = os.path.join(dir_color, "%08d.jpg") - path_depth = os.path.join(dir_depth, "%08d.png") - path_ir = os.path.join(dir_ir, "%08d.png") + center_x = size[0] / 2 + center_y = size[1] / 2 - groundtruth = [] + radius = min(center_x - template.size[0], center_y - template.size[1]) - center_x = size[0] / 2 - center_y = size[1] / 2 + speed = (math.pi * 2) / length + offset = (math.pi * 2) / objects - radius = min(center_x - template.size[0], center_y - template.size[1]) + for i in range(length): + frame_color = background_color.copy() + frame_depth = background_depth.copy() + frame_ir = background_ir.copy() - speed = (math.pi * 2) / length + for o in range(objects): - for i in range(length): - frame_color = background_color.copy() - frame_depth = background_depth.copy() - frame_ir = background_ir.copy() - - x = int(center_x + math.cos(i * speed) * radius - template.size[0] / 2) - y = int(center_y + math.sin(i * speed) * radius - template.size[1] / 2) + x = int(center_x + math.cos(i * speed + offset * o) * radius - template.size[0] / 2) + y = int(center_y + math.sin(i * speed + offset * o) * radius - template.size[1] / 2) frame_color.paste(template, (x, y), template) frame_depth.paste(10, (x, y), template) frame_ir.paste(240, (x, y), template) - frame_color.save(path_color % (i + 1)) - frame_depth.save(path_depth % (i + 1)) - frame_ir.save(path_ir % (i + 1)) - - groundtruth.append(Rectangle(x, y, template.size[0], template.size[1])) - - write_file(os.path.join(base, "groundtruth.txt"), groundtruth) - metadata = {"name": "dummy", "fps" : 30, "format" : "dummy", - "channel.default": "color"} - write_properties(os.path.join(base, "sequence"), metadata) - + groundtruth[o].append(Rectangle(x, y, template.size[0], template.size[1])) + + frame_color.save(path_color % (i + 1)) + 
frame_depth.save(path_depth % (i + 1)) + frame_ir.save(path_ir % (i + 1)) + + if objects == 1: + write_trajectory(os.path.join(base, "groundtruth.txt"), groundtruth[0]) + else: + for i, g in groundtruth.items(): + write_trajectory(os.path.join(base, "groundtruth_%03d.txt" % i), g) + + metadata = {"name": "dummy", "fps" : 30, "format" : "dummy", + "channel.default": "color"} + write_properties(os.path.join(base, "sequence"), metadata) + +def generate_dummy(length=100, size=(640, 480), objects=1): + """Create a new dummy sequence. + + Args: + length (int, optional): The length of the sequence. Defaults to 100. + size (tuple, optional): The size of the sequence. Defaults to (640, 480). + objects (int, optional): The number of objects in the sequence. Defaults to 1. + """ + from vot.dataset import load_sequence + + base = os.path.join(tempfile.gettempdir(), "vot_dummy_%d_%d_%d_%d" % (length, size[0], size[1], objects)) + if not os.path.isdir(base) or not os.path.isfile(os.path.join(base, "sequence")): + _generate(base, length, size, objects) + + return load_sequence(base) diff --git a/vot/dataset/got10k.py b/vot/dataset/got10k.py index 48f6a7b..e354bd8 100644 --- a/vot/dataset/got10k.py +++ b/vot/dataset/got10k.py @@ -1,153 +1,128 @@ +""" GOT-10k dataset adapter module. The format of GOT-10k dataset is very similar to a subset of VOT, so there +is a lot of code duplication.""" import os import glob -import logging -from collections import OrderedDict import configparser import six -from vot.dataset import Dataset, DatasetException, BaseSequence, PatternFileListChannel -from vot.region import parse -from vot.utilities import Progress +from vot import get_logger +from vot.dataset import DatasetException, BasedSequence, \ + PatternFileListChannel, SequenceData, Sequence +from vot.region import Special +from vot.region.io import read_trajectory -logger = logging.getLogger("vot") +logger = get_logger() def load_channel(source): - + """ Load channel from the given source. + + Args: + source (str): Path to the source. If the source is a directory, it is + assumed to be a pattern file list. If the source is a file, it is + assumed to be a video file. + + Returns: + Channel: Channel object. 
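+
+    Example (illustrative; the path is a placeholder):
+        >>> channel = load_channel("GOT-10k_Train_000001")   # expands to GOT-10k_Train_000001/%08d.jpg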
+ """ extension = os.path.splitext(source)[1] if extension == '': source = os.path.join(source, '%08d.jpg') return PatternFileListChannel(source) -class GOT10kSequence(BaseSequence): - - def __init__(self, base, name=None, dataset=None): - self._base = base - if name is None: - name = os.path.basename(base) - super().__init__(name, dataset) - - @staticmethod - def check(path: str): - return os.path.isfile(os.path.join(path, 'groundtruth.txt')) and not os.path.isfile(os.path.join(path, 'sequence')) - - def _read_metadata(self): - metadata = dict(fps=30, format="default") - - if os.path.isfile(os.path.join(self._base, 'meta_info.ini')): - config = configparser.ConfigParser() - config.read(os.path.join(self._base, 'meta_info.ini')) - metadata.update(config["METAINFO"]) - metadata["fps"] = int(metadata["anno_fps"][:-3]) - - metadata["channel.default"] = "color" - - return metadata - - def _read(self): - - channels = {} - tags = {} - values = {} - groundtruth = [] - - channels["color"] = load_channel(os.path.join(self._base, "%08d.jpg")) - self._metadata["channel.default"] = "color" - self._metadata["width"], self._metadata["height"] = six.next(six.itervalues(channels)).size - - groundtruth_file = os.path.join(self._base, self.metadata("groundtruth", "groundtruth.txt")) - - with open(groundtruth_file, 'r') as filehandle: - for region in filehandle.readlines(): - groundtruth.append(parse(region)) - - self._metadata["length"] = len(groundtruth) - tagfiles = glob.glob(os.path.join(self._base, '*.label')) +def _read_data(metadata): + """ Read data from the given metadata. + + Args: + metadata (dict): Metadata dictionary. + """ + channels = {} + tags = {} + values = {} + groundtruth = [] - for tagfile in tagfiles: - with open(tagfile, 'r') as filehandle: - tagname = os.path.splitext(os.path.basename(tagfile))[0] - tag = [line.strip() == "1" for line in filehandle.readlines()] - while not len(tag) >= len(groundtruth): - tag.append(False) - tags[tagname] = tag + base = metadata["root"] - valuefiles = glob.glob(os.path.join(self._base, '*.value')) + channels["color"] = load_channel(os.path.join(base, "%08d.jpg")) + metadata["channel.default"] = "color" + metadata["width"], metadata["height"] = six.next(six.itervalues(channels)).size - for valuefile in valuefiles: - with open(valuefile, 'r') as filehandle: - valuename = os.path.splitext(os.path.basename(valuefile))[0] - value = [float(line.strip()) for line in filehandle.readlines()] - while not len(value) >= len(groundtruth): - value.append(0.0) - values[valuename] = value + groundtruth_file = os.path.join(base, metadata.get("groundtruth", "groundtruth.txt")) + groundtruth = read_trajectory(groundtruth_file) - for name, channel in channels.items(): - if not channel.length == len(groundtruth): - raise DatasetException("Length mismatch for channel %s" % name) + if len(groundtruth) == 1 and channels["color"].length > 1: + # We are dealing with testing dataset, only first frame is available, so we pad the + # groundtruth with unknowns. 
Only unsupervised experiment will work, but it is ok + groundtruth.extend([Special(Sequence.UNKNOWN)] * (channels["color"].length - 1)) - for name, tag in tags.items(): - if not len(tag) == len(groundtruth): - tag_tmp = len(groundtruth) * [False] - tag_tmp[:len(tag)] = tag - tag = tag_tmp + metadata["length"] = len(groundtruth) - for name, value in values.items(): - if not len(value) == len(groundtruth): - raise DatasetException("Length mismatch for value %s" % name) + tagfiles = glob.glob(os.path.join(base, '*.label')) - return channels, groundtruth, tags, values + for tagfile in tagfiles: + with open(tagfile, 'r') as filehandle: + tagname = os.path.splitext(os.path.basename(tagfile))[0] + tag = [line.strip() == "1" for line in filehandle.readlines()] + while not len(tag) >= len(groundtruth): + tag.append(False) + tags[tagname] = tag -class GOT10kDataset(Dataset): + valuefiles = glob.glob(os.path.join(base, '*.value')) - def __init__(self, path, sequence_list="list.txt"): - super().__init__(path) + for valuefile in valuefiles: + with open(valuefile, 'r') as filehandle: + valuename = os.path.splitext(os.path.basename(valuefile))[0] + value = [float(line.strip()) for line in filehandle.readlines()] + while not len(value) >= len(groundtruth): + value.append(0.0) + values[valuename] = value - if not os.path.isabs(sequence_list): - sequence_list = os.path.join(path, sequence_list) + for name, channel in channels.items(): + if not channel.length == len(groundtruth): + raise DatasetException("Length mismatch for channel %s" % name) - if not os.path.isfile(sequence_list): - raise DatasetException("Sequence list does not exist") + for name, tag in tags.items(): + if not len(tag) == len(groundtruth): + tag_tmp = len(groundtruth) * [False] + tag_tmp[:len(tag)] = tag + tag = tag_tmp - with open(sequence_list, 'r') as handle: - names = handle.readlines() + for name, value in values.items(): + if not len(value) == len(groundtruth): + raise DatasetException("Length mismatch for value %s" % name) - self._sequences = OrderedDict() + objects = {"object" : groundtruth} - with Progress("Loading dataset", len(names)) as progress: + return SequenceData(channels, objects, tags, values, len(groundtruth)) - for name in names: - self._sequences[name.strip()] = GOT10kSequence(os.path.join(path, name.strip()), dataset=self) - progress.relative(1) +from vot.dataset import sequence_reader - @staticmethod - def check(path: str): - if not os.path.isfile(os.path.join(path, 'list.txt')): - return False +@sequence_reader.register("GOT-10k") +def read_sequence(path): + """ Read GOT-10k sequence from the given path. + + Args: + path (str): Path to the sequence. 
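+
+    Returns:
+        Sequence: The sequence object, or None when the path does not contain both
+        groundtruth.txt and meta_info.ini (i.e. it is not a GOT-10k sequence).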
+ """ - with open(os.path.join(path, 'list.txt'), 'r') as handle: - sequence = handle.readline().strip() - return GOT10kSequence.check(os.path.join(path, sequence)) + if not (os.path.isfile(os.path.join(path, 'groundtruth.txt')) and os.path.isfile(os.path.join(path, 'meta_info.ini'))): + return None - @property - def path(self): - return self._path + metadata = dict(fps=30, format="default") - @property - def length(self): - return len(self._sequences) + if os.path.isfile(os.path.join(path, 'meta_info.ini')): + config = configparser.ConfigParser() + config.read(os.path.join(path, 'meta_info.ini')) + metadata.update(config["METAINFO"]) + metadata["fps"] = int(metadata["anno_fps"][:-3]) - def __getitem__(self, key): - return self._sequences[key] + metadata["root"] = path + metadata["name"] = os.path.basename(path) + metadata["channel.default"] = "color" - def __contains__(self, key): - return key in self._sequences + return BasedSequence(metadata["name"], _read_data, metadata) - def __iter__(self): - return self._sequences.values().__iter__() - def list(self): - return list(self._sequences.keys()) diff --git a/vot/dataset/otb.py b/vot/dataset/otb.py index b4c8afa..bc956e3 100644 --- a/vot/dataset/otb.py +++ b/vot/dataset/otb.py @@ -1,15 +1,16 @@ +""" OTB dataset adapter module. OTB is one of the earliest tracking benchmarks. It is a collection of 50/100 sequences +with ground truth annotations. The dataset is available at http://cvlab.hanyang.ac.kr/tracker_benchmark/datasets.html. +""" -from collections import OrderedDict import os -import logging import six -from vot.dataset import BaseSequence, Dataset, DatasetException, PatternFileListChannel +from vot import get_logger +from vot.dataset import BasedSequence, DatasetException, PatternFileListChannel, SequenceData from vot.utilities import Progress -from vot.region import parse - -logger = logging.getLogger("vot") +from vot.region.io import parse_region +logger = get_logger() _BASE_URL = "http://cvlab.hanyang.ac.kr/tracker_benchmark/seq/" @@ -17,8 +18,8 @@ "Car1", "Car4", "CarDark", "CarScale", "ClifBar", "Couple", "Crowds", "David", "Deer", "Diving", "DragonBaby", "Dudek", "Football", "Freeman4", "Girl", "Human3", "Human4", "Human6", "Human9", "Ironman", "Jump", "Jumping", "Liquor", "Matrix", "MotorRolling", "Panda", "RedTeam", "Shaking", - "Singer2", "Skating1", "Skating2", "Skiing", "Soccer", "Surfer", "Sylvester", "Tiger2", - "Trellis", "Walking", "Walking2", "Woman" ] + "Singer2", "Skating1", "Skating2_1", "Skating2_2", "Skiing", "Soccer", "Surfer", "Sylvester", "Tiger2", + "Trellis", "Walking", "Walking2", "Woman"] _SEQUENCES = { "Basketball": {"attributes": ["IV", "OCC", "DEF", "OPR", "BC"]}, @@ -39,7 +40,7 @@ "Crowds": {"attributes": ["IV", "DEF", "BC"]}, "David": {"attributes": ["IV", "SV", "OCC", "DEF", "MB", "IPR", "OPR"], "start": 300, "stop": 770}, "Deer": {"attributes": ["MB", "FM", "IPR", "BC", "LR"]}, - "Diving": {"attributes": ["SV", "DEF", "IPR"]}, + "Diving": {"attributes": ["SV", "DEF", "IPR"], "stop": 215}, "DragonBaby": {"attributes": ["SV", "OCC", "MB", "FM", "IPR", "OPR", "OV"]}, "Dudek": {"attributes": ["SV", "OCC", "DEF", "FM", "IPR", "OPR", "OV", "BC"]}, "Football": {"attributes": ["OCC", "IPR", "OPR", "BC"]}, @@ -60,7 +61,8 @@ "Shaking": {"attributes": ["IV", "SV", "IPR", "OPR", "BC"]}, "Singer2": {"attributes": ["IV", "DEF", "IPR", "OPR", "BC"]}, "Skating1": {"attributes": ["IV", "SV", "OCC", "DEF", "OPR", "BC"]}, - "Skating2": {"objects": 2, "attributes": ["SV", "OCC", "DEF", "FM", "OPR"]}, + "Skating2_1": 
{"attributes": ["SV", "OCC", "DEF", "FM", "OPR"], "base": "Skating2", "groundtruth" : "groundtruth_rect.1.txt"}, + "Skating2_2": {"attributes": ["SV", "OCC", "DEF", "FM", "OPR"], "base": "Skating2", "groundtruth" : "groundtruth_rect.2.txt"}, "Skiing": {"attributes": ["IV", "SV", "DEF", "IPR", "OPR"]}, "Soccer": {"attributes": ["IV", "SV", "OCC", "MB", "FM", "IPR", "OPR", "BC"]}, "Surfer": {"attributes": ["SV", "FM", "IPR", "OPR", "LR"]}, @@ -72,10 +74,10 @@ "Woman": {"attributes": ["IV", "SV", "OCC", "DEF", "MB", "FM", "OPR"]}, # OTB-100 sequences "Bird2": {"attributes": ["OCC", "DEF", "FM", "IPR", "OPR"]}, - "BlurCar1": {"attributes": ["MB", "FM"]}, - "BlurCar3": {"attributes": ["MB", "FM"]}, - "BlurCar4": {"attributes": ["MB", "FM"]}, - "Board": {"attributes": ["SV", "MB", "FM", "OPR", "OV", "BC"]}, + "BlurCar1": {"attributes": ["MB", "FM"], "start": 247}, + "BlurCar3": {"attributes": ["MB", "FM"], "start": 3}, + "BlurCar4": {"attributes": ["MB", "FM"], "start": 18}, + "Board": {"attributes": ["SV", "MB", "FM", "OPR", "OV", "BC"], "zeros": 5}, "Bolt2": {"attributes": ["DEF", "BC"]}, "Boy": {"attributes": ["SV", "MB", "FM", "IPR", "OPR"]}, "Car2": {"attributes": ["IV", "SV", "MB", "FM", "BC"]}, @@ -103,8 +105,8 @@ "Human5": {"attributes": ["SV", "OCC", "DEF"]}, "Human7": {"attributes": ["IV", "SV", "OCC", "DEF", "MB", "FM"]}, "Human8": {"attributes": ["IV", "SV", "DEF"]}, - "Jogging1": {"attributes": ["OCC", "DEF", "OPR"], "base": "Jogging"}, - "Jogging2": {"attributes": ["OCC", "DEF", "OPR"], "base": "Jogging"}, + "Jogging1": {"attributes": ["OCC", "DEF", "OPR"], "base": "Jogging", "groundtruth" : "groundtruth_rect.1.txt"}, + "Jogging2": {"attributes": ["OCC", "DEF", "OPR"], "base": "Jogging", "groundtruth" : "groundtruth_rect.2.txt"}, "KiteSurf": {"attributes": ["IV", "OCC", "IPR", "OPR"]}, "Lemming": {"attributes": ["IV", "SV", "OCC", "FM", "OPR", "OV"]}, "Man": {"attributes": ["IV"]}, @@ -123,107 +125,96 @@ "Vase": {"attributes": ["SV", "FM", "IPR"]}, } -class OTBSequence(BaseSequence): - - def __init__(self, root, name=None, dataset=None): - - metadata = _SEQUENCES[self.name] - self._base = os.path.join(root, metadata.get("base", name)) - - super().__init__(name, dataset) - - @staticmethod - def check(path: str): - return os.path.isfile(os.path.join(path, 'groundtruth_rect.txt')) - - def _read_metadata(self): - - metadata = _SEQUENCES[self.name] - - return {"attributes": metadata["attributes"]} - - def _read(self): - - channels = {} - groundtruth = [] - - metadata = _SEQUENCES[self.name] - - channels["color"] = PatternFileListChannel(os.path.join(self._base, "img", "%04d.jpg"), - start=metadata.get("start", 1), end=metadata.get("end", None)) - - self._metadata["channel.default"] = "color" - self._metadata["width"], self._metadata["height"] = six.next(six.itervalues(channels)).size +def _load_sequence(metadata): + """Load a sequence from the OTB dataset. + + Args: + metadata (dict): Sequence metadata. 
+ """ - groundtruth_file = os.path.join(self._base, "groundtruth_rect.txt") + channels = {} + groundtruth = [] - with open(groundtruth_file, 'r') as filehandle: - for region in filehandle.readlines(): - groundtruth.append(parse(region)) + attributes = metadata.get("attributes", {}) - self._metadata["length"] = len(groundtruth) + channels["color"] = PatternFileListChannel(os.path.join(metadata["path"], "img", "%%0%dd.jpg" % attributes.get("zeros", 4)), + start=attributes.get("start", 1), end=attributes.get("stop", None)) - if not channels["color"].length == len(groundtruth): - raise DatasetException("Length mismatch between groundtruth and images") + metadata["channel.default"] = "color" + metadata["width"], metadata["height"] = six.next(six.itervalues(channels)).size - return channels, groundtruth, {}, {} + groundtruth_file = os.path.join(metadata["path"], attributes.get("groundtruth", "groundtruth_rect.txt")) -class GOT10kDataset(Dataset): + with open(groundtruth_file, 'r') as filehandle: + for region in filehandle.readlines(): + groundtruth.append(parse_region(region.replace("\t", ",").replace(" ", ","))) - def __init__(self, path, otb50: bool = False): - super().__init__(path) + metadata["length"] = len(groundtruth) - dataset = _SEQUENCES + if not len(channels["color"]) == len(groundtruth): + raise DatasetException("Length mismatch between groundtruth and images %d != %d" % (len(channels["color"]), len(groundtruth))) + + objects = {"object" : groundtruth} - if otb50: - dataset = {k: v for k, v in dataset.items() if k in _OTB50_SUBSET} + return SequenceData(channels, objects, {}, {}, len(groundtruth)) - self._sequences = OrderedDict() +from vot.dataset import sequence_reader - with Progress("Loading dataset", len(dataset)) as progress: +@sequence_reader.register("otb") +def read_sequence(path: str): + """Reads a sequence from OTB dataset. The sequence is identified by the name of the folder and the + groundtruth_rect.txt file is expected to be present in the folder. + + Args: + path (str): Path to the sequence folder. + + Returns: + Sequence: The sequence object. + """ - for name in sorted(list(dataset.keys())): - self._sequences[name.strip()] = OTBSequence(path, name, dataset=self) - progress.relative(1) + if not os.path.isfile(os.path.join(path, 'groundtruth_rect.txt')): + return None - @staticmethod - def check(path: str): - if os.path.isfile(os.path.join(path, 'list.txt')): - return False + name = os.path.basename(path) - for sequence in _OTB50_SUBSET: - return OTBSequence.check(os.path.join(path, sequence)) + if name not in _SEQUENCES: + return None - @property - def path(self): - return self._path + metadata = {"attributes": _SEQUENCES[name], "path": path} + return BasedSequence(name.strip(), _load_sequence, metadata) - @property - def length(self): - return len(self._sequences) - - def __getitem__(self, key): - return self._sequences[key] - - def __contains__(self, key): - return key in self._sequences - - def __iter__(self): - return self._sequences.values().__iter__() - - def list(self): - return list(self._sequences.keys()) +from vot.dataset import dataset_downloader +@dataset_downloader.register("otb50") +def download_otb50(path: str): + """Downloads OTB50 dataset to the given path. + + Args: + path (str): Path to the dataset folder. 
+ """ + dataset = _SEQUENCES + dataset = {k: v for k, v in dataset.items() if k in _OTB50_SUBSET} + _download_dataset(path, dataset) + +@dataset_downloader.register("otb100") +def download_otb100(path: str): + """Downloads OTB100 dataset to the given path. + + Args: + path (str): Path to the dataset folder. + """ + dataset = _SEQUENCES + _download_dataset(path, dataset) -def download_dataset(path: str, otb50: bool = False): +def _download_dataset(path: str, dataset: dict): + """Downloads the given dataset to the given path. + + Args: + path (str): Path to the dataset folder. + """ from vot.utilities.net import download_uncompress, join_url, NetworkException - dataset = _SEQUENCES - - if otb50: - dataset = {k: v for k, v in dataset.items() if k in _OTB50_SUBSET} - with Progress("Downloading", len(dataset)) as progress: for name, metadata in dataset.items(): name = metadata.get("base", name) @@ -237,6 +228,7 @@ def download_dataset(path: str, otb50: bool = False): progress.relative(1) - -if __name__ == "__main__": - download_dataset("") \ No newline at end of file + # Write sequence list to a list.txt file + with open(os.path.join(path, "list.txt"), 'w') as filehandle: + for name in dataset.keys(): + filehandle.write("%s\n" % name) \ No newline at end of file diff --git a/vot/dataset/proxy.py b/vot/dataset/proxy.py index 0dde363..c1bc1d1 100644 --- a/vot/dataset/proxy.py +++ b/vot/dataset/proxy.py @@ -1,46 +1,209 @@ +""" Proxy sequence classes that allow to modify the behaviour of a sequence without changing the underlying data.""" +from typing import List, Set, Tuple -from typing import List +from vot.region import Region from vot.dataset import Channel, Sequence, Frame +class ProxySequence(Sequence): + """A proxy sequence base that forwards requests to undelying source sequence. Meant as a base class. + """ + + def __init__(self, source: Sequence, name: str = None): + """Creates a proxy sequence. + + Args: + source (Sequence): Source sequence object + """ + if name is None: + name = source.name + super().__init__(name) + self._source = source + + def __len__(self): + """Returns the length of the sequence. Forwards the request to the source sequence. + + Returns: + int: Length of the sequence. + """ + return len(self) + + def frame(self, index: int) -> Frame: + """Returns a frame object for the given index. Forwards the request to the source sequence. + + Args: + index (int): Index of the frame. + + Returns: + Frame: Frame object.""" + return Frame(self, index) + + def metadata(self, name, default=None): + """Returns a metadata value for the given name. Forwards the request to the source sequence. + + Args: + name (str): Name of the metadata. + default (object, optional): Default value to return if the metadata is not found. Defaults to None. + + Returns: + object: Metadata value. + """ + return self._source.metadata(name, default) + + def channel(self, channel=None): + """Returns a channel object for the given name. Forwards the request to the source sequence. + + Args: + channel (str, optional): Name of the channel. Defaults to None. + + Returns: + Channel: Channel object. + """ + return self._source.channel(channel) + + def channels(self): + """Returns a list of channel names. Forwards the request to the source sequence. + + Returns: + list: List of channel names. + """ + return self._source.channels() + + def objects(self): + """Returns a list of object ids. Forwards the request to the source sequence. + + Returns: + list: List of object ids. 
+ """ + return self._source.objects() + + def object(self, id, index=None): + """Returns an object for the given id. Forwards the request to the source sequence. + + Args: + id (str): Id of the object. + index (int, optional): Index of the frame. Defaults to None. + + Returns: + Object: Object object. + """ + return self._source.object(id, index) + + def groundtruth(self, index: int = None) -> List[Region]: + """Returns a list of groundtruth regions for the given index. Forwards the request to the source sequence. + + Args: + index (int, optional): Index of the frame. Defaults to None. + + Returns: + list: List of groundtruth regions. + """ + return self._source.groundtruth(index) + + def tags(self, index=None): + """Returns a list of tags for the given index. Forwards the request to the source sequence. + + Args: + index (int, optional): Index of the frame. Defaults to None. + + Returns: + list: List of tags. + """ + return self._source.tags(index) + + def values(self, index=None): + """Returns a list of values for the given index. Forwards the request to the source sequence. + + Args: + index (int, optional): Index of the frame. Defaults to None. + """ + return self._source.values(index) + + @property + def size(self) -> Tuple[int, int]: + """Returns the size of the sequence. Forwards the request to the source sequence. + + Returns: + Tuple[int, int]: Size of the sequence. + """ + return self._source.size + + class FrameMapChannel(Channel): + """A proxy channel that maps frames from a source channel in another order.""" def __init__(self, source: Channel, frame_map: List[int]): + """Creates a frame mapping proxy channel. + + Args: + source (Channel): Source channel object + frame_map (List[int]): A list of frame indices in the source channel that will form the proxy. The list is filtered + so that all indices that are out of bounds are removed. + + """ super().__init__() self._source = source self._map = frame_map - @property - def length(self): + def __len__(self): + """Returns the length of the channel.""" return len(self._map) def frame(self, index): + """Returns a frame object for the given index. + + Args: + index (int): Index of the frame. + + Returns: + Frame: Frame object. + """ return self._source.frame(self._map[index]) def filename(self, index): + """Returns the filename of the frame for the given index. Index is mapped according to the frame map before the request is forwarded to the source channel. + + Args: + index (int): Index of the frame. + + Returns: + str: Filename of the frame. + """ return self._source.filename(self._map[index]) @property def size(self): + """Returns the size of the channel. + + Returns: + Tuple[int, int]: Size of the channel. + """ return self._source.size -class FrameMapSequence(Sequence): +class FrameMapSequence(ProxySequence): + """A proxy sequence that maps frames from a source sequence in another order. + """ def __init__(self, source: Sequence, frame_map: List[int]): - super().__init__(source.name, source.dataset) - self._source = source - self._map = [i for i in frame_map if i >= 0 and i < source.length] + """Creates a frame mapping proxy sequence. - def __len__(self): - return self.length - - def frame(self, index): - return Frame(self, index) - - def metadata(self, name, default=None): - return self._source.metadata(name, default) + Args: + source (Sequence): Source sequence object + frame_map (List[int]): A list of frame indices in the source sequence that will form the proxy. 
The list is filtered + so that all indices that are out of bounds are removed. + """ + super().__init__(source) + self._map = [i for i in frame_map if i >= 0 and i < len(source)] def channel(self, channel=None): + """Returns a channel object for the given channel name. + + Args: + channel (str): Name of the channel. + + Returns: + Channel: Channel object. + """ sourcechannel = self._source.channel(channel) if sourcechannel is None: @@ -49,9 +212,32 @@ def channel(self, channel=None): return FrameMapChannel(sourcechannel, self._map) def channels(self): + """Returns a list of channel names. + + Returns: + list: List of channel names. + """ return self._source.channels() - def groundtruth(self, index=None): + def frame(self, index: int) -> Frame: + """Returns a frame object for the given index. Forwards the request to the source sequence with the mapped index. + + Args: + index (int): Index of the frame. + + Returns: + Frame: Frame object. + """ + return self._source.frame(self._map[index]) + + def groundtruth(self, index: int = None) -> List[Region]: + """Returns a list of groundtruth regions for the given index. Forwards the request to the source sequence with the mapped index. + + Args: + index (int, optional): Index of the frame. Defaults to None. + + Returns: + list: List of groundtruth regions.""" if index is None: groundtruth = [None] * len(self) for i, m in enumerate(self._map): @@ -60,21 +246,142 @@ def groundtruth(self, index=None): else: return self._source.groundtruth(self._map[index]) + def object(self, id, index=None): + """Returns an object for the given id. Forwards the request to the source sequence with the mapped index. + + Args: + id (str): Id of the object. + index (int, optional): Index of the frame. Defaults to None. + + Returns: + Region: Object region or a list of object regions. + """ + if index is None: + groundtruth = [None] * len(self) + for i, m in enumerate(self._map): + groundtruth[i] = self._source.object(id, m) + return groundtruth + else: + return super().object(id, self._map[index]) + def tags(self, index=None): + """Returns a list of tags for the given index. Forwards the request to the source sequence with the mapped index. + + Args: + index (int, optional): Index of the frame. Defaults to None. + + Returns: + list: List of tags. + """ if index is None: + # TODO: this is probably not correct return self._source.tags() else: return self._source.tags(self._map[index]) def values(self, index=None): + """Returns a list of values for the given index. Forwards the request to the source sequence with the mapped index. + + Args: + index (int, optional): Index of the frame. Defaults to None. + + Returns: + list: List of values. + """ if index is None: + # TODO: this is probably not correct return self._source.values() return self._source.values(self._map[index]) - @property - def size(self): - return self._source.size + def __len__(self) -> int: + """Returns the length of the sequence. The length is the same as the length of the frame map. - @property - def length(self): + Returns: + int: Length of the sequence. + """ return len(self._map) + +class ChannelFilterSequence(ProxySequence): + """A proxy sequence that only makes specific channels visible. + """ + + def __init__(self, source: Sequence, channels: Set[str]): + """Creates a channel filter proxy sequence. + + Args: + source (Sequence): Source sequence object + channels (Set[str]): A set of channel names that will be visible in the proxy sequence. 
The set is filtered
+            so that all channel names that are not in the source sequence are removed.
+        """
+        super().__init__(source)
+        self._filter = [i for i in channels if i in source.channels()]
+
+    def channel(self, channel=None):
+        """Returns a channel object for the given channel name. If the channel is not in the filter, None is returned.
+
+        Args:
+            channel (str): Name of the channel.
+
+        Returns:
+            Channel: Channel object.
+        """
+        if channel not in self._filter:
+            return None
+        return self._source.channel(channel)
+
+    def channels(self):
+        """Returns a list of channel names.
+
+        Returns:
+            list: List of channel names.
+        """
+        return set(self._filter)
+
+class ObjectFilterSequence(ProxySequence):
+    """A proxy sequence that makes only a specific object visible.
+    """
+
+    def __init__(self, source: Sequence, id: str, trim: bool=False):
+        """Creates an object filter proxy sequence.
+
+        Args:
+            source (Sequence): Source sequence object
+            id (str): ID of the object that will be visible in the proxy sequence.
+
+        Keyword Args:
+            trim (bool): If true, the sequence will be trimmed to the first and last frame where the object is visible.
+        """
+        super().__init__(source, "%s_%s" % (source.name, id))
+        self._id = id
+        # TODO: implement trim
+        self._trim = trim
+
+    def objects(self):
+        """Returns a dictionary of all objects in the sequence.
+
+        Returns:
+            Dict[str, Object]: Dictionary of all objects in the sequence.
+        """
+        objects = self._source.objects()
+        return {self._id: objects[self._id]}
+
+    def object(self, id, index=None):
+        """Returns an object for the given id.
+
+        Args:
+            id (str): ID of the object.
+
+        Returns:
+            Region: Object region.
+        """
+        if id != self._id:
+            return None
+        return self._source.object(id, index)
+
+    def groundtruth(self, index: int = None) -> List[Region]:
+        """Returns the groundtruth for the given index.
+
+        Args:
+            index (int): Index of the frame.
+
+        Returns:
+            List[Region]: Trajectory of the filtered object.
+        """
+        return self._source.object(self._id, index)
\ No newline at end of file
diff --git a/vot/dataset/trackingnet.py b/vot/dataset/trackingnet.py
new file mode 100644
index 0000000..25d597c
--- /dev/null
+++ b/vot/dataset/trackingnet.py
@@ -0,0 +1,129 @@
+""" Dataset adapter for the TrackingNet dataset. Note that the dataset is organized differently from the VOT datasets;
+annotated frames are stored in a separate directory. The dataset also contains train and test splits. The loader
+assumes that only one of the splits is used at a time and that the path is given to this part of the dataset. """
+
+import os
+import glob
+import logging
+from collections import OrderedDict
+
+import six
+
+from vot.dataset import Dataset, DatasetException, \
+    BasedSequence, PatternFileListChannel, SequenceData, \
+    Sequence
+from vot.region import Special
+from vot.region.io import read_trajectory
+from vot.utilities import Progress
+
+logger = logging.getLogger("vot")
+
+def load_channel(source):
+    """ Load channel from the given source.
+
+    Args:
+        source (str): Path to the source. If the source is a directory, it is
+            assumed to be a pattern file list. If the source is a file, it is
+            assumed to be a video file.
+
+    Returns:
+        Channel: Channel object.
+    """
+
+    extension = os.path.splitext(source)[1]
+
+    if extension == '':
+        source = os.path.join(source, '%d.jpg')
+    return PatternFileListChannel(source)
+
+
+def _read_data(metadata):
+    """Internal function for reading data from the given metadata for a TrackingNet sequence.
+
+    Args:
+        metadata (dict): Metadata dictionary.
+
+    Returns:
+        SequenceData: Sequence data object.
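The TrackingNet adapter that starts above assumes the standard split layout, with annotations and frames in sibling directories: `<split>/anno/<name>.txt` and `<split>/frames/<name>/`. The path arithmetic that `read_sequence` (below) performs can be checked in isolation; the example path here is made up:

```python
import os

# Hypothetical annotation file inside a TrackingNet split (made-up path).
path = "/data/TrackingNet/TRAIN_0/anno/airplane-1.txt"

name = os.path.splitext(os.path.basename(path))[0]  # 'airplane-1'
root = os.path.dirname(os.path.dirname(path))       # '/data/TrackingNet/TRAIN_0'
frames = os.path.join(root, "frames", name)         # '.../TRAIN_0/frames/airplane-1'

print(name, root, frames, sep="\n")
```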
+ """ + + channels = {} + tags = {} + values = {} + groundtruth = [] + + name = metadata["name"] + root = metadata["root"] + + channels["color"] = load_channel(os.path.join(root, 'frames', name)) + metadata["channel.default"] = "color" + metadata["width"], metadata["height"] = six.next(six.itervalues(channels)).size + + groundtruth = read_trajectory(root) + + if len(groundtruth) == 1 and channels["color"].length > 1: + # We are dealing with testing dataset, only first frame is available, so we pad the + # groundtruth with unknowns. Only unsupervised experiment will work, but it is ok + groundtruth.extend([Special(Sequence.UNKNOWN)] * (channels["color"].length - 1)) + + metadata["length"] = len(groundtruth) + + objects = {"object" : groundtruth} + + return SequenceData(channels, objects, tags, values, len(groundtruth)) + +from vot.dataset import sequence_reader + +sequence_reader.register("trackingnet") +def read_sequence(path): + """ Read sequence from the given path. Different to VOT datasets, the sequence is not + a directory, but a file. From the file name the sequence name is extracted and the + path to image frames is inferred based on standard TrackingNet directory structure. + + Args: + path (str): Path to the sequence groundtruth. + + Returns: + Sequence: Sequence object. + """ + if not os.path.isfile(path): + return None + + name, ext = os.path.splitext(os.path.basename(path)) + + if ext != '.txt': + return None + + root = os.path.dirname(os.path.dirname(os.path.dirname(path))) + + if not os.path.isfile(path) and os.path.isdir(os.path.join(root, 'frames', name)): + return None + + metadata = dict(fps=30) + metadata["channel.default"] = "color" + metadata["name"] = name + metadata["root"] = root + + return BasedSequence(name, _read_data, metadata) + +from vot.dataset import sequence_indexer + +sequence_indexer.register("trackingnet") +def list_sequences(path): + """ List sequences in the given path. The path is expected to be the root of the TrackingNet dataset split. + + Args: + path (str): Path to the dataset root. + + Returns: + list: List of sequences. 
+ """ + for dirname in ["anno", "frames"]: + if not os.path.isdir(os.path.join(path, dirname)): + return None + + sequences = list(glob.glob(os.path.join(path, "anno", "*.txt"))) + + return sequences + + diff --git a/vot/dataset/vot.py b/vot/dataset/vot.py deleted file mode 100644 index efdc768..0000000 --- a/vot/dataset/vot.py +++ /dev/null @@ -1,278 +0,0 @@ - -import os -import glob -import logging -from collections import OrderedDict - -import six - -import cv2 - -from vot.dataset import Dataset, DatasetException, Sequence, BaseSequence, PatternFileListChannel -from vot.region import parse, write_file -from vot.utilities import Progress, localize_path, read_properties, write_properties - -logger = logging.getLogger("vot") - -def load_channel(source): - - extension = os.path.splitext(source)[1] - - if extension == '': - source = os.path.join(source, '%08d.jpg') - return PatternFileListChannel(source) - -class VOTSequence(BaseSequence): - - def __init__(self, base, name=None, dataset=None): - self._base = base - if name is None: - name = os.path.basename(base) - super().__init__(name, dataset) - - @staticmethod - def check(path): - return os.path.isfile(os.path.join(path, 'sequence')) - - def _read_metadata(self): - metadata = dict(fps=30, format="default") - metadata["channel.default"] = "color" - - metadata_file = os.path.join(self._base, 'sequence') - metadata.update(read_properties(metadata_file)) - - return metadata - - def _read(self): - - channels = {} - tags = {} - values = {} - groundtruth = [] - - for c in ["color", "depth", "ir"]: - channel_path = self.metadata("channels.%s" % c, None) - if not channel_path is None: - channels[c] = load_channel(os.path.join(self._base, localize_path(channel_path))) - - # Load default channel if no explicit channel data available - if len(channels) == 0: - channels["color"] = load_channel(os.path.join(self._base, "color", "%08d.jpg")) - else: - self._metadata["channel.default"] = next(iter(channels.keys())) - - self._metadata["width"], self._metadata["height"] = six.next(six.itervalues(channels)).size - - groundtruth_file = os.path.join(self._base, self.metadata("groundtruth", "groundtruth.txt")) - - with open(groundtruth_file, 'r') as filehandle: - for region in filehandle.readlines(): - groundtruth.append(parse(region)) - - self._metadata["length"] = len(groundtruth) - - tagfiles = glob.glob(os.path.join(self._base, '*.tag')) + glob.glob(os.path.join(self._base, '*.label')) - - for tagfile in tagfiles: - with open(tagfile, 'r') as filehandle: - tagname = os.path.splitext(os.path.basename(tagfile))[0] - tag = [line.strip() == "1" for line in filehandle.readlines()] - while not len(tag) >= len(groundtruth): - tag.append(False) - tags[tagname] = tag - - valuefiles = glob.glob(os.path.join(self._base, '*.value')) - - for valuefile in valuefiles: - with open(valuefile, 'r') as filehandle: - valuename = os.path.splitext(os.path.basename(valuefile))[0] - value = [float(line.strip()) for line in filehandle.readlines()] - while not len(value) >= len(groundtruth): - value.append(0.0) - values[valuename] = value - - for name, channel in channels.items(): - if not channel.length == len(groundtruth): - raise DatasetException("Length mismatch for channel %s (%d != %d)" % (name, channel.length, len(groundtruth))) - - for name, tag in tags.items(): - if not len(tag) == len(groundtruth): - tag_tmp = len(groundtruth) * [False] - tag_tmp[:len(tag)] = tag - tag = tag_tmp - - for name, value in values.items(): - if not len(value) == len(groundtruth): - raise 
DatasetException("Length mismatch for value %s" % name) - - return channels, groundtruth, tags, values - -class VOTDataset(Dataset): - - def __init__(self, path): - super().__init__(path) - - if not os.path.isfile(os.path.join(path, "list.txt")): - raise DatasetException("Dataset not available locally") - - with open(os.path.join(path, "list.txt"), 'r') as fd: - names = fd.readlines() - - self._sequences = OrderedDict() - - with Progress("Loading dataset", len(names)) as progress: - - for name in names: - self._sequences[name.strip()] = VOTSequence(os.path.join(path, name.strip()), dataset=self) - progress.relative(1) - - @staticmethod - def check(path: str): - if not os.path.isfile(os.path.join(path, 'list.txt')): - return False - - with open(os.path.join(path, 'list.txt'), 'r') as handle: - sequence = handle.readline().strip() - return VOTSequence.check(os.path.join(path, sequence)) - - @property - def path(self): - return self._path - - @property - def length(self): - return len(self._sequences) - - def __getitem__(self, key): - return self._sequences[key] - - def __contains__(self, key): - return key in self._sequences - - def __iter__(self): - return self._sequences.values().__iter__() - - def list(self): - return list(self._sequences.keys()) - - @classmethod - def download(self, url, path="."): - from vot.utilities.net import download_uncompress, download_json, get_base_url, join_url, NetworkException - - if os.path.splitext(url)[1] == '.zip': - logger.info('Downloading sequence bundle from "%s". This may take a while ...', url) - - try: - download_uncompress(url, path) - except NetworkException as e: - raise DatasetException("Unable do download dataset bundle, Please try to download the bundle manually from {} and uncompress it to {}'".format(url, path)) - except IOError as e: - raise DatasetException("Unable to extract dataset bundle, is the target directory writable and do you have enough space?") - - else: - - meta = download_json(url) - - logger.info('Downloading sequence dataset "%s" with %s sequences.', meta["name"], len(meta["sequences"])) - - base_url = get_base_url(url) + "/" - - with Progress("Downloading", len(meta["sequences"])) as progress: - for sequence in meta["sequences"]: - sequence_directory = os.path.join(path, sequence["name"]) - os.makedirs(sequence_directory, exist_ok=True) - - data = {'name': sequence["name"], 'fps': sequence["fps"], 'format': 'default'} - - annotations_url = join_url(base_url, sequence["annotations"]["url"]) - - try: - download_uncompress(annotations_url, sequence_directory) - except NetworkException as e: - raise DatasetException("Unable do download annotations bundle") - except IOError as e: - raise DatasetException("Unable to extract annotations bundle, is the target directory writable and do you have enough space?") - - for cname, channel in sequence["channels"].items(): - channel_directory = os.path.join(sequence_directory, cname) - os.makedirs(channel_directory, exist_ok=True) - - channel_url = join_url(base_url, channel["url"]) - - try: - download_uncompress(channel_url, channel_directory) - except NetworkException as e: - raise DatasetException("Unable do download channel bundle") - except IOError as e: - raise DatasetException("Unable to extract channel bundle, is the target directory writable and do you have enough space?") - - if "pattern" in channel: - data["channels." + cname] = cname + os.path.sep + channel["pattern"] - else: - data["channels." 
+ cname] = cname + os.path.sep - - write_properties(os.path.join(sequence_directory, 'sequence'), data) - - progress.relative(1) - - with open(os.path.join(path, "list.txt"), "w") as fp: - for sequence in meta["sequences"]: - fp.write('{}\n'.format(sequence["name"])) - -def write_sequence(directory: str, sequence: Sequence): - - channels = sequence.channels() - - metadata = dict() - metadata["channel.default"] = sequence.metadata("channel.default", "color") - metadata["fps"] = sequence.metadata("fps", "30") - - for channel in channels: - cdir = os.path.join(directory, channel) - os.makedirs(cdir, exist_ok=True) - - metadata["channels.%s" % channel] = os.path.join(channel, "%08d.jpg") - - for i in range(sequence.length): - frame = sequence.frame(i).channel(channel) - cv2.imwrite(os.path.join(cdir, "%08d.jpg" % (i + 1)), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) - - for tag in sequence.tags(): - data = "\n".join(["1" if tag in sequence.tags(i) else "0" for i in range(sequence.length)]) - with open(os.path.join(directory, "%s.tag" % tag), "w") as fp: - fp.write(data) - - for value in sequence.values(): - data = "\n".join([ str(sequence.values(i).get(value, "")) for i in range(sequence.length)]) - with open(os.path.join(directory, "%s.value" % value), "w") as fp: - fp.write(data) - - write_file(os.path.join(directory, "groundtruth.txt"), [f.groundtruth() for f in sequence]) - write_properties(os.path.join(directory, "sequence"), metadata) - - -VOT_DATASETS = { - "vot2013" : "http://data.votchallenge.net/vot2013/dataset/description.json", - "vot2014" : "http://data.votchallenge.net/vot2014/dataset/description.json", - "vot2015" : "http://data.votchallenge.net/vot2015/dataset/description.json", - "vot-tir2015" : "http://www.cvl.isy.liu.se/research/datasets/ltir/version1.0/ltir_v1_0_8bit.zip", - "vot2016" : "http://data.votchallenge.net/vot2016/main/description.json", - "vot-tir2016" : "http://data.votchallenge.net/vot2016/vot-tir2016.zip", - "vot2017" : "http://data.votchallenge.net/vot2017/main/description.json", - "vot-st2018" : "http://data.votchallenge.net/vot2018/main/description.json", - "vot-lt2018" : "http://data.votchallenge.net/vot2018/longterm/description.json", - "vot-st2019" : "http://data.votchallenge.net/vot2019/main/description.json", - "vot-lt2019" : "http://data.votchallenge.net/vot2019/longterm/description.json", - "vot-rgbd2019" : "http://data.votchallenge.net/vot2019/rgbd/description.json", - "vot-rgbt2019" : "http://data.votchallenge.net/vot2019/rgbtir/meta/description.json", - "vot-st2020" : "https://data.votchallenge.net/vot2020/shortterm/description.json", - "vot-rgbt2020" : "http://data.votchallenge.net/vot2020/rgbtir/meta/description.json", - "vot-st2021": "https://data.votchallenge.net/vot2021/shortterm/description.json", - "test" : "http://data.votchallenge.net/toolkit/test.zip", - "segmentation" : "http://box.vicos.si/tracking/vot20_test_dataset.zip" -} - -def download_dataset(name, path="."): - if not name in VOT_DATASETS: - raise ValueError("Unknown dataset") - VOTDataset.download(VOT_DATASETS[name], path) \ No newline at end of file diff --git a/vot/document/__init__.py b/vot/document/__init__.py index 3ddbd52..893976f 100644 --- a/vot/document/__init__.py +++ b/vot/document/__init__.py @@ -1,15 +1,14 @@ +""" This module contains classes for generating reports and visualizations. 
""" -import os import typing from abc import ABC, abstractmethod import json -import math import inspect import threading -import logging -import tempfile import datetime import collections +import collections.abc +import sys from asyncio import wait from asyncio.futures import wrap_future @@ -24,18 +23,28 @@ from attributee import Attributee, Object, Nested, String, Callable, Integer, List from vot import __version__ as version -from vot import check_debug +from vot import get_logger from vot.dataset import Sequence from vot.tracker import Tracker from vot.analysis import Axes -from vot.experiment import Experiment, analysis_resolver from vot.utilities import class_fullname from vot.utilities.data import Grid class Plot(object): + """ Base class for all plots. """ def __init__(self, identifier: str, xlabel: str, ylabel: str, xlimits: typing.Tuple[float, float], ylimits: typing.Tuple[float, float], trait = None): + """ Initializes the plot. + + Args: + identifier (str): The identifier of the plot. + xlabel (str): The label of the x axis. + ylabel (str): The label of the y axis. + xlimits (tuple): The limits of the x axis. + ylimits (tuple): The limits of the y axis. + trait (str): The trait of the plot. + """ self._identifier = identifier @@ -54,25 +63,32 @@ def __init__(self, identifier: str, xlabel: str, ylabel: str, self._axes.autoscale(False, axis="y") def __call__(self, key, data): + """ Draws the data on the plot.""" self.draw(key, data) def draw(self, key, data): + """ Draws the data on the plot.""" raise NotImplementedError @property def axes(self) -> Axes: + """ Returns the axes of the plot.""" return self._axes def save(self, output, fmt): + """ Saves the plot to a file.""" self._figure.savefig(output, format=fmt, bbox_inches='tight', transparent=True) @property def identifier(self): + """ Returns the identifier of the plot.""" return self._identifier class ScatterPlot(Plot): + """ A scatter plot.""" def draw(self, key, data): + """ Draws the data on the plot. """ if data is None or len(data) != 2: return @@ -81,8 +97,10 @@ def draw(self, key, data): #handle.set_gid("report_%s_%d" % (self._identifier, style["number"])) class LinePlot(Plot): + """ A line plot.""" def draw(self, key, data): + """ Draws the data on the plot.""" if data is None or len(data) < 1: return @@ -101,8 +119,10 @@ def draw(self, key, data): # handle[0].set_gid("report_%s_%d" % (self._identifier, style["number"])) class ResultsJSONEncoder(json.JSONEncoder): + """ JSON encoder for results. """ def default(self, o): + """ Default encoder. """ if isinstance(o, Grid): return list(o) elif isinstance(o, datetime.date): @@ -113,12 +133,15 @@ def default(self, o): return super().default(o) class ResultsYAMLEncoder(yaml.Dumper): + """ YAML encoder for results.""" def represent_tuple(self, data): + """ Represents a tuple. """ return self.represent_list(list(data)) def represent_object(self, o): + """ Represents an object. """ if isinstance(o, Grid): return self.represent_list(list(o)) elif isinstance(o, datetime.date): @@ -136,6 +159,7 @@ def represent_object(self, o): ResultsYAMLEncoder.add_multi_representer(np.inexact, ResultsYAMLEncoder.represent_float) def generate_serialized(trackers: typing.List[Tracker], sequences: typing.List[Sequence], results, storage: "Storage", serializer: str): + """ Generates a serialized report of the results. 
""" doc = dict() doc["toolkit"] = version @@ -162,6 +186,7 @@ def generate_serialized(trackers: typing.List[Tracker], sequences: typing.List[S raise RuntimeError("Unknown serializer") def configure_axes(figure, rect=None, _=None): + """ Configures the axes of the plot. """ axes = PlotAxes(figure, rect or [0, 0, 1, 1]) @@ -170,6 +195,7 @@ def configure_axes(figure, rect=None, _=None): return axes def configure_figure(traits=None): + """ Configures the figure of the plot. """ args = {} if traits == "ar": @@ -177,63 +203,94 @@ def configure_figure(traits=None): elif traits == "eao": args["figsize"] = (7, 5) elif traits == "attributes": - args["figsize"] = (15, 5) + args["figsize"] = (10, 5) return Figure(**args) class PlotStyle(object): + """ A style for a plot.""" def line_style(self, opacity=1): + """ Returns the style for a line.""" raise NotImplementedError def point_style(self): + """ Returns the style for a point.""" raise NotImplementedError class DefaultStyle(PlotStyle): + """ The default style for a plot.""" colormap = get_cmap("tab20b") colorcount = 20 markers = ["o", "v", "<", ">", "^", "8", "*"] def __init__(self, number): + """ Initializes the style. + + Args: + number (int): The number of the style. + """ super().__init__() self._number = number def line_style(self, opacity=1): + """ Returns the style for a line. + + Args: + opacity (float): The opacity of the line. + """ color = DefaultStyle.colormap((self._number % DefaultStyle.colorcount + 1) / DefaultStyle.colorcount) if opacity < 1: color = colors.to_rgba(color, opacity) return dict(linewidth=1, c=color) def point_style(self): + """ Returns the style for a point. + + Args: + color (str): The color of the point. + opacity (float): The opacity of the line. + """ color = DefaultStyle.colormap((self._number % DefaultStyle.colorcount + 1) / DefaultStyle.colorcount) marker = DefaultStyle.markers[self._number % len(DefaultStyle.markers)] return dict(marker=marker, c=[color]) class Legend(object): + """ A legend for a plot.""" def __init__(self, style_factory=DefaultStyle): + """ Initializes the legend. + + Args: + style_factory (PlotStyleFactory): The style factory. + """ self._mapping = collections.OrderedDict() self._counter = 0 self._style_factory = style_factory def _number(self, key): + """ Returns the number for a key.""" if not key in self._mapping: self._mapping[key] = self._counter self._counter += 1 return self._mapping[key] def __getitem__(self, key) -> PlotStyle: + """ Returns the style for a key.""" number = self._number(key) return self._style_factory(number) def _style(self, number): + """ Returns the style for a number.""" raise NotImplementedError def keys(self): + """ Returns the keys of the legend.""" return self._mapping.keys() def figure(self, key): + """ Returns a figure for a key.""" style = self[key] figure = Figure(figsize=(0.1, 0.1)) # TODO: hardcoded axes = PlotAxes(figure, [0, 0, 1, 1], yticks=[], xticks=[], frame_on=False) @@ -245,6 +302,7 @@ def figure(self, key): return figure class StyleManager(Attributee): + """ A manager for styles. 
""" plots = Callable(default=DefaultStyle) axes = Callable(default=configure_axes) @@ -253,13 +311,16 @@ class StyleManager(Attributee): _context = threading.local() def __init__(self, **kwargs): + """ Initializes a new instance of the StyleManager class.""" super().__init__(**kwargs) self._legends = dict() def __getitem__(self, key) -> PlotStyle: + """ Gets the style for the given key.""" return self.plot_style(key) def legend(self, key) -> Legend: + """ Gets the legend for the given key.""" if inspect.isclass(key): klass = key else: @@ -271,18 +332,29 @@ def legend(self, key) -> Legend: return self._legends[klass] def plot_style(self, key) -> PlotStyle: + """ Gets the plot style for the given key.""" return self.legend(key)[key] def make_axes(self, figure, rect=None, trait=None) -> Axes: + """ Makes the axes for the given figure.""" return self.axes(figure, rect, trait) def make_figure(self, trait=None) -> typing.Tuple[Figure, Axes]: + """ Makes the figure for the given trait. + + Args: + trait: The trait for which to make the figure. + + Returns: + A tuple containing the figure and the axes. + """ figure = self.figure(trait) axes = self.make_axes(figure, trait=trait) return figure, axes def __enter__(self): + """Enters the context of the style manager.""" manager = getattr(StyleManager._context, 'style_manager', None) @@ -294,6 +366,7 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): + """Exits the context of the style manager.""" manager = getattr(StyleManager._context, 'style_manager', None) if manager == self: @@ -301,21 +374,36 @@ def __exit__(self, exc_type, exc_value, traceback): @staticmethod def default() -> "StyleManager": + """ Gets the default style manager.""" manager = getattr(StyleManager._context, 'style_manager', None) if manager is None: + get_logger().info("Creating new style manager") manager = StyleManager() StyleManager._context.style_manager = manager return manager class TrackerSorter(Attributee): + """ A sorter for trackers. """ experiment = String(default=None) analysis = String(default=None) result = Integer(val_min=0, default=0) def __call__(self, experiments, trackers, sequences): + """ Sorts the trackers. + + Arguments: + experiments (list of Experiment): The experiments. + trackers (list of Tracker): The trackers. + sequences (list of Sequence): The sequences. + + Returns: + A list of indices of the trackers in the sorted order. 
+ """ + from vot.analysis import AnalysisError + if self.experiment is None or self.analysis is None: return range(len(trackers)) @@ -329,8 +417,12 @@ def __call__(self, experiments, trackers, sequences): if analysis is None: raise RuntimeError("Analysis not found") - future = analysis.commit(experiment, trackers, sequences) - result = future.result() + try: + + future = analysis.commit(experiment, trackers, sequences) + result = future.result() + except AnalysisError as e: + raise RuntimeError("Unable to sort trackers", e) scores = [x[self.result] for x in result] indices = [i[0] for i in sorted(enumerate(scores), reverse=True, key=lambda x: x[1])] @@ -338,12 +430,17 @@ def __call__(self, experiments, trackers, sequences): return indices class Generator(Attributee): + """ A generator for reports.""" async def generate(self, experiments, trackers, sequences): raise NotImplementedError async def process(self, analyses, experiment, trackers, sequences): - if not isinstance(analyses, collections.Iterable): + if sys.version_info >= (3, 3): + _Iterable = collections.abc.Iterable + else: + _Iterable = collections.Iterable + if not isinstance(analyses, _Iterable): analyses = [analyses] futures = [] @@ -359,6 +456,7 @@ async def process(self, analyses, experiment, trackers, sequences): return (future.result() for future in futures) class ReportConfiguration(Attributee): + """ A configuration for reports.""" style = Nested(StyleManager) sort = Nested(TrackerSorter) @@ -366,6 +464,17 @@ class ReportConfiguration(Attributee): # TODO: replace this with report generator and separate json/yaml dump def generate_document(format: str, config: ReportConfiguration, trackers: typing.List[Tracker], sequences: typing.List[Sequence], results, storage: "Storage"): + """ Generates a report document. + + Args: + format: The format of the report. + config: The configuration of the report. + trackers: The trackers to include in the report. + sequences: The sequences to include in the report. + results: The results to include in the report. + storage: The storage to use for the report. + + """ from .html import generate_html_document from .latex import generate_latex_document @@ -376,13 +485,12 @@ def generate_document(format: str, config: ReportConfiguration, trackers: typing generate_serialized(trackers, sequences, results, storage, "yaml") else: order = config.sort(results.keys(), trackers, sequences) - trackers = [trackers[i] for i in order] with config.style: if format == "html": generate_html_document(trackers, sequences, results, storage) elif format == "latex": - generate_latex_document(trackers, sequences, results, storage, False) + generate_latex_document(trackers, sequences, results, storage, False, order=order) elif format == "pdf": - generate_latex_document(trackers, sequences, results, storage, True) + generate_latex_document(trackers, sequences, results, storage, True, order=order) diff --git a/vot/document/common.py b/vot/document/common.py index 3b748a9..8a645e0 100644 --- a/vot/document/common.py +++ b/vot/document/common.py @@ -1,3 +1,4 @@ +"""Common functions for document generation.""" import os import math @@ -5,11 +6,23 @@ from vot.analysis import Measure, Point, Plot, Curve, Sorting, Axes def read_resource(name): + """Reads a resource file from the package directory. 
The file is read as a string.""" path = os.path.join(os.path.dirname(__file__), name) with open(path, "r") as filehandle: return filehandle.read() +def per_tracker(a): + """Returns true if the analysis is per-tracker.""" + return a.axes == Axes.TRACKERS + def extract_measures_table(trackers, results): + """Extracts a table of measures from the results. The table is a list of lists, where each list is a column. + The first column is the tracker name, the second column is the measure name, and the rest of the columns are the values for each tracker. + + Args: + trackers (list): List of trackers. + results (dict): Dictionary of results. + """ table_header = [[], [], []] table_data = dict() column_order = [] @@ -19,7 +32,7 @@ def extract_measures_table(trackers, results): descriptions = analysis.describe() # Ignore all non per-tracker results - if analysis.axes != Axes.TRACKERS: + if not per_tracker(analysis): continue for i, description in enumerate(descriptions): @@ -31,6 +44,9 @@ def extract_measures_table(trackers, results): table_header[2].append(description) column_order.append(description.direction) + if aresults is None: + continue + for tracker, values in zip(trackers, aresults): if not tracker in table_data: table_data[tracker] = list() @@ -63,9 +79,8 @@ def extract_measures_table(trackers, results): return table_header, table_data, table_order - - -def extract_plots(trackers, results): +def extract_plots(trackers, results, order=None): + """Extracts a list of plots from the results. The list is a list of tuples, where each tuple is a pair of strings and a plot.""" plots = dict() j = 0 @@ -75,7 +90,7 @@ def extract_plots(trackers, results): descriptions = analysis.describe() # Ignore all non per-tracker results - if analysis.axes != Axes.TRACKERS: + if not per_tracker(analysis): continue for i, description in enumerate(descriptions): @@ -103,17 +118,20 @@ def extract_plots(trackers, results): else: continue - for tracker, values in zip(trackers, aresults): + for t in order if order is not None else range(len(trackers)): + tracker = trackers[t] + values = aresults[t, 0] data = values[i] if not values is None else None plot(tracker, data) - experiment_plots.append((description.name, plot)) + experiment_plots.append((analysis.title + " - " + description.name, plot)) plots[experiment] = experiment_plots return plots def format_value(data): + """Formats a value for display.""" if data is None: return "N/A" if isinstance(data, str): @@ -125,6 +143,7 @@ def format_value(data): return str(data) def merge_repeats(objects): + """Merges repeated objects in a list into a list of tuples (object, count).""" if not objects: return [] diff --git a/vot/document/html.py b/vot/document/html.py index c27953c..21ef299 100644 --- a/vot/document/html.py +++ b/vot/document/html.py @@ -1,4 +1,4 @@ - +"""HTML report generation. 
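`merge_repeats` above collapses runs of equal values into `(object, count)` pairs; `itertools.groupby` yields an equivalent formulation, shown here as a sketch rather than the toolkit's actual implementation:

```python
from itertools import groupby

def merge_repeats(objects):
    """Collapse runs of equal items into (item, count) pairs."""
    return [(item, len(list(group))) for item, group in groupby(objects or [])]

assert merge_repeats(["a", "a", "b", "a"]) == [("a", 2), ("b", 1), ("a", 1)]
```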
This module is used to generate HTML reports from the results of the experiments.""" import os import io import logging @@ -20,12 +20,14 @@ ORDER_CLASSES = {1: "first", 2: "second", 3: "third"} def insert_cell(value, order): + """Inserts a cell into the data table.""" attrs = dict(data_sort_value=order, data_value=value) if order in ORDER_CLASSES: attrs["cls"] = ORDER_CLASSES[order] td(format_value(value), **attrs) def table_cell(value): + """Returns a cell for the data table.""" if isinstance(value, str): return value elif isinstance(value, Tracker): @@ -35,6 +37,7 @@ def table_cell(value): return format_value(value) def grid_table(data: Grid, rows: List[str], columns: List[str]): + """Generates a table from a grid object.""" assert data.dimensions == 2 assert data.size(0) == len(rows) and data.size(1) == len(columns) @@ -57,24 +60,36 @@ def grid_table(data: Grid, rows: List[str], columns: List[str]): return element def generate_html_document(trackers: List[Tracker], sequences: List[Sequence], results, storage: Storage): + """Generates an HTML document from the results of the experiments. + + Args: + trackers (list): List of trackers. + sequences (list): List of sequences. + results (dict): Dictionary of results. + storage (Storage): Storage object. + """ def insert_figure(figure): + """Inserts a matplotlib figure into the document.""" buffer = io.StringIO() figure.save(buffer, "SVG") raw(buffer.getvalue()) def insert_mplfigure(figure): + """Inserts a matplotlib figure into the document.""" buffer = io.StringIO() figure.savefig(buffer, format="SVG", bbox_inches='tight', pad_inches=0.01, dpi=200) raw(buffer.getvalue()) def add_style(name, linked=False): + """Adds a style to the document.""" if linked: link(rel='stylesheet', href='file://' + os.path.join(os.path.dirname(__file__), name)) else: style(read_resource(name)) def add_script(name, linked=False): + """Adds a script to the document.""" if linked: script(type='text/javascript', src='file://' + os.path.join(os.path.dirname(__file__), name)) else: diff --git a/vot/document/latex.py b/vot/document/latex.py index 83c8fb4..e768031 100644 --- a/vot/document/latex.py +++ b/vot/document/latex.py @@ -1,16 +1,15 @@ - +"""This module contains functions for generating LaTeX documents with results.""" import io -import logging import tempfile import datetime from typing import List from pylatex.base_classes import Container from pylatex.package import Package -from pylatex import Document, Section, Command, LongTable, MultiColumn, MultiRow, Figure, UnsafeCommand +from pylatex import Document, Section, Command, LongTable, MultiColumn, Figure, UnsafeCommand from pylatex.utils import NoEscape -from vot import toolkit_version +from vot import toolkit_version, get_logger from vot.tracker import Tracker from vot.dataset import Sequence from vot.workspace import Storage @@ -20,25 +19,31 @@ TRACKER_GROUP = "default" class Chunk(Container): + """A container that does not add a newline after the content.""" def dumps(self): + """Returns the LaTeX representation of the container.""" return self.dumps_content() def strip_comments(src, wrapper=True): + """Strips comments from a LaTeX source file.""" return "\n".join([line for line in src.split("\n") if not line.startswith("%") and (wrapper or not line.startswith(r"\makeat"))]) def insert_figure(figure): + """Inserts a figure into a LaTeX document.""" buffer = io.StringIO() figure.save(buffer, "PGF") return NoEscape(strip_comments(buffer.getvalue())) def insert_mplfigure(figure, wrapper=True): + """Inserts a 
matplotlib figure into a LaTeX document.""" buffer = io.StringIO() figure.savefig(buffer, format="PGF", bbox_inches='tight', pad_inches=0.01) return NoEscape(strip_comments(buffer.getvalue(), wrapper)) def generate_symbols(container, trackers): + """Generates a LaTeX command for each tracker. The command is named after the tracker reference and contains the tracker symbol.""" legend = StyleManager.default().legend(Tracker) @@ -50,20 +55,37 @@ def generate_symbols(container, trackers): container.append(Command("makeatother")) -def generate_latex_document(trackers: List[Tracker], sequences: List[Sequence], results, storage: Storage, build=False, multipart=True): +def generate_latex_document(trackers: List[Tracker], sequences: List[Sequence], results, storage: Storage, build=False, multipart=True, order=None) -> str: + """Generates a LaTeX document with the results. The document is returned as a string. If build is True, the document is compiled and the PDF is returned. + + Args: + + trackers (list): List of trackers. + sequences (list): List of sequences. + results (dict): Dictionary of results. + storage (Storage): Storage object. + build (bool): If True, the document is compiled and the PDF is returned. + multipart (bool): If True, the document is split into multiple files. + order (list): List of tracker indices to use for ordering. + """ order_marks = {1: "first", 2: "second", 3: "third"} def format_cell(value, order): + """Formats a cell in the data table.""" cell = format_value(value) if order in order_marks: cell = Command(order_marks[order], cell) return cell - logger = logging.getLogger("vot") + logger = get_logger() table_header, table_data, table_order = extract_measures_table(trackers, results) - plots = extract_plots(trackers, results) + + if order is not None: + ordered_trackers = [trackers[i] for i in order] + else: + ordered_trackers = trackers doc = Document(page_numbers=True) @@ -79,18 +101,19 @@ def format_cell(value, order): if multipart: container = Chunk() - generate_symbols(container, trackers) + generate_symbols(container, ordered_trackers) with storage.write("symbols.tex") as out: container.dump(out) doc.preamble.append(Command("input", "symbols.tex")) else: - generate_symbols(doc.preamble, trackers) + generate_symbols(doc.preamble, ordered_trackers) doc.preamble.append(Command('title', 'VOT report')) doc.preamble.append(Command('author', 'Toolkit version ' + toolkit_version())) doc.preamble.append(Command('date', datetime.datetime.now().isoformat())) doc.append(NoEscape(r'\maketitle')) + if len(table_header[2]) == 0: logger.debug("No measures found, skipping table") else: @@ -107,10 +130,20 @@ def format_cell(value, order): data_table.end_table_header() data_table.add_hline() - for tracker, data in table_data.items(): + for tracker in ordered_trackers: + data = table_data[tracker] data_table.add_row([UnsafeCommand("Tracker", [tracker.reference, TRACKER_GROUP])] + [format_cell(x, order[tracker] if not order is None else None) for x, order in zip(data, table_order)]) + if order is not None: + z_order = [0] * len(order) + for i, j in enumerate(order): + z_order[max(order) - i] = j + else: + z_order = list(range(len(trackers))) + + plots = extract_plots(trackers, results, z_order) + for experiment, experiment_plots in plots.items(): if len(experiment_plots) == 0: continue @@ -131,7 +164,7 @@ def format_cell(value, order): if build: temp = tempfile.mktemp() - logger.debug("Generating to tempourary output %s", temp) + logger.debug("Generating to temporary output %s", 
temp) doc.generate_pdf(temp, clean_tex=True) storage.copy(temp + ".pdf", "report.pdf") else: diff --git a/vot/experiment/__init__.py b/vot/experiment/__init__.py index 8f4b55d..bb1a770 100644 --- a/vot/experiment/__init__.py +++ b/vot/experiment/__init__.py @@ -1,4 +1,5 @@ - +""" Experiments are the main building blocks of the toolkit. They are used to evaluate trackers on sequences in +various ways.""" import logging import typing @@ -9,33 +10,44 @@ from attributee import Attributee, Object, Integer, Float, Nested, List -from vot.tracker import RealtimeTrackerRuntime, TrackerException +from vot import get_logger +from vot.tracker import TrackerException from vot.utilities import Progress, to_number, import_class experiment_registry = ClassRegistry("vot_experiment") - transformer_registry = ClassRegistry("vot_transformer") class RealtimeConfig(Attributee): + """Config proxy for real-time experiment. + """ grace = Integer(val_min=0, default=0) fps = Float(val_min=0, default=20) class NoiseConfig(Attributee): + """Config proxy for noise modifiers in experiments.""" # Not implemented yet placeholder = Integer(default=1) class InjectConfig(Attributee): + """Config proxy for parameter injection in experiments.""" # Not implemented yet placeholder = Integer(default=1) def transformer_resolver(typename, context, **kwargs): + """Resolve a transformer from a string. If the transformer is not registered, it is imported as a class and + instantiated with the provided arguments. + + Args: + typename (str): Name of the transformer + context (Attributee): Context of the resolver + + Returns: + Transformer: Resolved transformer + """ from vot.experiment.transformer import Transformer - if "parent" in context: - storage = context["parent"].storage.substorage("cache").substorage("transformer") - else: - storage = None + storage = context.parent.storage.substorage("cache").substorage("transformer") if typename in transformer_registry: transformer = transformer_registry.get(typename, cache=storage, **kwargs) @@ -47,6 +59,17 @@ def transformer_resolver(typename, context, **kwargs): return transformer_class(cache=storage, **kwargs) def analysis_resolver(typename, context, **kwargs): + """Resolve an analysis from a string. If the analysis is not registered, it is imported as a class and + instantiated with the provided arguments. + + Args: + typename (str): Name of the analysis + context (Attributee): Context of the resolver + + Returns: + Analysis: Resolved analysis + """ + from vot.analysis import Analysis, analysis_registry if typename in analysis_registry: @@ -57,20 +80,37 @@ def analysis_resolver(typename, context, **kwargs): assert issubclass(analysis_class, Analysis) analysis = analysis_class(**kwargs) - if "parent" in context: - assert analysis.compatible(context["parent"]) + assert analysis.compatible(context.parent) return analysis class Experiment(Attributee): + """Experiment abstract base class. Each experiment is responsible for running a tracker on a sequence and + storing results into dedicated storage. 
+ """ - realtime = Nested(RealtimeConfig, default=None) + UNKNOWN = 0 + INITIALIZATION = 1 + + realtime = Nested(RealtimeConfig, default=None, description="Realtime modifier config") noise = Nested(NoiseConfig, default=None) inject = Nested(InjectConfig, default=None) transformers = List(Object(transformer_resolver), default=[]) analyses = List(Object(analysis_resolver), default=[]) - def __init__(self, _identifier: str, _storage: "LocalStorage", **kwargs): + def __init__(self, _identifier: str, _storage: "Storage", **kwargs): + """Initialize an experiment. + + Args: + _identifier (str): Identifier of the experiment + _storage (Storage): Storage to use for storing results + + Keyword Args: + **kwargs: Additional arguments + + Raises: + ValueError: If the identifier is not valid + """ self._identifier = _identifier self._storage = _storage super().__init__(**kwargs) @@ -78,71 +118,229 @@ def __init__(self, _identifier: str, _storage: "LocalStorage", **kwargs): @property def identifier(self) -> str: + """Identifier of the experiment. + + Returns: + str: Identifier of the experiment + """ return self._identifier + @property + def _multiobject(self) -> bool: + """Whether the experiment is multi-object or not. + + Returns: + bool: Whether the experiment is multi-object or not + """ + # TODO: at some point this may be a property for all experiments + return False + @property def storage(self) -> "Storage": + """Storage used by the experiment. + + Returns: + Storage: Storage used by the experiment + """ return self._storage - def _get_initialization(self, sequence: "Sequence", index: int): - return sequence.groundtruth(index) + def _get_initialization(self, sequence: "Sequence", index: int, id: str = None): + """Get initialization for a given sequence, index and object id. + + Args: + sequence (Sequence): Sequence to get initialization for + index (int): Index of the frame to get initialization for + id (str): Object id to get initialization for + + Returns: + Initialization: Initialization for the given sequence, index and object id + + Raises: + ValueError: If the sequence does not contain the given index or object id + """ + if not self._multiobject and id is None: + return sequence.groundtruth(index) + else: + return sequence.frame(index).object(id) + + def _get_runtime(self, tracker: "Tracker", sequence: "Sequence", multiobject=False): + """Get runtime for a given tracker and sequence. Can convert single-object runtimes to multi-object runtimes. 
+ + Args: + tracker (Tracker): Tracker to get runtime for + sequence (Sequence): Sequence to get runtime for + multiobject (bool): Whether the runtime should be multi-object or not + + Returns: + TrackerRuntime: Runtime for the given tracker and sequence + + Raises: + TrackerException: If the tracker does not support multi-object experiments + """ + from ..tracker import SingleObjectTrackerRuntime, RealtimeTrackerRuntime, MultiObjectTrackerRuntime + + runtime = tracker.runtime() + + if multiobject: + if not runtime.multiobject: + raise TrackerException("Tracker {} does not support multi-object experiments".format(tracker.identifier)) + #runtime = MultiObjectTrackerRuntime(runtime) + else: + runtime = SingleObjectTrackerRuntime(runtime) - def _get_runtime(self, tracker: "Tracker", sequence: "Sequence"): if not self.realtime is None: grace = to_number(self.realtime.grace, min_n=0) fps = to_number(self.realtime.fps, min_n=0, conversion=float) interval = 1 / float(sequence.metadata("fps", fps)) - runtime = RealtimeTrackerRuntime(tracker.runtime(), grace, interval) - else: - runtime = tracker.runtime() + runtime = RealtimeTrackerRuntime(runtime, grace, interval) + return runtime @abstractmethod def execute(self, tracker: "Tracker", sequence: "Sequence", force: bool = False, callback: typing.Callable = None): + """Execute the experiment for a given tracker and sequence. + + Args: + tracker (Tracker): Tracker to execute + sequence (Sequence): Sequence to execute + force (bool): Whether to force execution even if the results are already present + callback (typing.Callable): Callback to call after each frame + + Returns: + Results: Results for the tracker and sequence + """ raise NotImplementedError @abstractmethod def scan(self, tracker: "Tracker", sequence: "Sequence"): + """ Scan results for a given tracker and sequence. + + Args: + tracker (Tracker): Tracker to scan results for + sequence (Sequence): Sequence to scan results for + + Returns: + Results: Results for the tracker and sequence + """ raise NotImplementedError def results(self, tracker: "Tracker", sequence: "Sequence") -> "Results": - from vot.tracker import Results - from vot.workspace import LocalStorage + """Get results for a given tracker and sequence. + + Args: + tracker (Tracker): Tracker to get results for + sequence (Sequence): Sequence to get results for + + Returns: + Results: Results for the tracker and sequence + """ if tracker.storage is not None: return tracker.storage.results(tracker, self, sequence) return self._storage.results(tracker, self, sequence) def log(self, identifier: str): + """Get a log file for the experiment. + + Args: + identifier (str): Identifier of the log + + Returns: + str: Path to the log file + """ return self._storage.substorage("logs").write("{}_{:%Y-%m-%dT%H-%M-%S.%f%z}.log".format(identifier, datetime.now())) - def transform(self, sequence: "Sequence"): - for transformer in self.transformers: - sequence = transformer(sequence) - return sequence + def transform(self, sequences): + """Transform a list of sequences using the experiment transformers. + + Args: + sequences (typing.List[Sequence]): List of sequences to transform + + Returns: + typing.List[Sequence]: List of transformed sequences. The number of sequences may be larger than the input as some transformers may split sequences. 
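# The real-time wrapper configured above budgets one frame every 1/fps seconds,
# preferring the sequence's own "fps" metadata and falling back to the configured
# default. A dependency-free sketch of that computation (the metadata dict is a
# stand-in for sequence.metadata):

def frame_interval(metadata: dict, default_fps: float = 20.0) -> float:
    """Seconds available per frame in a real-time experiment."""
    return 1.0 / float(metadata.get("fps", default_fps))

if __name__ == "__main__":
    print(frame_interval({"fps": 25}))  # 0.04 seconds per frame
    print(frame_interval({}))           # fallback: 20 fps -> 0.05 seconds per frame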
+ """ + from vot.dataset import Sequence + from vot.experiment.transformer import SingleObject + if isinstance(sequences, Sequence): + sequences = [sequences] + + transformers = list(self.transformers) + + if not self._multiobject: + get_logger().debug("Adding single object transformer since experiment is not multi-object") + transformers.insert(0, SingleObject(cache=None)) + + # Process sequences one transformer at the time. The number of sequences may grow + for transformer in transformers: + transformed = [] + for sequence in sequences: + get_logger().debug("Transforming sequence {} with transformer {}.{}".format(sequence.identifier, transformer.__class__.__module__, transformer.__class__.__name__)) + transformed.extend(transformer(sequence)) + sequences = transformed + + return sequences from .multirun import UnsupervisedExperiment, SupervisedExperiment from .multistart import MultiStartExperiment def run_experiment(experiment: Experiment, tracker: "Tracker", sequences: typing.List["Sequence"], force: bool = False, persist: bool = False): + """A helper function that performs a given experiment with a given tracker on a list of sequences. + + Args: + experiment (Experiment): The experiment object + tracker (Tracker): The tracker object + sequences (typing.List[Sequence]): List of sequences. + force (bool, optional): Ignore the cached results, rerun all the experiments. Defaults to False. + persist (bool, optional): Continue runing even if exceptions were raised. Defaults to False. + + Raises: + TrackerException: If the experiment is interrupted + """ class EvaluationProgress(object): + """A helper class that wraps a progress bar and updates it based on the number of finished sequences.""" def __init__(self, description, total): + """Initialize the progress bar. + + Args: + description (str): Description of the progress bar + total (int): Total number of sequences + + Raises: + ValueError: If the total number of sequences is not positive + """ self.bar = Progress(description, total) self._finished = 0 def __call__(self, progress): + """Update the progress bar. The progress is a number between 0 and 1. 
+ + Args: + progress (float): Progress of the current sequence + + Raises: + ValueError: If the progress is not between 0 and 1 + """ self.bar.absolute(self._finished + min(1, max(0, progress))) def push(self): + """Advance the progress bar by one finished sequence.""" self._finished = self._finished + 1 self.bar.absolute(self._finished) + def close(self): + """Close the progress bar.""" + self.bar.close() + logger = logging.getLogger("vot") + transformed = [] + for sequence in sequences: + transformed.extend(experiment.transform(sequence)) + sequences = transformed + progress = EvaluationProgress("{}/{}".format(tracker.identifier, experiment.identifier), len(sequences)) for sequence in sequences: - sequence = experiment.transform(sequence) try: experiment.execute(tracker, sequence, force=force, callback=progress) except TrackerException as te: @@ -153,6 +351,7 @@ def push(self): flog.write(te.log) logger.error("Tracker output written to file: %s", flog.name) if not persist: - raise te + raise TrackerException("Experiment interrupted", te, tracker=tracker) progress.push() + progress.close() \ No newline at end of file diff --git a/vot/experiment/helpers.py b/vot/experiment/helpers.py new file mode 100644 index 0000000..8488c79 --- /dev/null +++ b/vot/experiment/helpers.py @@ -0,0 +1,53 @@ +""" Helper classes for experiments.""" + +from vot.dataset import Sequence +from vot.region import RegionType + +def _objectstart(sequence: Sequence, id: str): + """Returns the first frame where the object appears in the sequence.""" + trajectory = sequence.object(id) + return [x is None or x.type == RegionType.SPECIAL for x in trajectory].index(False) + +class MultiObjectHelper(object): + """Helper class for multi-object sequences. It provides methods for querying active objects at a given frame.""" + + def __init__(self, sequence: Sequence): + """Initialize the helper class. + + Args: + sequence (Sequence): The sequence to be used. + """ + self._sequence = sequence + self._ids = list(sequence.objects()) + start = [_objectstart(sequence, id) for id in self._ids] + self._ids = sorted(zip(start, self._ids), key=lambda x: x[0]) + + def new(self, position: int): + """Returns a list of objects that appear at the given frame. + + Args: + position (int): The frame number. + + Returns: + [list]: A list of object ids. + """ + return [x[1] for x in self._ids if x[0] == position] + + def objects(self, position: int): + """Returns a list of objects that have appeared up to and including the given frame. + + Args: + position (int): The frame number. + + Returns: + [list]: A list of object ids. + """ + return [x[1] for x in self._ids if x[0] <= position] + + def all(self): + """Returns a list of all objects in the sequence. + + Returns: + [list]: A list of object ids. + """ + return [x[1] for x in self._ids] \ No newline at end of file diff --git a/vot/experiment/multirun.py b/vot/experiment/multirun.py index e9186d4..422eaf0 100644 --- a/vot/experiment/multirun.py +++ b/vot/experiment/multirun.py @@ -1,5 +1,5 @@ -#pylint: disable=W0223 - +"""Multi-run experiments. This module contains the implementation of multi-run experiments. + Multi-run experiments are used to run a tracker multiple times on the same sequence.
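# The EvaluationProgress helper above maps a per-sequence fraction in [0, 1] onto
# a global position: finished sequences plus the clamped fraction of the current
# one. A minimal sketch of that bookkeeping, with the Progress widget replaced by
# a print statement:

class ProgressSketch:

    def __init__(self, total: int):
        self._total = total
        self._finished = 0

    def __call__(self, fraction: float) -> None:
        # Clamp the per-sequence fraction so a misbehaving callback cannot move
        # the bar outside the current sequence's slot.
        position = self._finished + min(1.0, max(0.0, fraction))
        print("progress: %.2f / %d" % (position, self._total))

    def push(self) -> None:
        self._finished += 1  # one more sequence fully processed


if __name__ == "__main__":
    progress = ProgressSketch(total=3)
    progress(0.5)   # halfway through the first sequence -> 0.50 / 3
    progress.push()
    progress(0.25)  # quarter of the second sequence -> 1.25 / 3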
""" from typing import Callable from vot.dataset import Sequence @@ -8,100 +8,196 @@ from attributee import Boolean, Integer, Float, List, String from vot.experiment import Experiment, experiment_registry -from vot.tracker import Tracker, Trajectory +from vot.tracker import Tracker, Trajectory, ObjectStatus class MultiRunExperiment(Experiment): + """Base class for multi-run experiments. Multi-run experiments are used to run a tracker multiple times on the same sequence.""" repetitions = Integer(val_min=1, default=1) early_stop = Boolean(default=True) def _can_stop(self, tracker: Tracker, sequence: Sequence): + """Check whether the experiment can be stopped early. + + Args: + tracker (Tracker): The tracker to be checked. + sequence (Sequence): The sequence to be checked. + + Returns: + bool: True if the experiment can be stopped early, False otherwise. + """ if not self.early_stop: return False - trajectories = self.gather(tracker, sequence) - if len(trajectories) < 3: - return False + + for o in sequence.objects(): - for trajectory in trajectories[1:]: - if not trajectory.equals(trajectories[0]): + trajectories = self.gather(tracker, sequence, objects=[o]) + if len(trajectories) < 3: return False + + for trajectory in trajectories[1:]: + if not trajectory.equals(trajectories[0]): + return False return True def scan(self, tracker: Tracker, sequence: Sequence): + """Scan the results of the experiment for the given tracker and sequence. + + Args: + tracker (Tracker): The tracker to be scanned. + sequence (Sequence): The sequence to be scanned. + + Returns: + [tuple]: A tuple containing three elements. The first element is a boolean indicating whether the experiment is complete. The second element is a list of files that are present. The third element is the results object. + """ results = self.results(tracker, sequence) files = [] complete = True + multiobject = len(sequence.objects()) > 1 + assert self._multiobject or not multiobject - for i in range(1, self.repetitions+1): - name = "%s_%03d" % (sequence.name, i) - if Trajectory.exists(results, name): - files.extend(Trajectory.gather(results, name)) - elif self._can_stop(tracker, sequence): - break - else: - complete = False - break + for o in sequence.objects(): + prefix = sequence.name if not multiobject else "%s_%s" % (sequence.name, o) + for i in range(1, self.repetitions+1): + name = "%s_%03d" % (prefix, i) + if Trajectory.exists(results, name): + files.extend(Trajectory.gather(results, name)) + elif self._can_stop(tracker, sequence): + break + else: + complete = False + break return complete, files, results - def gather(self, tracker: Tracker, sequence: Sequence): + def gather(self, tracker: Tracker, sequence: Sequence, objects = None, pad = False): + """Gather trajectories for the given tracker and sequence. + + Args: + tracker (Tracker): The tracker to be used. + sequence (Sequence): The sequence to be used. + objects (list, optional): The list of objects to be gathered. Defaults to None. + pad (bool, optional): Whether to pad the list of trajectories with None values. Defaults to False. + + Returns: + list: The list of trajectories. 
+ """ trajectories = list() + + multiobject = len(sequence.objects()) > 1 + + assert self._multiobject or not multiobject results = self.results(tracker, sequence) - for i in range(1, self.repetitions+1): - name = "%s_%03d" % (sequence.name, i) - if Trajectory.exists(results, name): - trajectories.append(Trajectory.read(results, name)) + + if objects is None: + objects = list(sequence.objects()) + + for o in objects: + prefix = sequence.name if not multiobject else "%s_%s" % (sequence.name, o) + for i in range(1, self.repetitions+1): + name = "%s_%03d" % (prefix, i) + if Trajectory.exists(results, name): + trajectories.append(Trajectory.read(results, name)) + elif pad: + trajectories.append(None) return trajectories @experiment_registry.register("unsupervised") class UnsupervisedExperiment(MultiRunExperiment): + """Unsupervised experiment. This experiment is used to run a tracker multiple times on the same sequence without any supervision.""" + + multiobject = Boolean(default=False) + + @property + def _multiobject(self) -> bool: + """Whether the experiment is multi-object or not. + + Returns: + bool: True if the experiment is multi-object, False otherwise. + """ + return self.multiobject def execute(self, tracker: Tracker, sequence: Sequence, force: bool = False, callback: Callable = None): + """Execute the experiment for the given tracker and sequence. + + Args: + tracker (Tracker): The tracker to be used. + sequence (Sequence): The sequence to be used. + force (bool, optional): Whether to force the execution. Defaults to False. + callback (Callable, optional): The callback to be used. Defaults to None. + """ + + from .helpers import MultiObjectHelper results = self.results(tracker, sequence) - with self._get_runtime(tracker, sequence) as runtime: + multiobject = len(sequence.objects()) > 1 + assert self._multiobject or not multiobject + + helper = MultiObjectHelper(sequence) + + def result_name(sequence, o, i): + """Get the name of the result file.""" + return "%s_%s_%03d" % (sequence.name, o, i) if multiobject else "%s_%03d" % (sequence.name, i) + + with self._get_runtime(tracker, sequence, self._multiobject) as runtime: for i in range(1, self.repetitions+1): - name = "%s_%03d" % (sequence.name, i) - if Trajectory.exists(results, name) and not force: + trajectories = {} + + for o in helper.all(): trajectories[o] = Trajectory(len(sequence)) + + if all([Trajectory.exists(results, result_name(sequence, o, i)) for o in trajectories.keys()]) and not force: continue if self._can_stop(tracker, sequence): return - trajectory = Trajectory(sequence.length) + _, elapsed = runtime.initialize(sequence.frame(0), [ObjectStatus(self._get_initialization(sequence, 0, x), {}) for x in helper.new(0)]) - _, properties, elapsed = runtime.initialize(sequence.frame(0), self._get_initialization(sequence, 0)) + for x in helper.new(0): + trajectories[x].set(0, Special(Trajectory.INITIALIZATION), {"time": elapsed}) - properties["time"] = elapsed + for frame in range(1, len(sequence)): + state, elapsed = runtime.update(sequence.frame(frame), [ObjectStatus(self._get_initialization(sequence, 0, x), {}) for x in helper.new(frame)]) - trajectory.set(0, Special(Special.INITIALIZATION), properties) + if not isinstance(state, list): + state = [state] - for frame in range(1, sequence.length): - region, properties, elapsed = runtime.update(sequence.frame(frame)) + for x, object in zip(helper.objects(frame), state): + object.properties["time"] = elapsed # TODO: what to do with time stats? 
+ trajectories[x].set(frame, object.region, object.properties) - properties["time"] = elapsed + if callback: + callback(float(i-1) / self.repetitions + (float(frame) / (self.repetitions * len(sequence)))) - trajectory.set(frame, region, properties) - - trajectory.write(results, name) + for o, trajectory in trajectories.items(): + trajectory.write(results, result_name(sequence, o, i)) - if callback: - callback(i / self.repetitions) @experiment_registry.register("supervised") class SupervisedExperiment(MultiRunExperiment): + """Supervised experiment. This experiment is used to run a tracker multiple times on the same sequence with supervision (reinitialization in case of failure).""" + + FAILURE = 2 skip_initialize = Integer(val_min=1, default=1) skip_tags = List(String(), default=[]) failure_overlap = Float(val_min=0, val_max=1, default=0) def execute(self, tracker: Tracker, sequence: Sequence, force: bool = False, callback: Callable = None): + """Execute the experiment for the given tracker and sequence. + + Args: + tracker (Tracker): The tracker to be used. + sequence (Sequence): The sequence to be used. + force (bool, optional): Whether to force the execution. Defaults to False. + callback (Callable, optional): The callback to be used. Defaults to None. + """ results = self.results(tracker, sequence) @@ -116,37 +212,35 @@ def execute(self, tracker: Tracker, sequence: Sequence, force: bool = False, cal if self._can_stop(tracker, sequence): return - trajectory = Trajectory(sequence.length) + trajectory = Trajectory(len(sequence)) frame = 0 - while frame < sequence.length: - - _, properties, elapsed = runtime.initialize(sequence.frame(frame), self._get_initialization(sequence, frame)) + while frame < len(sequence): - properties["time"] = elapsed + _, elapsed = runtime.initialize(sequence.frame(frame), self._get_initialization(sequence, frame)) - trajectory.set(frame, Special(Special.INITIALIZATION), properties) + trajectory.set(frame, Special(Trajectory.INITIALIZATION), {"time" : elapsed}) frame = frame + 1 - while frame < sequence.length: + while frame < len(sequence): - region, properties, elapsed = runtime.update(sequence.frame(frame)) + object, elapsed = runtime.update(sequence.frame(frame)) - properties["time"] = elapsed + object.properties["time"] = elapsed - if calculate_overlap(region, sequence.groundtruth(frame), sequence.size) <= self.failure_overlap: - trajectory.set(frame, Special(Special.FAILURE), properties) + if calculate_overlap(object.region, sequence.groundtruth(frame), sequence.size) <= self.failure_overlap: + trajectory.set(frame, Special(SupervisedExperiment.FAILURE), object.properties) frame = frame + self.skip_initialize if self.skip_tags: - while frame < sequence.length: + while frame < len(sequence): if not [t for t in sequence.tags(frame) if t in self.skip_tags]: break frame = frame + 1 break else: - trajectory.set(frame, region, properties) + trajectory.set(frame, object.region, object.properties) frame = frame + 1 if callback: diff --git a/vot/experiment/multistart.py b/vot/experiment/multistart.py index 04f0e4f..dd1bccd 100644 --- a/vot/experiment/multistart.py +++ b/vot/experiment/multistart.py @@ -1,4 +1,6 @@ +""" This module implements the multistart experiment. """ + from typing import Callable from vot.dataset import Sequence @@ -11,9 +13,18 @@ from vot.tracker import Tracker, Trajectory def find_anchors(sequence: Sequence, anchor="anchor"): + """Find anchor frames in the sequence. 
Anchor frames are frames where the given object is visible and can be used for initialization. + + Args: + sequence (Sequence): The sequence to be scanned. + anchor (str, optional): The name of the object to be used as an anchor. Defaults to "anchor". + + Returns: + [tuple]: A tuple containing two lists of frames. The first list contains forward anchors, the second list contains backward anchors. + """ forward = [] backward = [] - for frame in range(sequence.length): + for frame in range(len(sequence)): values = sequence.values(frame) if anchor in values: if values[anchor] > 0: @@ -24,10 +35,22 @@ def find_anchors(sequence: Sequence, anchor="anchor"): @experiment_registry.register("multistart") class MultiStartExperiment(Experiment): + """The multistart experiment. The experiment works by utilizing anchor frames in the sequence. + Anchor frames are frames where the given object is visible and can be used for initialization. + The tracker is then initialized in each anchor frame and run until the end of the sequence either forward or backward. + """ anchor = String(default="anchor") - def scan(self, tracker: Tracker, sequence: Sequence): + def scan(self, tracker: Tracker, sequence: Sequence) -> tuple: + """Scan the results of the experiment for the given tracker and sequence. + + Args: + tracker (Tracker): The tracker to be scanned. + sequence (Sequence): The sequence to be scanned. + + Returns: + [tuple]: A tuple containing three elements. The first element is a boolean indicating whether the experiment is complete. The second element is a list of files that are present. The third element is the results object.""" files = [] complete = True @@ -48,7 +71,18 @@ def scan(self, tracker: Tracker, sequence: Sequence): return complete, files, results - def execute(self, tracker: Tracker, sequence: Sequence, force: bool = False, callback: Callable = None): + def execute(self, tracker: Tracker, sequence: Sequence, force: bool = False, callback: Callable = None) -> None: + """Execute the experiment for the given tracker and sequence. + + Args: + tracker (Tracker): The tracker to be executed. + sequence (Sequence): The sequence to be executed. + force (bool, optional): Force re-execution of the experiment. Defaults to False. + callback (Callable, optional): A callback function that is called after each frame. Defaults to None. + + Raises: + RuntimeError: If the sequence does not contain any anchors. 
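# Each anchor produces a sub-sequence through a frame map, exactly as execute()
# below builds its FrameMapSequence proxies: run forward from the anchor to the
# end of the sequence, or backward from the anchor to the first frame. A sketch:

def anchor_frame_map(anchor: int, length: int, reverse: bool = False) -> list:
    if reverse:
        return list(reversed(range(0, anchor + 1)))
    return list(range(anchor, length))

if __name__ == "__main__":
    print(anchor_frame_map(3, 6))                # [3, 4, 5]
    print(anchor_frame_map(3, 6, reverse=True))  # [3, 2, 1, 0]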
+ """ results = self.results(tracker, sequence) @@ -71,22 +105,20 @@ def execute(self, tracker: Tracker, sequence: Sequence, force: bool = False, cal if reverse: proxy = FrameMapSequence(sequence, list(reversed(range(0, i + 1)))) else: - proxy = FrameMapSequence(sequence, list(range(i, sequence.length))) - - trajectory = Trajectory(proxy.length) + proxy = FrameMapSequence(sequence, list(range(i, len(sequence)))) - _, properties, elapsed = runtime.initialize(proxy.frame(0), self._get_initialization(proxy, 0)) + trajectory = Trajectory(len(proxy)) - properties["time"] = elapsed + _, elapsed = runtime.initialize(proxy.frame(0), self._get_initialization(proxy, 0)) - trajectory.set(0, Special(Special.INITIALIZATION), properties) + trajectory.set(0, Special(Trajectory.INITIALIZATION), {"time": elapsed}) - for frame in range(1, proxy.length): - region, properties, elapsed = runtime.update(proxy.frame(frame)) + for frame in range(1, len(proxy)): + object, elapsed = runtime.update(proxy.frame(frame)) - properties["time"] = elapsed + object.properties["time"] = elapsed - trajectory.set(frame, region, properties) + trajectory.set(frame, object.region, object.properties) trajectory.write(results, name) diff --git a/vot/experiment/transformer.py b/vot/experiment/transformer.py index 0fe856d..d8a932e 100644 --- a/vot/experiment/transformer.py +++ b/vot/experiment/transformer.py @@ -1,36 +1,82 @@ +""" Transformer module for experiments.""" + import os from abc import abstractmethod +from typing import List from PIL import Image -from attributee import Attributee, Integer, Float +from attributee import Attributee, Integer, Float, Boolean -from vot.dataset import Sequence, VOTSequence, InMemorySequence +from vot.dataset import Sequence, InMemorySequence from vot.dataset.proxy import FrameMapSequence -from vot.dataset.vot import write_sequence +from vot.dataset.common import write_sequence, read_sequence from vot.region import RegionType from vot.utilities import arg_hash from vot.experiment import transformer_registry class Transformer(Attributee): + """Base class for transformers. Transformers are used to generate new modified sequences from existing ones.""" def __init__(self, cache: "LocalStorage", **kwargs): + """Initialize the transformer. + + Args: + cache (LocalStorage): The cache to be used for storing generated sequences. + """ super().__init__(**kwargs) self._cache = cache @abstractmethod - def __call__(self, sequence: Sequence) -> Sequence: + def __call__(self, sequence: Sequence) -> List[Sequence]: + """Generate a list of sequences from the given sequence. The generated sequences are stored in the cache if needed. + + Args: + sequence (Sequence): The sequence to be transformed. + + Returns: + [list]: A list of generated sequences. + """ raise NotImplementedError +@transformer_registry.register("singleobject") +class SingleObject(Transformer): + """Transformer that generates a sequence for each object in the given sequence.""" + + trim = Boolean(default=False, description="Trim each generated sequence to a visible subsection for the selected object") + + def __call__(self, sequence: Sequence) -> List[Sequence]: + """Generate a list of sequences from the given sequence. + + Args: + sequence (Sequence): The sequence to be transformed. 
+ """ + from vot.dataset.proxy import ObjectFilterSequence + + if len(sequence.objects()) == 1: + return [sequence] + + return [ObjectFilterSequence(sequence, id, self.trim) for id in sequence.objects()] + @transformer_registry.register("redetection") class Redetection(Transformer): + """Transformer that test redetection of the object in the sequence. The object is shown in several frames and then moved to a different location. + + This tranformer can only be used with single-object sequences.""" length = Integer(default=100, val_min=1) initialization = Integer(default=5, val_min=1) padding = Float(default=2, val_min=0) scaling = Float(default=1, val_min=0.1, val_max=10) - def __call__(self, sequence: Sequence) -> Sequence: + def __call__(self, sequence: Sequence) -> List[Sequence]: + """Generate a list of sequences from the given sequence. + + Args: + sequence (Sequence): The sequence to be transformed. + """ + + assert len(sequence.objects()) == 1, "Redetection transformer can only be used with single-object sequences." chache_dir = self._cache.directory(self, arg_hash(sequence.name, **self.dump())) @@ -61,6 +107,6 @@ def __call__(self, sequence: Sequence) -> Sequence: write_sequence(chache_dir, generated) - source = VOTSequence(chache_dir, name=sequence.name) - mapping = [0] * self.initialization + [1] * (self.length - self.initialization) - return FrameMapSequence(source, mapping) + source = read_sequence(chache_dir) + mapping = [0] * self.initialization + [1] * (len(self) - self.initialization) + return [FrameMapSequence(source, mapping)] diff --git a/vot/region/__init__.py b/vot/region/__init__.py index 32bb161..58524ac 100644 --- a/vot/region/__init__.py +++ b/vot/region/__init__.py @@ -1,17 +1,29 @@ +""" This module contains classes for region representation and manipulation. Regions are also used to represent results + of trackers as well as groundtruth trajectories. The module also contains functions for calculating overlaps between + regions and for converting between different region types.""" + from abc import abstractmethod, ABC -from typing import Tuple from enum import Enum from vot import ToolkitException from vot.utilities.draw import DrawHandle -class RegionException(Exception): +class RegionException(ToolkitException): """General region exception""" class ConversionException(RegionException): """Region conversion exception, the conversion cannot be performed """ def __init__(self, *args, source=None): + """Constructor + + Args: + *args: Arguments for the base exception + + Keyword Arguments: + source (Region): Source region (default: {None}) + + """ super().__init__(*args) self._source = source @@ -25,29 +37,37 @@ class RegionType(Enum): class Region(ABC): """ - Base class for all region containers - - :var type: type of the region + Base class for all region containers. """ def __init__(self): + """Base constructor""" pass @property @abstractmethod def type(self): + """Return type of the region + + Returns: + RegionType -- Type of the region + """ pass @abstractmethod def copy(self): """Copy region to another object + + Returns: + Region -- Copy of the region """ @abstractmethod def convert(self, rtype: RegionType): """Convert region to another type. Note that some conversions degrade information. - Arguments: - rtype {RegionType} -- Desired type. + + Args: + rtype (RegionType): Target region type to convert to. 
""" @abstractmethod @@ -57,19 +77,16 @@ def is_empty(self): class Special(Region): """ - Special region + Special region, meaning of the code can change depending on the context :var code: Code value """ - UNKNOWN = 0 - INITIALIZATION = 1 - FAILURE = 2 - def __init__(self, code): """ Constructor - :param code: Special code + Args: + code (int): Code value """ super().__init__() self._code = int(code) @@ -80,12 +97,26 @@ def __str__(self): @property def type(self): + """Return type of the region""" return RegionType.SPECIAL def copy(self): + """Copy region to another object""" return Special(self._code) def convert(self, rtype: RegionType): + """Convert region to another type. Note that some conversions degrade information. + + Args: + rtype (RegionType): Target region type to convert to. + + Raises: + ConversionException: Unable to convert special region to another type + + Returns: + Region -- Converted region + """ + if rtype == RegionType.SPECIAL: return self.copy() else: @@ -93,19 +124,23 @@ def convert(self, rtype: RegionType): @property def code(self): - """Retiurns special code for this region + """Retiurns special code for this region. Returns: int -- Type code """ return self._code def draw(self, handle: DrawHandle): + """Draw region to the image using the provided handle. + + Args: + handle (DrawHandle): Draw handle + """ pass def is_empty(self): - return False + """ Check if region is empty. Special regions are always empty by definition.""" + return True -from vot.region.io import read_file, write_file -from .shapes import Rectangle, Polygon, Mask -from .io import read_file, write_file, parse from .raster import calculate_overlap, calculate_overlaps +from .shapes import Rectangle, Polygon, Mask \ No newline at end of file diff --git a/vot/region/io.py b/vot/region/io.py index 8e5371d..5847da8 100644 --- a/vot/region/io.py +++ b/vot/region/io.py @@ -1,16 +1,22 @@ +""" Utilities for reading and writing regions from and to files. """ + import math -from typing import Union, TextIO +from typing import List, Union, TextIO +import io import numpy as np -from numba import jit +import numba -from vot.region import Special +@numba.njit(cache=True) +def mask_to_rle(m, maxstride=100000000): + """ Converts a binary mask to RLE encoding. This is a Numba decorated function that is compiled just-in-time for faster execution. -@jit(nopython=True) -def mask_to_rle(m): - """ - # Input: 2-D numpy array - # Output: list of numbers (1st number = #0s, 2nd number = #1s, 3rd number = #0s, ...) + Args: + m (np.ndarray): 2-D binary mask + maxstride (int): Maximum number of consecutive 0s or 1s in the RLE encoding. If the number of consecutive 0s or 1s is larger than maxstride, it is split into multiple elements. 
+ + Returns: + List[int]: RLE encoding of the mask """ # reshape mask to vector v = m.reshape((m.shape[0] * m.shape[1])) @@ -29,31 +35,49 @@ def mask_to_rle(m): # go over all elements and check if two consecutive are the same for i in range(1, v.size): if v[i] != v[i - 1]: - rle.append(i - last_idx) + length = i - last_idx + # if length is larger than maxstride, split it into multiple elements + while length > maxstride: + rle.append(maxstride) + rle.append(0) + length -= maxstride + # add remaining length + if length > 0: + rle.append(length) last_idx = i if v.size > 0: # handle last element of rle if last_idx < v.size - 1: # last element is the same as one element before it - add number of these last elements - rle.append(v.size - last_idx) + length = v.size - last_idx + while length > maxstride: + rle.append(maxstride) + rle.append(0) + length -= maxstride + if length > 0: + rle.append(length) else: # last element is different than one element before - add 1 rle.append(1) return rle -@jit(nopython=True) +@numba.njit(cache=True) def rle_to_mask(rle, width, height): + """ Converts RLE encoding to a binary mask. This is a Numba decorated function that is compiled just-in-time for faster execution. + + Args: + rle (List[int]): RLE encoding of the mask + width (int): Width of the mask + height (int): Height of the mask + + Returns: + np.ndarray: 2-D binary mask """ - rle: input rle mask encoding - each evenly-indexed element represents number of consecutive 0s - each oddly indexed element represents number of consecutive 1s - width and height are dimensions of the mask - output: 2-D binary mask - """ + # allocate list of zeros - v = [0] * (width * height) + v = np.zeros(width * height, dtype=np.uint8) # set id of the last different element to the beginning of the vector idx_ = 0 @@ -66,7 +90,8 @@ def rle_to_mask(rle, width, height): # reshape vector into 2-D mask # return np.reshape(np.array(v, dtype=np.uint8), (height, width)) # numba bug / not supporting np.reshape - return np.array(v, dtype=np.uint8).reshape((height, width)) + #return np.array(v, dtype=np.uint8).reshape((height, width)) + return v.reshape((height, width)) def create_mask_from_string(mask_encoding): """ @@ -87,12 +112,13 @@ def create_mask_from_string(mask_encoding): from vot.region.raster import mask_bounds def encode_mask(mask): - """ - mask: input binary mask, type: uint8 - output: full RLE encoding in the format: (x0, y0, w, h), RLE - first get minimal axis-aligned region which contains all positive pixels - extract this region from mask and calculate mask RLE within the region - output position and size of the region, dimensions of the full mask and RLE encoding + """ Encode a binary mask to a string in the following format: x0, y0, w, h, RLE. 
+ + Args: + mask (np.ndarray): 2-D binary mask + + Returns: + tuple: Offset and size of the region (x0, y0, w, h) and its RLE encoding """ # calculate coordinates of the top-left corner and region width and height (minimal region containing all 1s) x_min, y_min, x_max, y_max = mask_bounds(mask) @@ -113,17 +139,23 @@ def encode_mask(mask): return (tl_x, tl_y, region_w, region_h), rle -def parse(string): - """ - parse string to the appropriate region format and return region object +def parse_region(string: str) -> "Region": + """Parse input string to the appropriate region format and return Region object + + Args: + string (str): comma separated list of values + + Returns: + Region: resulting region """ + from vot import config + from vot.region import Special + from vot.region.shapes import Rectangle, Polygon, Mask - if string[0] == 'm': # input is a mask - decode it m_, offset_ = create_mask_from_string(string[1:].split(',')) - return Mask(m_, offset=offset_) + return Mask(m_, offset=offset_, optimize=config.mask_optimize_read) else: # input is not a mask - check if special, rectangle or polygon tokens = [float(t) for t in string.split(',')] @@ -139,30 +171,141 @@ def parse(string): return Special(0) else: return Polygon([(x_, y_) for x_, y_ in zip(tokens[::2], tokens[1::2])]) - print('Unknown region format.') return None -def read_file(fp: Union[str, TextIO]): +def read_trajectory_binary(fp: io.RawIOBase): + """Reads a trajectory from a binary file and returns a list of regions. + + Args: + fp (io.RawIOBase): File pointer to the binary file + + Returns: + list: List of regions + """ + import struct + from cachetools import LRUCache, cached + from vot.region import Special + from vot.region.shapes import Rectangle, Polygon, Mask + + buffer = dict(data=fp.read(), offset = 0) + + @cached(cache=LRUCache(maxsize=32)) + def calcsize(format): + """Calculate size of the struct format""" + return struct.calcsize(format) + + def read(format: str): + """Read struct from the buffer and update offset""" + unpacked = struct.unpack_from(format, buffer["data"], buffer["offset"]) + buffer["offset"] += calcsize(format) + return unpacked + + _, length = read("= union[2] or union[1] >= union[3]: + # Two empty regions are considered to be identical + return float(1) + if not bounds is None: raster_bounds = (max(0, union[0]), max(0, union[1]), min(bounds[0] - 1, union[2]), min(bounds[1] - 1, union[3])) else: raster_bounds = union if raster_bounds[0] >= raster_bounds[2] or raster_bounds[1] >= raster_bounds[3]: + # Regions are not identical, but are outside rasterization bounds. return float(0) m1 = _region_raster(a, raster_bounds, at, ao) @@ -228,15 +333,20 @@ def _calculate_overlap(a: np.ndarray, b: np.ndarray, at: int, bt: int, ao: Optio from vot.region import Region, RegionException from vot.region.shapes import Shape, Rectangle, Polygon, Mask -def calculate_overlap(reg1: Shape, reg2: Shape, bounds: Optional[Tuple[int, int]] = None): - """ - Inputs: reg1 and reg2 are Region objects (Rectangle, Polygon or Mask) - bounds: size of the image, format: [width, height] - function first rasterizes both regions to 2-D binary masks and calculates overlap between them - """ +Bounds = Tuple[int, int] - if not isinstance(reg1, Shape) or not isinstance(reg2, Shape): - return float(0) +def calculate_overlap(reg1: Shape, reg2: Shape, bounds: Optional[Bounds] = None): + """ Calculate the overlap between two regions.
The function first rasterizes both regions to 2-D binary masks and calculates overlap between them + + Args: + reg1: first region + reg2: second region + bounds: 2-tuple with the bounds of the image (width, height) + + Returns: + float with the overlap between the two regions. Note that overlap is one by definition if both regions are empty. + + """ if isinstance(reg1, Rectangle): data1 = np.round(reg1._data) @@ -250,6 +360,10 @@ def calculate_overlap(reg1: Shape, reg2: Shape, bounds: Optional[Tuple[int, int] data1 = reg1.mask offset1 = reg1.offset type1 = _TYPE_MASK + else: + data1 = np.zeros((1, 1)) + offset1 = (0, 0) + type1 = _TYPE_EMPTY if isinstance(reg2, Rectangle): data2 = np.round(reg2._data) @@ -263,14 +377,26 @@ def calculate_overlap(reg1: Shape, reg2: Shape, bounds: Optional[Tuple[int, int] data2 = reg2.mask offset2 = reg2.offset type2 = _TYPE_MASK + else: + data2 = np.zeros((1, 1)) + offset2 = (0, 0) + type2 = _TYPE_EMPTY return _calculate_overlap(data1, data2, type1, type2, offset1, offset2, bounds) -def calculate_overlaps(first: List[Region], second: List[Region], bounds: Optional[Tuple[int, int]]): - """ - first and second are lists containing objects of type Region - bounds is in the format [width, height] - output: list of per-frame overlaps (floats) +def calculate_overlaps(first: List[Region], second: List[Region], bounds: Optional[Bounds] = None): + """ Calculate the overlap between two lists of regions. The function first rasterizes both regions to 2-D binary masks and calculates overlap between them + + Args: + first: first list of regions + second: second list of regions + bounds: 2-tuple with the bounds of the image (width, height) + + Returns: + list of floats with the overlap between the two regions. Note that overlap is one by definition if both regions are empty. + + Raises: + RegionException: if the lists are not of the same size """ if not len(first) == len(second): raise RegionException("List not of the same size {} != {}".format(len(first), len(second))) diff --git a/vot/region/shapes.py b/vot/region/shapes.py index b49edf3..d067fad 100644 --- a/vot/region/shapes.py +++ b/vot/region/shapes.py @@ -1,4 +1,4 @@ -import sys +""" Module for region shapes. """ from copy import copy from functools import reduce @@ -13,43 +13,67 @@ from vot.utilities.draw import DrawHandle class Shape(Region, ABC): + """ Base class for all shape regions. """ @abstractmethod def draw(self, handle: DrawHandle) -> None: + """ Draw the region to the given handle. + + """ pass @abstractmethod def resize(self, factor=1) -> "Shape": + """ Resize the region by the given factor. """ pass @abstractmethod def move(self, dx=0, dy=0) -> "Shape": + """ Move the region by the given offset. + + Args: + dx (float, optional): X offset. Defaults to 0. + dy (float, optional): Y offset. Defaults to 0. + + Returns: + Shape: Moved region. + """ pass @abstractmethod def rasterize(self, bounds: Tuple[int, int, int, int]) -> np.ndarray: + """ Rasterize the region to a binary mask. + + Args: + bounds (Tuple[int, int, int, int]): Bounds of the mask. + + Returns: + np.ndarray: Binary mask. + """ pass @abstractmethod def bounds(self) -> Tuple[int, int, int, int]: + """ Get the bounding box of the region. + + Returns: + Tuple[int, int, int, int]: Bounding box (x, y, width, height). 
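# A usage sketch of the overlap convention documented above, assuming the
# toolkit itself is importable (the expected values mirror the checks in
# vot/region/tests.py): identical regions give overlap 1, and two empty regions
# are also treated as identical by definition.

from vot.region import Rectangle
from vot.region.raster import calculate_overlap

full = Rectangle(0, 0, 100, 100)
empty = Rectangle(0, 0, 0, 0)

print(calculate_overlap(full, full))    # 1.0
print(calculate_overlap(empty, empty))  # 1.0, both regions empty
print(calculate_overlap(full, empty))   # expected 0.0, only one region is empty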
+ """ + pass class Rectangle(Shape): """ - Rectangle region - - :var x: top left x coord of the rectangle region - :var float y: top left y coord of the rectangle region - :var float w: width of the rectangle region - :var float h: height of the rectangle region + Rectangle region class for representing rectangular regions. """ def __init__(self, x=0, y=0, width=0, height=0): - """ Constructor + """ Constructor for rectangle region. - :param float x: top left x coord of the rectangle region - :param float y: top left y coord of the rectangle region - :param float w: width of the rectangle region - :param float h: height of the rectangle region + Args: + x (float, optional): X coordinate of the top left corner. Defaults to 0. + y (float, optional): Y coordinate of the top left corner. Defaults to 0. + width (float, optional): Width of the rectangle. Defaults to 0. + height (float, optional): Height of the rectangle. Defaults to 0. """ super().__init__() self._data = np.array([[x], [y], [width], [height]], dtype=np.float32) @@ -60,28 +84,42 @@ def __str__(self): @property def x(self): - return self._data[0, 0] + """ X coordinate of the top left corner. """ + return float(self._data[0, 0]) @property def y(self): - return self._data[1, 0] + """ Y coordinate of the top left corner. """ + return float(self._data[1, 0]) @property def width(self): - return self._data[2, 0] + """ Width of the rectangle.""" + return float(self._data[2, 0]) @property def height(self): - return self._data[3, 0] + """ Height of the rectangle.""" + return float(self._data[3, 0]) @property def type(self): + """ Type of the region.""" return RegionType.RECTANGLE def copy(self): + """ Copy region to another object. """ return copy(self) def convert(self, rtype: RegionType): + """ Convert region to another type. Note that some conversions degrade information. + + Args: + rtype (RegionType): Desired type. + + Raises: + ConversionException: Unable to convert rectangle region to {rtype} + """ if rtype == RegionType.RECTANGLE: return self.copy() elif rtype == RegionType.POLYGON: @@ -97,46 +135,86 @@ def convert(self, rtype: RegionType): raise ConversionException("Unable to convert rectangle region to {}".format(rtype), source=self) def is_empty(self): + """ Check if the region is empty. + + Returns: + bool: True if the region is empty, False otherwise. + """ if self.width > 0 and self.height > 0: return False else: return True def draw(self, handle: DrawHandle): + """ Draw the region to the given handle. + + Args: + handle (DrawHandle): Handle to draw to. + """ polygon = [(self.x, self.y), (self.x + self.width, self.y), \ (self.x + self.width, self.y + self.height), \ (self.x, self.y + self.height)] handle.polygon(polygon) def resize(self, factor=1): + """ Resize the region by the given factor. + + Args: + factor (float, optional): Resize factor. Defaults to 1. + + Returns: + Rectangle: Resized region. + """ return Rectangle(self.x * factor, self.y * factor, self.width * factor, self.height * factor) def center(self): + """ Get the center of the region. + + Returns: + tuple: Center coordinates (x,y). + """ return (self.x + self.width / 2, self.y + self.height / 2) def move(self, dx=0, dy=0): + """ Move the region by the given offset. + + Args: + dx (float, optional): X offset. Defaults to 0. + dy (float, optional): Y offset. Defaults to 0. + + Returns: + Rectangle: Moved region. 
+ """ return Rectangle(self.x + dx, self.y + dy, self.width, self.height) def rasterize(self, bounds: Tuple[int, int, int, int]): + """ Rasterize the region to a binary mask. + + Args: + bounds (tuple): Bounds of the mask (x1,y1,x2,y2). + """ from vot.region.raster import rasterize_rectangle return rasterize_rectangle(self._data, np.array(bounds)) def bounds(self): + """ Get the bounding box of the region. + + Returns: + tuple: Bounding box (x1,y1,x2,y2). + """ return int(round(self.x)), int(round(self.y)), int(round(self.width + self.x)), int(round(self.height + self.y)) class Polygon(Shape): """ - Polygon region - - :var list points: List of points as tuples [(x1,y1), (x2,y2),...,(xN,yN)] - :var int count: number of points + Polygon region defined by a list of points. The polygon is closed, i.e. the first and last point are connected. """ def __init__(self, points): """ Constructor - :param list points: List of points as tuples [(x1,y1), (x2,y2),...,(xN,yN)] + Args: + points (list): List of points as tuples [(x1,y1), (x2,y2),...,(xN,yN)] """ super().__init__() assert(points) @@ -150,22 +228,39 @@ def __str__(self): @property def type(self): + """ Get the region type. """ return RegionType.POLYGON @property def size(self): + """ Get the number of points. """ return self._points.shape[0] # pylint: disable=E1136 def __getitem__(self, i): + """ Get the i-th point.""" return self._points[i, 0], self._points[i, 1] def points(self): + """ Get the list of points. + + Returns: + list: List of points as tuples [(x1,y1), (x2,y2),...,(xN,yN)] + """ return [self[i] for i in range(self.size)] def copy(self): + """ Create a copy of the polygon. """ return copy(self) def convert(self, rtype: RegionType): + """ Convert the polygon to another region type. + + Args: + rtype (RegionType): Target region type. + + Returns: + Region: Converted region. + """ if rtype == RegionType.POLYGON: return self.copy() elif rtype == RegionType.RECTANGLE: @@ -183,15 +278,43 @@ def convert(self, rtype: RegionType): raise ConversionException("Unable to convert polygon region to {}".format(rtype), source=self) def draw(self, handle: DrawHandle): + """ Draw the polygon on the given handle. + + Args: + handle (DrawHandle): Handle to draw on. + """ handle.polygon([(p[0], p[1]) for p in self._points]) def resize(self, factor=1): + """ Resize the polygon by a factor. + + Args: + factor (float): Resize factor. + + Returns: + Polygon: Resized polygon. + """ return Polygon([(p[0] * factor, p[1] * factor) for p in self._points]) def move(self, dx=0, dy=0): + """ Move the polygon by a given offset. + + Args: + dx (float): X offset. + dy (float): Y offset. + + Returns: + Polygon: Moved polygon. + """ return Polygon([(p[0] + dx, p[1] + dy) for p in self._points]) def is_empty(self): + """ Check if the polygon is empty. + + Returns: + bool: True if the polygon is empty, False otherwise. + + """ top = np.min(self._points[:, 1]) bottom = np.max(self._points[:, 1]) left = np.min(self._points[:, 0]) @@ -199,10 +322,23 @@ def is_empty(self): return top == bottom or left == right def rasterize(self, bounds: Tuple[int, int, int, int]): + """ Rasterize the polygon into a binary mask. + + Args: + bounds (tuple): Bounding box of the mask as (left, top, right, bottom). + + Returns: + numpy.ndarray: Binary mask. + """ from vot.region.raster import rasterize_polygon return rasterize_polygon(self._points, bounds) def bounds(self): + """ Get the bounding box of the polygon. + + Returns: + tuple: Bounding box as (left, top, right, bottom). 
+ """ top = np.min(self._points[:, 1]) bottom = np.max(self._points[:, 1]) left = np.min(self._points[:, 0]) @@ -213,49 +349,74 @@ def bounds(self): from vot.region.io import mask_to_rle class Mask(Shape): - """Mask region + """Mask region defined by a binary mask. The mask is defined by a binary image and an offset. """ def __init__(self, mask: np.array, offset: Tuple[int, int] = (0, 0), optimize=False): + """ Constructor + + Args: + mask (numpy.ndarray): Binary mask. + offset (tuple): Offset of the mask as (x, y). + optimize (bool): Optimize the mask by removing empty rows and columns. + + """ super().__init__() self._mask = mask.astype(np.uint8) self._mask[self._mask > 0] = 1 self._offset = offset if optimize: # optimize is used when mask without an offset is given (e.g. full-image mask) self._optimize() - + def __str__(self): + """ Create string from class """ offset_str = '%d,%d' % self.offset region_sz_str = '%d,%d' % (self.mask.shape[1], self.mask.shape[0]) rle_str = ','.join([str(el) for el in mask_to_rle(self.mask)]) return 'm%s,%s,%s' % (offset_str, region_sz_str, rle_str) def _optimize(self): + """ Optimize the mask by removing empty rows and columns. If the mask is empty, the mask is set to zero size. + Do not call this method directly, it is called from the constructor. """ bounds = mask_bounds(self.mask) - if bounds[0] is None: + if bounds[2] == 0: # mask is empty self._mask = np.zeros((0, 0), dtype=np.uint8) self._offset = (0, 0) else: - self._mask = np.copy(self.mask[bounds[1]:bounds[3], bounds[0]:bounds[2]]) + + self._mask = np.copy(self.mask[bounds[1]:bounds[3]+1, bounds[0]:bounds[2]+1]) self._offset = (bounds[0] + self.offset[0], bounds[1] + self.offset[1]) @property def mask(self): + """ Get the mask. Note that you should not modify the mask directly. Also make sure to + take into account the offset when using the mask.""" return self._mask @property def offset(self): + """ Get the offset of the mask in pixels.""" return self._offset @property def type(self): + """ Get the region type.""" return RegionType.MASK def copy(self): + """ Create a copy of the mask.""" return copy(self) def convert(self, rtype: RegionType): + """ Convert the mask to another region type. The mask is converted to a rectangle or polygon by approximating bounding box of the mask. + + Args: + rtype (RegionType): Target region type. + + Returns: + Shape: Converted region. + """ if rtype == RegionType.MASK: return self.copy() elif rtype == RegionType.RECTANGLE: @@ -275,19 +436,43 @@ def convert(self, rtype: RegionType): raise ConversionException("Unable to convert mask region to {}".format(rtype), source=self) def draw(self, handle: DrawHandle): + """ Draw the mask into an image. + + Args: + handle (DrawHandle): Handle to the image. + """ handle.mask(self._mask, self.offset) def rasterize(self, bounds: Tuple[int, int, int, int]): + """ Rasterize the mask into a binary mask. The mask is cropped to the given bounds. + + Args: + bounds (tuple): Bounding box of the mask as (left, top, right, bottom). + + Returns: + numpy.ndarray: Binary mask. The mask is a copy of the original mask. + """ from vot.region.raster import copy_mask return copy_mask(self._mask, self._offset, np.array(bounds)) def is_empty(self): - if self.mask.shape[1] > 0 and self.mask.shape[0] > 0: - return False - else: - return True + """ Check if the mask is empty. + + Returns: + bool: True if the mask is empty, False otherwise. 
+ """ + bounds = mask_bounds(self.mask) + return bounds[2] == 0 or bounds[3] == 0 def resize(self, factor=1): + """ Resize the mask by a given factor. The mask is resized using nearest neighbor interpolation. + + Args: + factor (float): Resize factor. + + Returns: + Mask: Resized mask. + """ offset = (int(self.offset[0] * factor), int(self.offset[1] * factor)) height = max(1, int(self.mask.shape[0] * factor)) @@ -301,8 +486,22 @@ def resize(self, factor=1): return Mask(mask, offset, False) def move(self, dx=0, dy=0): + """ Move the mask by a given offset. + + Args: + dx (int): Horizontal offset. + dy (int): Vertical offset. + + Returns: + Mask: Moved mask. + """ return Mask(self._mask, (self.offset[0] + dx, self.offset[1] + dy)) def bounds(self): + """ Get the bounding box of the mask. + + Returns: + tuple: Bounding box of the mask as (left, top, right, bottom). + """ bounds = mask_bounds(self.mask) return bounds[0] + self.offset[0], bounds[1] + self.offset[1], bounds[2] + self.offset[0], bounds[3] + self.offset[1] diff --git a/vot/region/tests.py b/vot/region/tests.py index 561032a..b1a2ce9 100644 --- a/vot/region/tests.py +++ b/vot/region/tests.py @@ -1,4 +1,6 @@ +"""Tests for the region module. """ + import unittest import numpy as np @@ -6,21 +8,85 @@ from vot.region.raster import rasterize_polygon, rasterize_rectangle, copy_mask, calculate_overlap class TestRasterMethods(unittest.TestCase): + """Tests for the raster module.""" def test_rasterize_polygon(self): + """Tests if the polygon rasterization works correctly. """ points = np.array([[0, 0], [0, 100], [100, 100], [100, 0]], dtype=np.float32) np.testing.assert_array_equal(rasterize_polygon(points, (0, 0, 99, 99)), np.ones((100, 100), dtype=np.uint8)) def test_rasterize_rectangle(self): + """Tests if the rectangle rasterization works correctly.""" np.testing.assert_array_equal(rasterize_rectangle(np.array([[0], [0], [100], [100]], dtype=np.float32), (0, 0, 99, 99)), np.ones((100, 100), dtype=np.uint8)) def test_copy_mask(self): + """Tests if the mask copy works correctly.""" mask = np.ones((100, 100), dtype=np.uint8) np.testing.assert_array_equal(copy_mask(mask, (0, 0), (0, 0, 99, 99)), np.ones((100, 100), dtype=np.uint8)) def test_calculate_overlap(self): + """Tests if the overlap calculation works correctly.""" from vot.region import Rectangle r1 = Rectangle(0, 0, 100, 100) self.assertEqual(calculate_overlap(r1, r1), 1) + r1 = Rectangle(0, 0, 0, 0) + self.assertEqual(calculate_overlap(r1, r1), 1) + + def test_empty_mask(self): + """Tests if the empty mask is correctly detected.""" + from vot.region import Mask + + mask = Mask(np.zeros((100, 100), dtype=np.uint8)) + self.assertTrue(mask.is_empty()) + + mask = Mask(np.ones((100, 100), dtype=np.uint8)) + self.assertFalse(mask.is_empty()) + + def test_binary_format(self): + """ Tests if the binary format of a region matched the plain-text one.""" + import io + + from vot.region import Rectangle, Polygon, Mask + from vot.region.io import read_trajectory, write_trajectory + from vot.region.raster import calculate_overlaps + + trajectory = [ + Rectangle(0, 0, 100, 100), + Rectangle(0, 10, 100, 100), + Rectangle(0, 0, 200, 100), + Polygon([[0, 0], [0, 100], [100, 100], [100, 0]]), + Mask(np.ones((100, 100), dtype=np.uint8)), + Mask(np.zeros((100, 100), dtype=np.uint8)), + ] + + binf = io.BytesIO() + txtf = io.StringIO() + + write_trajectory(binf, trajectory) + write_trajectory(txtf, trajectory) + + binf.seek(0) + txtf.seek(0) + + bint = read_trajectory(binf) + txtt = 
read_trajectory(txtf) + + o1 = calculate_overlaps(bint, txtt, None) + o2 = calculate_overlaps(bint, trajectory, None) + + self.assertTrue(np.all(np.array(o1) == 1)) + self.assertTrue(np.all(np.array(o2) == 1)) + + def test_rle(self): + """ Test if RLE encoding works for limited stride representation.""" + from vot.region.io import rle_to_mask, mask_to_rle + rle = [0, 2, 122103, 9, 260, 19, 256, 21, 256, 22, 254, 24, 252, 26, 251, 27, 250, 28, 249, 28, 250, 28, 249, 28, 249, 29, 249, 30, 247, 33, 245, 33, 244, 34, 244, 37, 241, 39, 239, 41, 237, 41, 236, 43, 235, 45, 234, 47, 233, 47, 231, 48, 230, 48, 230, 11, 7, 29, 231, 9, 9, 29, 230, 8, 11, 28, 230, 7, 12, 28, 230, 7, 13, 27, 231, 5, 14, 27, 233, 2, 16, 26, 253, 23, 255, 22, 256, 20, 258, 19, 259, 17, 3] + rle = np.array(rle) + m1 = rle_to_mask(np.array(rle, dtype=np.int32), 277, 478) + + r2 = mask_to_rle(m1, maxstride=255) + m2 = rle_to_mask(np.array(r2, dtype=np.int32), 277, 478) + + np.testing.assert_array_equal(m1, m2) \ No newline at end of file diff --git a/vot/stack/__init__.py b/vot/stack/__init__.py index 44da321..8aead08 100644 --- a/vot/stack/__init__.py +++ b/vot/stack/__init__.py @@ -1,29 +1,34 @@ +"""Stacks are collections of experiments that are grouped together for convenience. Stacks are used to organize experiments and to run them in +batch mode. +""" import os -import json -import glob -import collections -from typing import List +from typing import List, Mapping import yaml from attributee import Attributee, String, Boolean, Map, Object from vot.experiment import Experiment, experiment_registry -from vot.experiment.transformer import Transformer from vot.utilities import import_class -from vot.analysis import Analysis def experiment_resolver(typename, context, **kwargs): + """Resolves experiment objects from stack definitions. This function is used by the stack module to resolve experiment objects from stack + definitions. It is not intended to be used directly. - if "key" in context: - identifier = context["key"] - else: - identifier = None + Args: + typename (str): Name of the experiment class + context (Attributee): Context of the experiment + kwargs (dict): Additional arguments + + Returns: + Experiment: Experiment object + """ + identifier = context.key storage = None - if "parent" in context: - if getattr(context["parent"], "workspace", None) is not None: - storage = context["parent"].workspace.storage + + if getattr(context.parent, "workspace", None) is not None: + storage = context.parent.workspace.storage if typename in experiment_registry: experiment = experiment_registry.get(typename, _identifier=identifier, _storage=storage, **kwargs) @@ -35,14 +40,22 @@ def experiment_resolver(typename, context, **kwargs): return experiment_class(_identifier=identifier, _storage=storage, **kwargs) class Stack(Attributee): + """Stack class represents a collection of experiments. Stacks are used to organize experiments and to run them in batch mode. + """ - title = String() - dataset = String(default="") + title = String(default="Stack") + dataset = String(default=None) url = String(default="") deprecated = Boolean(default=False) experiments = Map(Object(experiment_resolver)) def __init__(self, name: str, workspace: "Workspace", **kwargs): + """Creates a new stack object. 
+ + Args: + name (str): Name of the stack + workspace (Workspace): Workspace object + """ self._workspace = workspace self._name = name @@ -50,22 +63,45 @@ def __init__(self, name: str, workspace: "Workspace", **kwargs): @property def workspace(self): + """Returns the workspace object for the stack.""" return self._workspace @property def name(self): + """Returns the name of the stack.""" return self._name def __iter__(self): + """Iterates over experiments in the stack.""" return iter(self.experiments.values()) def __len__(self): + """Returns the number of experiments in the stack.""" return len(self.experiments) def __getitem__(self, identifier): + """Returns the experiment with the given identifier. + + Args: + identifier (str): Identifier of the experiment + + Returns: + Experiment: Experiment object + + """ return self.experiments[identifier] -def resolve_stack(name, *directories): +def resolve_stack(name: str, *directories: List[str]) -> str: + """Searches for stack file in the given directories and returns its absolute path. If given an absolute path as input + it simply returns it. + + Args: + name (str): Name of the stack + directories (List[str]): Directories that will be used + + Returns: + str: Absolute path to stack file + """ if os.path.isabs(name): return name if os.path.isfile(name) else None for directory in directories: @@ -77,11 +113,24 @@ def resolve_stack(name, *directories): return full return None -def list_integrated_stacks(): +def list_integrated_stacks() -> Mapping[str, str]: + """List stacks that come with the toolkit + + Returns: + Map[str, str]: A mapping of stack ids and stack title pairs + """ + + from pathlib import Path + stacks = {} - for stack_file in glob.glob(os.path.join(os.path.dirname(__file__), "*.yaml")): - with open(stack_file, 'r') as fp: + root = Path(os.path.join(os.path.dirname(__file__))) + + for stack_path in root.rglob("*.yaml"): + with open(stack_path, 'r') as fp: stack_metadata = yaml.load(fp, Loader=yaml.BaseLoader) - stacks[os.path.splitext(os.path.basename(stack_file))[0]] = stack_metadata.get("title", "") + if stack_metadata is None: + continue + key = str(stack_path.relative_to(root).with_name(os.path.splitext(stack_path.name)[0])) + stacks[key] = stack_metadata.get("title", "") return stacks \ No newline at end of file diff --git a/vot/stack/otb100.yaml b/vot/stack/otb100.yaml new file mode 100644 index 0000000..ff604f5 --- /dev/null +++ b/vot/stack/otb100.yaml @@ -0,0 +1,10 @@ +title: OTB100 dataset experiment stack +url: http://cvlab.hanyang.ac.kr/tracker_benchmark/index.html +dataset: otb100 +experiments: + baseline: + type: unsupervised + analyses: + - type: average_accuracy + name: accuracy + burnin: 1 \ No newline at end of file diff --git a/vot/stack/otb50.yaml b/vot/stack/otb50.yaml new file mode 100644 index 0000000..1794bdd --- /dev/null +++ b/vot/stack/otb50.yaml @@ -0,0 +1,10 @@ +title: OTB50 dataset experiment stack +url: http://cvlab.hanyang.ac.kr/tracker_benchmark/index.html +dataset: otb50 +experiments: + baseline: + type: unsupervised + analyses: + - type: average_accuracy + name: accuracy + burnin: 1 \ No newline at end of file diff --git a/vot/stack/tests.py b/vot/stack/tests.py index 083f5c5..0b5c68f 100644 --- a/vot/stack/tests.py +++ b/vot/stack/tests.py @@ -1,19 +1,27 @@ -import os +"""Tests for the experiment stack module.""" + import unittest import yaml -from vot.workspace import Workspace, NullStorage +from vot.workspace import NullStorage from vot.stack import Stack, list_integrated_stacks, resolve_stack 
class NoWorkspace: + """Empty workspace, does not save anything + """ @property def storage(self): + """Returns the storage object for the workspace. """ return NullStorage() class TestStacks(unittest.TestCase): + """Tests for the experiment stack utilities + """ def test_stacks(self): + """Test loading integrated stacks + """ stacks = list_integrated_stacks() for stack_name in stacks: diff --git a/vot/stack/tests/basic.yaml b/vot/stack/tests/basic.yaml new file mode 100644 index 0000000..368bce4 --- /dev/null +++ b/vot/stack/tests/basic.yaml @@ -0,0 +1,7 @@ +title: VOT Basic Test Stack +url: http://www.votchallenge.net/ +dataset: https://data.votchallenge.net/toolkit/test.zip +experiments: + baseline: + type: unsupervised + repetitions: 1 \ No newline at end of file diff --git a/vot/stack/tests/multiobject.yaml b/vot/stack/tests/multiobject.yaml new file mode 100644 index 0000000..6b6d6ba --- /dev/null +++ b/vot/stack/tests/multiobject.yaml @@ -0,0 +1,21 @@ +title: VOTS2023 Test Stack +dataset: https://data.votchallenge.net/vots2023/test/description.json +experiments: + baseline: + type: unsupervised + repetitions: 1 + multiobject: True + analyses: + - type: average_accuracy + name: Quality + burnin: 0 + ignore_unknown: False + weighted: False + - type: average_success_plot + name: Quality plot + burnin: 0 + ignore_unknown: False + - type: longterm_ar + name: AR + - type: average_quality_auxiliary + name: Auxiliary \ No newline at end of file diff --git a/vot/stack/testing.yaml b/vot/stack/tests/segmentation.yaml similarity index 85% rename from vot/stack/testing.yaml rename to vot/stack/tests/segmentation.yaml index 04c1459..43e5288 100644 --- a/vot/stack/testing.yaml +++ b/vot/stack/tests/segmentation.yaml @@ -1,6 +1,6 @@ -title: VOT testing +title: VOT Segmentation testing url: http://www.votchallenge.net/ -dataset: vot:segmentation +dataset: http://box.vicos.si/tracking/vot20_test_dataset.zip experiments: baseline: type: multistart diff --git a/vot/stack/vot2013.yaml b/vot/stack/vot2013.yaml index 85ad419..3006908 100644 --- a/vot/stack/vot2013.yaml +++ b/vot/stack/vot2013.yaml @@ -1,6 +1,6 @@ title: VOT2013 challenge url: http://www.votchallenge.net/vot2013/ -dataset: vot:vot2013 +dataset: http://data.votchallenge.net/vot2013/dataset/description.json deprecated: True experiments: baseline: diff --git a/vot/stack/vot2014.yaml b/vot/stack/vot2014.yaml index 171df7e..3713448 100644 --- a/vot/stack/vot2014.yaml +++ b/vot/stack/vot2014.yaml @@ -1,5 +1,5 @@ title: VOT2014 challenge -dataset: vot:vot2014 +dataset: http://data.votchallenge.net/vot2014/dataset/description.json url: http://www.votchallenge.net/vot2014/ deprecated: True experiments: diff --git a/vot/stack/vot2015.yaml b/vot/stack/vot2015/rgb.yaml similarity index 82% rename from vot/stack/vot2015.yaml rename to vot/stack/vot2015/rgb.yaml index eb30eec..8e3c2ab 100644 --- a/vot/stack/vot2015.yaml +++ b/vot/stack/vot2015/rgb.yaml @@ -1,5 +1,5 @@ title: VOT2015 challenge -dataset: vot:vot2015 +dataset: http://data.votchallenge.net/vot2015/dataset/description.json url: http://www.votchallenge.net/vot2015/ experiments: baseline: diff --git a/vot/stack/vottir2015.yaml b/vot/stack/vot2015/tir.yaml similarity index 72% rename from vot/stack/vottir2015.yaml rename to vot/stack/vot2015/tir.yaml index 5ab84de..e2d09cf 100644 --- a/vot/stack/vottir2015.yaml +++ b/vot/stack/vot2015/tir.yaml @@ -1,5 +1,5 @@ title: VOT-TIR2015 challenge -dataset: vot:vot-tir2015 +dataset: 
http://www.cvl.isy.liu.se/research/datasets/ltir/version1.0/ltir_v1_0_8bit.zip url: http://www.votchallenge.net/vot2015/ experiments: baseline: diff --git a/vot/stack/vot2016.yaml b/vot/stack/vot2016/rgb.yaml similarity index 87% rename from vot/stack/vot2016.yaml rename to vot/stack/vot2016/rgb.yaml index 7dd3de6..8a616ea 100644 --- a/vot/stack/vot2016.yaml +++ b/vot/stack/vot2016/rgb.yaml @@ -1,5 +1,5 @@ title: VOT2016 challenge -dataset: vot:vot2016 +dataset: http://data.votchallenge.net/vot2016/main/description.json url: http://www.votchallenge.net/vot2016/ experiments: baseline: diff --git a/vot/stack/vottir2016.yaml b/vot/stack/vot2016/tir.yaml similarity index 79% rename from vot/stack/vottir2016.yaml rename to vot/stack/vot2016/tir.yaml index a5b09a3..6d4366b 100644 --- a/vot/stack/vottir2016.yaml +++ b/vot/stack/vot2016/tir.yaml @@ -1,5 +1,5 @@ title: VOT-TIR2016 challenge -dataset: vot:vot-tir2016 +dataset: http://data.votchallenge.net/vot2016/vot-tir2016.zip url: http://www.votchallenge.net/vot2016/ experiments: baseline: diff --git a/vot/stack/vot2017.yaml b/vot/stack/vot2017.yaml index 3a362ff..bc7a24b 100644 --- a/vot/stack/vot2017.yaml +++ b/vot/stack/vot2017.yaml @@ -1,5 +1,5 @@ title: VOT2017 challenge -dataset: vot:vot2017 +dataset: http://data.votchallenge.net/vot2017/main/description.json url: http://www.votchallenge.net/vot2017/ experiments: baseline: diff --git a/vot/stack/vot2018/longterm.yaml b/vot/stack/vot2018/longterm.yaml new file mode 100644 index 0000000..b31e677 --- /dev/null +++ b/vot/stack/vot2018/longterm.yaml @@ -0,0 +1,20 @@ +title: VOT-LT2018 challenge +dataset: http://data.votchallenge.net/vot2018/longterm/description.json +url: http://www.votchallenge.net/vot2018/ +experiments: + longterm: + type: unsupervised + repetitions: 1 + analyses: + - type: average_tpr + name: average_tpr + - type: pr_curve + - type: f_curve + redetection: + type: unsupervised + transformers: + - type: redetection + length: 200 + initialization: 5 + padding: 2 + scaling: 3 diff --git a/vot/stack/vot2018.yaml b/vot/stack/vot2018/shortterm.yaml similarity index 91% rename from vot/stack/vot2018.yaml rename to vot/stack/vot2018/shortterm.yaml index fb7942a..60a97af 100644 --- a/vot/stack/vot2018.yaml +++ b/vot/stack/vot2018/shortterm.yaml @@ -1,5 +1,5 @@ title: VOT-ST2018 challenge -dataset: vot:vot-st2018 +dataset: http://data.votchallenge.net/vot2018/main/description.json url: http://www.votchallenge.net/vot2018/ experiments: baseline: diff --git a/vot/stack/votlt2019.yaml b/vot/stack/vot2019/longterm.yaml similarity index 85% rename from vot/stack/votlt2019.yaml rename to vot/stack/vot2019/longterm.yaml index 65dba8a..0bc2805 100644 --- a/vot/stack/votlt2019.yaml +++ b/vot/stack/vot2019/longterm.yaml @@ -1,5 +1,5 @@ title: VOT-LT2019 challenge -dataset: vot:vot-lt2019 +dataset: http://data.votchallenge.net/vot2019/longterm/description.json url: http://www.votchallenge.net/vot2019/ experiments: longterm: diff --git a/vot/stack/votrgbd2019.yaml b/vot/stack/vot2019/rgbd.yaml similarity index 77% rename from vot/stack/votrgbd2019.yaml rename to vot/stack/vot2019/rgbd.yaml index f54726d..481fcf7 100644 --- a/vot/stack/votrgbd2019.yaml +++ b/vot/stack/vot2019/rgbd.yaml @@ -1,5 +1,5 @@ title: VOT-RGBD2019 challenge -dataset: vot:vot-rgbd2019 +dataset: http://data.votchallenge.net/vot2019/rgbd/description.json url: http://www.votchallenge.net/vot2019/ experiments: rgbd-unsupervised: diff --git a/vot/stack/vot2019/rgbtir.yaml b/vot/stack/vot2019/rgbtir.yaml new file mode 100644 index 
0000000..15c60cf --- /dev/null +++ b/vot/stack/vot2019/rgbtir.yaml @@ -0,0 +1,15 @@ +title: VOT-RGBTIR2019 challenge +dataset: http://data.votchallenge.net/vot2019/rgbtir/meta/description.json +url: http://www.votchallenge.net/vot2019/ +experiments: + baseline: + type: multistart + realtime: + grace: 3 + analyses: + - type: multistart_average_ar + - type: multistart_eao_score + low: 115 + high: 755 + - type: multistart_eao_curve + high: 755 \ No newline at end of file diff --git a/vot/stack/vot2019.yaml b/vot/stack/vot2019/shortterm.yaml similarity index 91% rename from vot/stack/vot2019.yaml rename to vot/stack/vot2019/shortterm.yaml index 387605b..1ee18e7 100644 --- a/vot/stack/vot2019.yaml +++ b/vot/stack/vot2019/shortterm.yaml @@ -1,5 +1,5 @@ title: VOT-ST2019 challenge -dataset: vot:vot-st2019 +dataset: http://data.votchallenge.net/vot2019/main/description.json url: http://www.votchallenge.net/vot2019/ experiments: baseline: diff --git a/vot/stack/votlt2020.yaml b/vot/stack/vot2020/longterm.yaml similarity index 85% rename from vot/stack/votlt2020.yaml rename to vot/stack/vot2020/longterm.yaml index d8418fd..e698f87 100644 --- a/vot/stack/votlt2020.yaml +++ b/vot/stack/vot2020/longterm.yaml @@ -1,5 +1,5 @@ title: VOT-LT2020 challenge -dataset: vot:vot-lt2019 +dataset: http://data.votchallenge.net/vot2019/longterm/description.json url: http://www.votchallenge.net/vot2020/ experiments: longterm: diff --git a/vot/stack/votrgbd2020.yaml b/vot/stack/vot2020/rgbd.yaml similarity index 77% rename from vot/stack/votrgbd2020.yaml rename to vot/stack/vot2020/rgbd.yaml index f3265e8..75a8061 100644 --- a/vot/stack/votrgbd2020.yaml +++ b/vot/stack/vot2020/rgbd.yaml @@ -1,5 +1,5 @@ title: VOT-RGBD2020 challenge -dataset: vot:vot-rgbd2019 +dataset: http://data.votchallenge.net/vot2019/rgbd/description.json url: http://www.votchallenge.net/vot2020/ experiments: rgbd-unsupervised: diff --git a/vot/stack/votrgbtir2020.yaml b/vot/stack/vot2020/rgbtir.yaml similarity index 81% rename from vot/stack/votrgbtir2020.yaml rename to vot/stack/vot2020/rgbtir.yaml index c656ec3..98945c3 100644 --- a/vot/stack/votrgbtir2020.yaml +++ b/vot/stack/vot2020/rgbtir.yaml @@ -1,5 +1,5 @@ title: VOT-RGBTIR2020 challenge -dataset: vot:vot-rgbt2020 +dataset: http://data.votchallenge.net/vot2020/rgbtir/meta/description.json url: http://www.votchallenge.net/vot2020/ experiments: baseline: diff --git a/vot/stack/vot2020.yaml b/vot/stack/vot2020/shortterm.yaml similarity index 91% rename from vot/stack/vot2020.yaml rename to vot/stack/vot2020/shortterm.yaml index b736d6b..0452531 100644 --- a/vot/stack/vot2020.yaml +++ b/vot/stack/vot2020/shortterm.yaml @@ -1,5 +1,5 @@ title: VOT-ST2020 challenge -dataset: vot:vot-st2020 +dataset: https://data.votchallenge.net/vot2020/shortterm/description.json url: http://www.votchallenge.net/vot2020/ experiments: baseline: diff --git a/vot/stack/votlt2021.yaml b/vot/stack/vot2021/lt.yaml similarity index 85% rename from vot/stack/votlt2021.yaml rename to vot/stack/vot2021/lt.yaml index 842820b..d31c092 100644 --- a/vot/stack/votlt2021.yaml +++ b/vot/stack/vot2021/lt.yaml @@ -1,5 +1,5 @@ title: VOT-LT2021 challenge -dataset: vot:vot-lt2019 +dataset: http://data.votchallenge.net/vot2019/longterm/description.json url: http://www.votchallenge.net/vot2021/ experiments: longterm: diff --git a/vot/stack/votrgbd2021.yaml b/vot/stack/vot2021/rgbd.yaml similarity index 77% rename from vot/stack/votrgbd2021.yaml rename to vot/stack/vot2021/rgbd.yaml index 44eaa3c..cfffa5e 100644 --- 
a/vot/stack/votrgbd2021.yaml +++ b/vot/stack/vot2021/rgbd.yaml @@ -1,5 +1,5 @@ title: VOT-RGBD2021 challenge -dataset: vot:vot-rgbd2019 +dataset: http://data.votchallenge.net/vot2019/rgbd/description.json url: http://www.votchallenge.net/vot2021/ experiments: rgbd-unsupervised: diff --git a/vot/stack/vot2021.yaml b/vot/stack/vot2021/st.yaml similarity index 91% rename from vot/stack/vot2021.yaml rename to vot/stack/vot2021/st.yaml index 566d68b..b1f8a7b 100644 --- a/vot/stack/vot2021.yaml +++ b/vot/stack/vot2021/st.yaml @@ -1,5 +1,5 @@ title: VOT-ST2021 challenge -dataset: vot:vot-st2021 +dataset: https://data.votchallenge.net/vot2021/shortterm/description.json url: http://www.votchallenge.net/vot2021/ experiments: baseline: diff --git a/vot/stack/vot2022/depth.yaml b/vot/stack/vot2022/depth.yaml new file mode 100644 index 0000000..0b97f8e --- /dev/null +++ b/vot/stack/vot2022/depth.yaml @@ -0,0 +1,16 @@ +title: VOT-D2022 challenge +dataset: https://data.votchallenge.net/vot2022/depth/description.json +url: https://www.votchallenge.net/vot2022/ +experiments: + baseline: + type: multistart + analyses: + - type: multistart_eao_score + name: eaoscore + low: 115 + high: 755 + - type: multistart_eao_curve + name: eaocurve + high: 755 + - type: multistart_average_ar + name: ar diff --git a/vot/stack/vot2022/lt.yaml b/vot/stack/vot2022/lt.yaml new file mode 100644 index 0000000..fb80d22 --- /dev/null +++ b/vot/stack/vot2022/lt.yaml @@ -0,0 +1,20 @@ +title: VOT-LT2022 challenge +dataset: https://data.votchallenge.net/vot2022/lt/description.json +url: https://www.votchallenge.net/vot2022/ +experiments: + longterm: + type: unsupervised + repetitions: 1 + analyses: + - type: average_tpr + name: average_tpr + - type: pr_curve + - type: f_curve + redetection: + type: unsupervised + transformers: + - type: redetection + length: 200 + initialization: 5 + padding: 2 + scaling: 3 diff --git a/vot/stack/vot2022/rgbd.yaml b/vot/stack/vot2022/rgbd.yaml new file mode 100644 index 0000000..c5a8f19 --- /dev/null +++ b/vot/stack/vot2022/rgbd.yaml @@ -0,0 +1,16 @@ +title: VOT-RGBD2022 challenge +dataset: https://data.votchallenge.net/vot2022/rgbd/description.json +url: https://www.votchallenge.net/vot2022/ +experiments: + baseline: + type: multistart + analyses: + - type: multistart_eao_score + name: eaoscore + low: 115 + high: 755 + - type: multistart_eao_curve + name: eaocurve + high: 755 + - type: multistart_average_ar + name: ar diff --git a/vot/stack/vot2022/stb.yaml b/vot/stack/vot2022/stb.yaml new file mode 100644 index 0000000..5378a63 --- /dev/null +++ b/vot/stack/vot2022/stb.yaml @@ -0,0 +1,37 @@ +title: VOT-ST2022 bounding-box challenge +dataset: https://data.votchallenge.net/vot2022/stb/description.json +url: https://www.votchallenge.net/vot2022/ +experiments: + baseline: + type: multistart + analyses: + - type: multistart_eao_score + name: eaoscore + low: 115 + high: 755 + - type: multistart_eao_curve + name: eaocurve + high: 755 + - type: multistart_average_ar + name: ar + realtime: + type: multistart + realtime: + grace: 3 + analyses: + - type: multistart_eao_score + name: eaoscore + low: 115 + high: 755 + - type: multistart_eao_curve + name: eaocurve + high: 755 + - type: multistart_average_ar + name: ar + unsupervised: + type: unsupervised + repetitions: 1 + analyses: + - type: average_accuracy + name: accuracy + burnin: 1 \ No newline at end of file diff --git a/vot/stack/vot2022/sts.yaml b/vot/stack/vot2022/sts.yaml new file mode 100644 index 0000000..99f8de9 --- /dev/null +++ 
b/vot/stack/vot2022/sts.yaml
@@ -0,0 +1,37 @@
+title: VOT-ST2022 segmentation challenge
+dataset: https://data.votchallenge.net/vot2022/sts/description.json
+url: https://www.votchallenge.net/vot2022/
+experiments:
+  baseline:
+    type: multistart
+    analyses:
+      - type: multistart_eao_score
+        name: eaoscore
+        low: 115
+        high: 755
+      - type: multistart_eao_curve
+        name: eaocurve
+        high: 755
+      - type: multistart_average_ar
+        name: ar
+  realtime:
+    type: multistart
+    realtime:
+      grace: 3
+    analyses:
+      - type: multistart_eao_score
+        name: eaoscore
+        low: 115
+        high: 755
+      - type: multistart_eao_curve
+        name: eaocurve
+        high: 755
+      - type: multistart_average_ar
+        name: ar
+  unsupervised:
+    type: unsupervised
+    repetitions: 1
+    analyses:
+      - type: average_accuracy
+        name: accuracy
+        burnin: 1
\ No newline at end of file
diff --git a/vot/stack/vots2023.yaml b/vot/stack/vots2023.yaml
new file mode 100644
index 0000000..6de44b1
--- /dev/null
+++ b/vot/stack/vots2023.yaml
@@ -0,0 +1,7 @@
+title: VOTS2023 Challenge Stack
+dataset: https://data.votchallenge.net/vots2023/dataset/description.json
+experiments:
+  baseline:
+    type: unsupervised
+    repetitions: 1
+    multiobject: True
\ No newline at end of file
diff --git a/vot/tracker/__init__.py b/vot/tracker/__init__.py
index b7119cf..eb6a468 100644
--- a/vot/tracker/__init__.py
+++ b/vot/tracker/__init__.py
@@ -1,11 +1,12 @@
+""" This module contains the base classes for trackers and the registry of known trackers. """
 import os
 import re
 import configparser
 import logging
 import copy

-from typing import Tuple
-from collections import OrderedDict
+from typing import Tuple, List, Union
+from collections import OrderedDict, namedtuple
 from abc import abstractmethod, ABC

 import yaml
@@ -18,20 +19,35 @@ logger = logging.getLogger("vot")

 class TrackerException(ToolkitException):
+    """ Base class for all tracker related exceptions."""
+
     def __init__(self, *args, tracker, tracker_log=None):
+        """ Initialize the exception.
+
+        Args:
+            tracker (Tracker): Tracker that caused the exception.
+            tracker_log (str, optional): Optional log message. Defaults to None.
+        """
         super().__init__(*args)
         self._tracker_log = tracker_log
         self._tracker = tracker

     @property
-    def log(self):
+    def log(self) -> str:
+        """ Returns the log message of the tracker.
+
+        Returns:
+            str: Log message of the tracker.
+        """
         return self._tracker_log

     @property
     def tracker(self):
+        """ Returns the tracker that caused the exception."""
         return self._tracker

 class TrackerTimeoutException(TrackerException):
+    """ Exception raised when the tracker communication times out."""
     pass

 VALID_IDENTIFIER = re.compile("^[a-zA-Z0-9-_]+$")
@@ -39,12 +55,39 @@ class TrackerTimeoutException(TrackerException):
 VALID_REFERENCE = re.compile("^([a-zA-Z0-9-_]+)(@[a-zA-Z0-9-_]*)?$")

 def is_valid_identifier(identifier):
+    """Checks if the identifier is valid.
+
+    Args:
+        identifier (str): The identifier to check.
+
+    Returns:
+        bool: True if the identifier is valid, False otherwise.
+    """
     return not VALID_IDENTIFIER.match(identifier) is None

 def is_valid_reference(reference):
+    """Checks if the reference is valid.
+
+    Args:
+        reference (str): The reference to check.
+
+    Returns:
+        bool: True if the reference is valid, False otherwise.
+    """
     return not VALID_REFERENCE.match(reference) is None

 def parse_reference(reference):
+    """Parses the reference into identifier and version.
+
+    Args:
+        reference (str): The reference to parse.
+
+    Returns:
+        tuple: A tuple containing the identifier and the version.
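# --- Editor's illustration (not part of the patch): how the 'identifier@version'
# reference grammar behaves, using the same regular expression shape as
# VALID_REFERENCE above. The exact return convention of parse_reference is only
# partially visible in this hunk, so this helper is an assumption-labelled sketch.
import re

_REFERENCE = re.compile("^([a-zA-Z0-9-_]+)(@[a-zA-Z0-9-_]*)?$")

def split_reference(reference):
    m = _REFERENCE.match(reference)
    if not m:
        return None, None
    version = m.group(2)[1:] if m.group(2) else None
    return m.group(1), version

assert split_reference("siamfc@2021") == ("siamfc", "2021")
assert split_reference("siamfc") == ("siamfc", None)
assert split_reference("not a reference") == (None, None)
# --- end illustration ---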
+
+    Raises:
+        ValueError: If the reference is not valid.
+    """
     matches = VALID_REFERENCE.match(reference)
     if not matches:
         return None, None
@@ -53,8 +96,15 @@ def parse_reference(reference):

 _runtime_protocols = {}

 class Registry(object):
+    """ Repository of known trackers. Trackers are loaded from manifest files in one or more directories. """

     def __init__(self, directories, root=os.getcwd()):
+        """ Initialize the registry.
+
+        Args:
+            directories (list): List of directories to scan for trackers.
+            root (str, optional): The root directory of the workspace. Defaults to os.getcwd().
+        """
         trackers = dict()
         registries = []
@@ -107,19 +157,37 @@ def __init__(self, directories, root=os.getcwd()):
         logger.debug("Found %d trackers", len(self._trackers))

     def __getitem__(self, reference) -> "Tracker":
+        """ Returns the tracker for the given reference. """
+
         return self.resolve(reference, skip_unknown=False, resolve_plural=False)[0]

     def __contains__(self, reference) -> bool:
+        """ Checks if the tracker is registered. """
         identifier, _ = parse_reference(reference)
         return identifier in self._trackers

     def __iter__(self):
+        """ Returns an iterator over the trackers."""
         return iter(self._trackers.values())

     def __len__(self):
+        """ Returns the number of trackers."""
         return len(self._trackers)

     def resolve(self, *references, storage=None, skip_unknown=True, resolve_plural=True):
+        """ Resolves the references to trackers.
+
+        Args:
+            storage (Storage, optional): Storage to use for resolving references. Defaults to None.
+            skip_unknown (bool, optional): Skip unknown trackers. Defaults to True.
+            resolve_plural (bool, optional): Resolve plural references. Defaults to True.
+
+        Raises:
+            ToolkitException: When a reference cannot be resolved.
+
+        Returns:
+            list: Resolved trackers.
+        """

         trackers = []
@@ -151,7 +219,16 @@ def resolve(self, *references, storage=None, skip_unknown=True, resolve_plural=T
         return trackers

-    def _find_versions(self, identifier, storage):
+    def _find_versions(self, identifier: str, storage: "Storage"):
+        """ Finds all versions of the tracker in the storage.
+
+        Args:
+            identifier (str): The identifier of the tracker.
+            storage (Storage): The storage to use for finding the versions.
+
+        Returns:
+            list: List of trackers.
+        """

         trackers = []
@@ -167,15 +244,33 @@ def _find_versions(self, identifier, storage):
         return trackers

     def references(self):
+        """ Returns a list of all tracker references.
+
+        Returns:
+            list: List of tracker references.
+        """
         return [t.reference for t in self._trackers.values()]

     def identifiers(self):
+        """ Returns a list of all tracker identifiers.
+
+        Returns:
+            list: List of tracker identifiers.
+        """
         return [t.identifier for t in self._trackers.values()]

 class Tracker(object):
+    """ Tracker definition class. """

     @staticmethod
     def _collect_envvars(**kwargs):
+        """ Collects environment variables from the keyword arguments.
+
+        Args:
+            **kwargs: Keyword arguments.
+
+        Returns:
+            tuple: Tuple of environment variables and other keyword arguments.
+        """
         envvars = dict()
         other = dict()
@@ -194,6 +289,14 @@ def _collect_envvars(**kwargs):

     @staticmethod
     def _collect_arguments(**kwargs):
+        """ Collects arguments from the keyword arguments.
+
+        Args:
+            **kwargs: Keyword arguments.
+
+        Returns:
+            tuple: Tuple of arguments and other keyword arguments.
+        """
         arguments = dict()
         other = dict()
@@ -212,6 +315,18 @@ def _collect_arguments(**kwargs):

     @staticmethod
     def _collect_metadata(**kwargs):
+        """ Collects metadata from the keyword arguments.
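# --- Editor's illustration (not part of the patch): the _collect_* helpers above
# all follow one prefix-partitioning pattern over **kwargs. This generic sketch
# assumes prefixes such as 'meta_' (matching the documented example below); the
# exact prefixes are not shown in this hunk.
def collect_prefixed(prefix, **kwargs):
    collected, other = {}, {}
    for key, value in kwargs.items():
        if key.startswith(prefix):
            collected[key[len(prefix):]] = value
        else:
            other[key] = value
    return collected, other

meta, rest = collect_prefixed("meta_", meta_author="John Doe", meta_year=2018, command="run")
assert meta == {"author": "John Doe", "year": 2018} and rest == {"command": "run"}
# --- end illustration ---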
+ + Args: + **kwargs: Keyword arguments. + + Returns: + tuple: Tuple of metadata and other keyword arguments. + + Examples: + >>> Tracker._collect_metadata(meta_author="John Doe", meta_year=2018) + ({'author': 'John Doe', 'year': 2018}, {}) + """ metadata = dict() other = dict() @@ -229,6 +344,23 @@ def _collect_metadata(**kwargs): return metadata, other def __init__(self, _identifier, _source, command, protocol=None, label=None, version=None, tags=None, storage=None, **kwargs): + """ Initializes the tracker definition. + + Args: + _identifier (str): The identifier of the tracker. + _source (str): The source of the tracker. + command (str): The command to execute. + protocol (str, optional): The protocol of the tracker. Defaults to None. + label (str, optional): The label of the tracker. Defaults to None. + version (str, optional): The version of the tracker. Defaults to None. + tags (str, optional): The tags of the tracker. Defaults to None. + storage (str, optional): The storage of the tracker. Defaults to None. + **kwargs: Additional keyword arguments. + + Raises: + ValueError: When the identifier is not valid. + + """ from vot.workspace import LocalStorage self._identifier = _identifier self._source = _source @@ -267,6 +399,7 @@ def reversion(self, version=None) -> "Tracker": return tracker def runtime(self, log=False) -> "TrackerRuntime": + """Creates a new runtime instance for this tracker instance.""" if not self._command: raise TrackerException("Tracker does not have an attached executable", tracker=self) @@ -276,31 +409,49 @@ def runtime(self, log=False) -> "TrackerRuntime": return _runtime_protocols[self._protocol](self, self._command, log=log, envvars=self._envvars, arguments=self._arguments, **self._args) def __eq__(self, other): + """ Checks if two trackers are equal. + + Args: + other (Tracker): The other tracker. + + Returns: + bool: True if the trackers are equal, False otherwise. + """ if other is None or not isinstance(other, Tracker): return False return self.reference == other.identifier def __hash__(self): + """ Returns the hash of the tracker. """ return hash(self.reference) def __repr__(self): + """ Returns the string representation of the tracker. """ return self.reference @property def source(self): + """Returns the source of the tracker.""" return self._source @property def storage(self) -> "Storage": + """Returns the storage of the tracker results.""" return self._storage @property def identifier(self) -> str: + """Returns the identifier of the tracker.""" return self._identifier @property def label(self): + """Returns the label of the tracker. If the version is specified, the label will contain the version as well. + + Returns: + str: Label of the tracker. + """ if self._version is None: return self._label else: @@ -308,10 +459,20 @@ def label(self): @property def version(self) -> str: + """Returns the version of the tracker. If the version is not specified, None is returned. + + Returns: + str: Version of the tracker. + """ return self._version @property def reference(self) -> str: + """Returns the reference of the tracker. If the version is specified, the reference will contain the version as well. + + Returns: + str: Reference of the tracker. + """ if self._version is None: return self._identifier else: @@ -319,55 +480,124 @@ def reference(self) -> str: @property def protocol(self) -> str: + """Returns the communication protocol used by this tracker. 
+
+        Returns:
+            str: Communication protocol
+        """
         return self._protocol

     def describe(self):
+        """Returns a dictionary containing the tracker description.
+
+        Returns:
+            dict: Dictionary containing the tracker description.
+        """
         data = dict(command=self._command, label=self.label, protocol=self.protocol, arguments=self._arguments, env=self._envvars)
         data.update(self._args)
         return data

     def metadata(self, key):
+        """Returns the metadata value for specified key."""
         if not key in self._metadata:
             return None
         return self._metadata[key]

     def tagged(self, tag):
+        """Returns true if the tracker is tagged with specified tag.
+
+        Args:
+            tag (str): The tag to check.
+
+        Returns:
+            bool: True if the tracker is tagged with specified tag, False otherwise.
+        """
+        return tag in self._tags
+
+ObjectStatus = namedtuple("ObjectStatus", ["region", "properties"])
+
+Objects = Union[List[ObjectStatus], ObjectStatus]

 class TrackerRuntime(ABC):
+    """Base class for tracker runtime implementations. Tracker runtime is responsible for running the tracker executable and communicating with it."""

     def __init__(self, tracker: Tracker):
+        """Creates a new tracker runtime instance.
+
+        Args:
+            tracker (Tracker): The tracker instance.
+        """
         self._tracker = tracker

     @property
     def tracker(self) -> Tracker:
+        """Returns the tracker instance associated with this runtime."""
         return self._tracker

     def __enter__(self):
+        """Starts the tracker runtime."""
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
+        """Stops the tracker runtime."""
         self.stop()

+    @property
+    def multiobject(self):
+        """Returns True if the tracker supports multiple objects, False otherwise."""
+        return False
+
     @abstractmethod
     def stop(self):
+        """Stops the tracker runtime."""
         pass

     @abstractmethod
     def restart(self):
+        """Restarts the tracker runtime, usually starts a new process."""
         pass

     @abstractmethod
-    def initialize(self, frame: Frame, region: Region, properties: dict = None) -> Tuple[Region, dict, float]:
+    def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+        """Initializes the tracker runtime with specified frame and objects. Returns the initial objects and the time it took to initialize the tracker.
+
+        Arguments:
+            frame {Frame} -- The frame to initialize the tracker with.
+            new {Objects} -- The objects to initialize the tracker with.
+            properties {dict} -- The properties to initialize the tracker with.
+
+        Returns:
+            Tuple[Objects, float] -- The initial objects and the time it took to initialize the tracker.
+        """
         pass

     @abstractmethod
-    def update(self, frame: Frame, properties: dict = None) -> Tuple[Region, dict, float]:
+    def update(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+        """Updates the tracker runtime with specified frame and objects. Returns the updated objects and the time it took to update the tracker.
+
+        Arguments:
+            frame {Frame} -- The frame to update the tracker with.
+            new {Objects} -- The objects to update the tracker with.
+            properties {dict} -- The properties to update the tracker with.
+
+        Returns:
+            Tuple[Objects, float] -- The updated objects and the time it took to update the tracker.
+        """
         pass

 class RealtimeTrackerRuntime(TrackerRuntime):
+    """Base class for realtime tracker runtime implementations.
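# --- Editor's illustration (not part of the patch): a minimal, hypothetical
# TrackerRuntime satisfying the (Objects, elapsed-seconds) contract documented
# above. EchoRuntime is not part of the toolkit; it just reports back the
# objects it was given.
import time as _time

class EchoRuntime(TrackerRuntime):

    def __init__(self, tracker):
        super().__init__(tracker)
        self._objects = []

    def stop(self):
        self._objects = []

    def restart(self):
        self._objects = []

    def initialize(self, frame, new=None, properties=None):
        start = _time.perf_counter()
        self._objects = list(new) if isinstance(new, list) else ([] if new is None else [new])
        return list(self._objects), _time.perf_counter() - start

    def update(self, frame, new=None, properties=None):
        start = _time.perf_counter()
        if new is not None:
            self._objects.extend(new if isinstance(new, list) else [new])
        return list(self._objects), _time.perf_counter() - start
# --- end illustration ---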
+    Realtime tracker runtime is responsible for running the tracker executable and communicating with it while simulating given real-time constraints."""

     def __init__(self, runtime: TrackerRuntime, grace: int = 1, interval: float = 0.1):
+        """Initializes the realtime tracker runtime with specified tracker runtime, grace period and update interval.
+
+        Arguments:
+            runtime {TrackerRuntime} -- The tracker runtime to wrap.
+            grace {int} -- The grace period in frames. The tracker will be updated at least once during the grace period. (default: {1})
+            interval {float} -- The update interval in seconds. (default: {0.1})
+
+        """
         super().__init__(runtime.tracker)
         self._runtime = runtime
         self._grace = grace
@@ -376,27 +606,38 @@ def __init__(self, runtime: TrackerRuntime, grace: int = 1, interval: float = 0.
         self._time = 0
         self._out = None

-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.stop()
+    @property
+    def multiobject(self):
+        """Returns True if the tracker supports multiple objects, False otherwise."""
+        return self._runtime.multiobject

     def stop(self):
+        """Stops the tracker runtime."""
         self._runtime.stop()
         self._time = 0
         self._out = None

     def restart(self):
+        """Restarts the tracker runtime, usually starts a new process."""
         self._runtime.restart()
         self._time = 0
         self._out = None

-    def initialize(self, frame: Frame, region: Region, properties: dict = None) -> Tuple[Region, dict, float]:
+    def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+        """Initializes the tracker runtime with specified frame and objects. Returns the initial objects and the time it took to initialize the tracker.
+
+        Arguments:
+            frame {Frame} -- The frame to initialize the tracker with.
+            new {Objects} -- The objects to initialize the tracker with.
+            properties {dict} -- The properties to initialize the tracker with.
+
+        Returns:
+            Tuple[Objects, float] -- The initial objects and the time it took to initialize the tracker.
+        """
         self._countdown = self._grace
         self._out = None

-        out, prop, time = self._runtime.initialize(frame, region, properties)
+        out, time = self._runtime.initialize(frame, new, properties)

         if time > self._interval:
             if self._countdown > 0:
@@ -411,7 +652,17 @@ def initialize(self, frame: Frame, region: Region, properties: dict = None) -> T

-        return out, prop, time
+        return out, time

-    def update(self, frame: Frame, properties: dict = None) -> Tuple[Region, dict, float]:
+    def update(self, frame: Frame, _: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+        """Updates the tracker runtime with specified frame and objects. Returns the updated objects and the time it took to update the tracker.
+
+        Arguments:
+            frame {Frame} -- The frame to update the tracker with.
+            new {Objects} -- The objects to update the tracker with.
+            properties {dict} -- The properties to update the tracker with.
+
+        Returns:
+            Tuple[Objects, float] -- The updated objects and the time it took to update the tracker.
+        """

         if self._time > self._interval:
             self._time = self._time - self._interval
@@ -434,25 +685,43 @@ def update(self, frame: Frame, properties: dict = None) -> Tuple[Region, dict, f

 class PropertyInjectorTrackerRuntime(TrackerRuntime):
+    """Base class for tracker runtime implementations that inject properties into the tracker runtime."""

     def __init__(self, runtime: TrackerRuntime, **kwargs):
+        """Initializes the property injector tracker runtime with specified tracker runtime and properties.
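# --- Editor's illustration (not part of the patch): the real-time simulation
# above amounts to per-frame budget accounting. This sketch (values chosen to be
# float-exact) shows which response is visible at each frame when one response
# overruns the interval; it captures the idea, not the class's exact logic.
def realtime_outputs(durations, interval=0.125):
    outputs, debt, last = [], 0.0, None
    for i, duration in enumerate(durations):
        if debt <= 0:
            last = i                      # tracker caught up: fresh result
            debt += duration - interval
        else:
            debt -= interval              # tracker still busy: stale result
        outputs.append(last)
    return outputs

# A 0.3125 s response at frame 1 keeps frame 2 on the stale output.
assert realtime_outputs([0.0625, 0.3125, 0.0625, 0.0625, 0.0625]) == [0, 1, 1, 3, 4]
# --- end illustration ---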
+
+        Arguments:
+            runtime {TrackerRuntime} -- The tracker runtime to wrap.
+            **kwargs -- The properties to inject into the tracker runtime.
+        """
         super().__init__(runtime.tracker)
         self._runtime = runtime
         self._properties = {k : str(v) for k, v in kwargs.items()}

-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.stop()
+    @property
+    def multiobject(self):
+        """Returns True if the tracker supports multiple objects, False otherwise."""
+        return self._runtime.multiobject

     def stop(self):
+        """Stops the tracker runtime."""
         self._runtime.stop()

     def restart(self):
+        """Restarts the tracker runtime, usually starts a new process."""
         self._runtime.restart()

-    def initialize(self, frame: Frame, region: Region, properties: dict = None) -> Tuple[Region, dict, float]:
+    def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+        """Initializes the tracker runtime with specified frame and objects. Returns the initial objects and the time it took to initialize the tracker.
+        This method injects the properties into the tracker runtime.
+
+        Arguments:
+            frame {Frame} -- The frame to initialize the tracker with.
+            new {Objects} -- The objects to initialize the tracker with.
+            properties {dict} -- The properties to initialize the tracker with.
+
+        Returns:
+            Tuple[Objects, float] -- The initial objects and the time it took to initialize the tracker."""
         if not properties is None:
             tproperties = dict(properties)
@@ -461,12 +730,174 @@ def initialize(self, frame: Frame, region: Region, properties: dict = None) -> T
             tproperties.update(self._properties)

-        return self._runtime.initialize(frame, region, tproperties)
+        return self._runtime.initialize(frame, new, tproperties)
+
+
+    def update(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+        """Updates the tracker runtime with specified frame and objects. Returns the updated objects and the time it took to update the tracker.
+
+        Arguments:
+            frame {Frame} -- The frame to update the tracker with.
+            new {Objects} -- The objects to update the tracker with.
+            properties {dict} -- The properties to update the tracker with.
+
+        Returns:
+            Tuple[Objects, float] -- The updated objects and the time it took to update the tracker.
+
+        """
+        return self._runtime.update(frame, new, properties)
+
+
+class SingleObjectTrackerRuntime(TrackerRuntime):
+    """Wrapper for tracker runtime that only supports single object tracking. Used to enforce single object tracking even for multi object trackers."""
+
+    def __init__(self, runtime: TrackerRuntime):
+        """Initializes the single object tracker runtime with specified tracker runtime.
+
+        Arguments:
+            runtime {TrackerRuntime} -- The tracker runtime to wrap.
+        """
+        super().__init__(runtime.tracker)
+        self._runtime = runtime
+
+    @property
+    def multiobject(self):
+        """Returns False, since the tracker runtime only supports single object tracking."""
+        return False
+
+    def stop(self):
+        """Stops the tracker runtime.
+
+        Raises:
+            TrackerException -- If the tracker runtime does not support stopping.
+        """
+        self._runtime.stop()
+
+    def restart(self):
+        """Restarts the tracker runtime, usually starts a new process.
+
+        Raises:
+            TrackerException -- If the tracker runtime does not support restarting.
+        """
+        self._runtime.restart()
+
+    def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+        """Initializes the tracker runtime with specified frame and objects.
Returns the initial objects and the time it took to initialize the tracker.
+
+        Arguments:
+            frame {Frame} -- The frame to initialize the tracker with.
+            new {Objects} -- The objects to initialize the tracker with.
+            properties {dict} -- The properties to initialize the tracker with.
+
+        Returns:
+            Tuple[Objects, float] -- The initial objects and the time it took to initialize the tracker.
+        """
+
+        if isinstance(new, list) and len(new) != 1: raise TrackerException("Only supports single object tracking", tracker=self.tracker)
+        status, time = self._runtime.initialize(frame, new, properties)
+        if isinstance(status, list): status = status[0]
+        return status, time
+
+    def update(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+        """Updates the tracker runtime with specified frame and objects. Returns the updated objects and the time it took to update the tracker.
+
+        Arguments:
+            frame {Frame} -- The frame to update the tracker with.
+            new {Objects} -- The objects to update the tracker with.
+            properties {dict} -- The properties to update the tracker with.
+
+        Returns:
+            Tuple[Objects, float] -- The updated objects and the time it took to update the tracker.
+        """
+
+        if not new is None: raise TrackerException("Only supports single object tracking", tracker=self.tracker)
+        status, time = self._runtime.update(frame, new, properties)
+        if isinstance(status, list): status = status[0]
+        return status, time
+
+class MultiObjectTrackerRuntime(TrackerRuntime):
+    """ This is a wrapper for tracker runtimes that do not support multi object tracking. STILL IN DEVELOPMENT!"""
+
+    def __init__(self, runtime: TrackerRuntime):
+        """Initializes the multi object tracker runtime with specified tracker runtime.
+
+        Arguments:
+            runtime {TrackerRuntime} -- The tracker runtime to wrap.
+        """
+
+        super().__init__(runtime.tracker)
+        if runtime.multiobject:
+            self._runtime = runtime
+        else:
+            self._runtime = [runtime]
+        self._used = 0
+
+    @property
+    def multiobject(self):
+        """Always returns True, since the tracker runtime supports multi object tracking."""
+        return True
+
+    def stop(self):
+        """Stops the tracker runtime."""
+        if isinstance(self._runtime, TrackerRuntime):
+            self._runtime.stop()
+        else:
+            for r in self._runtime:
+                r.stop()
+
+    def restart(self):
+        """Restarts the tracker runtime, usually starts a new process."""
+        if isinstance(self._runtime, TrackerRuntime):
+            self._runtime.restart()
+        else:
+            for r in self._runtime:
+                r.restart()
+
+    def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]:
+        """Initializes the tracker runtime with specified frame and objects. Returns the initial objects and the time it took to initialize the tracker.
+        Internally this method initializes the tracker runtime for each object in the objects list.
+
+        Arguments:
+            frame {Frame} -- The frame to initialize the tracker with.
+            new {Objects} -- The objects to initialize the tracker with.
+            properties {dict} -- The properties to initialize the tracker with.
+
+        Returns:
+            Tuple[Objects, float] -- The initial objects and the time it took to initialize the tracker.
+ """ + + if isinstance(self._runtime, TrackerRuntime): + return self._runtime.initialize(frame, new, properties) + if isinstance(new, ObjectStatus): + new = [new] + + self._used = 0 + status = [] + for i, o in enumerate(new): + if i >= len(self._runtime): + self._runtime.append(self._tracker.runtime()) + self._runtime.initialize(frame, new, properties) + if isinstance(status, list): status = status[0] + return status - def update(self, frame: Frame, properties: dict = None) -> Tuple[Region, dict, float]: - return self._runtime.update(frame, properties) + def update(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]: + """Updates the tracker runtime with specified frame and objects. Returns the updated objects and the time it took to update the tracker. + Internally this method updates the tracker runtime for each object in the new objects list. + + Arguments: + frame {Frame} -- The frame to update the tracker with. + new {Objects} -- The objects to update the tracker with. + properties {dict} -- The properties to update the tracker with. + + Returns: + Tuple[Objects, float] -- The updated objects and the time it took to update the tracker. + """ + if not new is None: raise TrackerException("Only supports single object tracking") + status = self._runtime.update(frame, new, properties) + if isinstance(status, list): status = status[0] + return status try: diff --git a/vot/tracker/dummy.py b/vot/tracker/dummy.py index 2e06666..b80f31d 100644 --- a/vot/tracker/dummy.py +++ b/vot/tracker/dummy.py @@ -1,20 +1,23 @@ +""" Dummy tracker for testing purposes. """ + from __future__ import absolute_import import os from sys import path import time -from trax import Image, Region, Server, TraxStatus - def _main(): - region = None + """ Dummy tracker main function for testing purposes.""" + from trax import Image, Region, Server, TraxStatus + + objects = None with Server([Region.RECTANGLE], [Image.PATH]) as server: while True: request = server.wait() if request.type in [TraxStatus.QUIT, TraxStatus.ERROR]: break if request.type == TraxStatus.INITIALIZE: - region = request.region - server.status(region) + objects = request.objects + server.status(objects) time.sleep(0.1) if __name__ == '__main__': diff --git a/vot/tracker/results.py b/vot/tracker/results.py index 90a59bf..564d180 100644 --- a/vot/tracker/results.py +++ b/vot/tracker/results.py @@ -1,42 +1,113 @@ +"""Results module for storing and retrieving tracker results.""" + import os import fnmatch from typing import List from copy import copy -from vot.region import Region, RegionType, Special, write_file, read_file, calculate_overlap +from vot.region import Region, RegionType, Special, calculate_overlap +from vot.region.io import write_trajectory, read_trajectory from vot.utilities import to_string class Results(object): + """Generic results interface for storing and retrieving results.""" def __init__(self, storage: "Storage"): + """Creates a new results interface. + + Args: + storage (Storage): Storage interface + """ self._storage = storage def exists(self, name): + """Returns true if the given file exists in the results storage. + + Args: + name (str): File name + + Returns: + bool: True if the file exists + """ return self._storage.isdocument(name) def read(self, name): + """Returns a file handle for reading the given file from the results storage. 
+ + Args: + name (str): File name + + Returns: + file: File handle + """ + if name.endswith(".bin"): + return self._storage.read(name, binary=True) return self._storage.read(name) - def write(self, name): + def write(self, name: str): + """Returns a file handle for writing the given file to the results storage. + + Args: + name (str): File name + + Returns: + file: File handle + """ + if name.endswith(".bin"): + return self._storage.write(name, binary=True) return self._storage.write(name) def find(self, pattern): - return fnmatch.filter(self._storage.documents(), pattern) + """Returns a list of files matching the given pattern in the results storage. + + Args: + pattern (str): Pattern + Returns: + list: List of files + """ + + return fnmatch.filter(self._storage.documents(), pattern) + class Trajectory(object): + """Trajectory class for storing and retrieving tracker trajectories.""" + + UNKNOWN = 0 + INITIALIZATION = 1 + FAILURE = 2 @classmethod def exists(cls, results: Results, name: str) -> bool: - return results.exists(name + ".txt") + """Returns true if the trajectory exists in the results storage. + + Args: + results (Results): Results storage + name (str): Trajectory name (without extension) + + Returns: + bool: True if the trajectory exists + """ + return results.exists(name + ".bin") or results.exists(name + ".txt") @classmethod def gather(cls, results: Results, name: str) -> list: - - if not Trajectory.exists(results, name): + """Returns a list of files that are part of the trajectory. + + Args: + results (Results): Results storage + name (str): Trajectory name (without extension) + + Returns: + list: List of files + """ + + if results.exists(name + ".bin"): + files = [name + ".bin"] + elif results.exists(name + ".txt"): + files = [name + ".txt"] + else: return [] - files = [name + ".txt"] - for propertyfile in results.find(name + "_*.value"): files.append(propertyfile) @@ -44,18 +115,38 @@ def gather(cls, results: Results, name: str) -> list: @classmethod def read(cls, results: Results, name: str) -> 'Trajectory': + """Reads a trajectory from the results storage. + + Args: + results (Results): Results storage + name (str): Trajectory name (without extension) + + Returns: + Trajectory: Trajectory + """ def parse_float(line): + """Parses a float from a line. + + Args: + line (str): Line + + Returns: + float: Float value + """ if not line.strip(): return None return float(line.strip()) - if not results.exists(name + ".txt"): + if results.exists(name + ".txt"): + with results.read(name + ".txt") as fp: + regions = read_trajectory(fp) + elif results.exists(name + ".bin"): + with results.read(name + ".bin") as fp: + regions = read_trajectory(fp) + else: raise FileNotFoundError("Trajectory data not found: {}".format(name)) - with results.read(name + ".txt") as fp: - regions = read_file(fp) - trajectory = Trajectory(len(regions)) trajectory._regions = regions @@ -67,14 +158,29 @@ def parse_float(line): trajectory._properties[propertyname] = [parse_float(line) for line in lines] except ValueError: trajectory._properties[propertyname] = [line.strip() for line in lines] - + return trajectory - def __init__(self, length:int): - self._regions = [Special(Special.UNKNOWN)] * length + def __init__(self, length: int): + """Creates a new trajectory of the given length. 
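# --- Editor's illustration (not part of the patch): a hypothetical usage sketch
# of the Trajectory API described here, assuming a 3-frame sequence and that
# per-frame property lists default to None for unset frames.
from vot.region import Rectangle

t = Trajectory(3)
t.set(0, Rectangle(0, 0, 50, 50), {"confidence": 1.0})
t.set(1, Rectangle(5, 5, 50, 50), {"confidence": 0.9})
# Frame 2 keeps its Special(Trajectory.UNKNOWN) placeholder.
assert len(t) == 3
print(t.properties(1))  # -> {'confidence': 0.9}, under the stated assumptions
# --- end illustration ---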
+
+        Args:
+            length (int): Trajectory length
+        """
+        self._regions = [Special(Trajectory.UNKNOWN)] * length
         self._properties = dict()

     def set(self, frame: int, region: Region, properties: dict = None):
+        """Sets the region for the given frame.
+
+        Args:
+            frame (int): Frame index
+            region (Region): Region
+            properties (dict, optional): Frame properties. Defaults to None.
+
+        Raises:
+            IndexError: Frame index out of bounds
+        """
         if frame < 0 or frame >= len(self._regions):
             raise IndexError("Frame index out of bounds")
@@ -89,14 +195,41 @@ def set(self, frame: int, region: Region, properties: dict = None):
             self._properties[k][frame] = v

     def region(self, frame: int) -> Region:
+        """Returns the region for the given frame.
+
+        Args:
+            frame (int): Frame index
+
+        Raises:
+            IndexError: Frame index out of bounds
+
+        Returns:
+            Region: Region
+        """
         if frame < 0 or frame >= len(self._regions):
             raise IndexError("Frame index out of bounds")
         return self._regions[frame]

     def regions(self) -> List[Region]:
+        """ Returns the list of regions.
+
+        Returns:
+            List[Region]: List of regions
+        """
         return copy(self._regions)

     def properties(self, frame: int = None) -> dict:
+        """Returns the properties for the given frame or all properties if frame is None.
+
+        Args:
+            frame (int, optional): Frame index. Defaults to None.
+
+        Raises:
+            IndexError: Frame index out of bounds
+
+        Returns:
+            dict: Properties
+        """
         if frame is None:
             return tuple(self._properties.keys())
@@ -107,12 +240,29 @@ def properties(self, frame: int = None) -> dict:

         return {k : v[frame] for k, v in self._properties.items() if not v[frame] is None}

     def __len__(self):
+        """Returns the length of the trajectory.
+
+        Returns:
+            int: Length
+        """
         return len(self._regions)

     def write(self, results: Results, name: str):
-
-        with results.write(name + ".txt") as fp:
-            write_file(fp, self._regions)
+        """Writes the trajectory to the results storage.
+
+        Args:
+            results (Results): Results storage
+            name (str): Trajectory name (without extension)
+        """
+        from vot import config
+
+        if config.results_binary:
+            with results.write(name + ".bin") as fp:
+                write_trajectory(fp, self._regions)
+        else:
+            with results.write(name + ".txt") as fp:
+                write_trajectory(fp, self._regions)

         for k, v in self._properties.items():
             with results.write(name + "_" + k + ".value") as fp:
@@ -120,6 +270,16 @@ def write(self, results: Results, name: str):

     def equals(self, trajectory: 'Trajectory', check_properties: bool = False, overlap_threshold: float = 0.99999):
+        """Returns true if the trajectories are equal.
+
+        Args:
+            trajectory (Trajectory): The trajectory to compare with.
+            check_properties (bool, optional): Also compare per-frame properties. Defaults to False.
+            overlap_threshold (float, optional): Minimum per-frame region overlap for two regions to be considered equal. Defaults to 0.99999.
+
+        Returns:
+            bool: True if the trajectories are equal, False otherwise.
+        """
         if not len(self) == len(trajectory):
             return False
diff --git a/vot/tracker/tests.py b/vot/tracker/tests.py
index 9126bd9..1334c44 100644
--- a/vot/tracker/tests.py
+++ b/vot/tracker/tests.py
@@ -1,18 +1,20 @@
-
+""" Unit tests for the tracker module.
""" import unittest -from ..dataset.dummy import DummySequence +from ..dataset.dummy import generate_dummy from ..tracker.dummy import DummyTracker class TestStacks(unittest.TestCase): + """Tests for the stacks module.""" def test_tracker_test(self): + """Test tracker runtime with dummy sequence and dummy tracker.""" tracker = DummyTracker - sequence = DummySequence(10) + sequence = generate_dummy(10) with tracker.runtime(log=False) as runtime: runtime.initialize(sequence.frame(0), sequence.groundtruth(0)) - for i in range(1, sequence.length): + for i in range(1, len(sequence)): runtime.update(sequence.frame(i)) diff --git a/vot/tracker/trax.py b/vot/tracker/trax.py index 055e828..8fb0618 100644 --- a/vot/tracker/trax.py +++ b/vot/tracker/trax.py @@ -1,4 +1,8 @@ +""" TraX protocol implementation for the toolkit. TraX is a communication protocol for visual object tracking. + It enables communication between a tracker and a client. The protocol was originally developed for the VOT challenge to address + the need for a unified communication interface between trackers and benchmarking tools. +""" import sys import os import time @@ -27,7 +31,7 @@ from vot.dataset import Frame, DatasetException from vot.region import Region, Polygon, Rectangle, Mask -from vot.tracker import Tracker, TrackerRuntime, TrackerException +from vot.tracker import Tracker, TrackerRuntime, TrackerException, Objects, ObjectStatus from vot.utilities import to_logical, to_number, normalize_path PORT_POOL_MIN = 9090 @@ -36,27 +40,40 @@ logger = logging.getLogger("vot") class LogAggregator(object): + """ Aggregates log messages from the tracker. """ def __init__(self): + """ Initializes the aggregator.""" self._fragments = [] def __call__(self, fragment): + """ Appends a new fragment to the log.""" self._fragments.append(fragment) def __str__(self): + """ Returns the aggregated log.""" return "".join(self._fragments) class ColorizedOutput(object): + """ Colorized output for the tracker.""" def __init__(self): + """ Initializes the colorized output.""" colorama.init() def __call__(self, fragment): + """ Prints a new fragment to the output. + + Args: + fragment: The fragment to be printed. + """ print(colorama.Fore.CYAN + fragment + colorama.Fore.RESET, end="") class PythonCrashHelper(object): + """ Helper class for detecting Python crashes in the tracker.""" def __init__(self): + """ Initializes the crash helper.""" self._matcher = re.compile(r''' ^Traceback [\s\S]+? @@ -64,12 +81,27 @@ def __init__(self): ''', re.M | re.X) def __call__(self, log, directory): + """ Detects Python crashes in the log. + + Args: + log: The log to be checked. + directory: The directory where the log is stored. + """ matches = self._matcher.findall(log) if len(matches) > 0: return matches[-1].group(0) return None def convert_frame(frame: Frame, channels: list) -> dict: + """ Converts a frame to a dictionary of Trax images. + + Args: + frame: The frame to be converted. + channels: The list of channels to be converted. + + Returns: + A dictionary of Trax images. + """ tlist = dict() for channel in channels: @@ -82,16 +114,31 @@ def convert_frame(frame: Frame, channels: list) -> dict: return tlist def convert_region(region: Region) -> TraxRegion: + """ Converts a region to a Trax region. + + Args: + region: The region to be converted. + + Returns: + A Trax region. 
+    """
     if isinstance(region, Rectangle):
         return TraxRectangle.create(region.x, region.y, region.width, region.height)
     elif isinstance(region, Polygon):
         return TraxPolygon.create([region[i] for i in range(region.size)])
     elif isinstance(region, Mask):
         return TraxMask.create(region.mask, x=region.offset[0], y=region.offset[1])
-
     return None

 def convert_traxregion(region: TraxRegion) -> Region:
+    """ Converts a TraX region to a toolkit region.
+
+    Args:
+        region: The TraX region to be converted.
+
+    Returns:
+        A region.
+    """
     if region.type == TraxRegion.RECTANGLE:
         x, y, width, height = region.bounds()
         return Rectangle(x, y, width, height)
@@ -99,22 +146,61 @@
         return Polygon(list(region))
     elif region.type == TraxRegion.MASK:
         return Mask(region.array(), region.offset(), optimize=True)
+    return None
+
+def convert_objects(objects: Objects) -> list:
+    """ Converts toolkit objects to a list of TraX (region, properties) pairs.
+
+    Args:
+        objects: The objects to be converted: None, a single region or object status, or a list of object statuses.
+
+    Returns:
+        A list of (TraxRegion, dict) pairs.
+    """
+    if objects is None: return []
+    if isinstance(objects, (list, )):
+        return [(convert_region(o.region), dict(o.properties)) for o in objects]
+    if isinstance(objects, (ObjectStatus, )):
+        return [(convert_region(objects.region), dict(objects.properties))]
+    else:
+        return [(convert_region(objects), dict())]
+
+def convert_traxobjects(region: TraxRegion) -> Region:
+    """ Converts a TraX region to a toolkit region. Mirrors convert_traxregion.
+
+    Args:
+        region: The TraX region to be converted.
+
+    Returns:
+        A region.
+    """
+    if region.type == TraxRegion.RECTANGLE:
+        x, y, width, height = region.bounds()
+        return Rectangle(x, y, width, height)
+    elif region.type == TraxRegion.POLYGON:
+        return Polygon(list(region))
+    elif region.type == TraxRegion.MASK:
+        return Mask(region.array(), region.offset(), optimize=True)
     return None

 class TestRasterMethods(unittest.TestCase):
+    """ Tests for the raster methods. """

     def test_convert_traxregion(self):
+        """ Tests the conversion of TraX regions."""
         convert_traxregion(TraxRectangle.create(0, 0, 10, 10))
         convert_traxregion(TraxPolygon.create([(0, 0), (10, 0), (10, 10), (0, 10)]))
         convert_traxregion(TraxMask.create(np.ones((100, 100), dtype=np.uint8)))

     def test_convert_region(self):
+        """ Tests the conversion of regions."""
         convert_region(Rectangle(0, 0, 10, 10))
         convert_region(Polygon([(0, 0), (10, 0), (10, 10), (0, 10)]))
         convert_region(Mask(np.ones((100, 100), dtype=np.uint8)))

 def open_local_port(port: int):
+    """ Opens a local port for listening."""
     socket = socketio.socket(socketio.AF_INET, socketio.SOCK_STREAM)
     try:
         socket.setsockopt(socketio.SOL_SOCKET, socketio.SO_REUSEADDR, 1)
@@ -129,12 +215,25 @@
     return None

 def normalize_paths(paths, tracker):
+    """ Normalizes a list of paths relative to the tracker source."""
     root = os.path.dirname(tracker.source)
     return [normalize_path(path, root) for path in paths]

 class TrackerProcess(object):
-
+    """ A tracker process. This class is used to run trackers in a separate process and handles
+    starting, stopping and communication with the process. """
+
     def __init__(self, command: str, envvars=dict(), timeout=30, log=False, socket=False):
+        """ Initializes a new tracker process.
+
+        Args:
+            command: The command to run the tracker.
+            envvars: A dictionary of environment variables to be set for the tracker process.
+            timeout: The timeout for the tracker process.
+            log: Whether to log the tracker output.
+            socket: Whether to use a socket for communication.
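+
+        Example (illustrative sketch; the command is hypothetical):
+            process = TrackerProcess("python -m my_tracker", timeout=30)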
+ + """ environment = dict(os.environ) environment.update(envvars) @@ -188,6 +287,7 @@ def __init__(self, command: str, envvars=dict(), timeout=30, log=False, socket=F self._client = Client( stream=(self._process.stdin.fileno(), self._process.stdout.fileno()), log=log ) + except TraxException as e: self.terminate() self._watchdog_reset(False) @@ -195,8 +295,15 @@ def __init__(self, command: str, envvars=dict(), timeout=30, log=False, socket=F self._watchdog_reset(False) self._has_vot_wrapper = not self._client.get("vot") is None + self._multiobject = self._client.get("multiobject") def _watchdog_reset(self, enable=True): + """ Resets the watchdog. + + Args: + enable: Whether to enable the watchdog. + + """ if self._watchdog_counter == 0: return @@ -206,6 +313,7 @@ def _watchdog_reset(self, enable=True): self._watchdog_counter = -1 def _watchdog_loop(self): + """ The watchdog loop. This loop is used to monitor the tracker process and terminate it if it does not respond anymore.""" while self.alive: time.sleep(0.1) @@ -219,28 +327,46 @@ def _watchdog_loop(self): @property def has_vot_wrapper(self): + """ Whether the tracker has a VOT wrapper. VOT wrapper limits TraX functionality and injects a property at handshake to let the client know this.""" return self._has_vot_wrapper @property def returncode(self): + """ The return code of the tracker process.""" return self._returncode @property def workdir(self): + """ The working directory of the tracker process.""" return self._workdir @property def interrupted(self): + """ Whether the tracker process was interrupted.""" return self._watchdog_counter == 0 @property def alive(self): + """ Whether the tracker process is alive.""" if self._process is None: return False self._returncode = self._process.returncode return self._returncode is None - def initialize(self, frame: Frame, region: Region, properties: dict = None) -> Tuple[Region, dict, float]: + def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]: + """ Initializes the tracker. This method is used to initialize the tracker with the first frame. It returns the initial state of the tracker. + + Args: + frame: The first frame. + new: The initial state of the tracker. + properties: The properties to be set for the tracker. + + Returns: + The initial state of the tracker. + + Raises: + TraxException: If the tracker is not alive. + """ if not self.alive: raise TraxException("Tracker not alive") @@ -249,33 +375,55 @@ def initialize(self, frame: Frame, region: Region, properties: dict = None) -> T properties = dict() tlist = convert_frame(frame, self._client.channels) - tregion = convert_region(region) + tobjects = convert_objects(new) self._watchdog_reset(True) - region, properties, elapsed = self._client.initialize(tlist, tregion, properties) + status, elapsed = self._client.initialize(tlist, tobjects, properties) self._watchdog_reset(False) - return convert_traxregion(region), properties.dict(), elapsed + status = [ObjectStatus(convert_traxregion(region), properties) for region, properties in status] + + return status, elapsed + + + def update(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]: + """ Updates the tracker with a new frame. This method is used to update the tracker with a new frame. It returns the new state of the tracker. + Args: + frame: The new frame. + new: The new state of the tracker. + properties: The properties to be set for the tracker. 
- def frame(self, frame: Frame, properties: dict = dict()) -> Tuple[Region, dict, float]: + Returns: + The new state of the tracker. + + Raises: + TraxException: If the tracker is not alive. + + """ if not self.alive: raise TraxException("Tracker not alive") tlist = convert_frame(frame, self._client.channels) + tobjects = convert_objects(new) + self._watchdog_reset(True) - region, properties, elapsed = self._client.frame(tlist, properties) + status, elapsed = self._client.frame(tlist, properties, tobjects) self._watchdog_reset(False) - return convert_traxregion(region), properties.dict(), elapsed + status = [ObjectStatus(convert_traxregion(region), properties) for region, properties in status] + + return status, elapsed def terminate(self): + """ Terminates the tracker. This method is used to terminate the tracker. It closes the connection to the tracker and terminates the tracker process. + """ with self._watchdog_lock: if not self.alive: @@ -314,10 +462,12 @@ def terminate(self): self._process = None def __del__(self): + """ Destructor. This method is used to terminate the tracker process if it is still alive.""" if hasattr(self, "_workdir"): shutil.rmtree(self._workdir, ignore_errors=True) def wait(self): + """ Waits for the tracker to terminate. This method is used to wait for the tracker to terminate. It waits until the tracker process terminates.""" self._watchdog_reset(True) @@ -331,8 +481,23 @@ def wait(self): class TraxTrackerRuntime(TrackerRuntime): + """ The TraX tracker runtime. This class is used to run a tracker using the TraX protocol.""" def __init__(self, tracker: Tracker, command: str, log: bool = False, timeout: int = 30, linkpaths=None, envvars=None, arguments=None, socket=False, restart=False, onerror=None): + """ Initializes the TraX tracker runtime. + + Args: + tracker: The tracker to be run. + command: The command to run the tracker. + log: Whether to log the output of the tracker. + timeout: The timeout in seconds for the tracker to respond. + linkpaths: The paths to be added to the PATH environment variable. + envvars: The environment variables to be set for the tracker. + arguments: The arguments to be passed to the tracker. + socket: Whether to use a socket to communicate with the tracker. + restart: Whether to restart the tracker if it crashes. + onerror: The error handler to be called if the tracker crashes. + """ super().__init__(tracker) self._command = command self._process = None @@ -365,9 +530,17 @@ def __init__(self, tracker: Tracker, command: str, log: bool = False, timeout: i @property def tracker(self) -> Tracker: + """ The associated tracker object. """ return self._tracker + @property + def multiobject(self): + """ Whether the tracker supports multiple objects.""" + self._connect() + return self._process._multiobject + def _connect(self): + """ Connects to the tracker. This method is used to connect to the tracker. It starts the tracker process if it is not running yet.""" if not self._process: if not self._output is None: log = self._output @@ -378,6 +551,7 @@ def _connect(self): self._restart = True def _error(self, exception): + """ Handles an error. This method is used to handle an error. It calls the error handler if it is set.""" workdir = None timeout = False if not self._output is None: @@ -409,13 +583,24 @@ def _error(self, exception): tracker_log=log if not self._output is None else None) def restart(self): + """ Restarts the tracker. This method is used to restart the tracker. 
It stops the tracker process and starts it again.""" try: self.stop() self._connect() except TraxException as e: self._error(e) - def initialize(self, frame: Frame, region: Region, properties: dict = None) -> Tuple[Region, dict, float]: + def initialize(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]: + """ Initializes the tracker. This method is used to initialize the tracker. It starts the tracker process if it is not running yet. + + Args: + frame: The initial frame. + new: The initial objects. + properties: The initial properties. + + Returns: + A tuple containing the initial objects and the initial score. + """ try: if self._restart: self.stop() @@ -426,35 +611,72 @@ def initialize(self, frame: Frame, region: Region, properties: dict = None) -> T if not properties is None: tproperties.update(properties) - return self._process.initialize(frame, region, tproperties) + return self._process.initialize(frame, new, tproperties) except TraxException as e: self._error(e) - def update(self, frame: Frame, properties: dict = None) -> Tuple[Region, dict, float]: + def update(self, frame: Frame, new: Objects = None, properties: dict = None) -> Tuple[Objects, float]: + """ Updates the tracker. This method is used to update the tracker state with a new frame. + + Args: + frame: The current frame. + new: The current objects. + properties: The current properties. + + Returns: + A tuple containing the updated objects and the updated score. + """ try: if properties is None: properties = dict() - return self._process.frame(frame, properties) + return self._process.update(frame, new, properties) except TraxException as e: self._error(e) def stop(self): + """ Stops the tracker. This method is used to stop the tracker. It stops the tracker process.""" if not self._process is None: self._process.terminate() self._process = None def __del__(self): - if not self._process is None: - self._process.terminate() - self._process = None + """ Destructor. This method is used to stop the tracker process when the object is deleted.""" + self.stop() def escape_path(path): + """ Escapes a path. This method is used to escape a path. + + Args: + path: The path to escape. + + Returns: + The escaped path. + """ if sys.platform.startswith("win"): return path.replace("\\\\", "\\").replace("\\", "\\\\") else: return path def trax_python_adapter(tracker, command, envvars, paths="", log: bool = False, timeout: int = 30, linkpaths=None, arguments=None, python=None, socket=False, restart=False, **kwargs): + """ Creates a Python adapter for a tracker. This method is used to create a Python adapter for a tracker. + + Args: + tracker: The tracker to create the adapter for. + command: The command to run the tracker. + envvars: The environment variables to set. + paths: The paths to add to the Python path. + log: Whether to log the tracker output. + timeout: The timeout in seconds. + linkpaths: The paths to link. + arguments: The arguments to pass to the tracker. + python: The Python interpreter to use. + socket: Whether to use a socket to communicate with the tracker. + restart: Whether to restart the tracker after each frame. + kwargs: Additional keyword arguments. + + Returns: + The Python TraX runtime object. 
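+
+    Example (illustrative; the command and paths are hypothetical):
+        runtime = trax_python_adapter(tracker, "import my_tracker", {}, paths="/opt/trackers")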
+ """ if not isinstance(paths, list): paths = paths.split(os.pathsep) @@ -475,6 +697,25 @@ def trax_python_adapter(tracker, command, envvars, paths="", log: bool = False, return TraxTrackerRuntime(tracker, command, log=log, timeout=timeout, linkpaths=linkpaths, envvars=envvars, arguments=arguments, socket=socket, restart=restart) def trax_matlab_adapter(tracker, command, envvars, paths="", log: bool = False, timeout: int = 30, linkpaths=None, arguments=None, matlab=None, socket=False, restart=False, **kwargs): + """ Creates a Matlab adapter for a tracker. This method is used to create a Matlab adapter for a tracker. + + Args: + tracker: The tracker to create the adapter for. + command: The command to run the tracker. + envvars: The environment variables to set. + paths: The paths to add to the Matlab path. + log: Whether to log the tracker output. + timeout: The timeout in seconds. + linkpaths: The paths to link. + arguments: The arguments to pass to the tracker. + matlab: The Matlab executable to use. + socket: Whether to use a socket to communicate with the tracker. + restart: Whether to restart the tracker after each frame. + kwargs: Additional keyword arguments. + + Returns: + The Matlab TraX runtime object. + """ if not isinstance(paths, list): paths = paths.split(os.pathsep) @@ -514,6 +755,25 @@ def trax_matlab_adapter(tracker, command, envvars, paths="", log: bool = False, return TraxTrackerRuntime(tracker, command, log=log, timeout=timeout, linkpaths=linkpaths, envvars=envvars, arguments=arguments, socket=socket, restart=restart) def trax_octave_adapter(tracker, command, envvars, paths="", log: bool = False, timeout: int = 30, linkpaths=None, arguments=None, socket=False, restart=False, **kwargs): + """ Creates an Octave adapter for a tracker. This method is used to create an Octave adapter for a tracker. + + Args: + tracker: The tracker to create the adapter for. + command: The command to run the tracker. + envvars: The environment variables to set. + paths: The paths to add to the Octave path. + log: Whether to log the tracker output. + timeout: The timeout in seconds. + linkpaths: The paths to link. + arguments: The arguments to pass to the tracker. + socket: Whether to use a socket to communicate with the tracker. + restart: Whether to restart the tracker after each frame. + kwargs: Additional keyword arguments. + + Returns: + The Octave TraX runtime object. + """ + if not isinstance(paths, list): paths = paths.split(os.pathsep) diff --git a/vot/utilities/__init__.py b/vot/utilities/__init__.py index aebccaa..db90717 100644 --- a/vot/utilities/__init__.py +++ b/vot/utilities/__init__.py @@ -1,16 +1,18 @@ +""" This module contains various utility functions and classes used throughout the toolkit. """ + import os import sys import csv import re import hashlib -import errno import logging import inspect +import time import concurrent.futures as futures from logging import Formatter, LogRecord from numbers import Number -from typing import Tuple +from typing import Any, Mapping, Tuple import typing from vot import get_logger @@ -19,8 +21,18 @@ __ALIASES = dict() +def import_class(classpath: str) -> typing.Type: + """Import a class from a string by importing parent packages. + + Args: + classpath (str): String representing a canonical class name with all parent packages. 
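+
+    Example (illustrative):
+        tracker_class = import_class("vot.tracker.trax.TraxTrackerRuntime")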
+
+    Raises:
+        ImportError: Raised when the module or the class cannot be imported.

-def import_class(classpath):
+    Returns:
+        typing.Type: The imported class object.
+    """
     delimiter = classpath.rfind(".")
     if delimiter == -1:
         if classpath in __ALIASES:
@@ -33,7 +45,14 @@ def import_class(classpath):
     return getattr(module, classname)

 def alias(*args):
-    def register(cls):
+    """ Decorator for registering class aliases. Aliases are used to refer to classes by a short name.
+
+    Args:
+        *args: A list of strings representing aliases for the class.
+    """
+
+    def register(cls: typing.Type):
+        """ Register the class with the given aliases. """
         assert cls is not None
         for name in args:
             if name in __ALIASES:
@@ -45,9 +64,25 @@ def register(cls):
     return register

 def class_fullname(o):
+    """Returns the full name of the class of the given object.
+
+    Args:
+        o: The object to get the class name from.
+
+    Returns:
+        The full name of the class of the given object.
+    """
     return class_string(o.__class__)

 def class_string(kls):
+    """Returns the full name of the given class.
+
+    Args:
+        kls: The class to get the name from.
+
+    Returns:
+        The full name of the given class.
+    """
     assert inspect.isclass(kls)
     module = kls.__module__
     if module is None or module == str.__class__.__module__:
@@ -56,8 +91,27 @@ def class_string(kls):
     return module + '.' + kls.__name__

 def flip(size: Tuple[Number, Number]) -> Tuple[Number, Number]:
+    """Flips the given size tuple.
+
+    Args:
+        size: The size tuple to flip.
+
+    Returns:
+        The flipped size tuple.
+    """
     return (size[1], size[0])

+def flatten(nested_list):
+    """Flattens a nested list.
+
+    Args:
+        nested_list: The nested list to flatten.
+
+    Returns:
+        The flattened list.
+    """
+    return [item for sublist in nested_list for item in sublist]
+
 from vot.utilities.notebook import is_notebook

 if is_notebook():
@@ -70,22 +124,36 @@ def flip(size: Tuple[Number, Number]) -> Tuple[Number, Number]:
     from tqdm import tqdm

 class Progress(object):
+    """Wrapper around the tqdm progress bar; enables silencing the progress output and some further
+    customizations.
+    """

     class StreamProxy(object):
+        """Proxy class for tqdm to enable silent mode."""

         def write(self, x):
+            """Write function used by tqdm."""
             # Avoid print() second call (useless \n)
             if len(x.rstrip()) > 0:
                 tqdm.write(x)

+        def flush(self):
+            """Flush function used by tqdm."""
             #return getattr(self.file, "flush", lambda: None)()
             pass

     @staticmethod
     def logstream():
+        """Returns a stream proxy that can be used to redirect output to the progress bar."""
         return Progress.StreamProxy()

     def __init__(self, description="Processing", total=100):
+        """Creates a new progress bar.
+
+        Args:
+            description: The description of the progress bar.
+            total: The total number of steps.
+        """
         silent = get_logger().level > logging.INFO

         if not silent:
@@ -99,9 +167,22 @@ def __init__(self, description="Processing", total=100):
         self._total = total if not silent else 0

     def _percent(self, n):
+        """Returns the percentage of the given value.
+
+        Args:
+            n: The value to compute the percentage of.
+
+        Returns:
+            The percentage of the given value.
+        """
         return int((n * 100) / self._total)

     def absolute(self, value):
+        """Sets the progress to the given value.
+
+        Args:
+            value: The value to set the progress to.
+        """
         if self._tqdm is None:
             if self._total == 0:
                 return
@@ -113,6 +194,11 @@ def absolute(self, value):
             self._tqdm.update(value - self._tqdm.n)  # will also set self.n = b * bsize

     def relative(self, n):
+        """Increments the progress by the given value.
+
+        Args:
+            n: The value to increment the progress by.
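+
+        Example (illustrative):
+            with Progress("Processing", total=100) as progress:
+                for _ in range(100):
+                    progress.relative(1)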
+        """
         if self._tqdm is None:
             if self._total == 0:
                 return
@@ -124,6 +210,11 @@ def relative(self, n):
             self._tqdm.update(n)  # will also set self.n = b * bsize

     def total(self, t):
+        """Sets the total number of steps.
+
+        Args:
+            t: The total number of steps.
+        """
         if self._tqdm is None:
             if self._total == 0:
                 return
@@ -135,16 +226,26 @@ def total(self, t):
         self._tqdm.refresh()

     def __enter__(self):
+        """Enters the context manager."""
         return self

     def __exit__(self, exc_type, exc_value, traceback):
+        """Exits the context manager."""
         self.close()

     def close(self):
+        """Closes the progress bar."""
         if self._tqdm:
             self._tqdm.close()

 def extract_files(archive, destination, callback = None):
+    """Extracts all files from the given archive to the given destination.
+
+    Args:
+        archive: The archive to extract the files from.
+        destination: The destination to extract the files to.
+        callback: An optional callback function that is called after each file is extracted.
+    """
     from zipfile import ZipFile

     with ZipFile(file=archive) as zip_file:
@@ -181,21 +282,32 @@ def read_properties(filename: str, delimiter: str = '=') -> typing.Dict[str, str
             properties[groups.group(1)] = groups.group(2)
     return properties

-def write_properties(filename, dictionary, delimiter='='):
-    ''' Writes the provided dictionary in key sorted order to a properties
+def write_properties(filename: str, dictionary: Mapping[str, Any], delimiter: str = '='):
+    """Writes the provided dictionary in key sorted order to a properties
     file with each line in the format: key<delimiter>value
-        filename -- the name of the file to be written
-        dictionary -- a dictionary containing the key/value pairs.
-    '''
+
+    Args:
+        filename (str): the name of the file to be written
+        dictionary (Mapping[str, Any]): a dictionary containing the key/value pairs.
+        delimiter (str, optional): the separator written between key and value. Defaults to '='.
+    """
+
     open_kwargs = {'mode': 'w', 'newline': ''} if six.PY3 else {'mode': 'wb'}
     with open(filename, **open_kwargs) as csvfile:
         writer = csv.writer(csvfile, delimiter=delimiter, escapechar='\\', quoting=csv.QUOTE_NONE)
         writer.writerows(sorted(dictionary.items()))

-def file_hash(filename):
+def file_hash(filename: str) -> Tuple[str, str]:
+    """Calculates MD5 and SHA1 hashes based on file content

-    # BUF_SIZE is totally arbitrary, change for your app!
+    Args:
+        filename (str): Filename of the file to open and analyze
+
+    Returns:
+        Tuple[str, str]: MD5 and SHA1 hashes as hexadecimal strings.
+    """
+
     bufsize = 65536  # let's read stuff in 64kb chunks!

     md5 = hashlib.md5()
@@ -211,7 +323,16 @@
     return md5.hexdigest(), sha1.hexdigest()

-def arg_hash(*args, **kwargs):
+def arg_hash(*args, **kwargs) -> str:
+    """Computes a hash based on input positional and keyword arguments.
+
+    The algorithm converts all arguments to strings and encloses them with delimiters. The
+    positional arguments are listed as-is, keyword arguments are sorted and encoded with their keys as
+    well as values.
+
+    Returns:
+        str: SHA1 hash as hexadecimal string
+    """
     sha1 = hashlib.sha1()

     for arg in args:
@@ -222,9 +343,25 @@
     return sha1.hexdigest()

-def which(program):
+def which(program: str) -> str:
+    """Locates an executable in the system PATH by its name.
+
+    Args:
+        program (str): Name of the executable
+
+    Returns:
+        str: Full path or None if not found
+    """
     def is_exe(fpath):
+        """Checks if the given path is an executable file.
+
+        Args:
+            fpath (str): Path to check
+
+        Returns:
+            bool: True if the path is an executable file
+        """
         return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

     fpath, _ = os.path.split(program)
@@ -240,6 +377,15 @@ def is_exe(fpath):
     return None

 def normalize_path(path, root=None):
+    """Normalizes the given path by making it absolute and removing redundant parts.
+
+    Args:
+        path (str): Path to normalize
+        root (str, optional): Root path to use if the given path is relative. Defaults to None.
+
+    Returns:
+        str: Normalized path
+    """
     if os.path.isabs(path):
         return path
     if not root:
@@ -247,18 +393,47 @@
     return os.path.normpath(os.path.join(root, path))

 def localize_path(path):
+    """Converts a path to the local format (backslashes on Windows, slashes on Linux)
+
+    Args:
+        path (str): Path to convert
+
+    Returns:
+        str: Converted path
+    """
     if sys.platform.startswith("win"):
         return path.replace("/", "\\")
     else:
         return path.replace("\\", "/")

-def to_string(n):
+def to_string(n: Any) -> str:
+    """Converts an object to a string; returns an empty string if the object is None (a slightly
+    different behaviour than the built-in string conversion).
+
+    Args:
+        n (Any): Object of any kind
+
+    Returns:
+        str: String representation (using built-in conversion)
+    """
     if n is None:
         return ""
     else:
         return str(n)

 def to_number(val, max_n = None, min_n = None, conversion=int):
+    """Converts the given value to a number and checks if it is within the given range. If the value is not a number,
+    a RuntimeError is raised.
+
+    Args:
+        val (Any): Value to convert
+        max_n (int, optional): Maximum allowed value. Defaults to None.
+        min_n (int, optional): Minimum allowed value. Defaults to None.
+        conversion (function, optional): Conversion function. Defaults to int.
+
+    Returns:
+        Number: Converted value (the type is determined by the conversion function)
+    """
     try:
         n = conversion(val)
@@ -274,6 +449,15 @@
     raise RuntimeError("Number conversion error")

 def to_logical(val):
+    """Converts the given value to a logical value (True/False). If the value cannot be interpreted as
+    a logical value, a RuntimeError is raised.
+
+    Args:
+        val (Any): Value to convert
+
+    Returns:
+        bool: Converted value
+    """
     try:
         if isinstance(val, str):
             return val.lower() in ['true', '1', 't', 'y', 'yes']
@@ -283,24 +467,63 @@
     except ValueError:
         raise RuntimeError("Logical value conversion error")

+def format_size(num, suffix="B"):
+    """Formats the given number as a human-readable size string.
+
+    Args:
+        num (int): Number to format
+        suffix (str, optional): Suffix to use. Defaults to "B".
+
+    Returns:
+        str: Formatted string
+    """
+    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
+        if abs(num) < 1024.0:
+            return f"{num:3.1f}{unit}{suffix}"
+        num /= 1024.0
+    return f"{num:.1f}Yi{suffix}"
+
 def singleton(class_):
+    """Singleton decorator for classes.
+
+    Args:
+        class_ (class): Class to decorate
+
+    Returns:
+        class: Decorated class
+
+    Example:
+        @singleton
+        class MyClass:
+            pass
+
+        a = MyClass()
+    """
     instances = {}

     def getinstance(*args, **kwargs):
+        """Returns the singleton instance of the class. If the instance does not exist, it is created."""
         if class_ not in instances:
             instances[class_] = class_(*args, **kwargs)
         return instances[class_]
     return getinstance

 class ColoredFormatter(Formatter):
+    """Colored log formatter using the colorama package.
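+
+    Example (illustrative; mirrors the CLI entrypoint in vot/utilities/cli.py):
+        handler = logging.StreamHandler()
+        handler.setFormatter(ColoredFormatter())
+        logging.getLogger("vot").addHandler(handler)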
+ """ class Empty(object): """An empty class used to copy :class:`~logging.LogRecord` objects without reinitializing them.""" def __init__(self, **kwargs): + """Initializes the formatter. + + Args: + **kwargs: Keyword arguments passed to the base class + + """ super().__init__(**kwargs) colorama.init() - self._styles = dict( debug=colorama.Fore.GREEN, verbose=colorama.Fore.BLACK, @@ -311,8 +534,15 @@ def __init__(self, **kwargs): critical=colorama.Fore.RED + colorama.Style.BRIGHT, ) + def format(self, record: LogRecord) -> str: + """Formats message by injecting colorama terminal codes for text coloring. + + Args: + record (LogRecord): Input log record - def format(self, record: LogRecord): + Returns: + str: Formatted string + """ style = self._styles[record.levelname.lower()] copy = ColoredFormatter.Empty() @@ -326,11 +556,20 @@ def format(self, record: LogRecord): class ThreadPoolExecutor(futures.ThreadPoolExecutor): + """Thread pool executor with a shutdown method that waits for all threads to finish. + """ + def __init__(self, *args, **kwargs): + """Initializes the thread pool executor.""" super().__init__(*args, **kwargs) #self._work_queue = Queue.Queue(maxsize=maxsize) def shutdown(self, wait=True): + """Shuts down the thread pool executor. If wait is True, waits for all threads to finish. + + Args: + wait (bool, optional): Wait for all threads to finish. Defaults to True. + """ import queue with self._shutdown_lock: self._shutdown = True @@ -344,3 +583,26 @@ def shutdown(self, wait=True): if wait: for t in self._threads: t.join() + +class Timer(object): + """Simple timer class for measuring elapsed time.""" + + def __init__(self, name=None): + """Initializes the timer. + + Args: + name (str, optional): Name of the timer. Defaults to None. + """ + self.name = name + + def __enter__(self): + """Starts the timer.""" + self._tstart = time.time() + + def __exit__(self, type, value, traceback): + """Stops the timer and prints the elapsed time.""" + elapsed = time.time() - self._tstart + if self.name: + print('[%s]: %.4fs' % (self.name, elapsed)) + else: + print('Elapsed: %.4fs' % elapsed) \ No newline at end of file diff --git a/vot/utilities/cli.py b/vot/utilities/cli.py index a229868..972122b 100644 --- a/vot/utilities/cli.py +++ b/vot/utilities/cli.py @@ -1,3 +1,5 @@ +"""Command line interface for the toolkit. This module provides a command line interface for the toolkit. It is used to run experiments, manage trackers and datasets, and to perform other tasks.""" + import os import sys import argparse @@ -18,6 +20,7 @@ class EnvDefault(argparse.Action): """Argparse action that resorts to a value in a specified envvar if no value is provided via program arguments. 
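+
+    Example (illustrative; mirrors the registry argument in main() below):
+        parser.add_argument("--registry", default=".", required=False,
+            action=EnvDefault, separator=os.path.pathsep, envvar='VOT_REGISTRY')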
""" def __init__(self, envvar, required=True, default=None, separator=None, **kwargs): + """Initialize the action""" if not default and envvar: if envvar in os.environ: default = os.environ[envvar] @@ -30,6 +33,7 @@ def __init__(self, envvar, required=True, default=None, separator=None, **kwargs **kwargs) def __call__(self, parser, namespace, values, option_string=None): + """Call the action""" if self.separator: values = values.split(self.separator) setattr(namespace, self.dest, values) @@ -40,8 +44,11 @@ def do_test(config: argparse.Namespace): Args: config (argparse.Namespace): Configuration """ - from vot.dataset.dummy import DummySequence - from vot.dataset import load_sequence + from vot.dataset.dummy import generate_dummy + from vot.dataset import load_sequence, Frame + from vot.tracker import ObjectStatus + from vot.experiment.helpers import MultiObjectHelper + trackers = Registry(config.registry) if not config.tracker: @@ -57,57 +64,86 @@ def do_test(config: argparse.Namespace): tracker = trackers[config.tracker] - logger.info("Generating dummy sequence") - - if config.sequence is None: - sequence = DummySequence(50) - else: - sequence = load_sequence(normalize_path(config.sequence)) - - logger.info("Obtaining runtime for tracker %s", tracker.identifier) - - if config.visualize: - import matplotlib.pylab as plt - from vot.utilities.draw import MatplotlibDrawHandle - figure = plt.figure() - figure.canvas.set_window_title('VOT Test') - axes = figure.add_subplot(1, 1, 1) - axes.set_aspect("equal") - handle = MatplotlibDrawHandle(axes, size=sequence.size) - handle.style(fill=False) - figure.show() - - runtime = None - + def visualize(axes, frame: Frame, reference, state): + """Visualize the frame and the state of the tracker. + + Args: + axes (matplotlib.axes.Axes): The axes to draw on. + frame (Frame): The frame to draw. + reference (list): List of references. + state (ObjectStatus): The state of the tracker. + + """ + axes.clear() + handle.image(frame.channel()) + if not isinstance(state, list): + state = [state] + for gt, st in zip(reference, state): + handle.style(color="green").region(gt) + handle.style(color="red").region(st.region) + try: runtime = tracker.runtime(log=True) - for repeat in range(1, 4): - - logger.info("Initializing tracker ({}/{})".format(repeat, 3)) + logger.info("Generating dummy sequence") - region, _, _ = runtime.initialize(sequence.frame(0), sequence.groundtruth(0)) + if config.sequence is None: + sequence = generate_dummy(50, objects=3 if runtime.multiobject else 1) + else: + sequence = load_sequence(normalize_path(config.sequence)) + + logger.info("Obtaining runtime for tracker %s", tracker.identifier) + + context = {"continue" : True} + + def on_press(event): + """Callback for key press event. + + Args: + event (matplotlib.backend_bases.Event): The event. 
+ """ + if event.key == 'q': + context["continue"] = False + + if config.visualize: + import matplotlib.pylab as plt + from vot.utilities.draw import MatplotlibDrawHandle + figure = plt.figure() + figure.canvas.set_window_title('VOT Test') + axes = figure.add_subplot(1, 1, 1) + axes.set_aspect("equal") + handle = MatplotlibDrawHandle(axes, size=sequence.size) + context["click"] = figure.canvas.mpl_connect('key_press_event', on_press) + handle.style(fill=False) + figure.show() + + helper = MultiObjectHelper(sequence) + + logger.info("Initializing tracker") + + frame = sequence.frame(0) + state, _ = runtime.initialize(frame, [ObjectStatus(frame.object(x), {}) for x in helper.new(0)]) + + if config.visualize: + visualize(axes, frame, [frame.object(x) for x in helper.objects(0)], state) + figure.canvas.draw() + + for i in range(1, len(sequence)): + + logger.info("Processing frame %d/%d", i, len(sequence)-1) + frame = sequence.frame(i) + state, _ = runtime.update(frame, [ObjectStatus(frame.object(x), {}) for x in helper.new(i)]) if config.visualize: - axes.clear() - handle.image(sequence.frame(0).channel()) - handle.style(color="green").region(sequence.frame(0).groundtruth()) - handle.style(color="red").region(region) + visualize(axes, frame, [frame.object(x) for x in helper.objects(i)], state) figure.canvas.draw() + figure.canvas.flush_events() - for i in range(1, sequence.length): - logger.info("Updating on frame %d/%d", i, sequence.length-1) - region, _, _ = runtime.update(sequence.frame(i)) - - if config.visualize: - axes.clear() - handle.image(sequence.frame(i).channel()) - handle.style(color="green").region(sequence.frame(i).groundtruth()) - handle.style(color="red").region(region) - figure.canvas.draw() + if not context["continue"]: + break - logger.info("Stopping tracker") + logger.info("Stopping tracker") runtime.stop() @@ -122,6 +158,11 @@ def do_test(config: argparse.Namespace): runtime.stop() def do_workspace(config: argparse.Namespace): + """Initialize / manage a workspace. + + Args: + config (argparse.Namespace): Configuration + """ from vot.workspace import WorkspaceException @@ -153,18 +194,23 @@ def do_workspace(config: argparse.Namespace): logger.error("Error during workspace initialization: %s", we) def do_evaluate(config: argparse.Namespace): + """Run an evaluation for a tracker on an experiment stack and a set of sequences. 
+ + Args: + config (argparse.Namespace): Configuration + """ from vot.experiment import run_experiment workspace = Workspace.load(config.workspace) - logger.info("Loaded workspace in '%s'", config.workspace) + logger.debug("Loaded workspace in '%s'", config.workspace) global_registry = [os.path.abspath(x) for x in config.registry] - registry = Registry(workspace.registry + global_registry, root=config.workspace) + registry = Registry(list(workspace.registry) + global_registry, root=config.workspace) - logger.info("Found data for %d trackers", len(registry)) + logger.debug("Found data for %d trackers", len(registry)) trackers = registry.resolve(*config.trackers, storage=workspace.storage.substorage("results"), skip_unknown=False) @@ -177,7 +223,7 @@ def do_evaluate(config: argparse.Namespace): try: for tracker in trackers: - logger.info("Evaluating tracker %s", tracker.identifier) + logger.debug("Evaluating tracker %s", tracker.identifier) for experiment in workspace.stack: run_experiment(experiment, tracker, workspace.dataset, config.force, config.persist) @@ -189,19 +235,24 @@ def do_evaluate(config: argparse.Namespace): logger.error("Evaluation interrupted by tracker error: {}".format(te)) def do_analysis(config: argparse.Namespace): + """Run an analysis for a tracker on an experiment stack and a set of sequences. + + Args: + config (argparse.Namespace): Configuration + """ from vot.analysis import AnalysisProcessor, process_stack_analyses from vot.document import generate_document workspace = Workspace.load(config.workspace) - logger.info("Loaded workspace in '%s'", config.workspace) + logger.debug("Loaded workspace in '%s'", config.workspace) global_registry = [os.path.abspath(x) for x in config.registry] - registry = Registry(workspace.registry + global_registry, root=config.workspace) + registry = Registry(list(workspace.registry) + global_registry, root=config.workspace) - logger.info("Found data for %d trackers", len(registry)) + logger.debug("Found data for %d trackers", len(registry)) if not config.trackers: trackers = workspace.list_results(registry) @@ -217,7 +268,7 @@ def do_analysis(config: argparse.Namespace): if config.workers == 1: if config.debug: - from vot.analysis._processor import DebugExecutor + from vot.analysis.processor import DebugExecutor logging.getLogger("concurrent.futures").setLevel(logging.DEBUG) executor = DebugExecutor() else: @@ -263,7 +314,7 @@ def do_pack(config: argparse.Namespace): """Package results to a ZIP file so that they can be submitted to a challenge. 
Args: - config ([type]): [description] + config (argparse.Namespace): Configuration """ import zipfile, io @@ -271,11 +322,9 @@ def do_pack(config: argparse.Namespace): workspace = Workspace.load(config.workspace) - logger.info("Loaded workspace in '%s'", config.workspace) - - registry = Registry(workspace.registry + config.registry, root=config.workspace) + logger.debug("Loaded workspace in '%s'", config.workspace) - logger.info("Found data for %d trackers", len(registry)) + registry = Registry(list(workspace.registry) + config.registry, root=config.workspace) tracker = registry[config.tracker] @@ -287,8 +336,8 @@ def do_pack(config: argparse.Namespace): with Progress("Scanning", len(workspace.dataset) * len(workspace.stack)) as progress: for experiment in workspace.stack: - for sequence in workspace.dataset: - sequence = experiment.transform(sequence) + sequences = experiment.transform(workspace.dataset) + for sequence in sequences: complete, files, results = experiment.scan(tracker, sequence) all_files.extend([(f, experiment.identifier, sequence.name, results) for f in files]) if not complete: @@ -300,7 +349,7 @@ def do_pack(config: argparse.Namespace): logger.error("Unable to continue, experiments not complete") return - logger.info("Collected %d files, compressing to archive ...", len(all_files)) + logger.debug("Collected %d files, compressing to archive ...", len(all_files)) timestamp = datetime.now() @@ -315,8 +364,11 @@ def do_pack(config: argparse.Namespace): with zipfile.ZipFile(workspace.storage.write(archive_name, binary=True), mode="w") as archive: for f in all_files: info = zipfile.ZipInfo(filename=os.path.join(f[1], f[2], f[0]), date_time=timestamp.timetuple()) - with io.TextIOWrapper(archive.open(info, mode="w")) as fout, f[3].read(f[0]) as fin: - copyfileobj(fin, fout) + with archive.open(info, mode="w") as fout, f[3].read(f[0]) as fin: + if isinstance(fin, io.TextIOBase): + copyfileobj(fin, io.TextIOWrapper(fout)) + else: + copyfileobj(fin, fout) progress.relative(1) info = zipfile.ZipInfo(filename="manifest.yml", date_time=timestamp.timetuple()) @@ -326,13 +378,13 @@ def do_pack(config: argparse.Namespace): logger.info("Result packaging successful, archive available in %s", archive_name) def main(): - """Entrypoint to the VOT Command Line Interface utility, should be executed as a program and provided with arguments. + """Entrypoint to the toolkit Command Line Interface utility, should be executed as a program and provided with arguments. 
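+
+    Example (illustrative; subcommands as dispatched below):
+        vot test
+        vot evaluate
+        vot analysis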
""" stream = logging.StreamHandler() stream.setFormatter(ColoredFormatter()) logger.addHandler(stream) - parser = argparse.ArgumentParser(description='VOT Toolkit Command Line Utility', prog="vot") + parser = argparse.ArgumentParser(description='VOT Toolkit Command Line Interface', prog="vot") parser.add_argument("--debug", "-d", default=False, help="Backup backend", required=False, action='store_true') parser.add_argument("--registry", default=".", help='Tracker registry paths', required=False, action=EnvDefault, \ separator=os.path.pathsep, envvar='VOT_REGISTRY') @@ -373,22 +425,29 @@ def main(): logger.setLevel(logging.INFO) - if args.debug or check_debug: + if args.debug or check_debug(): logger.setLevel(logging.DEBUG) - update, version = check_updates() - if update: - logger.warning("A newer version of the VOT toolkit is available (%s), please update.", version) + def check_version(): + """Check if a newer version of the toolkit is available.""" + update, version = check_updates() + if update: + logger.warning("A newer version of the VOT toolkit is available (%s), please update.", version) if args.action == "test": + check_version() do_test(args) elif args.action == "initialize": + check_version() do_workspace(args) elif args.action == "evaluate": + check_version() do_evaluate(args) elif args.action == "analysis": + check_version() do_analysis(args) elif args.action == "pack": + check_version() do_pack(args) else: parser.print_help() diff --git a/vot/utilities/data.py b/vot/utilities/data.py index 2c3c207..1b1aa02 100644 --- a/vot/utilities/data.py +++ b/vot/utilities/data.py @@ -1,23 +1,41 @@ +""" Data structures for storing data in a grid.""" + import functools -from typing import Tuple, Union, Iterable import unittest -import numpy as np - class Grid(object): + """ A grid is a multidimensional array with named dimensions. """ @staticmethod def scalar(obj): + """ Creates a grid with a single cell containing the given object. + + Args: + obj (object): The object to store in the grid. + """ grid = Grid(1,1) grid[0, 0] = obj return grid def __init__(self, *size): + """ Creates a grid with the given dimensions. + + Args: + size (int): The size of each dimension. + """ assert len(size) > 0 self._size = size self._data = [None] * functools.reduce(lambda x, y: x * y, size) def _ravel(self, pos): + """ Converts a multidimensional index to a single index. + + Args: + pos (tuple): The multidimensional index. + + Returns: + int: The single index. + """ if not isinstance(pos, tuple): pos = (pos, ) assert len(pos) == len(self._size) @@ -32,6 +50,14 @@ def _ravel(self, pos): return raveled def _unravel(self, index): + """ Converts a single index to a multidimensional index. + + Args: + index (int): The single index. + + Returns: + tuple: The multidimensional index. + """ unraveled = [] for n in reversed(self._size): unraveled.append(index % n) @@ -39,34 +65,76 @@ def _unravel(self, index): return tuple(reversed(unraveled)) def __str__(self): + """ Returns a string representation of the grid.""" return str(self._data) @property def dimensions(self): + """ Returns the number of dimensions of the grid. """ return len(self._size) def size(self, i: int = None): + """ Returns the size of the grid or the size of a specific dimension. + + Args: + i (int): The dimension to query. If None, the size of the grid is returned. + + Returns: + int: The size of the grid or the size of the given dimension. 
+ """ if i is None: return tuple(self._size) assert i >= 0 and i < len(self._size) return self._size[i] def __len__(self): + """ Returns the number of elements in the grid. """ return len(self._data) def __getitem__(self, i): + """ Returns the element at the given index. + + Args: + i (tuple): The index of the element. If the grid is one-dimensional, the index can be an integer. + + Returns: + object: The element at the given index. + """ return self._data[self._ravel(i)] def __setitem__(self, i, data): + """ Sets the element at the given index. + + Args: + i (tuple): The index of the element. If the grid is one-dimensional, the index can be an integer. + data (object): The data to store at the given index. + """ self._data[self._ravel(i)] = data def __iter__(self): + """ Returns an iterator over the elements of the grid. """ return iter(self._data) def cell(self, *i): + """ Returns the element at the given index packed in a scalar grid. + + Args: + i (int): The index of the element. If the grid is one-dimensional, the index can be an integer. + + Returns: + object: The element at the given index packed in a scalar grid. + """ return Grid.scalar(self[i]) def column(self, i): + """ Returns the column at the given index. + + Args: + i (int): The index of the column. + + Returns: + Grid: The column at the given index. + """ assert self.dimensions == 2 column = Grid(1, self.size()[0]) for j in range(self.size()[0]): @@ -74,6 +142,14 @@ def column(self, i): return column def row(self, i): + """ Returns the row at the given index. + + Args: + i (int): The index of the row. + + Returns: + Grid: The row at the given index. + """ assert self.dimensions == 2 row = Grid(self.size()[1], 1) for j in range(self.size()[1]): @@ -81,6 +157,15 @@ def row(self, i): return row def foreach(self, cb) -> "Grid": + """ Applies a function to each element of the grid. + + Args: + cb (function): The function to apply to each element. The first argument is the element, the following + arguments are the indices of the element. + + Returns: + Grid: A grid containing the results of the function. + """ result = Grid(*self._size) for i, x in enumerate(self._data): @@ -90,8 +175,10 @@ def foreach(self, cb) -> "Grid": return result class TestGrid(unittest.TestCase): + """ Unit tests for the Grid class. """ def test_foreach1(self): + """ Tests the foreach method. """ a = Grid(5, 3) @@ -100,6 +187,7 @@ def test_foreach1(self): self.assertTrue(all([x == 5 for x in b]), "Output incorrect") def test_foreach2(self): + """ Tests the foreach method. """ a = Grid(5, 6, 3) diff --git a/vot/utilities/draw.py b/vot/utilities/draw.py index da9215a..cb5fcc6 100644 --- a/vot/utilities/draw.py +++ b/vot/utilities/draw.py @@ -1,4 +1,4 @@ - +""" Drawing utilities for visualizing results.""" from typing import Tuple, List, Union @@ -13,6 +13,7 @@ from io import BytesIO def show_image(a): + """Shows an image in the IPython notebook.""" try: import IPython.display except ImportError: @@ -36,19 +37,37 @@ def show_image(a): } def resolve_color(color: Union[Tuple[float, float, float], str]): + """Resolves a color to a tuple of floats. 
If the color is a string, it is resolved from an internal palette.""" if isinstance(color, str): return _PALETTE.get(color, (0, 0, 0, 1)) return (np.clip(color[0], 0, 1), np.clip(color[1], 0, 1), np.clip(color[2], 0, 1)) class DrawHandle(object): + """Base class for drawing handles.""" def __init__(self, color: Union[Tuple[float, float, float], str] = (1, 0, 0), width: int = 1, fill: bool = False): + """Initializes the drawing handle. + + Args: + color (tuple or str): Color of the drawing handle. + width (int): Width of the drawing handle. + fill (bool): Whether to fill the drawing handle. + """ self._color = resolve_color(color) self._width = width self._fill = fill - def style(self, color: Union[Tuple[float, float, float], str] = (1, 0, 0), width: int = 1, fill: bool = False): + def style(self, color: Union[Tuple[float, float, float], str] = (1, 0, 0), width: int = 1, fill: bool = False) -> 'DrawHandle': + """Sets the style of the drawing handle. Returns self for chaining. + + Args: + color (tuple or str): Color of the drawing handle. + width (int): Width of the drawing handle. + fill (bool): Whether to fill the drawing handle. + + Returns: + self""" color = resolve_color(color) self._color = (color[0], color[1], color[2], 1) self._width = width @@ -59,34 +78,47 @@ def style(self, color: Union[Tuple[float, float, float], str] = (1, 0, 0), width return self def region(self, region): + """Draws a region.""" region.draw(self) return self def image(self, image: Union[np.ndarray, Image.Image], offset: Tuple[int, int] = None): + """Draws an image at the given offset.""" return self def line(self, p1: Tuple[float, float], p2: Tuple[float, float]): + """Draws a line between two points.""" return self def lines(self, points: List[Tuple[float, float]]): + """Draws a line between multiple points.""" return self def polygon(self, points: List[Tuple[float, float]]): + """Draws a polygon.""" return self def points(self, points: List[Tuple[float, float]]): + """Draws points.""" return self def rectangle(self, left: float, top: float, right: float, bottom: float): + """Draws a rectangle. + + The rectangle is defined by the top-left and bottom-right corners. + """ self.polygon([(left, top), (right, top), (right, bottom), (left, bottom)]) return self def mask(self, mask: np.array, offset: Tuple[int, int] = (0, 0)): + """Draws a mask.""" return self class MatplotlibDrawHandle(DrawHandle): + """Draw handle for Matplotlib. 
This handle is used for drawing to a Matplotlib axis.""" def __init__(self, axis, color: Tuple[float, float, float] = (1, 0, 0), width: int = 1, fill: bool = False, size: Tuple[int, int] = None): + """Initializes a new instance of the MatplotlibDrawHandle class.""" super().__init__(color, width, fill) self._axis = axis self._size = size @@ -96,6 +128,7 @@ def __init__(self, axis, color: Tuple[float, float, float] = (1, 0, 0), width: i def image(self, image: Union[np.ndarray, Image.Image], offset: Tuple[int, int] = None): + """Draws an image at the given offset.""" if offset is None: offset = (0, 0) @@ -113,16 +146,19 @@ def image(self, image: Union[np.ndarray, Image.Image], offset: Tuple[int, int] = return self def line(self, p1: Tuple[float, float], p2: Tuple[float, float]): + """Draws a line between two points.""" self._axis.plot((p1[0], p2[0]), (p1[1], p2[1]), linewidth=self._width, color=self._color) return self def lines(self, points: List[Tuple[float, float]]): + """Draws a line between multiple points.""" x = [x for x, _ in points] y = [y for _, y in points] self._axis.plot(x, y, linewidth=self._width, color=self._color) return self def polygon(self, points: List[Tuple[float, float]]): + """Draws a polygon.""" if self._fill: poly = Polygon(points, edgecolor=self._color, linewidth=self._width, fill=True, color=self._fill) else: @@ -131,11 +167,13 @@ def polygon(self, points: List[Tuple[float, float]]): return self def points(self, points: List[Tuple[float, float]]): + """Draws points.""" x, y = zip(*points) self._axis.plot(x, y, markeredgecolor=self._color, markeredgewidth=self._width, linewidth=0) return self def mask(self, mask: np.array, offset: Tuple[int, int] = (0, 0)): + """Draws a mask.""" # TODO: segmentation should also have option of non-filled mask[mask != 0] = 1 if self._fill: @@ -156,12 +194,15 @@ def mask(self, mask: np.array, offset: Tuple[int, int] = (0, 0)): class ImageDrawHandle(DrawHandle): + """Draw handle for Pillow. 
This handle is used for drawing to a Pillow image.""" @staticmethod def _convert_color(c, alpha=255): + """Converts a color from a float tuple to an integer tuple.""" return (int(c[0] * 255), int(c[1] * 255), int(c[2] * 255), alpha) def __init__(self, image: Union[np.ndarray, Image.Image], color: Tuple[float, float, float] = (1, 0, 0), width: int = 1, fill: bool = False): + """Initializes a new instance of the ImageDrawHandle class.""" super().__init__(color, width, fill) if isinstance(image, np.ndarray): image = Image.fromarray(image) @@ -171,13 +212,16 @@ def __init__(self, image: Union[np.ndarray, Image.Image], color: Tuple[float, fl @property def array(self) -> np.ndarray: + """Returns the image as a numpy array.""" return np.asarray(self._image) @property def snapshot(self) -> Image.Image: + """Returns a snapshot of the current image.""" return self._image.copy() def image(self, image: Union[np.ndarray, Image.Image], offset: Tuple[int, int] = None): + """Draws an image at the given offset.""" if isinstance(image, np.ndarray): if image.dtype == np.float32 or image.dtype == np.float64: image = (image * 255).astype(np.uint8) @@ -189,11 +233,13 @@ def image(self, image: Union[np.ndarray, Image.Image], offset: Tuple[int, int] = return self def line(self, p1, p2): + """Draws a line between two points.""" color = ImageDrawHandle._convert_color(self._color) self._handle.line([p1, p2], fill=color, width=self._width) return self def lines(self, points: List[Tuple[float, float]]): + """Draws a line between multiple points.""" if len(points) == 0: return color = ImageDrawHandle._convert_color(self._color) @@ -201,6 +247,7 @@ def lines(self, points: List[Tuple[float, float]]): return self def polygon(self, points: List[Tuple[float, float]]): + """Draws a polygon.""" if len(points) == 0: return self @@ -213,12 +260,14 @@ def polygon(self, points: List[Tuple[float, float]]): return self def points(self, points: List[Tuple[float, float]]): + """Draws points.""" color = ImageDrawHandle._convert_color(self._color) for (x, y) in points: self._handle.ellipse((x - 2, y - 2, x + 2, y + 2), outline=color, width=self._width) return self def mask(self, mask: np.array, offset: Tuple[int, int] = (0, 0)): + """Draws a mask.""" if mask.size == 0: return self diff --git a/vot/utilities/migration.py b/vot/utilities/migration.py index 9e975cd..8dc1ff8 100644 --- a/vot/utilities/migration.py +++ b/vot/utilities/migration.py @@ -1,4 +1,4 @@ - +""" Migration utilities for old workspaces (legacy Matlab toolkit)""" import os import re @@ -11,13 +11,32 @@ from vot.stack import resolve_stack from vot.workspace import WorkspaceException -def migrate_matlab_workspace(directory): +def migrate_matlab_workspace(directory: str): + """ Migrates a legacy matlab workspace to the new format. + + Args: + directory (str): The directory of the workspace. + + Raises: + WorkspaceException: If the workspace is already initialized. + WorkspaceException: If the workspace is not a legacy workspace. + """ logger = logging.getLogger("vot") logger.info("Attempting to migrate workspace in %s", directory) def scan_text(pattern, content, default=None): + """ Scans the text for a pattern and returns the first match. + + Args: + pattern (str): The pattern to search for. + content (str): The content to search in. + default (str): The default value if no match is found. + + Returns: + str: The first match or the default value. 
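+
+        Example (illustrative):
+            scan_text(r"version=(\d+)", "version=8", default="0")  # single match, returns "8"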
+ """ matches = re.findall(pattern, content) if not len(matches) == 1: return default diff --git a/vot/utilities/net.py b/vot/utilities/net.py index 488485b..0da9394 100644 --- a/vot/utilities/net.py +++ b/vot/utilities/net.py @@ -1,3 +1,4 @@ +""" Network utilities for the toolkit. """ import os import re @@ -7,23 +8,57 @@ import requests -from vot import ToolkitException +from vot import ToolkitException, get_logger class NetworkException(ToolkitException): + """ Exception raised when a network error occurs. """ pass def get_base_url(url): + """ Returns the base url of a given url. + + Args: + url (str): The url to parse. + + Returns: + str: The base url.""" return url.rsplit('/', 1)[0] def is_absolute_url(url): + """ Returns True if the given url is absolute. + + Args: + url (str): The url to parse. + + Returns: + bool: True if the url is absolute, False otherwise. + """ + return bool(urlparse(url).netloc) def join_url(url_base, url_path): + """ Joins a base url with a path. + + Args: + url_base (str): The base url. + url_path (str): The path to join. + + Returns: + str: The joined url. + """ if is_absolute_url(url_path): return url_path return urljoin(url_base, url_path) def get_url_from_gdrive_confirmation(contents): + """ Returns the url of a google drive file from the confirmation page. + + Args: + contents (str): The contents of the confirmation page. + + Returns: + str: The url of the file. + """ url = '' for line in contents.splitlines(): m = re.search(r'href="(\/uc\?export=download[^"]+)', line) @@ -45,17 +80,49 @@ def get_url_from_gdrive_confirmation(contents): def is_google_drive_url(url): + """ Returns True if the given url is a google drive url. + + Args: + url (str): The url to parse. + + Returns: + bool: True if the url is a google drive url, False otherwise. + """ m = re.match(r'^https?://drive.google.com/uc\?id=.*$', url) return m is not None def download_json(url): + """ Downloads a JSON file from the given url. + + Args: + url (str): The url to parse. + + Returns: + dict: The JSON content. + """ try: return requests.get(url).json() except requests.exceptions.RequestException as e: raise NetworkException("Unable to read JSON file {}".format(e)) -def download(url, output, callback=None, chunk_size=1024*32): +def download(url, output, callback=None, chunk_size=1024*32, retry=10): + """ Downloads a file from the given url. Supports google drive urls. + callback for progress report, automatically resumes download if connection is closed. + + Args: + url (str): The url to parse. + output (str): The output file path or file handle. + callback (function): The callback function for progress report. + chunk_size (int): The chunk size for download. + retry (int): The number of retries. + + Raises: + NetworkException: If the file is not available. 
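+
+    Example (illustrative; the URL is hypothetical):
+        download("https://example.com/data.zip", "data.zip",
+                 callback=lambda position, total: print(position, total))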
+ """ + + logger = get_logger() + with requests.session() as sess: is_gdrive = is_google_drive_url(url) @@ -79,6 +146,7 @@ def download(url, output, callback=None, chunk_size=1024*32): raise NetworkException("Permission denied for {}".format(gurl)) url = gurl + if output is None: if is_gdrive: m = re.search('filename="(.*)"', @@ -96,21 +164,49 @@ def download(url, output, callback=None, chunk_size=1024*32): tmp_file = None filehandle = output + position = 0 + progress = False + try: total = res.headers.get('Content-Length') - if total is not None: total = int(total) - - for chunk in res.iter_content(chunk_size=chunk_size): - filehandle.write(chunk) - if callback: - callback(len(chunk), total) - if tmp_file: - filehandle.close() - shutil.copy(tmp_file, output) - except IOError: - raise NetworkException("Error when downloading file") + while True: + try: + for chunk in res.iter_content(chunk_size=chunk_size): + filehandle.write(chunk) + position += len(chunk) + progress = True + if callback: + callback(position, total) + + if position < total: + raise requests.exceptions.RequestException("Connection closed") + + if tmp_file: + filehandle.close() + shutil.copy(tmp_file, output) + break + + except requests.exceptions.RequestException as e: + if not progress: + logger.warning("Error when downloading file, retrying") + retry-=1 + if retry < 1: + raise NetworkException("Unable to download file {}".format(e)) + res = sess.get(url, stream=True) + filehandle.seek(0) + position = 0 + else: + logger.warning("Error when downloading file, trying to resume download") + res = sess.get(url, stream=True, headers=({'Range': f'bytes={position}-'} if position > 0 else None)) + progress = False + + if position < total: + raise NetworkException("Unable to download file") + + except IOError as e: + raise NetworkException("Local I/O Error when downloading file: %s" % e) finally: try: if tmp_file: @@ -120,7 +216,17 @@ def download(url, output, callback=None, chunk_size=1024*32): return output + def download_uncompress(url, path): + """ Downloads a file from the given url and uncompress it to the given path. + + Args: + url (str): The url to parse. + path (str): The path to uncompress the file. + + Raises: + NetworkException: If the file is not available. + """ from vot.utilities import extract_files _, ext = os.path.splitext(urlparse(url).path) tmp_file = tempfile.mktemp(suffix=ext) diff --git a/vot/utilities/notebook.py b/vot/utilities/notebook.py index aedb7b3..f292fc0 100644 --- a/vot/utilities/notebook.py +++ b/vot/utilities/notebook.py @@ -1,9 +1,15 @@ +""" This module contains functions for visualization in Jupyter notebooks. """ import os import io from threading import Thread, Condition def is_notebook(): + """ Returns True if the current environment is a Jupyter notebook. + + Returns: + bool: True if the current environment is a Jupyter notebook. + """ try: from IPython import get_ipython if get_ipython() is None: @@ -17,93 +23,259 @@ def is_notebook(): else: return True - -def run_tracker(tracker: "Tracker", sequence: "Sequence"): +if is_notebook(): + from IPython.display import display from ipywidgets import widgets from vot.utilities.draw import ImageDrawHandle - def encode_image(handle): - with io.BytesIO() as output: - handle.snapshot.save(output, format="PNG") - return output.getvalue() + class SequenceView(object): + """ A widget for visualizing a sequence. """ - handle = ImageDrawHandle(sequence.frame(0).image()) + def __init__(self): + """ Initializes a new instance of the SequenceView class. 
+ + Args: + sequence (Sequence): The sequence to visualize. + """ - button_restart = widgets.Button(description='Restart') - button_next = widgets.Button(description='Next') - button_play = widgets.Button(description='Run') - frame = widgets.Label(value="") - frame.layout.display = "none" - frame2 = widgets.Label(value="") - image = widgets.Image(value=encode_image(handle), format="png", width=sequence.size[0] * 2, height=sequence.size[1] * 2) + self._handle = ImageDrawHandle(sequence.frame(0).image()) - state = dict(frame=0, auto=False, alive=True, region=None) - condition = Condition() + self._button_restart = widgets.Button(description='Restart') + self._button_next = widgets.Button(description='Next') + self._button_play = widgets.Button(description='Run') + self._frame = widgets.Label(value="") + self._frame.layout.display = "none" + self._frame_feedback = widgets.Label(value="") + self._image = widgets.Image(value=b"", format="png", width=sequence.size[0] * 2, height=sequence.size[1] * 2) - buttons = widgets.HBox(children=(frame, button_restart, button_next, button_play, frame2)) + state = dict(frame=0, auto=False, alive=True, region=None) + condition = Condition() - image.value = encode_image(handle) + self._buttons = widgets.HBox(children=(self._frame, self._button_restart, self._button_next, self._button_play, self._frame_feedback)) - def run(): + def _push_image(self, handle): + """ Encodes the image handle snapshot into PNG bytes for the image widget. - runtime = tracker.runtime() + Args: + handle (ImageDrawHandle): The image handle. + + Returns: + bytes: The encoded image. + """ + with io.BytesIO() as output: + handle.snapshot.save(output, format="PNG") + return output.getvalue() - while state["alive"]: + def visualize_tracker(tracker: "Tracker", sequence: "Sequence"): + """ Visualizes a tracker in a Jupyter notebook. - if state["frame"] == 0: - state["region"], _, _ = runtime.initialize(sequence.frame(0), sequence.groundtruth(0)) - else: - state["region"], _, _ = runtime.update(sequence.frame(state["frame"])) + Args: + tracker (Tracker): The tracker to visualize. + sequence (Sequence): The sequence to visualize. + """ + from IPython.display import display + from ipywidgets import widgets + from vot.utilities.draw import ImageDrawHandle - update_image() + def encode_image(handle): + """ Encodes an image so that it can be displayed in a Jupyter notebook. + + Args: + handle (ImageDrawHandle): The image handle.
+ + Returns: + bytes: The encoded image.""" + with io.BytesIO() as output: + handle.snapshot.save(output, format="PNG") + return output.getvalue() - with condition: - condition.wait() + handle = ImageDrawHandle(sequence.frame(0).image()) - if state["frame"] == sequence.length: - state["alive"] = False - continue + button_restart = widgets.Button(description='Restart') + button_next = widgets.Button(description='Next') + button_play = widgets.Button(description='Run') + frame = widgets.Label(value="") + frame.layout.display = "none" + frame2 = widgets.Label(value="") + image = widgets.Image(value=encode_image(handle), format="png", width=sequence.size[0] * 2, height=sequence.size[1] * 2) - state["frame"] = state["frame"] + 1 + state = dict(frame=0, auto=False, alive=True, region=None) + condition = Condition() + buttons = widgets.HBox(children=(frame, button_restart, button_next, button_play, frame2)) - def update_image(): - handle.image(sequence.frame(state["frame"]).image()) - handle.style(color="green").region(sequence.frame(state["frame"]).groundtruth()) - if state["region"]: - handle.style(color="red").region(state["region"]) image.value = encode_image(handle) - frame.value = "Frame: " + str(state["frame"] - 1) - def on_click(button): - if button == button_next: - with condition: - state["auto"] = False - condition.notify() - if button == button_restart: - with condition: - state["frame"] = 0 - condition.notify() - if button == button_play: + def run(): + """ Runs the tracker. """ + + runtime = tracker.runtime() + + while state["alive"]: + + if state["frame"] == 0: + state["region"], _, _ = runtime.initialize(sequence.frame(0), sequence.groundtruth(0)) + else: + state["region"], _, _ = runtime.update(sequence.frame(state["frame"])) + + update_image() + + with condition: + condition.wait() + + if state["frame"] == len(sequence): + state["alive"] = False + continue + + state["frame"] = state["frame"] + 1 + + + def update_image(): + """ Updates the image. """ + handle.image(sequence.frame(state["frame"]).image()) + handle.style(color="green").region(sequence.frame(state["frame"]).groundtruth()) + if state["region"]: + handle.style(color="red").region(state["region"]) + image.value = encode_image(handle) + frame.value = "Frame: " + str(state["frame"] - 1) + + def on_click(button): + """ Handles a button click. """ + if button == button_next: + with condition: + state["auto"] = False + condition.notify() + if button == button_restart: + with condition: + state["frame"] = 0 + condition.notify() + if button == button_play: + with condition: + state["auto"] = not state["auto"] + button.description = "Stop" if state["auto"] else "Run" + condition.notify() + + button_next.on_click(on_click) + button_restart.on_click(on_click) + button_play.on_click(on_click) + widgets.jslink((frame, "value"), (frame2, "value")) + + def on_update(_): + """ Handles a widget update.""" with condition: - state["auto"] = not state["auto"] - button.description = "Stop" if state["auto"] else "Run" - condition.notify() + if state["auto"]: + condition.notify() + + frame2.observe(on_update, names=("value", )) + + thread = Thread(target=run) + display(widgets.Box([widgets.VBox(children=(image, buttons))])) + thread.start() + + def visualize_results(experiment: "Experiment", sequence: "Sequence"): + """ Visualizes the results of an experiment in a Jupyter notebook. + + Args: + experiment (Experiment): The experiment to visualize. + sequence (Sequence): The sequence to visualize. 
+ + """ + + from IPython.display import display + from ipywidgets import widgets + from vot.utilities.draw import ImageDrawHandle - button_next.on_click(on_click) - button_restart.on_click(on_click) - button_play.on_click(on_click) - widgets.jslink((frame, "value"), (frame2, "value")) + def encode_image(handle): + """ Encodes an image so that it can be displayed in a Jupyter notebook. + + Args: + handle (ImageDrawHandle): The image handle. + + Returns: + bytes: The encoded image. + """ - def on_update(_): - with condition: - if state["auto"]: - condition.notify() + with io.BytesIO() as output: + handle.snapshot.save(output, format="PNG") + return output.getvalue() - frame2.observe(on_update, names=("value", )) + handle = ImageDrawHandle(sequence.frame(0).image()) + + button_restart = widgets.Button(description='Restart') + button_next = widgets.Button(description='Next') + button_play = widgets.Button(description='Run') + frame = widgets.Label(value="") + frame.layout.display = "none" + frame2 = widgets.Label(value="") + image = widgets.Image(value=encode_image(handle), format="png", width=sequence.size[0] * 2, height=sequence.size[1] * 2) + + state = dict(frame=0, auto=False, alive=True, region=None) + condition = Condition() + + buttons = widgets.HBox(children=(frame, button_restart, button_next, button_play, frame2)) + + image.value = encode_image(handle) + + def run(): + """ Runs the tracker. """ + + runtime = tracker.runtime() + + while state["alive"]: + + if state["frame"] == 0: + state["region"], _, _ = runtime.initialize(sequence.frame(0), sequence.groundtruth(0)) + else: + state["region"], _, _ = runtime.update(sequence.frame(state["frame"])) + + update_image() + + with condition: + condition.wait() + + if state["frame"] == len(sequence): + state["alive"] = False + continue + + state["frame"] = state["frame"] + 1 + + + def update_image(): + """ Updates the image. """ + handle.image(sequence.frame(state["frame"]).image()) + handle.style(color="green").region(sequence.frame(state["frame"]).groundtruth()) + if state["region"]: + handle.style(color="red").region(state["region"]) + image.value = encode_image(handle) + frame.value = "Frame: " + str(state["frame"] - 1) + + def on_click(button): + """ Handles a button click. 
""" + if button == button_next: + with condition: + state["auto"] = False + condition.notify() + if button == button_restart: + with condition: + state["frame"] = 0 + condition.notify() + if button == button_play: + with condition: + state["auto"] = not state["auto"] + button.description = "Stop" if state["auto"] else "Run" + condition.notify() + + button_next.on_click(on_click) + button_restart.on_click(on_click) + button_play.on_click(on_click) + widgets.jslink((frame, "value"), (frame2, "value")) + + def on_update(_): + """ Handles a widget update.""" + with condition: + if state["auto"]: + condition.notify() - thread = Thread(target=run) - display(widgets.Box([widgets.VBox(children=(image, buttons))])) - thread.start() + frame2.observe(on_update, names=("value", )) + thread = Thread(target=run) + display(widgets.Box([widgets.VBox(children=(image, buttons))])) + thread.start() \ No newline at end of file diff --git a/vot/version.py b/vot/version.py index 08d79c0..6dbe830 100644 --- a/vot/version.py +++ b/vot/version.py @@ -1 +1,4 @@ -__version__ = '0.5.1' \ No newline at end of file +""" +Toolkit version +""" +__version__ = '0.6.4' \ No newline at end of file diff --git a/vot/workspace/__init__.py b/vot/workspace/__init__.py index 7fc0626..08600d7 100644 --- a/vot/workspace/__init__.py +++ b/vot/workspace/__init__.py @@ -1,46 +1,68 @@ +"""This module contains the Workspace class that represents the main junction of trackers, datasets and experiments.""" import os import typing import importlib import yaml +from lazy_object_proxy import Proxy -from attributee import Attribute, Attributee, Nested, List, String +from attributee import Attribute, Attributee, Nested, List, String, CoerceContext from .. import ToolkitException, get_logger from ..dataset import Dataset, load_dataset from ..tracker import Registry, Tracker from ..stack import Stack, resolve_stack -from ..utilities import normalize_path, class_fullname +from ..utilities import normalize_path from ..document import ReportConfiguration from .storage import LocalStorage, Storage, NullStorage + _logger = get_logger() class WorkspaceException(ToolkitException): + """Errors related to workspace raise this exception + """ pass class StackLoader(Attribute): """Special attribute that converts a string or a dictionary input to a Stack object. 
""" - def coerce(self, value, ctx): + def coerce(self, value, context: typing.Optional[CoerceContext]): + """Coerce a value to a Stack object + + Args: + value (typing.Any): Value to coerce + context (typing.Optional[CoerceContext]): Coercion context + + Returns: + Stack: Coerced value + """ importlib.import_module("vot.analysis") importlib.import_module("vot.experiment") if isinstance(value, str): - stack_file = resolve_stack(value, ctx["parent"].directory) + stack_file = resolve_stack(value, context.parent.directory) if stack_file is None: raise WorkspaceException("Experiment stack does not exist") with open(stack_file, 'r') as fp: stack_metadata = yaml.load(fp, Loader=yaml.BaseLoader) - return Stack(value, ctx["parent"], **stack_metadata) + return Stack(value, context.parent, **stack_metadata) else: - return Stack(None, ctx["parent"], **value) + return Stack(None, context.parent, **value) - def dump(self, value): + def dump(self, value: "Stack") -> str: + """Dump a Stack object to a string or a dictionary + + Args: + value (Stack): Value to dump + + Returns: + str: Dumped value + """ if value.name is None: return value.dump() else: @@ -51,14 +73,14 @@ class Workspace(Attributee): given experiments on a provided dataset. """ - registry = List(String(transformer=lambda x, ctx: normalize_path(x, ctx["parent"].directory))) + registry = List(String(transformer=lambda x, ctx: normalize_path(x, ctx.parent.directory))) stack = StackLoader() sequences = String(default="sequences") report = Nested(ReportConfiguration) @staticmethod def initialize(directory: str, config: typing.Optional[typing.Dict] = None, download: bool = True) -> None: - """[summary] + """Initialize a new workspace in a given directory with the given config Args: directory (str): Root for workspace storage @@ -145,9 +167,12 @@ def __init__(self, directory: str, **kwargs): directory ([type]): [description] """ self._directory = directory - self._storage = LocalStorage(directory) if directory is not None else NullStorage() + self._storage = Proxy(lambda: LocalStorage(directory) if directory is not None else NullStorage()) + super().__init__(**kwargs) + + dataset_directory = normalize_path(self.sequences, directory) if not self.stack.dataset is None: diff --git a/vot/workspace/storage.py b/vot/workspace/storage.py index c1c5e41..df8e7f9 100644 --- a/vot/workspace/storage.py +++ b/vot/workspace/storage.py @@ -1,3 +1,4 @@ +"""Storage abstraction for the workspace.""" import os import pickle @@ -13,6 +14,8 @@ from ..dataset import Sequence from ..tracker import Tracker, Results +from attributee import Attributee, Boolean + class Storage(ABC): """Abstract superclass for workspace storage abstraction """ @@ -120,71 +123,133 @@ def copy(self, localfile: str, destination: str): pass class NullStorage(Storage): - """An implementation of dummy storage that does not save anything - """ + """An implementation of dummy storage that does not save anything.""" def results(self, tracker: Tracker, experiment: Experiment, sequence: Sequence): + """Returns results object for the given tracker, experiment, sequence combination.""" return Results(self) def __repr__(self) -> str: + """Returns a string representation of the storage object.""" return "".format(self._root) def write(self, name, binary=False): + """Opens the given file entry for writing, returns opened handle.""" if binary: return open(os.devnull, "wb") else: return open(os.devnull, "w") def documents(self): + """Lists documents in the storage.""" return [] def folders(self): + """Lists 
folders in the storage. Returns an empty list. + + Returns: + list: Empty list""" return [] def read(self, name, binary=False): + """Opens the given file entry for reading. + + Returns: + None: Always returns None since the storage is empty. + """ return None def isdocument(self, name): + """Checks if given name is a document/file in this storage. + + Returns: + bool: Returns False.""" return False def isfolder(self, name): + """Checks if given name is a folder in this storage. + + Returns: + bool: Returns False. + """ return False def delete(self, name) -> bool: + """Deletes a given document. + + Returns: + bool: Returns False since nothing is deleted.""" return False def substorage(self, name): + """Returns a substorage, storage object with root in a subfolder.""" return NullStorage() def copy(self, localfile, destination): + """Copy a document to another location. Does nothing.""" return class LocalStorage(Storage): - """Storage backed by the local filesystem. - """ + """Storage backed by the local filesystem. This is the default real storage implementation.""" def __init__(self, root: str): + """Creates a new local storage object. + + Args: + root (str): Root path of the storage. + """ self._root = root self._results = os.path.join(root, "results") def __repr__(self) -> str: + """Returns a string representation of the storage object.""" return "<Local storage: {}>".format(self._root) @property def base(self) -> str: + """Returns the base path of the storage.""" return self._root def results(self, tracker: Tracker, experiment: Experiment, sequence: Sequence): + """Returns results object for the given tracker, experiment, sequence combination. + + Args: + tracker (Tracker): Selected tracker + experiment (Experiment): Selected experiment + sequence (Sequence): Selected sequence + + Returns: + Results: Results object + """ storage = LocalStorage(os.path.join(self._results, tracker.reference, experiment.identifier, sequence.name)) return Results(storage) def documents(self): + """Lists documents in the storage. + + Returns: + list: List of document names. + """ return [name for name in os.listdir(self._root) if os.path.isfile(os.path.join(self._root, name))] def folders(self): + """Lists folders in the storage. + + Returns: + list: List of folder names. + """ return [name for name in os.listdir(self._root) if os.path.isdir(os.path.join(self._root, name))] def write(self, name: str, binary: bool = False): + """Opens the given file entry for writing, returns opened handle. + + Args: + name (str): File name. + binary (bool, optional): Open file in binary mode. Defaults to False. + + Returns: + file: Opened file handle. + """ full = os.path.join(self.base, name) os.makedirs(os.path.dirname(full), exist_ok=True) @@ -194,6 +259,15 @@ def write(self, name: str, binary: bool = False): return open(full, mode="w", newline="") def read(self, name, binary=False): + """Opens the given file entry for reading, returns opened handle. + + Args: + name (str): File name. + binary (bool, optional): Open file in binary mode. Defaults to False. + + Returns: + file: Opened file handle. + """ full = os.path.join(self.base, name) if binary: @@ -202,6 +276,14 @@ def read(self, name, binary=False): return open(full, mode="r", newline="") def delete(self, name) -> bool: + """Deletes a given document. Returns True if successful, False otherwise. + + Args: + name (str): File name. + + Returns: + bool: Returns True if successful, False otherwise.
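+ + Note: + Only regular files are deleted; folders are left untouched.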
+ """ full = os.path.join(self.base, name) if os.path.isfile(full): os.unlink(full) @@ -209,15 +291,49 @@ def delete(self, name) -> bool: return False def isdocument(self, name): + """Checks if given name is a document/file in this storage. + + Args: + name (str): Name of the entry to check + + Returns: + bool: Returns True if entry is a document, False otherwise. + """ return os.path.isfile(os.path.join(self._root, name)) def isfolder(self, name): + """Checks if given name is a folder in this storage. + + Args: + name (str): Name of the entry to check + + Returns: + bool: Returns True if entry is a folder, False otherwise. + """ return os.path.isdir(os.path.join(self._root, name)) def substorage(self, name): + """Returns a substorage, storage object with root in a subfolder. + + Args: + name (str): Name of the entry, must be a folder + + Returns: + Storage: Storage object + """ return LocalStorage(os.path.join(self.base, name)) def copy(self, localfile, destination): + """Copy a document to another location in the storage. + + Args: + localfile (str): Original location + destination (str): New location + + Raises: + IOError: If the destination is an absolute path. + + """ import shutil if os.path.isabs(destination): raise IOError("Only relative paths allowed") @@ -228,6 +344,17 @@ def copy(self, localfile, destination): shutil.move(localfile, os.path.join(self.base, full)) def directory(self, *args): + """Returns a path to a directory in the storage. + + Args: + *args: Path segments. + + Returns: + str: Path to the directory. + + Raises: + ValueError: If the path is not a directory. + """ segments = [] for arg in args: if arg is None: @@ -278,7 +405,19 @@ def _filename(self, key: typing.Union[typing.Tuple, str]) -> str: directory = "" return os.path.join(directory, filename) - def __getitem__(self, key: str): + def __getitem__(self, key: str) -> typing.Any: + """Retrieves an image from cache. If it does not exist, a KeyError is raised + + Args: + key (str): Key of the item + + Raises: + KeyError: Entry does not exist or cannot be retrieved + PickleError: Unable to + + Returns: + typing.Any: item value + """ try: return super().__getitem__(key) except KeyError as e: @@ -290,20 +429,32 @@ def __getitem__(self, key: str): data = pickle.load(filehandle) super().__setitem__(key, data) return data - except pickle.PickleError: - raise e + except pickle.PickleError as e: + raise KeyError(e) - def __setitem__(self, key: str, value: typing.Any): + def __setitem__(self, key: str, value: typing.Any) -> None: + """Sets an item for given key + + Args: + key (str): Item key + value (typing.Any): Item value + + """ super().__setitem__(key, value) filename = self._filename(key) try: with self._storage.write(filename, binary=True) as filehandle: - return pickle.dump(value, filehandle) + pickle.dump(value, filehandle) except pickle.PickleError: pass - def __delitem__(self, key: str): + def __delitem__(self, key: str) -> None: + """Operator for item deletion. + + Args: + key (str): Key of object to remove + """ try: super().__delitem__(key) filename = self._filename(key) @@ -314,6 +465,14 @@ def __delitem__(self, key: str): except KeyError: pass - def __contains__(self, key: str): + def __contains__(self, key: str) -> bool: + """Magic method, does the cache include an item for a given key. 
+ + Args: + key (str): Item key + + Returns: + bool: True if object exists for a given key + """ filename = self._filename(key) return self._storage.isdocument(filename) diff --git a/vot/workspace/tests.py b/vot/workspace/tests.py index bfd1b64..4ead470 100644 --- a/vot/workspace/tests.py +++ b/vot/workspace/tests.py @@ -1,4 +1,4 @@ - +"""Tests for workspace related methods and classes.""" import logging import tempfile @@ -9,8 +9,12 @@ from vot.workspace import Workspace, NullStorage class TestStacks(unittest.TestCase): + """Tests for workspace related methods + """ def test_void_storage(self): + """Test if void storage works + """ storage = NullStorage() @@ -20,6 +24,8 @@ def test_void_storage(self): self.assertIsNone(storage.read("test.data")) def test_local_storage(self): + """Test if local storage works + """ with tempfile.TemporaryDirectory() as testdir: storage = LocalStorage(testdir) @@ -32,16 +38,21 @@ def test_local_storage(self): # TODO: more tests def test_workspace_create(self): + """Test if workspace creation works + """ get_logger().setLevel(logging.WARN) # Disable progress bar - default_config = dict(stack="testing", registry=["./trackers.ini"]) + default_config = dict(stack="tests/basic", registry=["./trackers.ini"]) with tempfile.TemporaryDirectory() as testdir: Workspace.initialize(testdir, default_config, download=True) Workspace.load(testdir) def test_cache(self): + """Test if local storage cache works + """ + with tempfile.TemporaryDirectory() as testdir: cache = Cache(LocalStorage(testdir))