Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New operations updates #504

Merged
merged 3 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 98 additions & 43 deletions ddpui/api/transform_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,8 @@ def sync_sources(request):
@transformapi.post("/dbt_project/model/", auth=auth.CanManagePipelines())
def post_construct_dbt_model_operation(request, payload: CreateDbtModelPayload):
"""
Construct a model, operation and the edge in django db
same api will be used for creating the under_construction model and chaining operations
input_uuid(s) is required for the first model in the chain
Construct a model or chain operations on a under construction target model
"""
OP_CONFIG = payload.config

orguser: OrgUser = request.orguser
org = orguser.org
Expand All @@ -174,13 +171,13 @@ def post_construct_dbt_model_operation(request, payload: CreateDbtModelPayload):
if payload.op_type not in dbtautomation_service.OPERATIONS_DICT.keys():
raise HttpError(422, "Operation not supported")

target_model = None
if payload.model_uuid:
target_model = OrgDbtModel.objects.filter(uuid=payload.model_uuid).first()
is_multi_input_op = payload.op_type == "join"

# only under construction models can be modified
if target_model and not target_model.under_construction:
raise HttpError(422, "model is locked")
target_model = None
if payload.target_model_uuid:
target_model = OrgDbtModel.objects.filter(
uuid=payload.target_model_uuid
).first()

if not target_model:
target_model = OrgDbtModel.objects.create(
Expand All @@ -189,23 +186,48 @@ def post_construct_dbt_model_operation(request, payload: CreateDbtModelPayload):
under_construction=True,
)

# only under construction models can be modified
if not target_model.under_construction:
raise HttpError(422, "model is locked")

current_operations_chained = OrgDbtOperation.objects.filter(
dbtmodel=target_model
).count()

# input things for the first operation to be chained
if current_operations_chained == 0:
logger.info("Chaining the first operation")
logger.info("Making sure atleast one input orgdbtmodel is present")
input_models: list[OrgDbtModel] = []
seq: list[int] = []
other_input_columns: list[list[str]] = []

if not payload.input_uuids or len(payload.input_uuids) == 0:
raise HttpError(422, "input is required for the first model in the chain")
logger.info(
f"Operations chained for the target model {target_model.uuid} : {current_operations_chained}"
)

input_models = OrgDbtModel.objects.filter(uuid__in=payload.input_uuids).all()
if len(input_models) != len(payload.input_uuids):
if current_operations_chained == 0:
if not payload.input_uuid:
raise HttpError(422, "input is required")

model = OrgDbtModel.objects.filter(uuid=payload.input_uuid).first()
if not model:
raise HttpError(404, "input not found")

# create edge if it doesn't exist
input_models.append(model)

if is_multi_input_op: # multi input operation
if len(payload.other_inputs) == 0:
raise HttpError(422, "atleast 2 inputs are required for this operation")

payload.other_inputs.sort(key=lambda x: x.seq)

for other_input in payload.other_inputs:
model = OrgDbtModel.objects.filter(uuid=other_input.uuid).first()
if not model:
raise HttpError(404, "input not found")
seq.append(other_input.seq)
other_input_columns.append(other_input.columns)
input_models.append(model)

# we create edges only with tables/models at the start of the chain & not operation nodes
if current_operations_chained == 0:
for source in input_models:
edge = DbtEdge.objects.filter(
from_node=source, to_node=target_model
Expand All @@ -219,16 +241,43 @@ def post_construct_dbt_model_operation(request, payload: CreateDbtModelPayload):
logger.info("passed all validation; moving to create operation")

# source columns or selected columns
OP_CONFIG["source_columns"] = payload.select_columns
# there will be atleast one input
OP_CONFIG = payload.config
OP_CONFIG["source_columns"] = payload.source_columns
OP_CONFIG["other_inputs"] = []

# in case of mutli input; send the rest of the inputs in the config; dbt_automation will handle the rest
for dbtmodel, seq, columns in zip(input_models, seq, other_input_columns):
OP_CONFIG["other_inputs"].append(
{
"input": {
"input_type": dbtmodel.type,
"input_name": dbtmodel.name,
"source_name": dbtmodel.source_name,
},
"source_columns": columns,
"seq": seq,
}
)

input_config = {
"config": OP_CONFIG,
"type": payload.op_type,
"input_uuids": payload.input_uuids if current_operations_chained == 0 else [],
"input_models": [
{
"uuid": str(model.uuid),
"name": model.name,
"display_name": model.display_name,
"source_name": model.source_name,
"schema": model.schema,
"type": model.type,
}
for model in input_models
],
}
output_cols = dbtautomation_service.get_output_cols_for_operation(
org_warehouse, payload.op_type, OP_CONFIG.copy()
)

logger.info("creating operation")

dbt_op = OrgDbtOperation.objects.create(
Expand All @@ -248,11 +297,11 @@ def post_construct_dbt_model_operation(request, payload: CreateDbtModelPayload):
logger.info("updated output cols for the model")

return {
"id": target_model.uuid,
"input_type": target_model.type,
"source_name": target_model.source_name,
"input_name": target_model.name,
"schema": target_model.schema,
"id": dbt_op.uuid,
"output_cols": dbt_op.output_cols,
"config": dbt_op.config,
"type": "operation_node",
"target_model_id": dbt_op.dbtmodel.uuid,
}


Expand Down Expand Up @@ -374,21 +423,24 @@ def get_dbt_project_DAG(request):
):
operation_nodes.append(operation)

if operation.seq == 1:
input_uuids = operation.config["input_uuids"]
if input_uuids and len(input_uuids) > 0:
# edge(s) between the node(s) and their first operation
for op_src_node in OrgDbtModel.objects.filter(
uuid__in=input_uuids
).all():
res_edges.append(
{
"id": str(op_src_node.uuid) + "_" + str(operation.uuid),
"source": op_src_node.uuid,
"target": operation.uuid,
}
)
else:
if (
"input_models" in operation.config
and len(operation.config["input_models"]) > 0
):
input_models = operation.config["input_models"]
# edge(s) between the node(s) and other sources involved that are tables (OrgDbtModel)
for op_src_node in OrgDbtModel.objects.filter(
uuid__in=[model["uuid"] for model in input_models]
).all():
model_nodes.append(op_src_node)
res_edges.append(
{
"id": str(op_src_node.uuid) + "_" + str(operation.uuid),
"source": op_src_node.uuid,
"target": operation.uuid,
}
)
if operation.seq >= 2:
# for chained operations for seq >= 2
res_edges.append(
{
Expand Down Expand Up @@ -442,7 +494,10 @@ def get_dbt_project_DAG(request):
res["nodes"] = [
nn for nn in res_nodes if not (nn["id"] in seen or seen.add(nn["id"]))
]
res["edges"] = res_edges
seen = set()
res["edges"] = [
edg for edg in res_edges if not (edg["id"] in seen or seen.add(edg["id"]))
]

return res

Expand Down
21 changes: 15 additions & 6 deletions ddpui/core/dbtautomation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from dbt_automation.operations.flattenairbyte import flatten_operation

# operations
from dbt_automation.operations.flattenjson import flattenjson, flattenjson_dbt_sql

# from dbt_automation.operations.mergetables import union_tables, union_tables_sql
Expand All @@ -35,6 +34,10 @@
sync_sources,
generate_source_definitions_yaml,
)
from dbt_automation.operations.joins import join, joins_sql
from dbt_automation.operations.groupby import groupby, groupby_dbt_sql
from dbt_automation.operations.wherefilter import where_filter, where_filter_sql
from dbt_automation.operations.mergetables import union_tables, union_tables_sql
from dbt_automation.utils.warehouseclient import get_client
from dbt_automation.utils.dbtproject import dbtProject
from dbt_automation.utils.dbtsources import read_sources
Expand All @@ -58,18 +61,24 @@
"dropcolumns": drop_columns,
"renamecolumns": rename_columns,
"regexextraction": regex_extraction,
"join": join,
"groupby": groupby,
"where": where_filter,
}

OPERATIONS_DICT_SQL = {
"flattenjson": flattenjson_dbt_sql,
# "unionall": union_tables_sql,
"castdatatypes": cast_datatypes_sql,
# "unionall": union_tables_sql,
"coalescecolumns": coalesce_columns_dbt_sql,
"arithmetic": arithmetic_dbt_sql,
"concat": concat_columns_dbt_sql,
"dropcolumns": drop_columns_dbt_sql,
"renamecolumns": rename_columns_dbt_sql,
"regexextraction": regex_extraction_sql,
"join": joins_sql,
"groupby": groupby_dbt_sql,
"where": where_filter_sql,
}


Expand Down Expand Up @@ -118,19 +127,19 @@ def create_dbt_model_in_project(
wclient = _get_wclient(org_warehouse)

operations = []
input_uuids = []
input_models = []
for operation in (
OrgDbtOperation.objects.filter(dbtmodel=orgdbt_model).order_by("seq").all()
):
if operation.seq == 1:
input_uuids = operation.config["input_uuids"]
input_models = operation.config["input_models"]
operations.append(
{"type": operation.config["type"], "config": operation.config["config"]}
)

merge_input = []
for uuid in input_uuids:
source_model = OrgDbtModel.objects.filter(uuid=uuid).first()
for model in input_models:
source_model = OrgDbtModel.objects.filter(uuid=model["uuid"]).first()
if source_model:
merge_input.append(
{
Expand Down
18 changes: 14 additions & 4 deletions ddpui/schemas/dbt_workflow_schema.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
from ninja import Field, Schema
import uuid


class InputModelPayload(Schema):
"""
Schema to be expected when we are creating models in a chain
"""

uuid: str
columns: list[str] = []
seq: int = 1


class CreateDbtModelPayload(Schema):
"""
schema to define the payload required to create a custom org task
"""

model_uuid: str
select_columns: list[str]
config: dict
op_type: str
input_uuids: list[str] = []
target_model_uuid: str = ""
input_uuid: str = ""
source_columns: list[str] = []
other_inputs: list[InputModelPayload] = []


class CompleteDbtModelPayload(Schema):
Expand Down
Loading