Skip to content

Commit

Permalink
Improved performance and added logic to query by Country
Browse files Browse the repository at this point in the history
  • Loading branch information
arturogonzalezm committed Jul 20, 2024
1 parent 888fd2c commit e14ebbb
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 140 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ Here's a basic example of how to use the `WorldBankDataDownloader`:
```python
downloader = WorldBankDataDownloader()
all_data = downloader.download_all_data()
downloader.save_data_to_file(all_data, 'data/world_bank_data.json')
downloader.save_data_to_file(all_data, 'data/world_bank_data_optimised.json')
```

## Unit Tests
Expand Down
25 changes: 22 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,38 @@
"""
This is the main file to run the WorldBankDataDownloader.
"""
from src.world_bank_data_downloader import WorldBankDataDownloader
from src.utilts.singleton_logger import SingletonLogger
from src.utilts.world_bank_data_downloader import WorldBankDataDownloader


def main():
    """
    Download data from the World Bank API for every country and save it.

    Iterates over all known country codes, fetching every indicator for
    each country concurrently, then persists the combined result with the
    downloader's default output path.

    :return: None
    :rtype: None
    """
    # Set up logging (SingletonLogger always returns the same instance).
    logger = SingletonLogger().get_logger()

    # Initialize the downloader; its __init__ loads the country and
    # indicator code lists from the API.
    downloader = WorldBankDataDownloader()

    # Get all country codes and indicator codes
    country_codes = downloader.country_codes
    indicator_codes = downloader.indicator_codes

    # Download data for all countries and indicators. Lazy %-style args
    # avoid formatting the message when the log level filters it out.
    all_data = {}
    for country_code in country_codes:
        logger.info("Downloading data for country: %s", country_code)
        country_data = downloader.fetch_data_concurrently(country_code, indicator_codes)
        all_data[country_code] = country_data

    # Save the data to a file (downloader supplies the default filename).
    downloader.save_data_to_file(all_data)

    logger.info("Data download and save completed.")


if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pytest==8.2.2
coverage==7.6.0
pytest-cov==5.0.0
tenacity==8.5.0
aiohttp
Empty file added src/extract/__init__.py
Empty file.
42 changes: 42 additions & 0 deletions src/extract/world_bank_indicators_by_country.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
This script downloads data from the World Bank API for a specific country (Australia in this case) and saves it to a file.
"""
from src.utilts.singleton_logger import SingletonLogger
from src.utilts.world_bank_data_downloader import WorldBankDataDownloader


def main():
    """
    Download all World Bank indicator data for one country (Australia) and
    save it to a per-country JSON file, then reload and log it to verify
    the round-trip.

    :return: None
    :rtype: None
    """
    logger = SingletonLogger().get_logger()
    country_code = 'AUS'

    # Create an instance of the WorldBankDataDownloader
    downloader = WorldBankDataDownloader()

    # Get the list of all indicators.
    # NOTE(review): fetch_data_concurrently expects indicator *codes*;
    # assumes get_indicators() returns exactly that — confirm against
    # WorldBankDataDownloader.
    indicators = downloader.get_indicators()

    # Fetch data for Australia (AUS) for all indicators concurrently
    australia_data = downloader.fetch_data_concurrently(country_code, indicators)

    # Save the data to a file
    filename = f'../data/raw/{country_code}_world_bank_data.json'
    downloader.save_data_to_file(australia_data, filename=filename)

    # Load the data from the file (for verification)
    loaded_data = downloader.load_data_from_file(filename=filename)

    # Print the loaded data (or a subset of it). Lazy %-args keep the
    # hot loop from formatting messages that may be filtered out.
    for indicator, data in loaded_data.items():
        logger.info("Indicator: %s", indicator)
        for entry in data:
            logger.info(entry)
        logger.info("\n")


if __name__ == '__main__':
    main()
41 changes: 41 additions & 0 deletions src/utilts/singleton_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@

"""
SingletonLogger class is a singleton class that provides a single instance of logger object.
"""

import logging
import threading


class SingletonLogger:
    """
    Thread-safe singleton that provides a single shared logger instance.

    The first construction configures the underlying logger (name, level,
    format); every subsequent ``SingletonLogger(...)`` call returns the
    same instance and ignores its arguments.
    """

    _instance = None           # the one shared instance
    _lock = threading.RLock()  # guards first-time construction

    def __new__(cls, logger_name=None, log_level=logging.DEBUG, log_format=None):
        # Double-checked locking: cheap unlocked test first, then a locked
        # re-check so only one thread ever initializes the logger.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialize_logger(logger_name, log_level, log_format)
        return cls._instance

    def _initialize_logger(self, logger_name, log_level, log_format):
        """Configure the wrapped logger with a single console handler."""
        self._logger = logging.getLogger(logger_name or __name__)
        self._logger.setLevel(log_level)

        # Only attach a handler if the named logger has none yet; otherwise
        # a pre-configured logger would emit every record twice.
        if not self._logger.handlers:
            console_handler = logging.StreamHandler()
            console_handler.setLevel(log_level)

            formatter = logging.Formatter(
                log_format or "%(asctime)s - %(threadName)s - %(levelname)s - %(message)s"
            )
            console_handler.setFormatter(formatter)

            self._logger.addHandler(console_handler)

    def get_logger(self):
        """Return the shared :class:`logging.Logger` instance."""
        return self._logger


# logger = SingletonLogger().get_logger()
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""
This module contains a class for downloading data from the World Bank API.
This module contains the WorldBankDataDownloader class, which is used to download data from the World Bank API.
"""

import os
import time
import json
import logging
import requests

from concurrent.futures import ThreadPoolExecutor, as_completed
from tenacity import retry, stop_after_attempt, wait_exponential

from src.utilts.singleton_logger import SingletonLogger


class WorldBankDataDownloader:
"""Class for downloading data from the World Bank API."""
Expand All @@ -19,6 +21,7 @@ def __init__(self):
Initialize the WorldBankDataDownloader with the base URL, country codes, and indicator codes.
"""
self.base_url = 'http://api.worldbank.org/v2'
self.logger = SingletonLogger().get_logger()
self.country_codes = self.get_country_codes()
self.indicator_codes = self.get_indicators()
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -86,51 +89,51 @@ def fetch_data(self, country_code, indicator_code):
break
return all_pages_data

def fetch_data_concurrently(self, country_code, indicator_codes, max_workers=10):
    """
    Fetch data concurrently for a given country code and a list of indicator codes.

    :param country_code: The country code (e.g. ``'AUS'``).
    :param indicator_codes: The list of indicator codes to fetch.
    :param max_workers: The maximum number of worker threads to use.
    :return: A dictionary mapping each indicator code to its fetched data;
             indicators that returned no data or raised an error are omitted.
    """
    results = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit one fetch per indicator and remember which future belongs
        # to which indicator so failures can be attributed in the logs.
        future_to_indicator = {
            executor.submit(self.fetch_data, country_code, indicator_code): indicator_code
            for indicator_code in indicator_codes
        }
        for future in as_completed(future_to_indicator):
            indicator_code = future_to_indicator[future]
            try:
                data = future.result()
                if data:
                    results[indicator_code] = data
            except Exception as e:
                # Use the instance logger configured in __init__ (the
                # original called module-level logging), with lazy %-args.
                self.logger.error(
                    "Error fetching data for indicator %s: %s", indicator_code, e
                )
    return results

@staticmethod
def save_data_to_file(data, filename='../data/raw/world_bank_data_optimised.json'):
    """
    Save the data to a JSON file.

    Keys may be tuples (e.g. ``(country_code, indicator_code)``) or plain
    strings; tuple keys are flattened with a delimiter so they survive the
    round-trip through JSON, which only permits string keys.

    :param data: The data to save.
    :param filename: The filename to save the data to.
    """
    # Guard: a bare filename has an empty dirname, and makedirs('') raises.
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # Use a delimiter that is safe and won't appear in the keys. String
    # keys must pass through unchanged — joining a string would interleave
    # the delimiter between its characters ("AUS" -> "A__DELIM__U__DELIM__S").
    serializable_data = {
        ("__DELIM__".join(key) if isinstance(key, tuple) else key): value
        for key, value in data.items()
    }
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(serializable_data, f, ensure_ascii=False, indent=4)
    logging.info("Data saved to %s", filename)

@staticmethod
def load_data_from_file(filename='../data/raw/world_bank_data_optimised.json'):
    """
    Load the data from a JSON file.

    Keys that were flattened from tuples by ``save_data_to_file`` are split
    back into tuples; plain string keys are returned unchanged.

    :param filename: The filename to load the data from.
    :return: The loaded data.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # Only keys containing the delimiter were tuples originally; splitting
    # an ordinary string key would wrap it in a one-element tuple and break
    # round-tripping of string-keyed data.
    deserialized_data = {
        (tuple(key.split("__DELIM__")) if "__DELIM__" in key else key): value
        for key, value in data.items()
    }
    return deserialized_data
Loading

0 comments on commit e14ebbb

Please sign in to comment.