[structured config] Migrate resources from project-fully-featured to struct config resources
benpankow committed Jan 20, 2023
1 parent 20ca644 commit cb3d3ca
Showing 9 changed files with 113 additions and 101 deletions.
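The change has the same shape in every file below: resources and IO managers that were @io_manager/@resource factory functions with a config_schema become classes whose configuration is declared as typed fields, using the experimental dagster._config.structured_config module that the new code imports. A minimal sketch of that before/after shape (the class and function names here are hypothetical, not part of the commit):

from dagster import Field, IOManager, io_manager
from dagster._config.structured_config import StructuredConfigIOManager


# Before: behavior lives in an IOManager class, but configuration arrives
# through a decorated factory function and its config_schema.
class PlainParquetIOManager(IOManager):
    def __init__(self, base_path: str):
        self._base_path = base_path

    def handle_output(self, context, obj):
        ...

    def load_input(self, context):
        ...


@io_manager(config_schema={"base_path": Field(str, is_required=False)})
def plain_parquet_io_manager(init_context):
    return PlainParquetIOManager(
        base_path=init_context.resource_config.get("base_path", "/tmp")
    )


# After: the config field is a typed attribute on the class itself, and the
# instance is constructed directly wherever it is bound to a resource key.
class StructuredParquetIOManager(StructuredConfigIOManager):
    base_path: str = "/tmp"

    def handle_output(self, context, obj):
        ...

    def load_input(self, context):
        ...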
File: resources/__init__.py (RESOURCES_PROD / RESOURCES_STAGING / RESOURCES_LOCAL)
@@ -5,13 +5,10 @@
from dagster_dbt import dbt_cli_resource
from dagster_pyspark import pyspark_resource

from .common_bucket_s3_pickle_io_manager import common_bucket_s3_pickle_io_manager
from .duckdb_parquet_io_manager import duckdb_partitioned_parquet_io_manager
from .common_bucket_s3_pickle_io_manager import CommonBucketS3PickleIOManager
from .duckdb_parquet_io_manager import DuckDBPartitionedParquetIOManager
from .hn_resource import HNAPIClient, HNAPISubsampleClient
from .parquet_io_manager import (
local_partitioned_parquet_io_manager,
s3_partitioned_parquet_io_manager,
)
from .parquet_io_manager import LocalPartitionedParquetIOManager, S3PartitionedParquetIOManager
from .snowflake_io_manager import SnowflakeIOManager

DBT_PROJECT_DIR = file_relative_path(__file__, "../../dbt_project")
@@ -49,15 +46,19 @@
"account": os.getenv("SNOWFLAKE_ACCOUNT", ""),
"user": os.getenv("SNOWFLAKE_USER", ""),
"password": os.getenv("SNOWFLAKE_PASSWORD", ""),
"warehouse": "TINY_WAREHOUSE",
"warehouse": "ELEMENTL",
}

RESOURCES_PROD = {
"s3_bucket": "hackernews-elementl-prod",
"io_manager": common_bucket_s3_pickle_io_manager,
"io_manager": CommonBucketS3PickleIOManager(
s3=s3_resource, s3_bucket="hackernews-elementl-prod"
),
"s3": s3_resource,
"parquet_io_manager": s3_partitioned_parquet_io_manager,
"warehouse_io_manager": SnowflakeIOManager(dict(database="DEMO_DB", **SHARED_SNOWFLAKE_CONF)),
"parquet_io_manager": S3PartitionedParquetIOManager(
pyspark=configured_pyspark, s3_bucket="hackernews-elementl-dev"
),
"warehouse_io_manager": SnowflakeIOManager(database="DEMO_DB", **SHARED_SNOWFLAKE_CONF),
"pyspark": configured_pyspark,
"hn_client": HNAPISubsampleClient(subsample_rate=10),
"dbt": dbt_prod_resource,
@@ -66,22 +67,24 @@

RESOURCES_STAGING = {
"s3_bucket": "hackernews-elementl-dev",
"io_manager": common_bucket_s3_pickle_io_manager,
"io_manager": CommonBucketS3PickleIOManager(
s3=s3_resource, s3_bucket="hackernews-elementl-dev"
),
"s3": s3_resource,
"parquet_io_manager": s3_partitioned_parquet_io_manager,
"warehouse_io_manager": SnowflakeIOManager(
dict(database="DEMO_DB_STAGING", **SHARED_SNOWFLAKE_CONF)
"parquet_io_manager": S3PartitionedParquetIOManager(
pyspark=configured_pyspark, s3_bucket="hackernews-elementl-dev"
),
"warehouse_io_manager": SnowflakeIOManager(database="DEMO_DB_STAGING", **SHARED_SNOWFLAKE_CONF),
"pyspark": configured_pyspark,
"hn_client": HNAPISubsampleClient(subsample_rate=10),
"dbt": dbt_staging_resource,
}


RESOURCES_LOCAL = {
"parquet_io_manager": local_partitioned_parquet_io_manager,
"warehouse_io_manager": duckdb_partitioned_parquet_io_manager.configured(
{"duckdb_path": os.path.join(DBT_PROJECT_DIR, "hackernews.duckdb")},
"parquet_io_manager": LocalPartitionedParquetIOManager(pyspark=configured_pyspark),
"warehouse_io_manager": DuckDBPartitionedParquetIOManager(
duckdb_path=os.path.join(DBT_PROJECT_DIR, "hackernews.duckdb")
),
"pyspark": configured_pyspark,
"hn_client": HNAPIClient(),
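The practical effect in this file is that per-environment settings move out of .configured({...}) calls and separate resource keys and into constructor arguments, so each RESOURCES_* dict holds ready-made instances. A sketch of the staging parquet wiring shown above (the import path is assumed from the package layout in the test files; configured_pyspark stands in for the configured pyspark resource defined earlier in this file):

from dagster_pyspark import pyspark_resource
from project_fully_featured_v2_resources.resources.parquet_io_manager import (
    S3PartitionedParquetIOManager,
)

configured_pyspark = pyspark_resource  # stand-in for the configured resource above

# Before, the bucket was a separate "s3_bucket" resource key that the factory
# function looked up; now it is an ordinary keyword argument on the instance.
parquet_io_manager = S3PartitionedParquetIOManager(
    pyspark=configured_pyspark, s3_bucket="hackernews-elementl-dev"
)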
File: common_bucket_s3_pickle_io_manager.py
@@ -1,15 +1,19 @@
from dagster import build_init_resource_context, io_manager
from typing import Any

from dagster import build_init_resource_context
from dagster._config.structured_config import ResourceDependency, StructuredConfigIOManagerBase
from dagster._core.storage.io_manager import IOManager
from dagster_aws.s3 import s3_pickle_io_manager


@io_manager(required_resource_keys={"s3_bucket", "s3"})
def common_bucket_s3_pickle_io_manager(init_context):
"""
A version of the s3_pickle_io_manager that gets its bucket from another resource.
"""
return s3_pickle_io_manager(
build_init_resource_context(
config={"s3_bucket": init_context.resources.s3_bucket},
resources={"s3": init_context.resources.s3},
class CommonBucketS3PickleIOManager(StructuredConfigIOManagerBase):
s3_bucket: ResourceDependency[str]
s3: ResourceDependency[Any]

def create_io_manager_to_pass_to_user_code(self, context) -> IOManager:
return s3_pickle_io_manager(
build_init_resource_context(
config={"s3_bucket": self.s3_bucket},
resources={"s3": self.s3},
)
)
)
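Here ResourceDependency[...] fields replace the old required_resource_keys={"s3_bucket", "s3"} declaration: the bucket name and S3 client are handed to the IO manager directly, and create_io_manager_to_pass_to_user_code still delegates to the stock s3_pickle_io_manager. A sketch of the prod binding, matching RESOURCES_PROD above (import path assumed from the package layout in the test files):

from dagster_aws.s3 import s3_resource
from project_fully_featured_v2_resources.resources.common_bucket_s3_pickle_io_manager import (
    CommonBucketS3PickleIOManager,
)

# The dependencies are constructor arguments rather than resource keys the
# framework has to resolve before initializing the IO manager.
prod_io_manager = CommonBucketS3PickleIOManager(
    s3=s3_resource,
    s3_bucket="hackernews-elementl-prod",
)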
File: duckdb_parquet_io_manager.py
@@ -3,10 +3,8 @@
import duckdb
import pandas as pd
from dagster import (
Field,
PartitionKeyRange,
_check as check,
io_manager,
)
from dagster._seven.temp_dir import get_system_temp_directory

@@ -16,9 +14,12 @@
class DuckDBPartitionedParquetIOManager(PartitionedParquetIOManager):
"""Stores data in parquet files and creates duckdb views over those files."""

def __init__(self, base_path: str, duckdb_path: str):
super().__init__(base_path=base_path)
self._duckdb_path = check.str_param(duckdb_path, "duckdb_path")
duckdb_path: str
base_path: str = get_system_temp_directory()

@property
def _base_path(self):
return self.base_path

def handle_output(self, context, obj):
if obj is not None: # if this is a dbt output, then the value will be None
@@ -63,15 +64,4 @@ def _schema(self, context) -> str:
return f"{context.asset_key.path[-2]}"

def _connect_duckdb(self):
return duckdb.connect(database=self._duckdb_path, read_only=False)


@io_manager(
config_schema={"base_path": Field(str, is_required=False), "duckdb_path": str},
required_resource_keys={"pyspark"},
)
def duckdb_partitioned_parquet_io_manager(init_context):
return DuckDBPartitionedParquetIOManager(
base_path=init_context.resource_config.get("base_path", get_system_temp_directory()),
duckdb_path=init_context.resource_config["duckdb_path"],
)
return duckdb.connect(database=self.duckdb_path, read_only=False)
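In the DuckDB IO manager, optionality also moves out of the config schema: Field(str, is_required=False) plus a .get(...) fallback becomes a plain Python default on base_path, while duckdb_path stays required because it has no default. A sketch of constructing it with only the required field (import path assumed):

import os
import tempfile

from project_fully_featured_v2_resources.resources.duckdb_parquet_io_manager import (
    DuckDBPartitionedParquetIOManager,
)

# duckdb_path must be provided; base_path falls back to the system temp
# directory default declared on the class.
warehouse_io_manager = DuckDBPartitionedParquetIOManager(
    duckdb_path=os.path.join(tempfile.gettempdir(), "hackernews.duckdb")
)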
File: hn_resource.py
@@ -4,14 +4,16 @@
from typing import Any, Dict, Optional

import requests
from dagster._config.structured_config import Resource
from dagster._utils import file_relative_path
from dagster._utils.cached_method import cached_method

HNItemRecord = Dict[str, Any]

HN_BASE_URL = "https://hacker-news.firebaseio.com/v0"


class HNClient(ABC):
class HNClient(Resource, ABC):
@abstractmethod
def fetch_item_by_id(self, item_id: int) -> Optional[HNItemRecord]:
pass
@@ -39,19 +41,20 @@ def min_item_id(self) -> int:


class HNSnapshotClient(HNClient):
def __init__(self):
@cached_method
def load_items(self) -> Dict[str, HNItemRecord]:
file_path = file_relative_path(__file__, "../utils/snapshot.gzip")
with gzip.open(file_path, "r") as f:
self._items: Dict[str, HNItemRecord] = json.loads(f.read().decode())
return json.loads(f.read().decode())

def fetch_item_by_id(self, item_id: int) -> Optional[HNItemRecord]:
return self._items.get(str(item_id))
return self.load_items().get(str(item_id))

def fetch_max_item_id(self) -> int:
return int(list(self._items.keys())[-1])
return int(list(self.load_items().keys())[-1])

def min_item_id(self) -> int:
return int(list(self._items.keys())[0])
return int(list(self.load_items().keys())[0])


class HNAPISubsampleClient(HNClient):
@@ -60,9 +63,8 @@ class HNAPISubsampleClient(HNClient):
which is useful for testing / demoing purposes.
"""

def __init__(self, subsample_rate):
self._items = {}
self.subsample_rate = subsample_rate
subsample_rate: int
_items: Dict[int, HNItemRecord] = {}

def fetch_item_by_id(self, item_id: int) -> Optional[HNItemRecord]:
# map self.subsample_rate items to the same item_id, caching it for faster performance
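Because the structured-config Resource base class owns construction, the HN clients lose their custom __init__ methods: subsample_rate becomes a typed field, and the snapshot client now loads its gzip snapshot lazily through @cached_method instead of eagerly in __init__. A small sketch of that lazy-state pattern with a hypothetical resource (not part of the commit):

from typing import Dict

from dagster._config.structured_config import Resource
from dagster._utils.cached_method import cached_method


class GreetingClient(Resource):
    """Hypothetical resource: config is a typed field; expensive state is
    built once, on first use, inside a @cached_method."""

    greeting: str = "hello"

    @cached_method
    def load_phrases(self) -> Dict[str, str]:
        # Imagine an expensive load (file read, network call) here.
        return {"en": self.greeting, "fr": "bonjour"}

    def greet(self, lang: str) -> str:
        return self.load_phrases().get(lang, self.greeting)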
File: parquet_io_manager.py
@@ -2,19 +2,18 @@
from typing import Union

import pandas
import pyspark
from dagster import (
Field,
InputContext,
IOManager,
OutputContext,
_check as check,
io_manager,
)
from dagster._config.structured_config import ResourceDependency, StructuredConfigIOManager
from dagster._seven.temp_dir import get_system_temp_directory
from dagster_pyspark.resources import PySparkResource
from pyspark.sql import DataFrame as PySparkDataFrame


class PartitionedParquetIOManager(IOManager):
class PartitionedParquetIOManager(StructuredConfigIOManager):
"""
This IOManager will take in a pandas or pyspark dataframe and store it in parquet at the
specified path.
@@ -25,12 +24,13 @@ class PartitionedParquetIOManager(IOManager):
to where the data is stored.
"""

def __init__(self, base_path):
self._base_path = base_path
pyspark: ResourceDependency[PySparkResource]

def handle_output(
self, context: OutputContext, obj: Union[pandas.DataFrame, pyspark.sql.DataFrame]
):
@property
def _base_path(self):
raise NotImplementedError()

def handle_output(self, context: OutputContext, obj: Union[pandas.DataFrame, PySparkDataFrame]):
path = self._get_path(context)
if "://" not in self._base_path:
os.makedirs(os.path.dirname(path), exist_ok=True)
@@ -39,19 +39,19 @@ def handle_output(
row_count = len(obj)
context.log.info(f"Row count: {row_count}")
obj.to_parquet(path=path, index=False)
elif isinstance(obj, pyspark.sql.DataFrame):
elif isinstance(obj, PySparkDataFrame):
row_count = obj.count()
obj.write.parquet(path=path, mode="overwrite")
else:
raise Exception(f"Outputs of type {type(obj)} not supported.")

context.add_output_metadata({"row_count": row_count, "path": path})

def load_input(self, context) -> Union[pyspark.sql.DataFrame, str]:
def load_input(self, context) -> Union[PySparkDataFrame, str]:
path = self._get_path(context)
if context.dagster_type.typing_type == pyspark.sql.DataFrame:
if context.dagster_type.typing_type == PySparkDataFrame:
# return pyspark dataframe
return context.resources.pyspark.spark_session.read.parquet(path)
return self.pyspark.spark_session.read.parquet(path)

return check.failed(
f"Inputs of type {context.dagster_type} not supported. Please specify a valid type "
@@ -70,16 +70,17 @@ def _get_path(self, context: Union[InputContext, OutputContext]):
return os.path.join(self._base_path, f"{key}.pq")


@io_manager(
config_schema={"base_path": Field(str, is_required=False)},
required_resource_keys={"pyspark"},
)
def local_partitioned_parquet_io_manager(init_context):
return PartitionedParquetIOManager(
base_path=init_context.resource_config.get("base_path", get_system_temp_directory())
)
class LocalPartitionedParquetIOManager(PartitionedParquetIOManager):
base_path: str = get_system_temp_directory()

@property
def _base_path(self):
return self.base_path


class S3PartitionedParquetIOManager(PartitionedParquetIOManager):
s3_bucket: ResourceDependency[str]

@io_manager(required_resource_keys={"pyspark", "s3_bucket"})
def s3_partitioned_parquet_io_manager(init_context):
return PartitionedParquetIOManager(base_path="s3://" + init_context.resources.s3_bucket)
@property
def _base_path(self):
return "s3://" + self.s3_bucket
File: snowflake_io_manager.py
@@ -2,7 +2,8 @@
from datetime import datetime
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

from dagster import InputContext, IOManager, MetadataValue, OutputContext, TableColumn, TableSchema
from dagster import InputContext, MetadataValue, OutputContext, TableColumn, TableSchema
from dagster._config.structured_config import StructuredConfigIOManager
from pandas import (
DataFrame as PandasDataFrame,
read_sql,
@@ -36,14 +37,21 @@ def connect_snowflake(config, schema="public"):
conn.close()


class SnowflakeIOManager(IOManager):
class SnowflakeIOManager(StructuredConfigIOManager):
"""
This IOManager can handle outputs that are either Spark or Pandas DataFrames. In either case,
the data will be written to a Snowflake table specified by metadata on the relevant Out.
"""

def __init__(self, config):
self._config = config
account: str
user: str
password: str
database: str
warehouse: str

@property
def _config(self):
return self.dict()

def handle_output(self, context: OutputContext, obj: Union[PandasDataFrame, SparkDataFrame]):
schema, table = context.asset_key.path[-2], context.asset_key.path[-1] # type: ignore
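The Snowflake IO manager's connection settings become five typed fields, and the _config property hands self.dict() to the existing connect_snowflake helper, so the rest of the class body is untouched. A sketch of the prod-style construction, combining SHARED_SNOWFLAKE_CONF with the database argument as the first file does (import path assumed):

import os

from project_fully_featured_v2_resources.resources.snowflake_io_manager import (
    SnowflakeIOManager,
)

# Each connection setting is a plain keyword argument; self.dict() reproduces
# the old config mapping internally.
warehouse_io_manager = SnowflakeIOManager(
    account=os.getenv("SNOWFLAKE_ACCOUNT", ""),
    user=os.getenv("SNOWFLAKE_USER", ""),
    password=os.getenv("SNOWFLAKE_PASSWORD", ""),
    warehouse="ELEMENTL",
    database="DEMO_DB",
)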
Test file: core asset download test
@@ -11,7 +11,7 @@
from project_fully_featured_v2_resources.assets import core
from project_fully_featured_v2_resources.resources.hn_resource import HNSnapshotClient
from project_fully_featured_v2_resources.resources.parquet_io_manager import (
local_partitioned_parquet_io_manager,
LocalPartitionedParquetIOManager,
)


@@ -23,11 +23,10 @@ def test_download():
"io_manager": fs_io_manager.configured({"base_dir": temp_dir}),
"partition_start": ResourceDefinition.string_resource(),
"partition_end": ResourceDefinition.string_resource(),
"parquet_io_manager": local_partitioned_parquet_io_manager.configured(
{"base_path": temp_dir}
"parquet_io_manager": LocalPartitionedParquetIOManager(
base_path=temp_dir, pyspark=pyspark_resource
),
"warehouse_io_manager": mem_io_manager,
"pyspark": pyspark_resource,
"hn_client": HNSnapshotClient(),
"dbt": ResourceDefinition.none_resource(),
},
Test file: partitioned parquet IO manager test
@@ -7,7 +7,7 @@
from dagster_pyspark import pyspark_resource
from project_fully_featured_v2_resources.partitions import hourly_partitions
from project_fully_featured_v2_resources.resources.parquet_io_manager import (
local_partitioned_parquet_io_manager,
LocalPartitionedParquetIOManager,
)
from pyspark.sql import DataFrame as SparkDF

@@ -30,9 +30,8 @@ def spark_input_asset(pandas_df_asset: SparkDF):
res = materialize(
assets=[pandas_df_asset, spark_input_asset],
resources={
"pyspark": pyspark_resource,
"io_manager": local_partitioned_parquet_io_manager.configured(
{"base_path": temp_dir}
"io_manager": LocalPartitionedParquetIOManager(
pyspark=pyspark_resource, base_path=temp_dir
),
},
partition_key="2022-01-01-16:00",
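The test updates follow the same pattern: instead of .configured({"base_path": temp_dir}) plus a top-level "pyspark" resource key, the tests construct the IO manager directly with both values as keyword arguments. A condensed sketch of that construction (package layout assumed from the test imports above):

import tempfile

from dagster_pyspark import pyspark_resource
from project_fully_featured_v2_resources.resources.parquet_io_manager import (
    LocalPartitionedParquetIOManager,
)

# Direct construction replaces .configured(...); the pyspark dependency is
# passed in explicitly rather than resolved from a separate resource key.
with tempfile.TemporaryDirectory() as temp_dir:
    io_manager = LocalPartitionedParquetIOManager(
        pyspark=pyspark_resource, base_path=temp_dir
    )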
(The diff for the remaining changed file was not loaded in this view.)