From 78dfef04d39d6e07c384a125c82b4f6e8cdd6bcc Mon Sep 17 00:00:00 2001 From: mrepetto Date: Wed, 18 Sep 2019 12:39:59 -0300 Subject: [PATCH] migrated code to python 3.6.6 and refactored some code to improve it. --- dejavu.cnf.SAMPLE | 6 +- dejavu.py | 6 +- dejavu/__init__.py | 149 +++++---- dejavu/config/config.py | 74 +++++ dejavu/database.py | 38 +-- dejavu/database_handler/__init__.py | 0 dejavu/database_handler/mysql_database.py | 235 ++++++++++++++ dejavu/database_handler/mysql_queries.py | 126 ++++++++ dejavu/database_sql.py | 373 ---------------------- dejavu/decoder.py | 15 +- dejavu/fingerprint.py | 109 ++----- dejavu/recognize.py | 30 +- dejavu/testing.py | 109 ++++--- dejavu/wavio.py | 1 + example.py | 42 +-- requirements.txt | 14 +- run_tests.py | 8 +- setup.py | 6 +- 18 files changed, 681 insertions(+), 660 deletions(-) create mode 100644 dejavu/config/config.py create mode 100644 dejavu/database_handler/__init__.py create mode 100755 dejavu/database_handler/mysql_database.py create mode 100644 dejavu/database_handler/mysql_queries.py delete mode 100755 dejavu/database_sql.py diff --git a/dejavu.cnf.SAMPLE b/dejavu.cnf.SAMPLE index cd677b0d..9a89e256 100755 --- a/dejavu.cnf.SAMPLE +++ b/dejavu.cnf.SAMPLE @@ -2,7 +2,7 @@ "database": { "host": "127.0.0.1", "user": "root", - "passwd": "12345678", - "db": "dejavu" + "password": "rootpass", + "database": "dejavu" } -} \ No newline at end of file +} diff --git a/dejavu.py b/dejavu.py index a7f74a16..db0c73c2 100755 --- a/dejavu.py +++ b/dejavu.py @@ -24,7 +24,7 @@ def init(configpath): with open(configpath) as f: config = json.load(f) except IOError as err: - print("Cannot open configuration: %s. Exiting" % (str(err))) + print(("Cannot open configuration: %s. Exiting" % (str(err)))) sys.exit(1) # create a Dejavu instance @@ -67,8 +67,8 @@ def init(configpath): if len(args.fingerprint) == 2: directory = args.fingerprint[0] extension = args.fingerprint[1] - print("Fingerprinting all .%s files in the %s directory" - % (extension, directory)) + print(("Fingerprinting all .%s files in the %s directory" + % (extension, directory))) djv.fingerprint_directory(directory, ["." + extension], 4) elif len(args.fingerprint) == 1: diff --git a/dejavu/__init__.py b/dejavu/__init__.py index 7cc3f3b6..bf89d3a7 100755 --- a/dejavu/__init__.py +++ b/dejavu/__init__.py @@ -1,28 +1,23 @@ -from dejavu.database import get_database, Database -import dejavu.decoder as decoder -import fingerprint import multiprocessing import os -import traceback import sys +import traceback +import dejavu.decoder as decoder +from dejavu.config.config import (CONFIDENCE, DEFAULT_FS, + DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE, + FIELD_FILE_SHA1, OFFSET, OFFSET_SECS, + SONG_ID, SONG_NAME, TOPN) +from dejavu.database import get_database +from dejavu.fingerprint import fingerprint -class Dejavu(object): - - SONG_ID = "song_id" - SONG_NAME = 'song_name' - CONFIDENCE = 'confidence' - MATCH_TIME = 'match_time' - OFFSET = 'offset' - OFFSET_SECS = 'offset_seconds' +class Dejavu: def __init__(self, config): - super(Dejavu, self).__init__() - self.config = config # initialize db - db_cls = get_database(config.get("database_type", None)) + db_cls = get_database(config.get("database_type", "mysql").lower()) self.db = db_cls(**config.get("database", {})) self.db.setup() @@ -39,7 +34,7 @@ def get_fingerprinted_songs(self): self.songs = self.db.get_songs() self.songhashes_set = set() # to know which ones we've computed before for song in self.songs: - song_hash = song[Database.FIELD_FILE_SHA1] + song_hash = song[FIELD_FILE_SHA1] self.songhashes_set.add(song_hash) def fingerprint_directory(self, path, extensions, nprocesses=None): @@ -55,26 +50,23 @@ def fingerprint_directory(self, path, extensions, nprocesses=None): filenames_to_fingerprint = [] for filename, _ in decoder.find_files(path, extensions): - # don't refingerprint already fingerprinted files if decoder.unique_hash(filename) in self.songhashes_set: - print "%s already fingerprinted, continuing..." % filename + print(f"{filename} already fingerprinted, continuing...") continue filenames_to_fingerprint.append(filename) # Prepare _fingerprint_worker input - worker_input = zip(filenames_to_fingerprint, - [self.limit] * len(filenames_to_fingerprint)) + worker_input = list(zip(filenames_to_fingerprint, [self.limit] * len(filenames_to_fingerprint))) # Send off our tasks - iterator = pool.imap_unordered(_fingerprint_worker, - worker_input) + iterator = pool.imap_unordered(_fingerprint_worker, worker_input) # Loop till we have all of them while True: try: - song_name, hashes, file_hash = iterator.next() + song_name, hashes, file_hash = next(iterator) except multiprocessing.TimeoutError: continue except StopIteration: @@ -99,7 +91,7 @@ def fingerprint_file(self, filepath, song_name=None): song_name = song_name or songname # don't refingerprint already fingerprinted files if song_hash in self.songhashes_set: - print "%s already fingerprinted, continuing..." % song_name + print(f"{song_name} already fingerprinted, continuing...") else: song_name, hashes, file_hash = _fingerprint_worker( filepath, @@ -112,22 +104,21 @@ def fingerprint_file(self, filepath, song_name=None): self.db.set_song_fingerprinted(sid) self.get_fingerprinted_songs() - def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS): - hashes = fingerprint.fingerprint(samples, Fs=Fs) + def find_matches(self, samples, Fs=DEFAULT_FS): + hashes = fingerprint(samples, Fs=Fs) return self.db.return_matches(hashes) - def align_matches(self, matches): + def align_matches(self, matches, topn=TOPN): """ Finds hash matches that align in time with other matches and finds consensus about which hashes are "true" signal from the audio. - Returns a dictionary with match information. + Returns a list of dictionaries (based on topn) with match information. """ # align by diffs diff_counter = {} - largest = 0 largest_count = 0 - song_id = -1 + for tup in matches: sid, diff = tup if diff not in diff_counter: @@ -137,30 +128,65 @@ def align_matches(self, matches): diff_counter[diff][sid] += 1 if diff_counter[diff][sid] > largest_count: - largest = diff largest_count = diff_counter[diff][sid] - song_id = sid - # extract idenfication - song = self.db.get_song_by_id(song_id) - if song: - # TODO: Clarify what `get_song_by_id` should return. - songname = song.get(Dejavu.SONG_NAME, None) - else: - return None - - # return match info - nseconds = round(float(largest) / fingerprint.DEFAULT_FS * - fingerprint.DEFAULT_WINDOW_SIZE * - fingerprint.DEFAULT_OVERLAP_RATIO, 5) - song = { - Dejavu.SONG_ID : song_id, - Dejavu.SONG_NAME : songname.encode("utf8"), - Dejavu.CONFIDENCE : largest_count, - Dejavu.OFFSET : int(largest), - Dejavu.OFFSET_SECS : nseconds, - Database.FIELD_FILE_SHA1 : song.get(Database.FIELD_FILE_SHA1, None).encode("utf8"),} - return song + # create dic where key are songs ids + songs_num_matches = {} + for dc in diff_counter: + for sid in diff_counter[dc]: + match_val = diff_counter[dc][sid] + if (sid not in songs_num_matches) or (match_val > songs_num_matches[sid]['value']): + songs_num_matches[sid] = { + 'sid': sid, + 'value': match_val, + 'largest': dc + } + + # use dicc of songs to create an ordered (descending) list using the match value property assigned to each song + songs_num_matches_list = [] + for s in songs_num_matches: + songs_num_matches_list.append({ + 'sid': s, + 'object': songs_num_matches[s] + }) + + songs_num_matches_list_ordered = sorted(songs_num_matches_list, key=lambda x: x['object']['value'], + reverse=True) + + # iterate the ordered list and fill results + songs_result = [] + for s in songs_num_matches_list_ordered: + + # get expected variable by the original code + song_id = s['object']['sid'] + largest = s['object']['largest'] + largest_count = s['object']['value'] + + # extract identification + song = self.db.get_song_by_id(song_id) + if song: + # TODO: Clarify what `get_song_by_id` should return. + songname = song.get(SONG_NAME, None) + + # return match info + nseconds = round(float(largest) / DEFAULT_FS * + DEFAULT_WINDOW_SIZE * + DEFAULT_OVERLAP_RATIO, 5) + song = { + SONG_ID: song_id, + SONG_NAME: songname.encode("utf8"), + CONFIDENCE: largest_count, + OFFSET: int(largest), + OFFSET_SECS: nseconds, + FIELD_FILE_SHA1: song.get(FIELD_FILE_SHA1, None).encode("utf8") + } + + songs_result.append(song) + + # only consider up to topn elements in the result + if len(songs_result) > topn: + break + return songs_result def recognize(self, recognizer, *options, **kwoptions): r = recognizer(self) @@ -177,26 +203,15 @@ def _fingerprint_worker(filename, limit=None, song_name=None): songname, extension = os.path.splitext(os.path.basename(filename)) song_name = song_name or songname - channels, Fs, file_hash = decoder.read(filename, limit) + channels, fs, file_hash = decoder.read(filename, limit) result = set() channel_amount = len(channels) for channeln, channel in enumerate(channels): # TODO: Remove prints or change them into optional logging. - print("Fingerprinting channel %d/%d for %s" % (channeln + 1, - channel_amount, - filename)) - hashes = fingerprint.fingerprint(channel, Fs=Fs) - print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount, - filename)) + print(f"Fingerprinting channel {channeln + 1}/{channel_amount} for {filename}") + hashes = fingerprint(channel, Fs=fs) + print(f"Finished channel {channeln + 1}/{channel_amount} for {filename}") result |= set(hashes) return song_name, result, file_hash - - -def chunkify(lst, n): - """ - Splits a list into roughly n equal parts. - http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts - """ - return [lst[i::n] for i in xrange(n)] diff --git a/dejavu/config/config.py b/dejavu/config/config.py new file mode 100644 index 00000000..4e28bcc6 --- /dev/null +++ b/dejavu/config/config.py @@ -0,0 +1,74 @@ +# Dejavu +SONG_ID = "song_id" +SONG_NAME = 'song_name' +CONFIDENCE = 'confidence' +MATCH_TIME = 'match_time' +OFFSET = 'offset' +OFFSET_SECS = 'offset_seconds' + +# DATABASE CLASS INSTANCES: +DATABASES = { + 'mysql': ("dejavu.database_handler.mysql_database", "MySQLDatabase") +} + +# TABLE SONGS +SONGS_TABLENAME = "songs" + +# SONGS FIELDS +FIELD_SONG_ID = 'song_id' +FIELD_SONGNAME = 'song_name' +FIELD_FINGERPRINTED = "fingerprinted" +FIELD_FILE_SHA1 = 'file_sha1' + +# TABLE FINGERPRINTS +FINGERPRINTS_TABLENAME = "fingerprints" + +# FINGERPRINTS FIELDS +FIELD_HASH = 'hash' +FIELD_OFFSET = 'offset' + +# FINGERPRINTS CONFIG: +# Sampling rate, related to the Nyquist conditions, which affects +# the range frequencies we can detect. +DEFAULT_FS = 44100 + +# Size of the FFT window, affects frequency granularity +DEFAULT_WINDOW_SIZE = 4096 + +# Ratio by which each sequential window overlaps the last and the +# next window. Higher overlap will allow a higher granularity of offset +# matching, but potentially more fingerprints. +DEFAULT_OVERLAP_RATIO = 0.5 + +# Degree to which a fingerprint can be paired with its neighbors -- +# higher will cause more fingerprints, but potentially better accuracy. +DEFAULT_FAN_VALUE = 15 + +# Minimum amplitude in spectrogram in order to be considered a peak. +# This can be raised to reduce number of fingerprints, but can negatively +# affect accuracy. +DEFAULT_AMP_MIN = 10 + +# Number of cells around an amplitude peak in the spectrogram in order +# for Dejavu to consider it a spectral peak. Higher values mean less +# fingerprints and faster matching, but can potentially affect accuracy. +PEAK_NEIGHBORHOOD_SIZE = 20 + +# Thresholds on how close or far fingerprints can be in time in order +# to be paired as a fingerprint. If your max is too low, higher values of +# DEFAULT_FAN_VALUE may not perform as expected. +MIN_HASH_TIME_DELTA = 0 +MAX_HASH_TIME_DELTA = 200 + +# If True, will sort peaks temporally for fingerprinting; +# not sorting will cut down number of fingerprints, but potentially +# affect performance. +PEAK_SORT = True + +# Number of bits to grab from the front of the SHA1 hash in the +# fingerprint calculation. The more you grab, the more memory storage, +# with potentially lesser collisions of matches. +FINGERPRINT_REDUCTION = 20 + +# Number of results being returned for file recognition +TOPN = 2 \ No newline at end of file diff --git a/dejavu/database.py b/dejavu/database.py index e5732ff0..bc7154cd 100755 --- a/dejavu/database.py +++ b/dejavu/database.py @@ -1,22 +1,15 @@ -from __future__ import absolute_import import abc +import importlib +from dejavu.config.config import DATABASES -class Database(object): - __metaclass__ = abc.ABCMeta - - FIELD_FILE_SHA1 = 'file_sha1' - FIELD_SONG_ID = 'song_id' - FIELD_SONGNAME = 'song_name' - FIELD_OFFSET = 'offset' - FIELD_HASH = 'hash' - +class Database(object, metaclass=abc.ABCMeta): # Name of your Database subclass, this is used in configuration # to refer to your class type = None def __init__(self): - super(Database, self).__init__() + super().__init__() def before_fork(self): """ @@ -159,18 +152,11 @@ def return_matches(self, hashes): pass -def get_database(database_type=None): - # Default to using the mysql database - database_type = database_type or "mysql" - # Lower all the input. - database_type = database_type.lower() - - for db_cls in Database.__subclasses__(): - if db_cls.type == database_type: - return db_cls - - raise TypeError("Unsupported database type supplied.") - - -# Import our default database handler -import dejavu.database_sql +def get_database(database_type="mysql"): + path, db_class_name = DATABASES[database_type] + try: + db_module = importlib.import_module(path) + db_class = getattr(db_module, db_class_name) + return db_class + except ImportError: + raise TypeError("Unsupported database type supplied.") diff --git a/dejavu/database_handler/__init__.py b/dejavu/database_handler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dejavu/database_handler/mysql_database.py b/dejavu/database_handler/mysql_database.py new file mode 100755 index 00000000..4fb9bea7 --- /dev/null +++ b/dejavu/database_handler/mysql_database.py @@ -0,0 +1,235 @@ +import queue + +import mysql.connector +from mysql.connector.errors import DatabaseError + +import dejavu.database_handler.mysql_queries as queries +from dejavu.database import Database + + +class MySQLDatabase(Database): + type = "mysql" + + def __init__(self, **options): + super().__init__() + self.cursor = cursor_factory(**options) + self._options = options + + def after_fork(self): + # Clear the cursor cache, we don't want any stale connections from + # the previous process. + Cursor.clear_cache() + + def setup(self): + """ + Creates any non-existing tables required for dejavu to function. + + This also removes all songs that have been added but have no + fingerprints associated with them. + """ + with self.cursor() as cur: + cur.execute(queries.CREATE_SONGS_TABLE) + cur.execute(queries.CREATE_FINGERPRINTS_TABLE) + cur.execute(queries.DELETE_UNFINGERPRINTED) + + def empty(self): + """ + Drops tables created by dejavu and then creates them again + by calling `SQLDatabase.setup`. + + .. warning: + This will result in a loss of data + """ + with self.cursor() as cur: + cur.execute(queries.DROP_FINGERPRINTS) + cur.execute(queries.DROP_SONGS) + + self.setup() + + def delete_unfingerprinted_songs(self): + """ + Removes all songs that have no fingerprints associated with them. + """ + with self.cursor() as cur: + cur.execute(queries.DELETE_UNFINGERPRINTED) + + def get_num_songs(self): + """ + Returns number of songs the database has fingerprinted. + """ + with self.cursor() as cur: + cur.execute(queries.SELECT_UNIQUE_SONG_IDS) + count = cur.fetchone()[0] if cur.rowcount != 0 else 0 + + return count + + def get_num_fingerprints(self): + """ + Returns number of fingerprints the database has fingerprinted. + """ + with self.cursor() as cur: + cur.execute(queries.SELECT_NUM_FINGERPRINTS) + count = cur.fetchone()[0] if cur.rowcount != 0 else 0 + cur.close() + + return count + + def set_song_fingerprinted(self, sid): + """ + Set the fingerprinted flag to TRUE (1) once a song has been completely + fingerprinted in the database. + """ + with self.cursor() as cur: + cur.execute(queries.UPDATE_SONG_FINGERPRINTED, (sid,)) + + def get_songs(self): + """ + Return songs that have the fingerprinted flag set TRUE (1). + """ + with self.cursor(dictionary=True) as cur: + cur.execute(queries.SELECT_SONGS) + for row in cur: + yield row + + def get_song_by_id(self, sid): + """ + Returns song by its ID. + """ + with self.cursor(dictionary=True) as cur: + cur.execute(queries.SELECT_SONG, (sid,)) + return cur.fetchone() + + def insert(self, hash, sid, offset): + """ + Insert a (sha1, song_id, offset) row into database. + """ + with self.cursor() as cur: + cur.execute(queries.INSERT_FINGERPRINT, (hash, sid, offset)) + + def insert_song(self, song_name, file_hash): + """ + Inserts song in the database and returns the ID of the inserted record. + """ + with self.cursor() as cur: + cur.execute(queries.INSERT_SONG, (song_name, file_hash)) + return cur.lastrowid + + def query(self, hash): + """ + Return all tuples associated with hash. + + If hash is None, returns all entries in the + database (be careful with that one!). + """ + if hash: + with self.cursor() as cur: + cur.execute(queries.SELECT, (hash,)) + for sid, offset in cur: + yield (sid, offset) + else: # select all if no key + with self.cursor() as cur: + cur.execute(queries.SELECT_ALL) + for sid, offset in cur: + yield (sid, offset) + + def get_iterable_kv_pairs(self): + """ + Returns all tuples in database. + """ + return self.query(None) + + def insert_hashes(self, sid, hashes, batch=1000): + """ + Insert series of hash => song_id, offset + values into the database. + """ + values = [(sid, hash, int(offset)) for hash, offset in hashes] + + with self.cursor() as cur: + for index in range(0, len(hashes), batch): + cur.executemany(queries.INSERT_FINGERPRINT, values[index: index + batch]) + + def return_matches(self, hashes, batch=1000): + """ + Return the (song_id, offset_diff) tuples associated with + a list of (sha1, sample_offset) values. + """ + # Create a dictionary of hash => offset pairs for later lookups + mapper = {} + for hash, offset in hashes: + mapper[hash.upper()] = offset + + # Get an iterable of all the hashes we need + values = list(mapper.keys()) + + with self.cursor() as cur: + for index in range(0, len(values), batch): + # Create our IN part of the query + query = queries.SELECT_MULTIPLE + query = query % ', '.join(['UNHEX(%s)'] * len(values[index: index + batch])) + + cur.execute(query, values[index: index + batch]) + + for hash, sid, offset in cur: + # (sid, db_offset - song_sampled_offset) + yield (sid, offset - mapper[hash]) + + def __getstate__(self): + return self._options, + + def __setstate__(self, state): + self._options, = state + self.cursor = cursor_factory(**self._options) + + +def cursor_factory(**factory_options): + def cursor(**options): + options.update(factory_options) + return Cursor(**options) + return cursor + + +class Cursor(object): + """ + Establishes a connection to the database and returns an open cursor. + # Use as context manager + with Cursor() as cur: + cur.execute(query) + ... + """ + def __init__(self, dictionary=False, **options): + super().__init__() + + self._cache = queue.Queue(maxsize=5) + + try: + conn = self._cache.get_nowait() + # Ping the connection before using it from the cache. + conn.ping(True) + except queue.Empty: + conn = mysql.connector.connect(**options) + + self.conn = conn + self.dictionary = dictionary + + @classmethod + def clear_cache(cls): + cls._cache = queue.Queue(maxsize=5) + + def __enter__(self): + self.cursor = self.conn.cursor(dictionary=self.dictionary) + return self.cursor + + def __exit__(self, extype, exvalue, traceback): + # if we had a MySQL related error we try to rollback the cursor. + if extype is DatabaseError: + self.cursor.rollback() + + self.cursor.close() + self.conn.commit() + + # Put it back on the queue + try: + self._cache.put_nowait(self.conn) + except queue.Full: + self.conn.close() diff --git a/dejavu/database_handler/mysql_queries.py b/dejavu/database_handler/mysql_queries.py new file mode 100644 index 00000000..d3f78a29 --- /dev/null +++ b/dejavu/database_handler/mysql_queries.py @@ -0,0 +1,126 @@ +from dejavu.config.config import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED, + FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID, + FIELD_SONGNAME, FINGERPRINTS_TABLENAME, + SONGS_TABLENAME) + +""" +Queries: + +1) Find duplicates (shouldn't be any, though): + + select `hash`, `song_id`, `offset`, count(*) cnt + from fingerprints + group by `hash`, `song_id`, `offset` + having cnt > 1 + order by cnt asc; + +2) Get number of hashes by song: + + select song_id, song_name, count(song_id) as num + from fingerprints + natural join songs + group by song_id + order by count(song_id) desc; + +3) get hashes with highest number of collisions + + select + hash, + count(distinct song_id) as n + from fingerprints + group by `hash` + order by n DESC; + +=> 26 different songs with same fingerprint (392 times): + + select songs.song_name, fingerprints.offset + from fingerprints natural join songs + where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73"; +""" + +# creates +CREATE_SONGS_TABLE = f""" + CREATE TABLE IF NOT EXISTS `{SONGS_TABLENAME}` ( + `{FIELD_SONG_ID}` mediumint unsigned not null auto_increment, + `{FIELD_SONGNAME}` varchar(250) not null, + `{FIELD_FINGERPRINTED}` tinyint default 0, + `{FIELD_FILE_SHA1}` binary(20) not null, + PRIMARY KEY (`{FIELD_SONG_ID}`), + UNIQUE KEY `{FIELD_SONG_ID}` (`{FIELD_SONG_ID}`) +) ENGINE=INNODB;""" + +CREATE_FINGERPRINTS_TABLE = f""" + CREATE TABLE IF NOT EXISTS `{FINGERPRINTS_TABLENAME}` ( + `{FIELD_HASH}` binary(10) not null, + `{FIELD_SONG_ID}` mediumint unsigned not null, + `{FIELD_OFFSET}` int unsigned not null, + INDEX ({FIELD_HASH}), + UNIQUE KEY `unique_constraint` ({FIELD_SONG_ID}, {FIELD_OFFSET}, {FIELD_HASH}), + FOREIGN KEY ({FIELD_SONG_ID}) REFERENCES {SONGS_TABLENAME}({FIELD_SONG_ID}) ON DELETE CASCADE +) ENGINE=INNODB;""" + +# inserts (ignores duplicates) +INSERT_FINGERPRINT = f""" + INSERT IGNORE INTO `{FINGERPRINTS_TABLENAME}` ( + `{FIELD_SONG_ID}` + , `{FIELD_HASH}` + , `{FIELD_OFFSET}`) + VALUES (%s, UNHEX(%s), %s); +""" + +INSERT_SONG = f""" + INSERT INTO `{SONGS_TABLENAME}` (`{FIELD_SONGNAME}`,`{FIELD_FILE_SHA1}`) + VALUES (%s, UNHEX(%s)); +""" + +# selects +SELECT = f""" + SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` + FROM `{FINGERPRINTS_TABLENAME}` + WHERE `{FIELD_HASH}` = UNHEX(%s); +""" + +SELECT_MULTIPLE = f""" + SELECT HEX(`{FIELD_HASH}`), `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` + FROM `{FINGERPRINTS_TABLENAME}` + WHERE `{FIELD_HASH}` IN (%s); +""" + +SELECT_ALL = f"SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` FROM `{FINGERPRINTS_TABLENAME}`;" + +SELECT_SONG = f""" + SELECT `{FIELD_SONGNAME}`, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}` + FROM `{SONGS_TABLENAME}` + WHERE `{FIELD_SONG_ID}` = %s; +""" + +SELECT_NUM_FINGERPRINTS = f"SELECT COUNT(*) AS n FROM `{FINGERPRINTS_TABLENAME}`;" + +SELECT_UNIQUE_SONG_IDS = f""" + SELECT COUNT(`{FIELD_SONG_ID}`) AS n + FROM `{SONGS_TABLENAME}` + WHERE `{FIELD_FINGERPRINTED}` = 1; +""" + +SELECT_SONGS = f""" + SELECT + `{FIELD_SONG_ID}` + , `{FIELD_SONGNAME}` + , HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}` + FROM `{SONGS_TABLENAME}` + WHERE `{FIELD_FINGERPRINTED}` = 1; +""" + +# drops +DROP_FINGERPRINTS = f"DROP TABLE IF EXISTS `{FINGERPRINTS_TABLENAME}`;" +DROP_SONGS = f"DROP TABLE IF EXISTS `{SONGS_TABLENAME}`;" + +# update +UPDATE_SONG_FINGERPRINTED = f""" + UPDATE `{SONGS_TABLENAME}` SET `{FIELD_FINGERPRINTED}` = 1 WHERE `{FIELD_SONG_ID}` = %s; +""" + +# delete +DELETE_UNFINGERPRINTED = f""" + DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_FINGERPRINTED}` = 0; +""" diff --git a/dejavu/database_sql.py b/dejavu/database_sql.py deleted file mode 100755 index 0fe2e68d..00000000 --- a/dejavu/database_sql.py +++ /dev/null @@ -1,373 +0,0 @@ -from __future__ import absolute_import -from itertools import izip_longest -import Queue - -import MySQLdb as mysql -from MySQLdb.cursors import DictCursor - -from dejavu.database import Database - - -class SQLDatabase(Database): - """ - Queries: - - 1) Find duplicates (shouldn't be any, though): - - select `hash`, `song_id`, `offset`, count(*) cnt - from fingerprints - group by `hash`, `song_id`, `offset` - having cnt > 1 - order by cnt asc; - - 2) Get number of hashes by song: - - select song_id, song_name, count(song_id) as num - from fingerprints - natural join songs - group by song_id - order by count(song_id) desc; - - 3) get hashes with highest number of collisions - - select - hash, - count(distinct song_id) as n - from fingerprints - group by `hash` - order by n DESC; - - => 26 different songs with same fingerprint (392 times): - - select songs.song_name, fingerprints.offset - from fingerprints natural join songs - where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73"; - """ - - type = "mysql" - - # tables - FINGERPRINTS_TABLENAME = "fingerprints" - SONGS_TABLENAME = "songs" - - # fields - FIELD_FINGERPRINTED = "fingerprinted" - - # creates - CREATE_FINGERPRINTS_TABLE = """ - CREATE TABLE IF NOT EXISTS `%s` ( - `%s` binary(10) not null, - `%s` mediumint unsigned not null, - `%s` int unsigned not null, - INDEX (%s), - UNIQUE KEY `unique_constraint` (%s, %s, %s), - FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE - ) ENGINE=INNODB;""" % ( - FINGERPRINTS_TABLENAME, Database.FIELD_HASH, - Database.FIELD_SONG_ID, Database.FIELD_OFFSET, Database.FIELD_HASH, - Database.FIELD_SONG_ID, Database.FIELD_OFFSET, Database.FIELD_HASH, - Database.FIELD_SONG_ID, SONGS_TABLENAME, Database.FIELD_SONG_ID - ) - - CREATE_SONGS_TABLE = """ - CREATE TABLE IF NOT EXISTS `%s` ( - `%s` mediumint unsigned not null auto_increment, - `%s` varchar(250) not null, - `%s` tinyint default 0, - `%s` binary(20) not null, - PRIMARY KEY (`%s`), - UNIQUE KEY `%s` (`%s`) - ) ENGINE=INNODB;""" % ( - SONGS_TABLENAME, Database.FIELD_SONG_ID, Database.FIELD_SONGNAME, FIELD_FINGERPRINTED, - Database.FIELD_FILE_SHA1, - Database.FIELD_SONG_ID, Database.FIELD_SONG_ID, Database.FIELD_SONG_ID, - ) - - # inserts (ignores duplicates) - INSERT_FINGERPRINT = """ - INSERT IGNORE INTO %s (%s, %s, %s) values - (UNHEX(%%s), %%s, %%s); - """ % (FINGERPRINTS_TABLENAME, Database.FIELD_HASH, Database.FIELD_SONG_ID, Database.FIELD_OFFSET) - - INSERT_SONG = "INSERT INTO %s (%s, %s) values (%%s, UNHEX(%%s));" % ( - SONGS_TABLENAME, Database.FIELD_SONGNAME, Database.FIELD_FILE_SHA1) - - # selects - SELECT = """ - SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s); - """ % (Database.FIELD_SONG_ID, Database.FIELD_OFFSET, FINGERPRINTS_TABLENAME, Database.FIELD_HASH) - - SELECT_MULTIPLE = """ - SELECT HEX(%s), %s, %s FROM %s WHERE %s IN (%%s); - """ % (Database.FIELD_HASH, Database.FIELD_SONG_ID, Database.FIELD_OFFSET, - FINGERPRINTS_TABLENAME, Database.FIELD_HASH) - - SELECT_ALL = """ - SELECT %s, %s FROM %s; - """ % (Database.FIELD_SONG_ID, Database.FIELD_OFFSET, FINGERPRINTS_TABLENAME) - - SELECT_SONG = """ - SELECT %s, HEX(%s) as %s FROM %s WHERE %s = %%s; - """ % (Database.FIELD_SONGNAME, Database.FIELD_FILE_SHA1, Database.FIELD_FILE_SHA1, SONGS_TABLENAME, Database.FIELD_SONG_ID) - - SELECT_NUM_FINGERPRINTS = """ - SELECT COUNT(*) as n FROM %s - """ % (FINGERPRINTS_TABLENAME) - - SELECT_UNIQUE_SONG_IDS = """ - SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1; - """ % (Database.FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED) - - SELECT_SONGS = """ - SELECT %s, %s, HEX(%s) as %s FROM %s WHERE %s = 1; - """ % (Database.FIELD_SONG_ID, Database.FIELD_SONGNAME, Database.FIELD_FILE_SHA1, Database.FIELD_FILE_SHA1, - SONGS_TABLENAME, FIELD_FINGERPRINTED) - - # drops - DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME - DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME - - # update - UPDATE_SONG_FINGERPRINTED = """ - UPDATE %s SET %s = 1 WHERE %s = %%s - """ % (SONGS_TABLENAME, FIELD_FINGERPRINTED, Database.FIELD_SONG_ID) - - # delete - DELETE_UNFINGERPRINTED = """ - DELETE FROM %s WHERE %s = 0; - """ % (SONGS_TABLENAME, FIELD_FINGERPRINTED) - - def __init__(self, **options): - super(SQLDatabase, self).__init__() - self.cursor = cursor_factory(**options) - self._options = options - - def after_fork(self): - # Clear the cursor cache, we don't want any stale connections from - # the previous process. - Cursor.clear_cache() - - def setup(self): - """ - Creates any non-existing tables required for dejavu to function. - - This also removes all songs that have been added but have no - fingerprints associated with them. - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.CREATE_SONGS_TABLE) - cur.execute(self.CREATE_FINGERPRINTS_TABLE) - cur.execute(self.DELETE_UNFINGERPRINTED) - - def empty(self): - """ - Drops tables created by dejavu and then creates them again - by calling `SQLDatabase.setup`. - - .. warning: - This will result in a loss of data - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.DROP_FINGERPRINTS) - cur.execute(self.DROP_SONGS) - - self.setup() - - def delete_unfingerprinted_songs(self): - """ - Removes all songs that have no fingerprints associated with them. - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.DELETE_UNFINGERPRINTED) - - def get_num_songs(self): - """ - Returns number of songs the database has fingerprinted. - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.SELECT_UNIQUE_SONG_IDS) - - for count, in cur: - return count - return 0 - - def get_num_fingerprints(self): - """ - Returns number of fingerprints the database has fingerprinted. - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.SELECT_NUM_FINGERPRINTS) - - for count, in cur: - return count - return 0 - - def set_song_fingerprinted(self, sid): - """ - Set the fingerprinted flag to TRUE (1) once a song has been completely - fingerprinted in the database. - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,)) - - def get_songs(self): - """ - Return songs that have the fingerprinted flag set TRUE (1). - """ - with self.cursor(cursor_type=DictCursor, charset="utf8") as cur: - cur.execute(self.SELECT_SONGS) - for row in cur: - yield row - - def get_song_by_id(self, sid): - """ - Returns song by its ID. - """ - with self.cursor(cursor_type=DictCursor, charset="utf8") as cur: - cur.execute(self.SELECT_SONG, (sid,)) - return cur.fetchone() - - def insert(self, hash, sid, offset): - """ - Insert a (sha1, song_id, offset) row into database. - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.INSERT_FINGERPRINT, (hash, sid, offset)) - - def insert_song(self, songname, file_hash): - """ - Inserts song in the database and returns the ID of the inserted record. - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.INSERT_SONG, (songname, file_hash)) - return cur.lastrowid - - def query(self, hash): - """ - Return all tuples associated with hash. - - If hash is None, returns all entries in the - database (be careful with that one!). - """ - # select all if no key - query = self.SELECT_ALL if hash is None else self.SELECT - - with self.cursor(charset="utf8") as cur: - cur.execute(query) - for sid, offset in cur: - yield (sid, offset) - - def get_iterable_kv_pairs(self): - """ - Returns all tuples in database. - """ - return self.query(None) - - def insert_hashes(self, sid, hashes): - """ - Insert series of hash => song_id, offset - values into the database. - """ - values = [] - for hash, offset in hashes: - values.append((hash, sid, offset)) - - with self.cursor(charset="utf8") as cur: - for split_values in grouper(values, 1000): - cur.executemany(self.INSERT_FINGERPRINT, split_values) - - def return_matches(self, hashes): - """ - Return the (song_id, offset_diff) tuples associated with - a list of (sha1, sample_offset) values. - """ - # Create a dictionary of hash => offset pairs for later lookups - mapper = {} - for hash, offset in hashes: - mapper[hash.upper()] = offset - - # Get an iteratable of all the hashes we need - values = mapper.keys() - - with self.cursor(charset="utf8") as cur: - for split_values in grouper(values, 1000): - # Create our IN part of the query - query = self.SELECT_MULTIPLE - query = query % ', '.join(['UNHEX(%s)'] * len(split_values)) - - cur.execute(query, split_values) - - for hash, sid, offset in cur: - # (sid, db_offset - song_sampled_offset) - yield (sid, offset - mapper[hash]) - - def __getstate__(self): - return (self._options,) - - def __setstate__(self, state): - self._options, = state - self.cursor = cursor_factory(**self._options) - - -def grouper(iterable, n, fillvalue=None): - args = [iter(iterable)] * n - return (filter(None, values) for values - in izip_longest(fillvalue=fillvalue, *args)) - - -def cursor_factory(**factory_options): - def cursor(**options): - options.update(factory_options) - return Cursor(**options) - return cursor - - -class Cursor(object): - """ - Establishes a connection to the database and returns an open cursor. - - - ```python - # Use as context manager - with Cursor() as cur: - cur.execute(query) - ``` - """ - - def __init__(self, cursor_type=mysql.cursors.Cursor, **options): - super(Cursor, self).__init__() - - self._cache = Queue.Queue(maxsize=5) - try: - conn = self._cache.get_nowait() - except Queue.Empty: - conn = mysql.connect(**options) - else: - # Ping the connection before using it from the cache. - conn.ping(True) - - self.conn = conn - self.conn.autocommit(False) - self.cursor_type = cursor_type - - @classmethod - def clear_cache(cls): - cls._cache = Queue.Queue(maxsize=5) - - def __enter__(self): - self.cursor = self.conn.cursor(self.cursor_type) - return self.cursor - - def __exit__(self, extype, exvalue, traceback): - # if we had a MySQL related error we try to rollback the cursor. - if extype is mysql.MySQLError: - self.cursor.rollback() - - self.cursor.close() - self.conn.commit() - - # Put it back on the queue - try: - self._cache.put_nowait(self.conn) - except Queue.Full: - self.conn.close() diff --git a/dejavu/decoder.py b/dejavu/decoder.py index 04aa39f4..92990685 100755 --- a/dejavu/decoder.py +++ b/dejavu/decoder.py @@ -3,9 +3,10 @@ import numpy as np from pydub import AudioSegment from pydub.utils import audioop -import wavio +from . import wavio from hashlib import sha1 + def unique_hash(filepath, blocksize=2**20): """ Small function to generate a hash to uniquely generate a file. Inspired by MD5 version here: @@ -14,7 +15,7 @@ def unique_hash(filepath, blocksize=2**20): Works with large files. """ s = sha1() - with open(filepath , "rb") as f: + with open(filepath, "rb") as f: while True: buf = f.read(blocksize) if not buf: @@ -29,7 +30,7 @@ def find_files(path, extensions): for dirpath, dirnames, files in os.walk(path): for extension in extensions: - for f in fnmatch.filter(files, "*.%s" % extension): + for f in fnmatch.filter(files, f"*.{extension}"): p = os.path.join(dirpath, f) yield (p, extension) @@ -53,15 +54,15 @@ def read(filename, limit=None): if limit: audiofile = audiofile[:limit * 1000] - data = np.fromstring(audiofile._data, np.int16) + data = np.fromstring(audiofile.raw_data, np.int16) channels = [] - for chn in xrange(audiofile.channels): + for chn in range(audiofile.channels): channels.append(data[chn::audiofile.channels]) - fs = audiofile.frame_rate + audiofile.frame_rate except audioop.error: - fs, _, audiofile = wavio.readwav(filename) + _, _, audiofile = wavio.readwav(filename) if limit: audiofile = audiofile[:limit * 1000] diff --git a/dejavu/fingerprint.py b/dejavu/fingerprint.py index f56118ac..ce8d8dba 100755 --- a/dejavu/fingerprint.py +++ b/dejavu/fingerprint.py @@ -1,74 +1,32 @@ -import numpy as np +import hashlib +from operator import itemgetter + import matplotlib.mlab as mlab import matplotlib.pyplot as plt +import numpy as np from scipy.ndimage.filters import maximum_filter -from scipy.ndimage.morphology import (generate_binary_structure, - iterate_structure, binary_erosion) -import hashlib -from operator import itemgetter +from scipy.ndimage.morphology import (binary_erosion, + generate_binary_structure, + iterate_structure) + +from dejavu.config.config import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE, + DEFAULT_FS, DEFAULT_OVERLAP_RATIO, + DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION, + MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA, + PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT) IDX_FREQ_I = 0 IDX_TIME_J = 1 -###################################################################### -# Sampling rate, related to the Nyquist conditions, which affects -# the range frequencies we can detect. -DEFAULT_FS = 44100 - -###################################################################### -# Size of the FFT window, affects frequency granularity -DEFAULT_WINDOW_SIZE = 4096 - -###################################################################### -# Ratio by which each sequential window overlaps the last and the -# next window. Higher overlap will allow a higher granularity of offset -# matching, but potentially more fingerprints. -DEFAULT_OVERLAP_RATIO = 0.5 - -###################################################################### -# Degree to which a fingerprint can be paired with its neighbors -- -# higher will cause more fingerprints, but potentially better accuracy. -DEFAULT_FAN_VALUE = 15 - -###################################################################### -# Minimum amplitude in spectrogram in order to be considered a peak. -# This can be raised to reduce number of fingerprints, but can negatively -# affect accuracy. -DEFAULT_AMP_MIN = 10 - -###################################################################### -# Number of cells around an amplitude peak in the spectrogram in order -# for Dejavu to consider it a spectral peak. Higher values mean less -# fingerprints and faster matching, but can potentially affect accuracy. -PEAK_NEIGHBORHOOD_SIZE = 20 - -###################################################################### -# Thresholds on how close or far fingerprints can be in time in order -# to be paired as a fingerprint. If your max is too low, higher values of -# DEFAULT_FAN_VALUE may not perform as expected. -MIN_HASH_TIME_DELTA = 0 -MAX_HASH_TIME_DELTA = 200 - -###################################################################### -# If True, will sort peaks temporally for fingerprinting; -# not sorting will cut down number of fingerprints, but potentially -# affect performance. -PEAK_SORT = True - -###################################################################### -# Number of bits to grab from the front of the SHA1 hash in the -# fingerprint calculation. The more you grab, the more memory storage, -# with potentially lesser collisions of matches. -FINGERPRINT_REDUCTION = 20 - -def fingerprint(channel_samples, Fs=DEFAULT_FS, + +def fingerprint(channel_samples, + Fs=DEFAULT_FS, wsize=DEFAULT_WINDOW_SIZE, wratio=DEFAULT_OVERLAP_RATIO, fan_value=DEFAULT_FAN_VALUE, amp_min=DEFAULT_AMP_MIN): """ - FFT the channel, log transform output, find local maxima, then return - locally sensitive hashes. + FFT the channel, log transform output, find local maxima, then return locally sensitive hashes. """ # FFT the signal and extract frequency components arr2D = mlab.specgram( @@ -78,11 +36,9 @@ def fingerprint(channel_samples, Fs=DEFAULT_FS, window=mlab.window_hanning, noverlap=int(wsize * wratio))[0] - # apply log transform since specgram() returns linear array - arr2D = 10 * np.log10(arr2D) - arr2D[arr2D == -np.inf] = 0 # replace infs with zeros + # Apply log transform since specgram() returns linear array. 0s are excluded to avoid np warning. + arr2D = 10 * np.log10(arr2D, out=np.zeros_like(arr2D), where=(arr2D != 0)) - # find local maxima local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min) # return hashes @@ -97,39 +53,35 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN): # find local maxima using our filter shape local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D background = (arr2D == 0) - eroded_background = binary_erosion(background, structure=neighborhood, - border_value=1) + eroded_background = binary_erosion(background, structure=neighborhood, border_value=1) # Boolean mask of arr2D with True at peaks (Fixed deprecated boolean operator by changing '-' to '^') detected_peaks = local_max ^ eroded_background # extract peaks amps = arr2D[detected_peaks] - j, i = np.where(detected_peaks) + freqs, times = np.where(detected_peaks) # filter peaks amps = amps.flatten() - peaks = zip(i, j, amps) - peaks_filtered = filter(lambda x: x[2]>amp_min, peaks) # freq, time, amp # get indices for frequency and time - frequency_idx = [] - time_idx = [] - for x in peaks_filtered: - frequency_idx.append(x[1]) - time_idx.append(x[0]) - + filter_idxs = np.where(amps > amp_min) + + freqs_filter = freqs[filter_idxs] + times_filter = times[filter_idxs] + if plot: # scatter of the peaks fig, ax = plt.subplots() ax.imshow(arr2D) - ax.scatter(time_idx, frequency_idx) + ax.scatter(times_filter, freqs_filter) ax.set_xlabel('Time') ax.set_ylabel('Frequency') ax.set_title("Spectrogram") plt.gca().invert_yaxis() plt.show() - return zip(frequency_idx, time_idx) + return list(zip(freqs_filter, times_filter)) def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE): @@ -151,7 +103,6 @@ def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE): t2 = peaks[i + j][IDX_TIME_J] t_delta = t2 - t1 - if t_delta >= MIN_HASH_TIME_DELTA and t_delta <= MAX_HASH_TIME_DELTA: - h = hashlib.sha1( - "%s|%s|%s" % (str(freq1), str(freq2), str(t_delta))) + if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA: + h = hashlib.sha1(f"{str(freq1)}|{str(freq2)}|{str(t_delta)}".encode('utf-8')) yield (h.hexdigest()[0:FINGERPRINT_REDUCTION], t1) diff --git a/dejavu/recognize.py b/dejavu/recognize.py index 269a82af..3d6c6226 100755 --- a/dejavu/recognize.py +++ b/dejavu/recognize.py @@ -1,16 +1,16 @@ -# encoding: utf-8 -import dejavu.fingerprint as fingerprint -import dejavu.decoder as decoder +import time + import numpy as np import pyaudio -import time +import dejavu.decoder as decoder +from dejavu.config.config import DEFAULT_FS -class BaseRecognizer(object): +class BaseRecognizer(object): def __init__(self, dejavu): self.dejavu = dejavu - self.Fs = fingerprint.DEFAULT_FS + self.Fs = DEFAULT_FS def _recognize(self, *data): matches = [] @@ -24,32 +24,32 @@ def recognize(self): class FileRecognizer(BaseRecognizer): def __init__(self, dejavu): - super(FileRecognizer, self).__init__(dejavu) + super().__init__(dejavu) def recognize_file(self, filename): frames, self.Fs, file_hash = decoder.read(filename, self.dejavu.limit) t = time.time() - match = self._recognize(*frames) + matches = self._recognize(*frames) t = time.time() - t - if match: + for match in matches: match['match_time'] = t - return match + return matches def recognize(self, filename): return self.recognize_file(filename) class MicrophoneRecognizer(BaseRecognizer): - default_chunksize = 8192 - default_format = pyaudio.paInt16 - default_channels = 2 - default_samplerate = 44100 + default_chunksize = 8192 + default_format = pyaudio.paInt16 + default_channels = 2 + default_samplerate = 44100 def __init__(self, dejavu): - super(MicrophoneRecognizer, self).__init__(dejavu) + super().__init__(dejavu) self.audio = pyaudio.PyAudio() self.stream = None self.data = [] diff --git a/dejavu/testing.py b/dejavu/testing.py index d2a3b484..eb785780 100644 --- a/dejavu/testing.py +++ b/dejavu/testing.py @@ -1,14 +1,19 @@ -from __future__ import division + +import ast +import fnmatch +import logging +import os +import random +import re +import subprocess +import traceback + from pydub import AudioSegment -from dejavu.decoder import path_to_songname + from dejavu import Dejavu +from dejavu.decoder import path_to_songname from dejavu.fingerprint import * -import traceback -import fnmatch -import os, re, ast -import subprocess -import random -import logging + def set_seed(seed=None): """ @@ -20,6 +25,7 @@ def set_seed(seed=None): if seed != None: random.seed(seed) + def get_files_recursive(src, fmt): """ `src` is the source directory. @@ -29,6 +35,7 @@ def get_files_recursive(src, fmt): for filename in fnmatch.filter(filenames, '*' + fmt): yield os.path.join(root, filename) + def get_length_audio(audiopath, extension): """ Returns length of audio in seconds. @@ -37,10 +44,11 @@ def get_length_audio(audiopath, extension): try: audio = AudioSegment.from_file(audiopath, extension.replace(".", "")) except: - print "Error in get_length_audio(): %s" % traceback.format_exc() + print(f"Error in get_length_audio(): {traceback.format_exc()}") return None return int(len(audio) / 1000.0) + def get_starttime(length, nseconds, padding): """ `length` is total audio length in seconds @@ -52,6 +60,7 @@ def get_starttime(length, nseconds, padding): return 0 return random.randint(padding, maximum) + def generate_test_files(src, dest, nseconds, fmts=[".mp3", ".wav"], padding=10): """ Generates a test file for each file recursively in `src` directory @@ -75,42 +84,43 @@ def generate_test_files(src, dest, nseconds, fmts=[".mp3", ".wav"], padding=10): testsources = get_files_recursive(src, fmt) for audiosource in testsources: - print "audiosource:", audiosource + print("audiosource:", audiosource) filename, extension = os.path.splitext(os.path.basename(audiosource)) length = get_length_audio(audiosource, extension) starttime = get_starttime(length, nseconds, padding) - test_file_name = "%s_%s_%ssec.%s" % ( - os.path.join(dest, filename), starttime, - nseconds, extension.replace(".", "")) + test_file_name = f"{os.path.join(dest, filename)}_{starttime}_{nseconds}sec.{extension.replace('.', '')}" subprocess.check_output([ "ffmpeg", "-y", - "-ss", "%d" % starttime, - '-t' , "%d" % nseconds, + "-ss", f"{starttime}", + '-t', f"{nseconds}", "-i", audiosource, test_file_name]) + def log_msg(msg, log=True, silent=False): if log: logging.debug(msg) if not silent: - print msg + print(msg) + def autolabel(rects, ax): # attach some text labels for rect in rects: height = rect.get_height() - ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, - '%d' % int(height), ha='center', va='bottom') + ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{int(height)}', ha='center', va='bottom') + def autolabeldoubles(rects, ax): # attach some text labels for rect in rects: height = rect.get_height() - ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, - '%s' % round(float(height), 3), ha='center', va='bottom') + ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{round(float(height), 3)}', + ha='center', va='bottom') + class DejavuTest(object): def __init__(self, folder, seconds): @@ -120,35 +130,35 @@ def __init__(self, folder, seconds): self.test_seconds = seconds self.test_songs = [] - print "test_seconds", self.test_seconds + print("test_seconds", self.test_seconds) self.test_files = [ f for f in os.listdir(self.test_folder) if os.path.isfile(os.path.join(self.test_folder, f)) and re.findall("[0-9]*sec", f)[0] in self.test_seconds] - print "test_files", self.test_files + print("test_files", self.test_files) self.n_columns = len(self.test_seconds) self.n_lines = int(len(self.test_files) / self.n_columns) - print "columns:", self.n_columns - print "length of test files:", len(self.test_files) - print "lines:", self.n_lines + print("columns:", self.n_columns) + print("length of test files:", len(self.test_files)) + print("lines:", self.n_lines) # variable match results (yes, no, invalid) - self.result_match = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)] + self.result_match = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] - print "result_match matrix:", self.result_match + print("result_match matrix:", self.result_match) # variable match precision (if matched in the corrected time) - self.result_matching_times = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)] + self.result_matching_times = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] # variable mahing time (query time) - self.result_query_duration = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)] + self.result_query_duration = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] # variable confidence - self.result_match_confidence = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)] + self.result_match_confidence = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] self.begin() @@ -178,19 +188,17 @@ def create_plots(self, name, results, results_folder): # add some ax.set_ylabel(name) - ax.set_title("%s %s Results" % (self.test_seconds[sec], name)) + ax.set_title(f"{self.test_seconds[sec]} {name} Results") ax.set_xticks(ind + width) labels = [0 for x in range(0, self.n_lines)] for x in range(0, self.n_lines): - labels[x] = "song %s" % (x+1) + labels[x] = f"song {x+1}" ax.set_xticklabels(labels) box = ax.get_position() ax.set_position([box.x0, box.y0, box.width * 0.75, box.height]) - #ax.legend( (rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5)) - if name == 'Confidence': autolabel(rects1, ax) else: @@ -198,13 +206,13 @@ def create_plots(self, name, results, results_folder): plt.grid() - fig_name = os.path.join(results_folder, "%s_%s.png" % (name, self.test_seconds[sec])) + fig_name = os.path.join(results_folder, f"{name}_{self.test_seconds[sec]}.png") fig.savefig(fig_name) def begin(self): for f in self.test_files: log_msg('--------------------------------------------------') - log_msg('file: %s' % f) + log_msg(f'file: {f}') # get column col = self.get_column_id(re.findall("[0-9]*sec", f)[0]) @@ -235,8 +243,8 @@ def begin(self): # which song did we predict? result = ast.literal_eval(result) song_result = result["song_name"] - log_msg('song: %s' % song) - log_msg('song_result: %s' % song_result) + log_msg(f'song: {song}') + log_msg(f'song_result: {song_result}') if song_result != song: log_msg('invalid match') @@ -246,31 +254,28 @@ def begin(self): self.result_match_confidence[line][col] = 0 else: log_msg('correct match') - print self.result_match + print(self.result_match) self.result_match[line][col] = 'yes' self.result_query_duration[line][col] = round(result[Dejavu.MATCH_TIME],3) self.result_match_confidence[line][col] = result[Dejavu.CONFIDENCE] - song_start_time = re.findall("\_[^\_]+",f) + song_start_time = re.findall("_[^_]+", f) song_start_time = song_start_time[0].lstrip("_ ") - result_start_time = round((result[Dejavu.OFFSET] * DEFAULT_WINDOW_SIZE * - DEFAULT_OVERLAP_RATIO) / (DEFAULT_FS), 0) + result_start_time = round((result[Dejavu.OFFSET] * DEFAULT_WINDOW_SIZE * + DEFAULT_OVERLAP_RATIO) / DEFAULT_FS, 0) self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time) - if (abs(self.result_matching_times[line][col]) == 1): + if abs(self.result_matching_times[line][col]) == 1: self.result_matching_times[line][col] = 0 - log_msg('query duration: %s' % round(result[Dejavu.MATCH_TIME],3)) - log_msg('confidence: %s' % result[Dejavu.CONFIDENCE]) - log_msg('song start_time: %s' % song_start_time) - log_msg('result start time: %s' % result_start_time) - if (self.result_matching_times[line][col] == 0): + log_msg(f'query duration: {round(result[Dejavu.MATCH_TIME], 3)}') + log_msg(f'confidence: {result[Dejavu.CONFIDENCE]}') + log_msg(f'song start_time: {song_start_time}') + log_msg(f'result start time: {result_start_time}') + + if self.result_matching_times[line][col] == 0: log_msg('accurate match') else: log_msg('inaccurate match') log_msg('--------------------------------------------------\n') - - - - diff --git a/dejavu/wavio.py b/dejavu/wavio.py index e8d1fc33..70450570 100644 --- a/dejavu/wavio.py +++ b/dejavu/wavio.py @@ -5,6 +5,7 @@ # Github: github.com/WarrenWeckesser/wavio import wave as _wave + import numpy as _np diff --git a/example.py b/example.py index 1c99e69c..87aef09f 100755 --- a/example.py +++ b/example.py @@ -1,35 +1,37 @@ -import warnings import json -warnings.filterwarnings("ignore") +import warnings from dejavu import Dejavu from dejavu.recognize import FileRecognizer, MicrophoneRecognizer +warnings.filterwarnings("ignore") + + # load config from a JSON file (or anything outputting a python dictionary) with open("dejavu.cnf.SAMPLE") as f: config = json.load(f) if __name__ == '__main__': - # create a Dejavu instance - djv = Dejavu(config) + # create a Dejavu instance + djv = Dejavu(config) - # Fingerprint all the mp3's in the directory we give it - djv.fingerprint_directory("mp3", [".mp3"]) + # Fingerprint all the mp3's in the directory we give it + djv.fingerprint_directory("mp3", [".mp3"]) - # Recognize audio from a file - song = djv.recognize(FileRecognizer, "mp3/Sean-Fournier--Falling-For-You.mp3") - print "From file we recognized: %s\n" % song + # Recognize audio from a file + song = djv.recognize(FileRecognizer, "mp3/Sean-Fournier--Falling-For-You.mp3") + print(f"From file we recognized: {song}\n") - # Or recognize audio from your microphone for `secs` seconds - secs = 5 - song = djv.recognize(MicrophoneRecognizer, seconds=secs) - if song is None: - print "Nothing recognized -- did you play the song out loud so your mic could hear it? :)" - else: - print "From mic with %d seconds we recognized: %s\n" % (secs, song) + # Or recognize audio from your microphone for `secs` seconds + secs = 5 + song = djv.recognize(MicrophoneRecognizer, seconds=secs) + if song is None: + print("Nothing recognized -- did you play the song out loud so your mic could hear it? :)") + else: + print(f"From mic with %d seconds we recognized: {(secs, song)}\n") - # Or use a recognizer without the shortcut, in anyway you would like - recognizer = FileRecognizer(djv) - song = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3") - print "No shortcut, we recognized: %s\n" % song \ No newline at end of file + # Or use a recognizer without the shortcut, in anyway you would like + recognizer = FileRecognizer(djv) + song = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3") + print(f"No shortcut, we recognized: {song}\n") diff --git a/requirements.txt b/requirements.txt index 9478f734..4954c7c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,7 @@ -# requirements file +pydub==0.23.1 +PyAudio==0.2.11 +numpy==1.17.2 +scipy==1.3.1 +matplotlib==3.1.1 +mysql-connector-python==8.0.17 -### BEGIN ### -pydub>=0.9.4 -PyAudio>=0.2.7 -numpy>=1.8.2 -scipy>=0.12.1 -matplotlib>=1.3.1 -### END ### diff --git a/run_tests.py b/run_tests.py index b0dfde99..7ad387d0 100644 --- a/run_tests.py +++ b/run_tests.py @@ -86,10 +86,10 @@ n_secs = len(test_seconds) # set result variables -> 4d variables -all_match_counter = [[[0 for x in xrange(tests)] for x in xrange(3)] for x in xrange(n_secs)] -all_matching_times_counter = [[[0 for x in xrange(tests)] for x in xrange(2)] for x in xrange(n_secs)] -all_query_duration = [[[0 for x in xrange(tests)] for x in xrange(djv.n_lines)] for x in xrange(n_secs)] -all_match_confidence = [[[0 for x in xrange(tests)] for x in xrange(djv.n_lines)] for x in xrange(n_secs)] +all_match_counter = [[[0 for x in range(tests)] for x in range(3)] for x in range(n_secs)] +all_matching_times_counter = [[[0 for x in range(tests)] for x in range(2)] for x in range(n_secs)] +all_query_duration = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)] +all_match_confidence = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)] # group results by seconds for line in range(0, djv.n_lines): diff --git a/setup.py b/setup.py index 8312d1d5..9484507d 100644 --- a/setup.py +++ b/setup.py @@ -7,11 +7,11 @@ def parse_requirements(requirements): with open(requirements) as f: lines = [l for l in f] # remove spaces - stripped = map((lambda x: x.strip()), lines) + stripped = list(map((lambda x: x.strip()), lines)) # remove comments - nocomments = filter((lambda x: not x.startswith('#')), stripped) + nocomments = list(filter((lambda x: not x.startswith('#')), stripped)) # remove empty lines - reqs = filter((lambda x: x), nocomments) + reqs = list(filter((lambda x: x), nocomments)) return reqs PACKAGE_NAME = "PyDejavu"