feat: get config from YAML files #1

Open · wants to merge 4 commits into main
39 changes: 39 additions & 0 deletions .github/workflows/test_scripts.yaml
@@ -0,0 +1,39 @@
name: Test Script PR

on:
  pull_request:
    types: [edited, opened, synchronize, reopened]
    branches: [main]

jobs:
  image_build_and_test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'
          cache: 'pip' # caching pip dependencies

      - name: Cache pip dependencies
        id: cached-pip-dependencies
        uses: actions/cache@v4
        with:
          path: ~/.cache/pip
          key: pip-${{ runner.os }}-${{ hashFiles('**/requirements.txt') }}
          restore-keys: |
            pip-${{ runner.os }}-

      - name: Install dependencies
        run: |
          python -m venv venv
          source venv/bin/activate
          pip install -r requirements.txt

      - name: Test PR
        run: |
          source venv/bin/activate
          pytest
78 changes: 78 additions & 0 deletions config.py
@@ -0,0 +1,78 @@
"""Config"""

from typing import Dict
from yaml import safe_load
import json

SPARK_CONFIG = "spark-config.yaml"
REGEX_PATTERNS_CONFIG = "regex-patterns.yaml"


def get_config(config_file=None):
    """Load a YAML config file and return its contents as a dict."""
    if config_file is None:
        raise FileNotFoundError("Config file not found")
    # Use a context manager so the file handle is closed after loading
    with open(config_file) as f:
        return safe_load(f)


def get_spark_job_config(app_config) -> Dict:
    """Build one job's Spark config dict from its YAML block."""
    spark_app_config = {"config": {}}
    driver = app_config.get("driver", {})
    executor = app_config.get("executor", {})
    for key, value in driver.items():
        spark_app_config["config"][f"spark.driver.{key}"] = value
    for key, value in executor.items():
        spark_app_config["config"][f"spark.executor.{key}"] = value
    spark_app_config["master"] = app_config.get("master", "local[*]")
    spark_app_config["app"] = app_config.get("app", {})
    return spark_app_config


def parse_spark_config() -> Dict:
    """Parse the Spark config YAML into per-job config dicts."""
    yaml_config = get_config(SPARK_CONFIG)
    spark = yaml_config.get("spark", {})
    return {job: get_spark_job_config(job_config) for job, job_config in spark.items()}


def get_regex_patterns_config() -> Dict:
    """Get the regex patterns config as a plain dict."""
    yaml_config = get_config(REGEX_PATTERNS_CONFIG)
    return dict(yaml_config.get("regex_patterns", {}))


if __name__ == "__main__":
    config = parse_spark_config()
    print(json.dumps(config, indent=2))

    regex_config = get_regex_patterns_config()
    print(json.dumps(regex_config, indent=2))
"""
Generates JSON dict

{
"config": {
"spark.driver.memory": "12g",
"spark.driver.cores": 1,
"spark.executor.memory": "4g",
"spark.executor.instances": 4,
"spark.executor.cores": 1
},
"master": "spark://processing:7077",
"app": "Scan"
}
"""
46 changes: 46 additions & 0 deletions regex-patterns.yaml
@@ -0,0 +1,46 @@
regex_patterns:
  house:
    - "\\bdeep\\s+house\\b"
    - "\\btech\\s+house\\b"
    - "\\bprogressive\\s+house\\b"
    - "\\btropical\\s+house\\b"
    - "\\bfuture\\s+house\\b"
    - "\\bacid\\s+house\\b"
    - "\\belectro\\s+house\\b"
  techno:
    - "\\bminimal\\s+techno\\b"
    - "\\bdetroit\\s+techno\\b"
    - "\\bacid\\s+techno\\b"
    - "\\bindustrial\\s+techno\\b"
    - "\\bhard\\s+techno\\b"
    - "\\btech\\s+techno\\b"
  trance:
    - "\\bprogressive\\s+trance\\b"
    - "\\buplifting\\s+trance\\b"
    - "\\bpsychedelic\\s+trance\\b"
    - "\\bgoa\\s+trance\\b"
    - "\\btech\\s+trance\\b"
    - "\\bvocal\\s+trance\\b"
  dubstep:
    - "\\bbrostep\\b"
    - "\\bchillstep\\b"
    - "\\briddim\\b"
    - "\\bdeep\\s+dubstep\\b"
  drum_and_bass:
    - "\\bdnb\\b"
    - "\\bneurofunk\\b"
    - "\\bliquid\\s+funk\\b"
    - "\\bjump\\s+up\\b"
    - "\\bdarkstep\\b"
    - "\\bbreakcore\\b"
  electro:
    - "\\belectro\\b"
    - "\\belectro\\s+funk\\b"
    - "\\belectro\\s+clash\\b"
    - "\\bfuture\\s+electro\\b"
  hardcore:
    - "\\bhardcore\\b"
    - "\\bgabber\\b"
    - "\\bhappy\\s+hardcore\\b"
    - "\\bdigital\\s+hardcore\\b"
    - "\\bbreakbeat\\s+hardcore\\b"
21 changes: 12 additions & 9 deletions repartition.py
@@ -1,18 +1,21 @@
from pyspark.sql import SparkSession
from pyspark import StorageLevel
from config import parse_spark_config

import math

cfg = parse_spark_config()
repartition_config = cfg['repartition']

# Initialize Spark session
spark = SparkSession.builder \
    .master('spark://processing:7077') \
    .config("spark.driver.memory", "16g") \
    .config("spark.driver.cores", "1") \
    .config("spark.executor.memory", "32g") \
    .config("spark.executor.instances", "1") \
    .config("spark.executor.cores", "1") \
    .appName('RepartitionFile') \
    .getOrCreate()
spark_session = SparkSession.builder \
    .master(repartition_config["master"]) \
    .appName(repartition_config["app"])

# Apply the per-job settings loaded from spark-config.yaml
for key, value in repartition_config["config"].items():
    spark_session = spark_session.config(key, value)

spark = spark_session.getOrCreate()

# File size for repartition targets and input/output file paths
file_size_gb = 11  # approximate size of the input parquet files; should probably be read from disk
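
The rest of this diff is truncated. For context, a partition count derived from that file size could look like the sketch below; target_partition_mb and the formula are assumptions for illustration, not the PR's actual logic:

import math

file_size_gb = 11          # approximate input size, as noted above
target_partition_mb = 128  # assumed target partition size (not from the PR)

# Hypothetical: choose enough partitions that each lands near the target size.
num_partitions = math.ceil(file_size_gb * 1024 / target_partition_mb)  # 88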
3 changes: 3 additions & 0 deletions requirements.txt
@@ -9,3 +9,6 @@
pandas
fastparquet # used by the download wikipedia script to save the dataset
parquet-tools # this is used for CLI reading/inspection of parquet files

# Config Reading
pyyaml
pytest
21 changes: 12 additions & 9 deletions scan.py
@@ -3,18 +3,21 @@
from pyspark.sql.types import StringType, ArrayType
from pyspark import StorageLevel
from utils import regex_patterns
from config import parse_spark_config

import re

spark = SparkSession.builder \
    .master('spark://processing:7077') \
    .config("spark.driver.memory", "12g") \
    .config("spark.driver.cores", "1") \
    .config("spark.executor.memory", "4g") \
    .config("spark.executor.instances", "4") \
    .config("spark.executor.cores", "1") \
    .appName('Scan') \
    .getOrCreate()
cfg = parse_spark_config()
scan_config = cfg['scan']

spark_session = SparkSession.builder \
    .master(scan_config["master"]) \
    .appName(scan_config["app"])

# Apply the per-job settings loaded from spark-config.yaml
for key, value in scan_config["config"].items():
    spark_session = spark_session.config(key, value)

spark = spark_session.getOrCreate()

# Compile one combined regex per category; this significantly speeds up matching, and (?i) makes it case-insensitive
compiled_patterns = {category: re.compile("(?i)" + "|".join(patterns)) for category, patterns in regex_patterns.items()}
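
The remainder of scan.py is truncated; given the ArrayType and StringType imports, the matching step is presumably a UDF along these lines. This is a sketch under that assumption, not the PR's actual code:

from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, StringType

# Hypothetical UDF: return the categories whose compiled pattern matches the text.
@F.udf(returnType=ArrayType(StringType()))
def match_categories(text):
    if text is None:
        return []
    return [cat for cat, rx in compiled_patterns.items() if rx.search(text)]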
21 changes: 21 additions & 0 deletions spark-config.yaml
@@ -0,0 +1,21 @@
spark:
  scan:
    master: spark://processing:7077
    driver:
      memory: 12g
      cores: 1
    executor:
      memory: 4g
      instances: 4
      cores: 1
    app: Scan
  repartition:
    master: spark://processing:7077
    driver:
      memory: 16g
      cores: 1
    executor:
      memory: 32g
      instances: 1
      cores: 1
    app: RepartitionFile
79 changes: 79 additions & 0 deletions test_config.py
@@ -0,0 +1,79 @@
"""Test config module."""

import pytest
from unittest.mock import patch, mock_open
from config import get_config, get_spark_job_config, parse_spark_config, get_regex_patterns_config

# Sample YAML content for testing
SPARK_CONFIG_CONTENT = """
spark:
app1:
master: "local[*]"
driver:
memory: "4g"
cores: 2
executor:
memory: "2g"
instances: 1
cores: 1
"""

REGEX_PATTERNS_CONFIG_CONTENT = """
regex_patterns:
pattern1:
- '^[a-zA-Z0-9_]+$'
pattern2:
- '^\\d{3}-\\d{2}-\\d{4}$'
"""


def test_get_config_file_not_found():
with pytest.raises(FileNotFoundError):
get_config()


def test_get_config_success():
with patch("builtins.open", mock_open(read_data=SPARK_CONFIG_CONTENT)):
config = get_config("spark-config.yaml")
assert config["spark"]["app1"]["master"] == "local[*]"


def test_get_spark_job_config():
app_config = {
"master": "local[*]",
"driver": {
"memory": "4g",
"cores": 2
},
"executor": {
"memory": "2g",
"instances": 1,
"cores": 1
}
}
expected_config = {
"config": {
"spark.driver.memory": "4g",
"spark.driver.cores": 2,
"spark.executor.memory": "2g",
"spark.executor.instances": 1,
"spark.executor.cores": 1
},
"master": "local[*]",
"app": {}
}
assert get_spark_job_config(app_config) == expected_config


def test_parse_spark_config():
with patch("builtins.open", mock_open(read_data=SPARK_CONFIG_CONTENT)):
with patch("config.SPARK_CONFIG", "spark-config.yaml"):
config = parse_spark_config()
assert config["app1"]["config"]["spark.driver.memory"] == "4g"


def test_get_regex_patterns_config():
with patch("builtins.open", mock_open(read_data=REGEX_PATTERNS_CONFIG_CONTENT)):
with patch("config.REGEX_PATTERNS_CONFIG", "regex-patterns.yaml"):
config = get_regex_patterns_config()
assert config["pattern1"] == ["^[a-zA-Z0-9_]+$"]
26 changes: 26 additions & 0 deletions test_utils.py
@@ -0,0 +1,26 @@
"""Test Utils."""

import pytest
from unittest.mock import patch, mock_open
from config import get_regex_patterns_config

# Sample YAML content for testing
REGEX_PATTERNS_CONFIG_CONTENT = """
regex_patterns:
pattern1:
- '^[a-zA-Z0-9_]+$'
pattern2:
- '^\\d{3}-\\d{2}-\\d{4}$'
"""


def test_get_regex_patterns_config():
with patch("builtins.open", mock_open(read_data=REGEX_PATTERNS_CONFIG_CONTENT)):
with patch("config.REGEX_PATTERNS_CONFIG", "regex-patterns.yaml"):
regex_patterns = get_regex_patterns_config()
assert regex_patterns["pattern1"] == ["^[a-zA-Z0-9_]+$"]
assert regex_patterns["pattern2"] == ["^\\d{3}-\\d{2}-\\d{4}$"]


if __name__ == "__main__":
pytest.main()