From e14ebbb6a63fc2d9987d6d1baf0d37a7e6bcf2fc Mon Sep 17 00:00:00 2001 From: "Arturo Gonzalez M." Date: Sat, 20 Jul 2024 17:39:43 +1000 Subject: [PATCH] Improved performance and added logic to query by Country --- README.md | 2 +- main.py | 25 +- requirements.txt | 1 + src/extract/__init__.py | 0 .../world_bank_indicators_by_country.py | 42 ++++ src/utilts/singleton_logger.py | 41 ++++ .../world_bank_data_downloader.py | 57 ++--- tests/test_world_bank_data_downloader.py | 214 +++++++++--------- 8 files changed, 242 insertions(+), 140 deletions(-) create mode 100644 src/extract/__init__.py create mode 100644 src/extract/world_bank_indicators_by_country.py create mode 100644 src/utilts/singleton_logger.py rename src/{ => utilts}/world_bank_data_downloader.py (69%) diff --git a/README.md b/README.md index e05861f..9e5493b 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ Here's a basic example of how to use the `WorldBankDataDownloader`: ```python downloader = WorldBankDataDownloader() all_data = downloader.download_all_data() -downloader.save_data_to_file(all_data, 'data/world_bank_data.json') +downloader.save_data_to_file(all_data, 'data/world_bank_data_optimised.json') ``` ## Unit Tests diff --git a/main.py b/main.py index 19e3a0d..2b9bda9 100644 --- a/main.py +++ b/main.py @@ -1,19 +1,38 @@ """ This is the main file to run the WorldBankDataDownloader. """ -from src.world_bank_data_downloader import WorldBankDataDownloader +from src.utilts.singleton_logger import SingletonLogger +from src.utilts.world_bank_data_downloader import WorldBankDataDownloader def main(): """ - Main function to download all data from the World Bank API. + Main function to download data from the World Bank API. :return: None :rtype: None """ + # Set up logging + logger = SingletonLogger().get_logger() + + # Initialize the downloader downloader = WorldBankDataDownloader() - all_data = downloader.download_all_data() + + # Get all country codes and indicator codes + country_codes = downloader.country_codes + indicator_codes = downloader.indicator_codes + + # Download data for all countries and indicators + all_data = {} + for country_code in country_codes: + logger.info(f"Downloading data for country: {country_code}") + country_data = downloader.fetch_data_concurrently(country_code, indicator_codes) + all_data[country_code] = country_data + + # Save the data to a file downloader.save_data_to_file(all_data) + logger.info("Data download and save completed.") + if __name__ == "__main__": main() diff --git a/requirements.txt b/requirements.txt index a500045..b106015 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ pytest==8.2.2 coverage==7.6.0 pytest-cov==5.0.0 tenacity==8.5.0 +aiohttp diff --git a/src/extract/__init__.py b/src/extract/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/extract/world_bank_indicators_by_country.py b/src/extract/world_bank_indicators_by_country.py new file mode 100644 index 0000000..220a7a9 --- /dev/null +++ b/src/extract/world_bank_indicators_by_country.py @@ -0,0 +1,42 @@ +""" +This script downloads data from the World Bank API for a specific country (Australia in this case) and saves it to a file. +""" +from src.utilts.singleton_logger import SingletonLogger +from src.utilts.world_bank_data_downloader import WorldBankDataDownloader + + +def main(): + """ + Main function to download data from the World Bank API. + :return: None + :rtype: None + """ + logger = SingletonLogger().get_logger() + country_code = 'AUS' + + # Create an instance of the WorldBankDataDownloader + downloader = WorldBankDataDownloader() + + # Get the list of all indicators + indicators = downloader.get_indicators() + + # Fetch data for Australia (AUS) for all indicators concurrently + australia_data = downloader.fetch_data_concurrently(country_code, indicators) + + # Save the data to a file + filename = f'../data/raw/{country_code}_world_bank_data.json' + downloader.save_data_to_file(australia_data, filename=filename) + + # Load the data from the file (for verification) + loaded_data = downloader.load_data_from_file(filename=filename) + + # Print the loaded data (or a subset of it) + for indicator, data in loaded_data.items(): + logger.info(f"Indicator: {indicator}") + for entry in data: + logger.info(entry) + logger.info("\n") + + +if __name__ == '__main__': + main() diff --git a/src/utilts/singleton_logger.py b/src/utilts/singleton_logger.py new file mode 100644 index 0000000..6ec2ea7 --- /dev/null +++ b/src/utilts/singleton_logger.py @@ -0,0 +1,41 @@ + +""" +SingletonLogger class is a singleton class that provides a single instance of logger object. +""" + +import logging +import threading + + +class SingletonLogger: + _instance = None + _lock = threading.RLock() + + def __new__(cls, logger_name=None, log_level=logging.DEBUG, log_format=None): + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super().__new__(cls) + cls._instance._initialize_logger(logger_name, log_level, log_format) + return cls._instance + + def _initialize_logger(self, logger_name, log_level, log_format): + self._logger = logging.getLogger(logger_name or __name__) + self._logger.setLevel(log_level) + + console_handler = logging.StreamHandler() + console_handler.setLevel(log_level) + + if log_format: + formatter = logging.Formatter(log_format) + else: + formatter = logging.Formatter("%(asctime)s - %(threadName)s - %(levelname)s - %(message)s") + console_handler.setFormatter(formatter) + + self._logger.addHandler(console_handler) + + def get_logger(self): + return self._logger + + +# logger = SingletonLogger().get_logger() \ No newline at end of file diff --git a/src/world_bank_data_downloader.py b/src/utilts/world_bank_data_downloader.py similarity index 69% rename from src/world_bank_data_downloader.py rename to src/utilts/world_bank_data_downloader.py index 2cafd3f..34a51af 100644 --- a/src/world_bank_data_downloader.py +++ b/src/utilts/world_bank_data_downloader.py @@ -1,15 +1,17 @@ """ -This module contains a class for downloading data from the World Bank API. +This module contains the WorldBankDataDownloader class, which is used to download data from the World Bank API. """ - import os import time import json import logging import requests +from concurrent.futures import ThreadPoolExecutor, as_completed from tenacity import retry, stop_after_attempt, wait_exponential +from src.utilts.singleton_logger import SingletonLogger + class WorldBankDataDownloader: """Class for downloading data from the World Bank API.""" @@ -19,6 +21,7 @@ def __init__(self): Initialize the WorldBankDataDownloader with the base URL, country codes, and indicator codes. """ self.base_url = 'http://api.worldbank.org/v2' + self.logger = SingletonLogger().get_logger() self.country_codes = self.get_country_codes() self.indicator_codes = self.get_indicators() logging.basicConfig(level=logging.INFO) @@ -86,45 +89,44 @@ def fetch_data(self, country_code, indicator_code): break return all_pages_data - def download_all_data(self): + def fetch_data_concurrently(self, country_code, indicator_codes, max_workers=10): """ - Download data for all indicators and country codes. - :return: A dictionary containing the data for all country codes and indicator codes. + Fetch data concurrently for a given country code and a list of indicator codes. + :param country_code: The country code. + :param indicator_codes: The list of indicator codes. + :param max_workers: The maximum number of threads to use. + :return: A dictionary containing the data for the given country code and indicator codes. """ - all_data = {} - total_requests = len(self.country_codes) * len(self.indicator_codes) - completed_requests = 0 - - for country_code in self.country_codes: - for indicator_code in self.indicator_codes: - completed_requests += 1 - progress = (completed_requests / total_requests) * 100 - logging.info( - f"Progress: {progress:.2f}% - Fetching data for country {country_code} and indicator {indicator_code}") - - data = self.fetch_data(country_code, indicator_code) - if data: - all_data[(country_code, indicator_code)] = data - - time.sleep(0.5) # Add a 0.5-second delay between requests to avoid rate limiting - - return all_data + results = {} + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_indicator = {executor.submit(self.fetch_data, country_code, indicator_code): indicator_code for + indicator_code in indicator_codes} + for future in as_completed(future_to_indicator): + indicator_code = future_to_indicator[future] + try: + data = future.result() + if data: + results[indicator_code] = data + except Exception as e: + logging.error(f"Error fetching data for indicator {indicator_code}: {e}") + return results @staticmethod - def save_data_to_file(data, filename='../data/raw/world_bank_data.json'): + def save_data_to_file(data, filename='../data/raw/world_bank_data_optimised.json'): """ Save the data to a JSON file. :param data: The data to save. :param filename: The filename to save the data to. """ os.makedirs(os.path.dirname(filename), exist_ok=True) - serializable_data = {str(key): value for key, value in data.items()} + # Use a delimiter that is safe and won't appear in the keys + serializable_data = {"__DELIM__".join(key): value for key, value in data.items()} with open(filename, 'w', encoding='utf-8') as f: json.dump(serializable_data, f, ensure_ascii=False, indent=4) logging.info(f"Data saved to {filename}") @staticmethod - def load_data_from_file(filename='../data/raw/world_bank_data.json'): + def load_data_from_file(filename='../data/raw/world_bank_data_optimised.json'): """ Load the data from a JSON file. :param filename: The filename to load the data from. @@ -132,5 +134,6 @@ def load_data_from_file(filename='../data/raw/world_bank_data.json'): """ with open(filename, 'r', encoding='utf-8') as f: data = json.load(f) - deserialized_data = {eval(key): value for key, value in data.items()} + # Use the same delimiter to split the keys back into tuples + deserialized_data = {tuple(key.split("__DELIM__")): value for key, value in data.items()} return deserialized_data diff --git a/tests/test_world_bank_data_downloader.py b/tests/test_world_bank_data_downloader.py index 5c7059e..b061a55 100644 --- a/tests/test_world_bank_data_downloader.py +++ b/tests/test_world_bank_data_downloader.py @@ -1,19 +1,24 @@ """ Unit tests for the WorldBankDataDownloader class. """ +import os +import sys import pytest -import json -import requests -from unittest.mock import patch, MagicMock -from src.world_bank_data_downloader import WorldBankDataDownloader +from unittest.mock import patch, MagicMock, call +from concurrent.futures import Future + +from src.utilts.world_bank_data_downloader import WorldBankDataDownloader + +# Add the project root to the Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) @pytest.fixture def downloader(): """ - Fixture to create a WorldBankDataDownloader instance. - :return: A WorldBankDataDownloader instance. + Fixture to create an instance of the WorldBankDataDownloader class. + :return: An instance of the WorldBankDataDownloader class. :rtype: WorldBankDataDownloader """ return WorldBankDataDownloader() @@ -34,14 +39,13 @@ def mock_response(): {"id": "GBR", "name": "United Kingdom"} ] ] - mock.raise_for_status.return_value = None return mock def test_init(downloader): """ - Test the initialization of the WorldBankDataDownloader class. - :param downloader: A WorldBankDataDownloader instance. + Test the __init__ method of the WorldBankDataDownloader class. + :param downloader: An instance of the WorldBankDataDownloader class. :return: None :rtype: None """ @@ -50,12 +54,12 @@ def test_init(downloader): assert isinstance(downloader.indicator_codes, list) -@patch('requests.get') +@patch('src.utilts.world_bank_data_downloader.requests.get') def test_get_country_codes(mock_get, downloader, mock_response): """ - Test the get_country_codes method. + Test the get_country_codes method of the WorldBankDataDownloader class. :param mock_get: A mock of the requests.get function. - :param downloader: A WorldBankDataDownloader instance. + :param downloader: An instance of the WorldBankDataDownloader class. :param mock_response: A mock response object. :return: None :rtype: None @@ -63,141 +67,133 @@ def test_get_country_codes(mock_get, downloader, mock_response): mock_get.return_value = mock_response country_codes = downloader.get_country_codes() assert country_codes == ['USA', 'GBR'] - mock_get.assert_called_once_with(f'{downloader.base_url}/country?format=json&per_page=300', timeout=30) + mock_get.assert_called_once_with( + 'http://api.worldbank.org/v2/country?format=json&per_page=300', + timeout=30 + ) -@patch('requests.get') +@patch('src.utilts.world_bank_data_downloader.requests.get') def test_get_indicators(mock_get, downloader, mock_response): """ - Test the get_indicators method. + Test the get_indicators method of the WorldBankDataDownloader class. :param mock_get: A mock of the requests.get function. - :param downloader: A WorldBankDataDownloader instance. + :param downloader: An instance of the WorldBankDataDownloader class. :param mock_response: A mock response object. :return: None :rtype: None """ + mock_response.json.return_value = [ + {"page": 1, "pages": 1, "per_page": 50, "total": 2}, + [ + {"id": "SP.POP.TOTL", "name": "Population, total"}, + {"id": "NY.GDP.MKTP.CD", "name": "GDP (current US$)"} + ] + ] mock_get.return_value = mock_response indicators = downloader.get_indicators() - assert indicators == ['USA', 'GBR'] # Using the same mock response for simplicity - mock_get.assert_called_once_with(f'{downloader.base_url}/indicator?format=json&per_page=1000', timeout=30) - - -@patch('requests.get') -def test_fetch_data(mock_get, downloader, mock_response): - """ - Test the fetch_data method. - :param mock_get: A mock of the requests.get function. - :param downloader: A WorldBankDataDownloader instance. - :param mock_response: A mock response object. - :return: None - :rtype: None - """ - mock_get.return_value = mock_response - data = downloader.fetch_data('USA', 'GDP') - assert data == [{"id": "USA", "name": "United States"}, {"id": "GBR", "name": "United Kingdom"}] + assert indicators == ['SP.POP.TOTL', 'NY.GDP.MKTP.CD'] mock_get.assert_called_once_with( - f'{downloader.base_url}/country/USA/indicator/GDP?format=json&per_page=1000&page=1', + 'http://api.worldbank.org/v2/indicator?format=json&per_page=1000', timeout=30 ) -@patch.object(WorldBankDataDownloader, 'fetch_data') -def test_download_all_data(mock_fetch_data, downloader): +@patch('src.utilts.world_bank_data_downloader.requests.get') +def test_fetch_data(mock_get, downloader): """ - Test the download_all_data method. - :param mock_fetch_data: A mock of the fetch_data method. - :param downloader: A WorldBankDataDownloader instance. + Test the fetch_data method of the WorldBankDataDownloader class. + :param mock_get: A mock of the requests.get function. + :param downloader: An instance of the WorldBankDataDownloader class. :return: None :rtype: None """ - mock_fetch_data.return_value = [{"year": 2020, "value": 100}] - downloader.country_codes = ['USA', 'GBR'] - downloader.indicator_codes = ['GDP', 'POP'] - - all_data = downloader.download_all_data() + mock_responses = [ + MagicMock(json=lambda: [ + {"page": 1, "pages": 2, "per_page": 1, "total": 2}, + [{"year": "2020", "value": "100"}] + ]), + MagicMock(json=lambda: [ + {"page": 2, "pages": 2, "per_page": 1, "total": 2}, + [{"year": "2019", "value": "90"}] + ]) + ] + mock_get.side_effect = mock_responses - assert len(all_data) == 4 # 2 countries * 2 indicators - assert all_data[('USA', 'GDP')] == [{"year": 2020, "value": 100}] - assert mock_fetch_data.call_count == 4 + data = downloader.fetch_data('USA', 'SP.POP.TOTL') + assert data == [ + {"year": "2020", "value": "100"}, + {"year": "2019", "value": "90"} + ] + assert mock_get.call_count == 2 -def test_save_data_to_file(tmp_path, downloader): +@patch('src.utilts.world_bank_data_downloader.ThreadPoolExecutor') +def test_fetch_data_concurrently(mock_executor, downloader): """ - Test the save_data_to_file method. - :param tmp_path: A temporary directory. - :param downloader: A WorldBankDataDownloader instance. + Test the fetch_data_concurrently method of the WorldBankDataDownloader class. + :param mock_executor: A mock of the ThreadPoolExecutor class. + :param downloader: An instance of the WorldBankDataDownloader class. :return: None :rtype: None """ - data = {('USA', 'GDP'): [{"year": 2020, "value": 100}]} - filename = tmp_path / "test_data.json" - - downloader.save_data_to_file(data, filename) + # Set up mock data + mock_data_1 = [{"year": "2020", "value": "100"}] + mock_data_2 = [{"year": "2020", "value": "200"}] + + # Set up mock futures + mock_future_1 = Future() + mock_future_1.set_result(mock_data_1) + mock_future_2 = Future() + mock_future_2.set_result(mock_data_2) + + # Set up mock executor + mock_executor_instance = MagicMock() + mock_executor_instance.submit.side_effect = [mock_future_1, mock_future_2] + mock_executor.return_value.__enter__.return_value = mock_executor_instance + + # Call the method + result = downloader.fetch_data_concurrently('USA', ['SP.POP.TOTL', 'NY.GDP.MKTP.CD']) + + # Print the result for debugging + print(f"Result: {result}") + + # Assert the results + assert result == { + 'SP.POP.TOTL': mock_data_1, + 'NY.GDP.MKTP.CD': mock_data_2 + } + + # Assert that executor.submit was called with correct arguments + calls = [ + call(downloader.fetch_data, 'USA', 'SP.POP.TOTL'), + call(downloader.fetch_data, 'USA', 'NY.GDP.MKTP.CD') + ] + assert mock_executor_instance.submit.call_count == 2 + mock_executor_instance.submit.assert_has_calls(calls, any_order=True) - assert filename.exists() - with open(filename, 'r') as f: - saved_data = json.load(f) - assert saved_data == {"('USA', 'GDP')": [{"year": 2020, "value": 100}]} + # Assert that ThreadPoolExecutor was used + mock_executor.assert_called_once() -def test_load_data_from_file(tmp_path, downloader): +def test_save_and_load_data(downloader, tmp_path): """ - Test the load_data_from_file method. - :param tmp_path: A temporary directory. - :param downloader: A WorldBankDataDownloader instance. + Test the save_data_to_file and load_data_from_file methods of the WorldBankDataDownloader class. + :param downloader: An instance of the WorldBankDataDownloader class. + :param tmp_path: A temporary directory path. :return: None :rtype: None """ - data = {"('USA', 'GDP')": [{"year": 2020, "value": 100}]} + test_data = { + ('USA', 'SP.POP.TOTL'): [{"year": "2020", "value": "100"}], + ('GBR', 'NY.GDP.MKTP.CD'): [{"year": "2020", "value": "1000"}] + } filename = tmp_path / "test_data.json" - with open(filename, 'w') as f: - json.dump(data, f) + downloader.save_data_to_file(test_data, filename) loaded_data = downloader.load_data_from_file(filename) - assert loaded_data == {('USA', 'GDP'): [{"year": 2020, "value": 100}]} - - -@patch('requests.get') -def test_get_country_codes_error(mock_get, downloader): - """ - Test the get_country_codes method when an error occurs. - :param mock_get: A mock of the requests.get function. - :param downloader: A WorldBankDataDownloader instance. - :return: None - :rtype: None - """ - mock_get.side_effect = requests.exceptions.RequestException("API Error") - country_codes = downloader.get_country_codes() - assert country_codes == [] - - -@patch('requests.get') -def test_get_indicators_error(mock_get, downloader): - """ - Test the get_indicators method when an error occurs. - :param mock_get: A mock of the requests.get function. - :param downloader: A WorldBankDataDownloader instance. - :return: None - :rtype: None - """ - mock_get.side_effect = requests.exceptions.RequestException("API Error") - indicators = downloader.get_indicators() - assert indicators == [] - - -@patch('requests.get') -def test_fetch_data_error(mock_get, downloader): - """ - Test the fetch_data method when an error occurs. - :param mock_get: A mock of the requests.get function. - :param downloader: A WorldBankDataDownloader instance. - :return: None - :rtype: None - """ - mock_get.side_effect = requests.exceptions.RequestException("API Error") - data = downloader.fetch_data('USA', 'GDP') - assert data == [] + assert loaded_data == test_data # Run the tests