From 951460900c57d37e9ef2c7a0a6a4973b01d057e1 Mon Sep 17 00:00:00 2001 From: Daniel Rizk Date: Fri, 4 Oct 2024 14:30:03 -0400 Subject: [PATCH 01/10] add mdy, ymd, dmy and more tests --- NEWS.md | 4 ++ README.md | 2 +- docs/examples/UserGuide/from_queryex.jl | 4 +- src/joins_sq.jl | 86 ++++++++++++++++++++----- src/parsing_athena.jl | 6 ++ src/parsing_clickhouse.jl | 6 ++ src/parsing_duckdb.jl | 6 ++ src/parsing_gbq.jl | 6 ++ src/parsing_mysql.jl | 6 ++ src/parsing_postgres.jl | 6 ++ src/parsing_snowflake.jl | 6 ++ test/Project.toml | 3 +- test/comp_tests.jl | 39 +++++++++-- test/runtests.jl | 1 + 14 files changed, 156 insertions(+), 25 deletions(-) diff --git a/NEWS.md b/NEWS.md index a3abd6c..07bef6a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,4 +1,8 @@ # TidierDB.jl updates +## v0.4.2 - 2024-10-06 +- add `dmy`, `mdy`, `ymd` support for most backends +- add Date parsing and filtering tests + ## v0.4.1 - 2024-10-02 - Adds 50 tests comparing TidierDB to TidierData to assure accuracy across a complex chains of operations, including combinations of `@mutate`, `@summarize`, `@filter`, `@select`, `@group_by` and `@join` operations. diff --git a/README.md b/README.md index 355f103..507b080 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ TidierDB.jl currently supports the following top-level macros: | **Utility** | `@show_query`, `@collect`, `@head`, `@count`, `show_tables`, `@create_view` , `drop_view` | | **Helper Functions** | `across`, `desc`, `if_else`, `case_when`, `n`, `starts_with`, `ends_with`, `contains`, `as_float`, `as_integer`, `as_string`, `is_missing`, `missing_if`, `replace_missing` | | **TidierStrings.jl Functions** | `str_detect`, `str_replace`, `str_replace_all`, `str_remove_all`, `str_remove` | -| **TidierDates.jl Functions** | `year`, `month`, `day`, `hour`, `min`, `second`, `floor_date`, `difftime` | +| **TidierDates.jl Functions** | `year`, `month`, `day`, `hour`, `min`, `second`, `floor_date`, `difftime`, `mdy`, `ymd`, `dmy` | | **Aggregate Functions** | `mean`, `minimum`, `maximum`, `std`, `sum`, `cumsum`, `cor`, `cov`, `var`, all SQL aggregate `@summarize` supports any SQL aggregate function in addition to the list above. Simply write the function as written in SQL syntax and it will work. diff --git a/docs/examples/UserGuide/from_queryex.jl b/docs/examples/UserGuide/from_queryex.jl index ffafac3..ec9ab41 100644 --- a/docs/examples/UserGuide/from_queryex.jl +++ b/docs/examples/UserGuide/from_queryex.jl @@ -12,7 +12,7 @@ # Start a query to analyze fuel efficiency by number of cylinders. 
However, to further build on this query later, end the chain without using `@show_query` or `@collect` # ```julia -# query = DB.@chain DB.t(mtcars) begin +# query = DB.@chain DB.t(query) begin # DB.@group_by cyl # DB.@summarize begin # across(mpg, (mean, minimum, maximum)) @@ -30,7 +30,7 @@ # ## `from_query()` or `t(query)` # Now, `from_query`, or `t()` a convienece wrapper, will allow you to reuse the query to calculate the average horsepower for each efficiency category # ```julia -# DB.@chain DB.t(mtcars) begin +# DB.@chain DB.t(query) begin # DB.@left_join("mtcars2", cyl, cyl) # DB.@group_by(efficiency) # DB.@summarize(avg_hp = mean(hp)) diff --git a/src/joins_sq.jl b/src/joins_sq.jl index 08f2485..0721b4e 100644 --- a/src/joins_sq.jl +++ b/src/joins_sq.jl @@ -45,7 +45,33 @@ function finalize_query_jq(sqlquery::SQLQuery, from_clause) complete_query = join(filter(!isempty, query_parts), " ") return complete_query end - +function create_and_add_cte(sq, cte_name) + select_expressions = !isempty(sq.select) ? [sq.select] : ["*"] + cte_sql = " " * join(select_expressions, ", ") * " FROM " * sq.from + if sq.is_aggregated && !isempty(sq.groupBy) + cte_sql *= " " * sq.groupBy + sq.groupBy = "" + end + if !isempty(sq.where) + cte_sql *= " WHERE " * sq.where + sq.where = " " + end + if !isempty(sq.having) + cte_sql *= " " * sq.having + sq.having = " " + end + if !isempty(sq.select) + cte_sql *= " " + sq.select = " * " + end + # Create and add the new CTE + new_cte = CTE(name=string(cte_name), select=cte_sql) + push!(sq.ctes, new_cte) + sq.cte_count += 1 + cte_name = "cte_" * string(sq.cte_count) + most_recent_source = !isempty(sq.ctes) ? "cte_" * string(sq.cte_count - 1) : sq.from + return most_recent_source, cte_name +end """ $docstring_left_join @@ -78,7 +104,7 @@ macro left_join(sqlquery, join_table, lhs_column, rhs_column) cte_name_jq = "jcte_" * string(jq.cte_count) most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from -select_sql_jq = "SELECT * FROM " * most_recent_source_jq + select_sql_jq = finalize_query_jq(jq, most_recent_source_jq) new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq) push!(jq.ctes, new_cte_jq) jq.from = cte_name_jq @@ -97,7 +123,10 @@ select_sql_jq = "SELECT * FROM " * most_recent_source_jq end sq.metadata = vcat(sq.metadata, new_metadata) end - + if sq.groupBy != "" + most_recent_source, cte_name = create_and_add_cte(sq, cte_name) + end + join_sql = " " * most_recent_source * ".*, " * get_join_columns(sq.db, join_table_name, $lhs_col_str) * gbq_join_parse(most_recent_source) * " LEFT JOIN " * join_table_name * " ON " * @@ -121,7 +150,7 @@ select_sql_jq = "SELECT * FROM " * most_recent_source_jq jq.cte_count += 1 cte_name_jq = "jcte_" * string(jq.cte_count) # most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from -select_sql_jq = "SELECT * FROM " * most_recent_source_jq + select_sql_jq = finalize_query_jq(jq, most_recent_source_jq) new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq) push!(jq.ctes, new_cte_jq) jq.from = cte_name_jq @@ -140,6 +169,9 @@ select_sql_jq = "SELECT * FROM " * most_recent_source_jq end sq.metadata = vcat(sq.metadata, new_metadata) end + if sq.groupBy != "" + most_recent_source, cte_name = create_and_add_cte(sq, cte_name) + end join_clause = " LEFT JOIN " * join_table_name * " ON " * gbq_join_parse(join_table_name) * "." * $lhs_col_str * " = " * gbq_join_parse(sq.from) * "." 
* $rhs_col_str @@ -204,7 +236,9 @@ macro right_join(sqlquery, join_table, lhs_column, rhs_column) end sq.metadata = vcat(sq.metadata, new_metadata) end - + if sq.groupBy != "" + most_recent_source, cte_name = create_and_add_cte(sq, cte_name) + end join_sql = " " * most_recent_source * ".*, " * get_join_columns(sq.db, join_table_name, $lhs_col_str) * gbq_join_parse(most_recent_source) * " RIGHT JOIN " * join_table_name * " ON " * @@ -247,7 +281,9 @@ macro right_join(sqlquery, join_table, lhs_column, rhs_column) end sq.metadata = vcat(sq.metadata, new_metadata) end - + if sq.groupBy != "" + most_recent_source, cte_name = create_and_add_cte(sq, cte_name) + end join_clause = " RIGHT JOIN " * join_table_name * " ON " * gbq_join_parse(join_table_name) * "." * $lhs_col_str * " = " * gbq_join_parse(sq.from) * "." * $rhs_col_str @@ -311,7 +347,9 @@ macro inner_join(sqlquery, join_table, lhs_column, rhs_column) end sq.metadata = vcat(sq.metadata, new_metadata) end - + if sq.groupBy != "" + most_recent_source, cte_name = create_and_add_cte(sq, cte_name) + end join_sql = " " * most_recent_source * ".*, " * get_join_columns(sq.db, join_table_name, $lhs_col_str) * gbq_join_parse(most_recent_source) * " INNER JOIN " * join_table_name * " ON " * @@ -354,7 +392,9 @@ macro inner_join(sqlquery, join_table, lhs_column, rhs_column) end sq.metadata = vcat(sq.metadata, new_metadata) end - + if sq.groupBy != "" + most_recent_source, cte_name = create_and_add_cte(sq, cte_name) + end join_clause = " INNER JOIN " * join_table_name * " ON " * gbq_join_parse(join_table_name) * "." * $lhs_col_str * " = " * gbq_join_parse(sq.from) * "." * $rhs_col_str @@ -395,9 +435,7 @@ macro full_join(sqlquery, join_table, lhs_column, rhs_column) if needs_new_cte_jq for cte in jq.ctes cte.name = "j" * cte.name - end - jq.from_join = true - + end cte_name_jq = "jcte_" * string(jq.cte_count) most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from select_sql_jq = finalize_query_jq(jq, most_recent_source_jq) @@ -418,8 +456,11 @@ macro full_join(sqlquery, join_table, lhs_column, rhs_column) new_metadata = get_table_metadata_athena(sq.db, join_table_name, sq.athena_params) end sq.metadata = vcat(sq.metadata, new_metadata) + end - + if sq.groupBy != "" + most_recent_source, cte_name = create_and_add_cte(sq, cte_name) + end join_sql = " " * most_recent_source * ".*, " * get_join_columns(sq.db, join_table_name, $lhs_col_str) * gbq_join_parse(most_recent_source) * " FULL JOIN " * join_table_name * " ON " * @@ -462,12 +503,15 @@ macro full_join(sqlquery, join_table, lhs_column, rhs_column) end sq.metadata = vcat(sq.metadata, new_metadata) end - + if sq.groupBy != "" + most_recent_source, cte_name = create_and_add_cte(sq, cte_name) + end join_clause = " FULL JOIN " * join_table_name * " ON " * gbq_join_parse(join_table_name) * "." * $lhs_col_str * " = " * gbq_join_parse(sq.from) * "." 
* $rhs_col_str sq.from *= join_clause end + else error("Expected sqlquery to be an instance of SQLQuery") end @@ -527,7 +571,9 @@ macro semi_join(sqlquery, join_table, lhs_column, rhs_column) end sq.metadata = vcat(sq.metadata, new_metadata) end - + if sq.groupBy != "" + most_recent_source, cte_name = create_and_add_cte(sq, cte_name) + end join_sql = " " * most_recent_source * ".*, " * get_join_columns(sq.db, join_table_name, $lhs_col_str) * gbq_join_parse(most_recent_source) * " SEMI JOIN " * join_table_name * " ON " * @@ -570,7 +616,9 @@ macro semi_join(sqlquery, join_table, lhs_column, rhs_column) end sq.metadata = vcat(sq.metadata, new_metadata) end - + if sq.groupBy != "" + most_recent_source, cte_name = create_and_add_cte(sq, cte_name) + end join_clause = " SEMI JOIN " * join_table_name * " ON " * gbq_join_parse(join_table_name) * "." * $lhs_col_str * " = " * gbq_join_parse(sq.from) * "." * $rhs_col_str @@ -634,7 +682,9 @@ macro anti_join(sqlquery, join_table, lhs_column, rhs_column) end sq.metadata = vcat(sq.metadata, new_metadata) end - + if sq.groupBy != "" + most_recent_source, cte_name = create_and_add_cte(sq, cte_name) + end join_sql = " " * most_recent_source * ".*, " * get_join_columns(sq.db, join_table_name, $lhs_col_str) * gbq_join_parse(most_recent_source) * " ANTI JOIN " * join_table_name * " ON " * @@ -677,7 +727,9 @@ macro anti_join(sqlquery, join_table, lhs_column, rhs_column) end sq.metadata = vcat(sq.metadata, new_metadata) end - + if sq.groupBy != "" + most_recent_source, cte_name = create_and_add_cte(sq, cte_name) + end join_clause = " ANTI JOIN " * join_table_name * " ON " * gbq_join_parse(join_table_name) * "." * $lhs_col_str * " = " * gbq_join_parse(sq.from) * "." * $rhs_col_str diff --git a/src/parsing_athena.jl b/src/parsing_athena.jl index f62f1e1..4ee2469 100644 --- a/src/parsing_athena.jl +++ b/src/parsing_athena.jl @@ -128,6 +128,12 @@ function expr_to_sql_trino(expr, sq; from_summarize::Bool) return "EXTRACT(SECOND FROM " * string(a) * ")" elseif @capture(x, floordate(time_column_, unit_)) return :(DATE_TRUNC($unit, $time_column)) + elseif @capture(x, ymd(time_)) + return :(date_parse($time, "%Y-%m-%d")) + elseif @capture(x, mdy(time_)) + return :(date_parse($time, "%m-%d-%Y")) + elseif @capture(x, dmy(time_)) + return :(date_parse($time, "%d-%m-%Y")) elseif @capture(x, difftime(endtime_, starttime_, unit_)) return :(date_diff($unit, $starttime, $endtime)) elseif @capture(x, replacemissing(column_, replacement_value_)) diff --git a/src/parsing_clickhouse.jl b/src/parsing_clickhouse.jl index ecd492b..ca93065 100644 --- a/src/parsing_clickhouse.jl +++ b/src/parsing_clickhouse.jl @@ -128,6 +128,12 @@ function expr_to_sql_clickhouse(expr, sq; from_summarize::Bool) return "toSecond(" * string(a) * ")" elseif @capture(x, floordate(time_column_, unit_)) return :(DATE_TRUNC($unit, $time_column)) + elseif @capture(x, ymd(time_)) + return :(formatDateTime($time, "yyyy-MM-dd")) + elseif @capture(x, mdy(time_)) + return :(formatDateTime($time, "MM-dd-yyyy")) + elseif @capture(x, dmy(time_)) + return :(formatDateTime($time, "dd-MM-yyyy")) elseif @capture(x, difftime(endtime_, starttime_, unit_)) return :(date_diff($unit, $starttime, $endtime)) elseif @capture(x, replacemissing(column_, replacement_value_)) diff --git a/src/parsing_duckdb.jl b/src/parsing_duckdb.jl index 2b4c4c5..c262fe4 100644 --- a/src/parsing_duckdb.jl +++ b/src/parsing_duckdb.jl @@ -129,6 +129,12 @@ function expr_to_sql_duckdb(expr, sq; from_summarize::Bool) return :(DATE_TRUNC($unit, 
$time_column)) elseif @capture(x, difftime(endtime_, starttime_, unit_)) return :(date_diff($unit, $starttime, $endtime)) + elseif @capture(x, ymd(time_)) + return :(STRPTIME($time, "%Y-%m-%d")) + elseif @capture(x, mdy(time_)) + return :(STRPTIME($time, "%m-%d-%Y")) + elseif @capture(x, dmy(time_)) + return :(STRPTIME($time, "%d-%m-%Y")) elseif @capture(x, replacemissing(column_, replacement_value_)) return :(COALESCE($column, $replacement_value)) elseif @capture(x, missingif(column_, value_to_replace_)) diff --git a/src/parsing_gbq.jl b/src/parsing_gbq.jl index e990596..8b5668f 100644 --- a/src/parsing_gbq.jl +++ b/src/parsing_gbq.jl @@ -155,6 +155,12 @@ function expr_to_sql_gbq(expr, sq; from_summarize::Bool) return "EXTRACT(MINUTE FROM " * string(a) * ")" elseif @capture(x, second(a_)) return "EXTRACT(SECOND FROM " * string(a) * ")" + elseif @capture(x, ymd(time_)) + return :(PARSE_DATE($time, "%Y-%m-%d")) + elseif @capture(x, mdy(time_)) + return :(PARSE_DATE($time, "%m-%d-%Y")) + elseif @capture(x, dmy(time_)) + return :(PARSE_DATE($time, "%d-%m-%Y")) elseif @capture(x, floordate(time_column_, unit_)) return :(DATE_TRUNC($unit, $time_column)) elseif @capture(x, difftime(endtime_, starttime_, unit_)) diff --git a/src/parsing_mysql.jl b/src/parsing_mysql.jl index 3ed9989..1814141 100644 --- a/src/parsing_mysql.jl +++ b/src/parsing_mysql.jl @@ -129,6 +129,12 @@ function expr_to_sql_mysql(expr, sq; from_summarize::Bool) elseif @capture(x, floordate(time_column_, unit_)) # Call floordate_to_sql with the captured variables return floordate_to_sql(unit, time_column) + elseif @capture(x, ymd(time_)) + return :(STR_TO_DATE($time, "%Y-%m-%d")) + elseif @capture(x, mdy(time_)) + return :(STR_TO_DATE($time, "%m-%d-%Y")) + elseif @capture(x, dmy(time_)) + return :(STR_TO_DATE($time, "%d-%m-%Y")) elseif @capture(x, replacemissing(column_, replacement_value_)) return :(COALESCE($column, $replacement_value)) elseif @capture(x, missingif(column_, value_to_replace_)) diff --git a/src/parsing_postgres.jl b/src/parsing_postgres.jl index b7b94a6..0793df5 100644 --- a/src/parsing_postgres.jl +++ b/src/parsing_postgres.jl @@ -126,6 +126,12 @@ function expr_to_sql_postgres(expr, sq; from_summarize::Bool) return "EXTRACT(MINUTE FROM " * string(a) * ")" elseif @capture(x, second(a_)) return "EXTRACT(SECOND FROM " * string(a) * ")" + elseif @capture(x, ymd(time_)) + return :(TO_DATE($time, "YYYYMMDD")) + elseif @capture(x, mdy(time_)) + return :(TO_DATE($time, "MMDDYYYY")) + elseif @capture(x, dmy(time_)) + return :(TO_DATE($time, "DDMMYYYY")) elseif @capture(x, floordate(time_column_, unit_)) return :(DATE_TRUNC($unit, $time_column)) elseif @capture(x, replacemissing(column_, replacement_value_)) diff --git a/src/parsing_snowflake.jl b/src/parsing_snowflake.jl index a24f150..ad47ec0 100644 --- a/src/parsing_snowflake.jl +++ b/src/parsing_snowflake.jl @@ -129,6 +129,12 @@ function expr_to_sql_snowflake(expr, sq; from_summarize::Bool) return "EXTRACT(SECOND FROM " * string(a) * ")" elseif @capture(x, floordate(time_column_, unit_)) return :(DATE_TRUNC($unit, $time_column)) + elseif @capture(x, ymd(time_)) + return :(TO_DATE($time, "YYYY-MM-DD")) + elseif @capture(x, mdy(time_)) + return :(TO_DATE($time, "MM-DD-YYYY")) + elseif @capture(x, dmy(time_)) + return :(TO_DATE($time, "DD-MM-YYYY")) elseif @capture(x, replacemissing(column_, replacement_value_)) return :(COALESCE($column, $replacement_value)) elseif @capture(x, missingif(column_, value_to_replace_)) diff --git a/test/Project.toml b/test/Project.toml 
index 7630d67..a08e3c4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,4 +3,5 @@ TidierData = "fe2206b3-d496-4ee9-a338-6a095c4ece80" TidierDB = "86993f9b-bbba-4084-97c5-ee15961ad48b" TidierStrings = "248e6834-d0f8-40ef-8fbb-8e711d883e9c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" \ No newline at end of file +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +TIdierDates = "20186a3f-b5d3-468e-823e-77aae96fe2d8" \ No newline at end of file diff --git a/test/comp_tests.jl b/test/comp_tests.jl index 0a55a2c..2520798 100644 --- a/test/comp_tests.jl +++ b/test/comp_tests.jl @@ -11,10 +11,10 @@ @test all(Array(TDF_3 .== TDB_3)) end @testset "Group By Summarize" begin - TDF_1 = @chain test_df @group_by(groups) @summarize(value = sum(value)) - TDB_1 = @chain DB.t(test_db) DB.@group_by(groups) DB.@summarize(value = sum(value)) DB.@collect - TDF_2 = @chain test_df @group_by(groups) @summarize(across(value,(mean, minimum, maximum))) - TDB_2 = @chain DB.t(test_db) DB.@group_by(groups) DB.@summarize(across(value, (mean, minimum, maximum))) DB.@collect + TDF_1 = @chain test_df @group_by(groups) @summarize(value = sum(value), n = n()) + TDB_1 = @chain DB.t(test_db) DB.@group_by(groups) DB.@summarize(value = sum(value), n = n()) DB.@collect + TDF_2 = @chain test_df @group_by(groups) @summarize(across(starts_with("v"), (mean, minimum, maximum, std))) + TDB_2 = @chain DB.t(test_db) DB.@group_by(groups) DB.@summarize(across(starts_with("v"), (mean, minimum, maximum, std))) DB.@collect TDF_3 = @chain test_df @group_by(groups) @summarize(across(value,(mean, minimum, maximum))) @mutate(value_mean = value_mean + 4 * 4) TDB_3 = @chain DB.t(test_db) DB.@group_by(groups) DB.@summarize(across(value, (mean, minimum, maximum))) DB.@mutate(value_mean = value_mean + 4 * 4) DB.@collect @test all(Array(TDF_1 .== TDB_1)) @@ -87,6 +87,18 @@ TDB_7 = @chain DB.t(test_db) DB.@anti_join(DB.t(query), id2, id) DB.@collect TDF_8 = @right_join(test_df, @filter(df2, score > 85 && str_detect(id2, "C")), id = id2) TDB_8 = @chain DB.t(test_db) DB.@right_join(DB.t(query), id2, id) DB.@select(!id2) DB.@collect + # mutate in a new category, group by summarize, and then join on two summary tables based on key + TDF_9 = @chain test_df @mutate(category = if_else(value/2 >1, "X", "Y")) @group_by(category) @summarize(percent_mean= mean(percent)) @left_join((@chain df2 @group_by(category) @summarize(score_mean= mean(score))), category = category) + x = @chain DB.t(join_db) DB.@group_by(category) DB.@summarize(score_mean= mean(score)) + TDB_9 = @chain DB.t(test_db) DB.@mutate(category = if_else(value/2 >1, "X", "Y")) DB.@group_by(category) DB.@summarize(percent_mean= mean(percent)) DB.@left_join(DB.t(x), category, category) DB.@select(cte_4.category, percent_mean, score_mean) DB.@collect + TDF_10 = @chain test_df @mutate(category = if_else(value/2 >1, "X", "Y")) @group_by(category) @summarize(percent_mean= mean(percent)) @right_join((@chain df2 @group_by(category) @summarize(score_mean= mean(score))), category = category) + TDB_10 = @chain DB.t(test_db) DB.@mutate(category = if_else(value/2 >1, "X", "Y")) DB.@group_by(category) DB.@summarize(percent_mean= mean(percent)) DB.@right_join(DB.t(x), category, category) DB.@select(cte_4.category, percent_mean, score_mean) DB.@collect + # mutate in a new category, group by summarize, and then join on two summary tables based on key, then mutate and filter on new column + TDF_11 = @chain test_df @mutate(category = if_else(value/2 >1, 
"X", "Y")) @group_by(category) @summarize(percent_mean= mean(percent)) @right_join((@chain df2 @group_by(category) @summarize(score_mean= mean(score))), category = category) @mutate(test = score_mean^percent_mean) @filter(test > 10) + TDB_11 = @chain DB.t(test_db) DB.@mutate(category = if_else(value/2 >1, "X", "Y")) DB.@group_by(category) DB.@summarize(percent_mean= mean(percent)) DB.@right_join(DB.t(x), category, category) DB.@mutate(test = score_mean^percent_mean) DB.@filter(test > 10) DB.@select(cte_7.category, percent_mean, score_mean, test) DB.@collect + TDF_11 = @chain test_df @mutate(category = if_else(value/2 >1, "X", "Y")) @group_by(category) @summarize(percent_mean= mean(percent)) @full_join((@chain df2 @group_by(category) @summarize(score_mean= mean(score))), category = category) @mutate(test = score_mean^percent_mean) @filter(test > 10) + TDB_11 = @chain DB.t(test_db) DB.@mutate(category = if_else(value/2 >1, "X", "Y")) DB.@group_by(category) DB.@summarize(percent_mean= mean(percent)) DB.@full_join(DB.t(x), category, category) DB.@mutate(test = score_mean^percent_mean) DB.@filter(test > 10) DB.@select(cte_7.category, percent_mean, score_mean, test) DB.@collect + @test all(isequal.(Array(TDF_1), Array(TDB_1))) @test all(isequal.(Array(TDF_2), Array(TDB_2))) @test all(isequal.(Array(TDF_3), Array(TDB_3))) @@ -95,6 +107,9 @@ @test all(isequal.(Array(TDF_6), Array(TDB_6))) @test all(isequal.(Array(TDF_7), Array(TDB_7))) @test all(isequal.(Array(TDF_8), Array(TDB_8))) + @test all(isequal.(Array(TDF_9), Array(TDB_9))) + @test all(isequal.(Array(TDF_10), Array(TDB_10))) + @test all(isequal.(Array(TDF_11), Array(TDB_11))) end @testset "Mutate" begin # simple arithmetic mutates @@ -185,4 +200,20 @@ @test all(isequal.(Array(TDF_1), Array(TBD_1))) @test all(isequal.(Array(TDF_2), Array(TBD_2))) end + @testset "Date Parsing" begin + TDF_1 = @chain test_df @mutate(test = ymd_hms("2023-06-15 00:00:00")) + TDB_1 = @chain DB.t(test_db) DB.@mutate(test = ymd("2023-06-15")) DB.@collect + # Filter by date + TDF_2 = @chain test_df @mutate(test = ymd_hms("2023-06-15 00:00:00")) @filter(test < ymd("2023-04-14")) + TDB_2 = @chain DB.t(test_db) DB.@mutate(test = ymd("2023-06-15")) DB.@filter(test < ymd("2023-04-14")) DB.@collect + TDF_3 = @chain test_df @mutate(test = if_else(groups == "aa", ymd_hms("2023-06-15 00:00:00"), ymd_hms("2024-06-15 00:00:00"))) @filter(test == ymd("2023-06-15")) + TDB_3 = @chain DB.t(test_db) DB.@mutate(test = if_else(groups == "aa", ymd("2023-06-15"), ymd("2024-06-15"))) DB.@filter(test == ymd("2023-06-15")) DB.@collect + TDF_4 = @chain test_df @mutate(test = if_else(groups == "aa", ymd("2023-06-15"), ymd("2024-06-15")), test2= ymd("2020-06-15")) @mutate(tryt = if_else((test - test2) > Day(1095), "old", "young")) @select(tryt) + TDB_4 = @chain DB.t(test_db) DB.@mutate(test = if_else(groups == "aa", ymd("2023-06-15"), ymd("2024-06-15")), test2= ymd("2020-06-15")) DB.@mutate(tryt = if_else(Day((test - test2)) > 1095, "old", "young")) DB.@select(tryt) DB.@collect + @test all(isequal.(Array(TDF_1), Array(TDB_1))) + @test all(isequal.(Array(TDF_2), Array(TDB_2))) + @test all(isequal.(Array(TDF_3), Array(TDB_3))) + @test all(isequal.(Array(TDF_4), Array(TDB_4))) + + end end diff --git a/test/runtests.jl b/test/runtests.jl index 15afd54..52af548 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,7 @@ using TidierData using TidierStrings import TidierDB as DB using Test +using TidierDates test_df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], 
groups = [i % 2 == 0 ? "aa" : "bb" for i in 1:10], From 96f4bcded901c0ce83dee6fabb6c5b6fbb95b5e4 Mon Sep 17 00:00:00 2001 From: Daniel Rizk Date: Fri, 4 Oct 2024 15:46:54 -0400 Subject: [PATCH 02/10] enable adding date interval.. blended syntax --- src/TidierDB.jl | 2 +- test/comp_tests.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/TidierDB.jl b/src/TidierDB.jl index 97e2531..516b2c3 100644 --- a/src/TidierDB.jl +++ b/src/TidierDB.jl @@ -162,7 +162,7 @@ function finalize_query(sqlquery::SQLQuery) "FROM )" => ")" , "SELECT SELECT " => "SELECT ", "SELECT SELECT " => "SELECT ", "DISTINCT SELECT " => "DISTINCT ", "SELECT SELECT SELECT " => "SELECT ", "PARTITION BY GROUP BY" => "PARTITION BY", "GROUP BY GROUP BY" => "GROUP BY", "HAVING HAVING" => "HAVING", r"var\"(.*?)\"" => s"\1", r"\"\\\$" => "\"\$", "WHERE \"" => "WHERE ", "WHERE \"NOT" => "WHERE NOT", "%')\"" =>"%\")", "NULL)\"" => "NULL)", - "NULL))\"" => "NULL))" + "NULL))\"" => "NULL))", r"(?i)INTERVAL(\d)" => s"INTERVAL \1" ) complete_query = replace(complete_query, ", AS " => " AS ", "OR \"" => "OR ") if current_sql_mode[] == postgres() || current_sql_mode[] == duckdb() || current_sql_mode[] == mysql() || current_sql_mode[] == mssql() || current_sql_mode[] == clickhouse() || current_sql_mode[] == athena() || current_sql_mode[] == gbq() || current_sql_mode[] == oracle() || current_sql_mode[] == snowflake() || current_sql_mode[] == databricks() diff --git a/test/comp_tests.jl b/test/comp_tests.jl index 2520798..24dc30f 100644 --- a/test/comp_tests.jl +++ b/test/comp_tests.jl @@ -208,6 +208,7 @@ TDB_2 = @chain DB.t(test_db) DB.@mutate(test = ymd("2023-06-15")) DB.@filter(test < ymd("2023-04-14")) DB.@collect TDF_3 = @chain test_df @mutate(test = if_else(groups == "aa", ymd_hms("2023-06-15 00:00:00"), ymd_hms("2024-06-15 00:00:00"))) @filter(test == ymd("2023-06-15")) TDB_3 = @chain DB.t(test_db) DB.@mutate(test = if_else(groups == "aa", ymd("2023-06-15"), ymd("2024-06-15"))) DB.@filter(test == ymd("2023-06-15")) DB.@collect + # if_else based on value of date difference TDF_4 = @chain test_df @mutate(test = if_else(groups == "aa", ymd("2023-06-15"), ymd("2024-06-15")), test2= ymd("2020-06-15")) @mutate(tryt = if_else((test - test2) > Day(1095), "old", "young")) @select(tryt) TDB_4 = @chain DB.t(test_db) DB.@mutate(test = if_else(groups == "aa", ymd("2023-06-15"), ymd("2024-06-15")), test2= ymd("2020-06-15")) DB.@mutate(tryt = if_else(Day((test - test2)) > 1095, "old", "young")) DB.@select(tryt) DB.@collect @test all(isequal.(Array(TDF_1), Array(TDB_1))) From 325bbc0b97403cde4d0d0dd03b8dcde8dfdfb7ce Mon Sep 17 00:00:00 2001 From: Daniel Rizk Date: Fri, 4 Oct 2024 16:32:47 -0400 Subject: [PATCH 03/10] adds distinct tests, some readme updates --- NEWS.md | 4 ++-- Project.toml | 2 +- README.md | 2 +- docs/src/index.md | 6 +++--- src/joins_sq.jl | 4 ++-- test/comp_tests.jl | 12 +++++++++++- 6 files changed, 20 insertions(+), 10 deletions(-) diff --git a/NEWS.md b/NEWS.md index 07bef6a..e472beb 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,7 @@ # TidierDB.jl updates -## v0.4.2 - 2024-10-06 +## v0.4.2 - 2024-10-04 - add `dmy`, `mdy`, `ymd` support for most backends -- add Date parsing and filtering tests +- add date related tests ## v0.4.1 - 2024-10-02 - Adds 50 tests comparing TidierDB to TidierData to assure accuracy across a complex chains of operations, including combinations of `@mutate`, `@summarize`, `@filter`, `@select`, `@group_by` and `@join` operations. 
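A concrete picture of the date parsing and blended interval syntax these patches exercise, as a minimal sketch: `db`, `df_mem`, and `date_col` below are hypothetical stand-ins rather than names from this changeset, and the chain mirrors the new tests in test/comp_tests.jl.

```julia
# Minimal sketch, assuming a DuckDB backend; `df_mem` and `date_col` are
# hypothetical. `ymd("2023-06-15")` is translated to a DATE expression, and
# the blended `interval1year` syntax is rewritten into a SQL INTERVAL by
# `finalize_query` before the query is sent to the backend.
import TidierDB as DB

db = DB.connect(DB.duckdb())
DB.@chain DB.db_table(db, :df_mem) begin
    DB.@filter(date_col > ymd("2023-06-15") - interval1year)
    DB.@collect
end
```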
diff --git a/Project.toml b/Project.toml
index 323cf53..4939257 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "TidierDB"
 uuid = "86993f9b-bbba-4084-97c5-ee15961ad48b"
 authors = ["Daniel Rizk and contributors"]
-version = "0.4.1"
+version = "0.4.2"
 
 [deps]
 Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
diff --git a/README.md b/README.md
index 507b080..74e894d 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,7 @@ TidierDB.jl currently supports the following top-level macros:
 | **Helper Functions** | `across`, `desc`, `if_else`, `case_when`, `n`, `starts_with`, `ends_with`, `contains`, `as_float`, `as_integer`, `as_string`, `is_missing`, `missing_if`, `replace_missing` |
 | **TidierStrings.jl Functions** | `str_detect`, `str_replace`, `str_replace_all`, `str_remove_all`, `str_remove` |
 | **TidierDates.jl Functions** | `year`, `month`, `day`, `hour`, `min`, `second`, `floor_date`, `difftime`, `mdy`, `ymd`, `dmy` |
-| **Aggregate Functions** | `mean`, `minimum`, `maximum`, `std`, `sum`, `cumsum`, `cor`, `cov`, `var`, all SQL aggregate
+| **Aggregate Functions** | `mean`, `minimum`, `maximum`, `std`, `sum`, `cumsum`, `cor`, `cov`, `var`, all SQL aggregate functions
 
 `@summarize` supports any SQL aggregate function in addition to the list above. Simply write the function as written in SQL syntax and it will work.
 `@mutate` supports all builtin SQL functions as well.
diff --git a/docs/src/index.md b/docs/src/index.md
index bae5a8e..c0eacc2 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -35,11 +35,11 @@ TidierDB.jl currently supports:
 | **Data Manipulation** | `@arrange`, `@group_by`, `@filter`, `@select`, `@mutate` (supports `across`), `@summarize`/`@summarise` (supports `across`), `@distinct` |
 | **Joining** | `@left_join`, `@right_join`, `@inner_join`, `@anti_join`, `@full_join`, `@semi_join`, `@union` |
 | **Slice and Order** | `@slice_min`, `@slice_max`, `@slice_sample`, `@order`, `@window_order`, `@window_frame` |
-| **Utility** | `@show_query`, `@collect`, `@head`, `@count`, `show_tables`, `@create_view`, `drop_view` |
+| **Utility** | `@show_query`, `@collect`, `@head`, `@count`, `show_tables`, `@create_view` , `drop_view` |
 | **Helper Functions** | `across`, `desc`, `if_else`, `case_when`, `n`, `starts_with`, `ends_with`, `contains`, `as_float`, `as_integer`, `as_string`, `is_missing`, `missing_if`, `replace_missing` |
 | **TidierStrings.jl Functions** | `str_detect`, `str_replace`, `str_replace_all`, `str_remove_all`, `str_remove` |
-| **TidierDates.jl Functions** | `year`, `month`, `day`, `hour`, `min`, `second`, `floor_date`, `difftime` |
-| **Aggregate Functions** | `mean`, `minimum`, `maximum`, `std`, `sum`, `cumsum`, `cor`, `cov`, `var`
+| **TidierDates.jl Functions** | `year`, `month`, `day`, `hour`, `min`, `second`, `floor_date`, `difftime`, `mdy`, `ymd`, `dmy` |
+| **Aggregate Functions** | `mean`, `minimum`, `maximum`, `std`, `sum`, `cumsum`, `cor`, `cov`, `var`, all SQL aggregate functions
 
 `@summarize` supports any SQL aggregate function in addition to the list above. Simply write the function as written in SQL syntax and it will work.
 `@mutate` supports all builtin SQL functions as well.
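To make the aggregate row above concrete: names on the list are translated, and anything else is assumed to pass through to the backend as written in SQL syntax. A minimal sketch, where `db`, `df_mem`, `groups`, and `value` are hypothetical names:

```julia
# Minimal sketch, assuming a DuckDB backend with a table `df_mem` holding
# `groups` and `value` columns (hypothetical names).
import TidierDB as DB

db = DB.connect(DB.duckdb())
DB.@chain DB.db_table(db, :df_mem) begin
    DB.@group_by(groups)
    # `mean` is on the translated list; `median` is not, so it is assumed to
    # be passed through verbatim to DuckDB's own `median` aggregate.
    DB.@summarize(avg_value = mean(value), med_value = median(value))
    DB.@collect
end
```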
diff --git a/src/joins_sq.jl b/src/joins_sq.jl index 0721b4e..7545042 100644 --- a/src/joins_sq.jl +++ b/src/joins_sq.jl @@ -786,7 +786,7 @@ macro union(sqlquery, union_query) # Merge CTEs and metadata sq.ctes = vcat(sq.ctes, uq.ctes) - sq.metadata = vcat(sq.metadata, uq.metadata) + # sq.metadata = vcat(sq.metadata, uq.metadata) else # Treat uq as a table name union_sql = "SELECT * FROM " * sq.from * " UNION SELECT * FROM " * string(uq) @@ -796,7 +796,7 @@ macro union(sqlquery, union_query) else new_metadata = get_table_metadata_athena(sq.db, string(uq), sq.athena_params) end - sq.metadata = vcat(sq.metadata, new_metadata) + # sq.metadata = vcat(sq.metadata, new_metadata) end # Create a new CTE for the union diff --git a/test/comp_tests.jl b/test/comp_tests.jl index 24dc30f..06ea6df 100644 --- a/test/comp_tests.jl +++ b/test/comp_tests.jl @@ -211,10 +211,20 @@ # if_else based on value of date difference TDF_4 = @chain test_df @mutate(test = if_else(groups == "aa", ymd("2023-06-15"), ymd("2024-06-15")), test2= ymd("2020-06-15")) @mutate(tryt = if_else((test - test2) > Day(1095), "old", "young")) @select(tryt) TDB_4 = @chain DB.t(test_db) DB.@mutate(test = if_else(groups == "aa", ymd("2023-06-15"), ymd("2024-06-15")), test2= ymd("2020-06-15")) DB.@mutate(tryt = if_else(Day((test - test2)) > 1095, "old", "young")) DB.@select(tryt) DB.@collect + # filter by time with interval change + TDF_5 = @chain test_df @mutate(test = if_else(groups == "aa", ymd_hms("2023-06-15 00:00:00"), ymd_hms("2024-06-15 00:00:00"))) @filter(test > ymd("2023-06-15") - Year(1)) + TDB_5 = @chain DB.t(test_db) DB.@mutate(test = if_else(groups == "aa", ymd("2023-06-15"), ymd("2024-06-15"))) DB.@filter(test > ymd("2023-06-15") - interval1year) DB.@collect @test all(isequal.(Array(TDF_1), Array(TDB_1))) @test all(isequal.(Array(TDF_2), Array(TDB_2))) @test all(isequal.(Array(TDF_3), Array(TDB_3))) @test all(isequal.(Array(TDF_4), Array(TDB_4))) - + @test all(isequal.(Array(TDF_5), Array(TDB_5))) end + @testset "Distinct" begin + query = DB.@chain DB.t(test_db) DB.@mutate(value = value *2) DB.@filter(value > 5) + # using mutate to make the some rows distinct instead of creating a new df + TDF_1 = @chain test_df @bind_rows((@chain test_df @mutate(value = value *2) @filter(value > 5))) @mutate(value = if_else(value > 5, value/2, value)) @distinct() + TDB_1 = @chain DB.t(test_db) DB.@union(DB.t(query)) DB.@mutate(value = if_else(value > 5, value/2, value)) DB.@distinct() DB.@collect + @test all(isequal.(Array(TDF_1), Array(TDB_1))) + end end From ab120266dd018d065e6ce4ccda33f09fc06332d2 Mon Sep 17 00:00:00 2001 From: Daniel Rizk Date: Fri, 4 Oct 2024 16:47:21 -0400 Subject: [PATCH 04/10] one last test to get 60 --- test/comp_tests.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/comp_tests.jl b/test/comp_tests.jl index 06ea6df..2a849c7 100644 --- a/test/comp_tests.jl +++ b/test/comp_tests.jl @@ -96,8 +96,8 @@ # mutate in a new category, group by summarize, and then join on two summary tables based on key, then mutate and filter on new column TDF_11 = @chain test_df @mutate(category = if_else(value/2 >1, "X", "Y")) @group_by(category) @summarize(percent_mean= mean(percent)) @right_join((@chain df2 @group_by(category) @summarize(score_mean= mean(score))), category = category) @mutate(test = score_mean^percent_mean) @filter(test > 10) TDB_11 = @chain DB.t(test_db) DB.@mutate(category = if_else(value/2 >1, "X", "Y")) DB.@group_by(category) DB.@summarize(percent_mean= mean(percent)) 
DB.@right_join(DB.t(x), category, category) DB.@mutate(test = score_mean^percent_mean) DB.@filter(test > 10) DB.@select(cte_7.category, percent_mean, score_mean, test) DB.@collect - TDF_11 = @chain test_df @mutate(category = if_else(value/2 >1, "X", "Y")) @group_by(category) @summarize(percent_mean= mean(percent)) @full_join((@chain df2 @group_by(category) @summarize(score_mean= mean(score))), category = category) @mutate(test = score_mean^percent_mean) @filter(test > 10) - TDB_11 = @chain DB.t(test_db) DB.@mutate(category = if_else(value/2 >1, "X", "Y")) DB.@group_by(category) DB.@summarize(percent_mean= mean(percent)) DB.@full_join(DB.t(x), category, category) DB.@mutate(test = score_mean^percent_mean) DB.@filter(test > 10) DB.@select(cte_7.category, percent_mean, score_mean, test) DB.@collect + TDF_12 = @chain test_df @mutate(category = if_else(value/2 >1, "X", "Y")) @group_by(category) @summarize(percent_mean= mean(percent)) @full_join((@chain df2 @group_by(category) @summarize(score_mean= mean(score))), category = category) @mutate(test = score_mean^percent_mean) @filter(test > 10) + TDB_12 = @chain DB.t(test_db) DB.@mutate(category = if_else(value/2 >1, "X", "Y")) DB.@group_by(category) DB.@summarize(percent_mean= mean(percent)) DB.@full_join(DB.t(x), category, category) DB.@mutate(test = score_mean^percent_mean) DB.@filter(test > 10) DB.@select(cte_7.category, percent_mean, score_mean, test) DB.@collect @test all(isequal.(Array(TDF_1), Array(TDB_1))) @test all(isequal.(Array(TDF_2), Array(TDB_2))) @@ -110,6 +110,8 @@ @test all(isequal.(Array(TDF_9), Array(TDB_9))) @test all(isequal.(Array(TDF_10), Array(TDB_10))) @test all(isequal.(Array(TDF_11), Array(TDB_11))) + @test all(isequal.(Array(TDF_12), Array(TDB_12))) + end @testset "Mutate" begin # simple arithmetic mutates From 99bf22bef8106c273fae2a81a2e5088390ad924b Mon Sep 17 00:00:00 2001 From: Daniel Rizk Date: Fri, 4 Oct 2024 16:53:27 -0400 Subject: [PATCH 05/10] fix toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index a08e3c4..87375f0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,4 +4,4 @@ TidierDB = "86993f9b-bbba-4084-97c5-ee15961ad48b" TidierStrings = "248e6834-d0f8-40ef-8fbb-8e711d883e9c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -TIdierDates = "20186a3f-b5d3-468e-823e-77aae96fe2d8" \ No newline at end of file +TidierDates = "20186a3f-b5d3-468e-823e-77aae96fe2d8" \ No newline at end of file From cd45f7994ae4788e7bac0489e1a5aa0193c2aed8 Mon Sep 17 00:00:00 2001 From: Daniel Rizk Date: Mon, 7 Oct 2024 21:34:06 -0400 Subject: [PATCH 06/10] fix gbq intnl issue,improve gbq type mapping --- NEWS.md | 6 ++-- Project.toml | 1 + ext/GBQExt.jl | 83 +++++++++++++++++++++++++++++++++++++++----- src/TidierDB.jl | 6 ++-- src/docstrings.jl | 2 +- src/joins_sq.jl | 2 +- src/parsing_gbq.jl | 55 ++--------------------------- src/parsing_mssql.jl | 7 ++++ src/structs.jl | 7 +++- 9 files changed, 102 insertions(+), 67 deletions(-) diff --git a/NEWS.md b/NEWS.md index e472beb..ff88489 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,9 @@ # TidierDB.jl updates -## v0.4.2 - 2024-10-04 -- add `dmy`, `mdy`, `ymd` support for most backends +## v0.4.2 - 2024-10-08 +- add `dmy`, `mdy`, `ymd` support DuckDB, Postgres, GBQ, Clickhouse, MySQL, MsSQL, Athena - add date related tests +- improve Google Big Query type mapping when collecting df +- change `gbq()`'s `connect()` to accept `location` as second 
argument ## v0.4.1 - 2024-10-02 - Adds 50 tests comparing TidierDB to TidierData to assure accuracy across a complex chains of operations, including combinations of `@mutate`, `@summarize`, `@filter`, `@select`, `@group_by` and `@join` operations. diff --git a/Project.toml b/Project.toml index 4939257..d2f19d9 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,7 @@ AWS = "fbe9abb3-538b-5e4e-ba9e-bc94f4f92ebc" MySQL = "39abe10b-433b-5dbd-92d4-e302a9df00cd" ClickHouse = "82f2e89e-b495-11e9-1d9d-fb40d7cf2130" ODBC = "be6f12e9-ca4f-5eb2-a339-a4f995cc0291" +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" [extensions] diff --git a/ext/GBQExt.jl b/ext/GBQExt.jl index 6714043..5848cfa 100644 --- a/ext/GBQExt.jl +++ b/ext/GBQExt.jl @@ -2,7 +2,7 @@ module GBQExt using TidierDB using DataFrames -using GoogleCloud, HTTP, JSON3 +using GoogleCloud, HTTP, JSON3, Dates __init__() = println("Extension was loaded!") mutable struct GBQ @@ -10,14 +10,75 @@ mutable struct GBQ session::GoogleSession bigquery_resource bigquery_method + location::String end -function TidierDB.connect(::gbq, json_key_path::String, project_id::String) + +function apply_type_conversion_gbq(df, col_index, col_type) + +end + +function parse_gbq_df(df, column_types) + for (i, col_type) in enumerate(column_types) + # Check if column index is within bounds of DataFrame columns + if i <= size(df, 2) + try + apply_type_conversion_gbq(df, i, col_type) + catch e + # @warn "Failed to convert column $(i) to $(col_type): $e" + end + else + # @warn "Column index $(i) is out of bounds for the current DataFrame." + end + end; + return df +end + +type_map = Dict( + "STRING" => String, + "FLOAT" => Float64, + "INTEGER" => Int64, + "DATE" => Date, + "DATETIME" => DateTime, + "ARRAY" => Array, + "STRUCT" => Struct +) + +function convert_df_types!(df::DataFrame, new_names::Vector{String}, new_types::Vector{String}) + for (name, type_str) in zip(new_names, new_types) + if haskey(type_map, type_str) + # Get the corresponding Julia type + target_type = type_map[type_str] + + # Check if the DataFrame has the column + if hasproperty(df, name) + # Convert the column to the target type + if target_type == Float64 + df[!, name] = [x === nothing || ismissing(x) ? missing : parse(Float64, x) for x in df[!, name]] + elseif target_type == Int64 + df[!, name] = [x === nothing || ismissing(x) ? missing : parse(Int64, x) for x in df[!, name]] + elseif target_type == Date + df[!, name] = [x === nothing || ismissing(x) ? 
missing : Date(x) for x in df[!, name]] + else + df[!, name] = convert.(target_type, df[!, name]) + end + else + println("Warning: Column $name not found in DataFrame.") + end + else + println("Warning: Type $type_str is not recognized.") + end + end + return df +end + +function TidierDB.connect(::gbq, json_key_path::String, location::String) # Expand the user's path to the JSON key creds_path = expanduser(json_key_path) set_sql_mode(gbq()) # Create credentials and session for Google Cloud creds = JSONCredentials(creds_path) + project_id = JSONCredentials(creds_path).project_id session = GoogleSession(creds, ["https://www.googleapis.com/auth/bigquery"]) # Define the API method for BigQuery @@ -36,7 +97,7 @@ function TidierDB.connect(::gbq, json_key_path::String, project_id::String) ) # Store all data in a global GBQ instance - global gbq_instance = GBQ(project_id, session, bigquery_resource, bigquery_method) + global gbq_instance = GBQ(project_id, session, bigquery_resource, bigquery_method, location) # Return only the session return session @@ -47,7 +108,7 @@ function collect_gbq(conn, query) query_data = Dict( "query" => query, "useLegacySql" => false, - "location" => "US") + "location" => gbq_instance.location) response = GoogleCloud.api.execute( conn, @@ -62,23 +123,29 @@ function collect_gbq(conn, query) # Convert rows to DataFrame # First, extract column names from the schema column_names = [field["name"] for field in response_data["schema"]["fields"]] + # println(column_names) column_types = [field["type"] for field in response_data["schema"]["fields"]] + # println(column_types) # Then, convert each row's data (currently nested inside dicts with key "v") into arrays of dicts if !isempty(rows) # Return an empty DataFrame with the correct columns but 0 rows data = [get(row["f"][i], "v", missing) for row in rows, i in 1:length(column_names)] df = DataFrame(data, Symbol.(column_names)) - df = TidierDB.parse_gbq_df(df, column_types) + # df = TidierDB.parse_gbq_df(df, column_types) + convert_df_types!(df, column_names, column_types) + return df else # Convert each row's data (nested inside dicts with key "v") into arrays of dicts - df =DataFrame([Vector{Union{Missing, Any}}(undef, 0) for _ in column_names], Symbol.(column_names)) - df = TidierDB.parse_gbq_df(df, column_types) + df = DataFrame([Vector{Union{Missing, Any}}(undef, 0) for _ in column_names], Symbol.(column_names)) + # df = TidierDB.parse_gbq_df(df, column_types) + convert_df_types!(df, column_names, column_types) return df end return df end + function TidierDB.get_table_metadata(conn::GoogleSession{JSONCredentials}, table_name::String) query = " SELECT * FROM $table_name LIMIT 0 @@ -86,7 +153,7 @@ function TidierDB.get_table_metadata(conn::GoogleSession{JSONCredentials}, table query_data = Dict( "query" => query, "useLegacySql" => false, - "location" => "US") + "location" => gbq_instance.location) # Define the API resource response = GoogleCloud.api.execute( diff --git a/src/TidierDB.jl b/src/TidierDB.jl index 516b2c3..b00999e 100644 --- a/src/TidierDB.jl +++ b/src/TidierDB.jl @@ -162,13 +162,15 @@ function finalize_query(sqlquery::SQLQuery) "FROM )" => ")" , "SELECT SELECT " => "SELECT ", "SELECT SELECT " => "SELECT ", "DISTINCT SELECT " => "DISTINCT ", "SELECT SELECT SELECT " => "SELECT ", "PARTITION BY GROUP BY" => "PARTITION BY", "GROUP BY GROUP BY" => "GROUP BY", "HAVING HAVING" => "HAVING", r"var\"(.*?)\"" => s"\1", r"\"\\\$" => "\"\$", "WHERE \"" => "WHERE ", "WHERE \"NOT" => "WHERE NOT", "%')\"" =>"%\")", "NULL)\"" => 
"NULL)", - "NULL))\"" => "NULL))", r"(?i)INTERVAL(\d)" => s"INTERVAL \1" + "NULL))\"" => "NULL))", r"(?i)INTERVAL(\d+)([a-zA-Z]+)" => s"INTERVAL \1 \2" ) complete_query = replace(complete_query, ", AS " => " AS ", "OR \"" => "OR ") if current_sql_mode[] == postgres() || current_sql_mode[] == duckdb() || current_sql_mode[] == mysql() || current_sql_mode[] == mssql() || current_sql_mode[] == clickhouse() || current_sql_mode[] == athena() || current_sql_mode[] == gbq() || current_sql_mode[] == oracle() || current_sql_mode[] == snowflake() || current_sql_mode[] == databricks() complete_query = replace(complete_query, "\"" => "'", "==" => "=") end - + if current_sql_mode[] == postgres() + complete_query = replace(complete_query, r"INTERVAL (\d+) ([a-zA-Z]+)" => s"INTERVAL '\1 \2'") + end return complete_query end diff --git a/src/docstrings.jl b/src/docstrings.jl index 6d27956..2f23ab9 100644 --- a/src/docstrings.jl +++ b/src/docstrings.jl @@ -1068,7 +1068,7 @@ This function establishes a database connection based on the specified backend a # Connect to SQLite # conn = connect(sqlite()) # Connect to Google Big Query -# conn = connect(gbq(), "json_user_key_path", "project_id") +# conn = connect(gbq(), "json_user_key_path", "location") # Connect to Snowflake # conn = connect(snowflake(), "ac_id", "token", "Database_name", "Schema_name", "warehouse_name") # Connect to Microsoft SQL Server diff --git a/src/joins_sq.jl b/src/joins_sq.jl index 7545042..c9b7e08 100644 --- a/src/joins_sq.jl +++ b/src/joins_sq.jl @@ -4,7 +4,7 @@ function gbq_join_parse(input) input = string(input) parts = split(input, ".") if current_sql_mode[] == gbq() && length(parts) >=2 - return join(parts[2:end], ".") + return parts[end] elseif occursin(".", input) return split(input, '.')[end] else diff --git a/src/parsing_gbq.jl b/src/parsing_gbq.jl index 8b5668f..5ae9c7e 100644 --- a/src/parsing_gbq.jl +++ b/src/parsing_gbq.jl @@ -1,32 +1,3 @@ - - -function apply_type_conversion_gbq(df, col_index, col_type) - if col_type == "FLOAT" - df[!, col_index] = [ismissing(x) ? missing : parse(Float64, x) for x in df[!, col_index]] - elseif col_type == "INTEGER" - df[!, col_index] = [ismissing(x) ? missing : parse(Int, x) for x in df[!, col_index]] - elseif col_type == "STRING" - # Assuming varchar needs to stay as String, no conversion needed - end -end - -function parse_gbq_df(df, column_types) - for (i, col_type) in enumerate(column_types) - # Check if column index is within bounds of DataFrame columns - if i <= size(df, 2) - try - apply_type_conversion_gbq(df, i, col_type) - catch e - # @warn "Failed to convert column $(i) to $(col_type): $e" - end - else - # @warn "Column index $(i) is out of bounds for the current DataFrame." 
- end - end; - return df -end - - function expr_to_sql_gbq(expr, sq; from_summarize::Bool) # expr = parse_char_matching(expr) expr = exc_capture_bug(expr, names_to_modify) @@ -156,11 +127,11 @@ function expr_to_sql_gbq(expr, sq; from_summarize::Bool) elseif @capture(x, second(a_)) return "EXTRACT(SECOND FROM " * string(a) * ")" elseif @capture(x, ymd(time_)) - return :(PARSE_DATE($time, "%Y-%m-%d")) + return :(PARSE_DATE("%Y-%m-%d", $time)) elseif @capture(x, mdy(time_)) - return :(PARSE_DATE($time, "%m-%d-%Y")) + return :(PARSE_DATE("%m-%d-%Y", $time)) elseif @capture(x, dmy(time_)) - return :(PARSE_DATE($time, "%d-%m-%Y")) + return :(PARSE_DATE("%d-%m-%Y", $time)) elseif @capture(x, floordate(time_column_, unit_)) return :(DATE_TRUNC($unit, $time_column)) elseif @capture(x, difftime(endtime_, starttime_, unit_)) @@ -195,24 +166,4 @@ function expr_to_sql_gbq(expr, sq; from_summarize::Bool) end return x end -end - - - - -function process_column(input::String) - if current_sql_mode == :gbq - return join_gbq_parse(input, full=false) - else - return input - end -end - -function join_gbq_parse(input_str::String; full::Bool = true) - parts = split(input_str, ".") - if full - return input_str - else - return parts[end] - end end \ No newline at end of file diff --git a/src/parsing_mssql.jl b/src/parsing_mssql.jl index 4f01641..3225fc9 100644 --- a/src/parsing_mssql.jl +++ b/src/parsing_mssql.jl @@ -126,6 +126,13 @@ function expr_to_sql_mssql(expr, sq; from_summarize::Bool) return "DATEPART(MINUTE FROM " * string(a) * ")" elseif @capture(x, second(a_)) return "DATEPART(SECOND FROM " * string(a) * ")" + # https://www.mssqltips.com/sqlservertip/1145/date-and-time-conversions-using-sql-server/ + elseif @capture(x, ymd(time_column_)) + return :(convert(varchar, time_column, 23)) + elseif @capture(x, dmy(time_column_)) + return :(convert(varchar, time_column, 105)) + elseif @capture(x, mdy(time_column_)) + return :(convert(varchar, time_column, 10)) elseif @capture(x, floordate(time_column_, unit_)) return floordate_to_mssql(unit, time_column) elseif @capture(x, difftime(endtime_, starttime_, unit_)) diff --git a/src/structs.jl b/src/structs.jl index 9ba4df3..ab6b8a5 100644 --- a/src/structs.jl +++ b/src/structs.jl @@ -57,6 +57,7 @@ function add_interp_parameter!(name::Symbol, value::Any) add_interp_parameter2!(name, value) end +@reexport TidierDB """ $docstring_interpolate """ @@ -68,7 +69,11 @@ macro interpolate( args...) 
end name, value = arg.args quoted_name = QuoteNode(name) - push!(exprs, :(esc(add_interp_parameter!(Symbol($quoted_name), $((value)))))) + # try + push!(exprs, :(esc(add_interp_parameter!(Symbol($quoted_name), $((value)))))) + # catch e + # push!(exprs, :(esc(DB.add_interp_parameter!(Symbol($quoted_name), $((value)))))) + # end end return esc(Expr(:block, exprs...)) end From c7eaa3e52768eaf5170d2c4294f1a28454152941 Mon Sep 17 00:00:00 2001 From: Daniel Rizk Date: Mon, 7 Oct 2024 21:37:20 -0400 Subject: [PATCH 07/10] whoops --- src/structs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/structs.jl b/src/structs.jl index ab6b8a5..75d7d39 100644 --- a/src/structs.jl +++ b/src/structs.jl @@ -57,7 +57,7 @@ function add_interp_parameter!(name::Symbol, value::Any) add_interp_parameter2!(name, value) end -@reexport TidierDB + """ $docstring_interpolate """ From 49c086c1ac7aa45c94e714d07657b36882ef4554 Mon Sep 17 00:00:00 2001 From: Daniel Rizk Date: Mon, 7 Oct 2024 21:49:30 -0400 Subject: [PATCH 08/10] try adding dates to docs toml --- test/Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 87375f0..7eed131 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,4 +4,5 @@ TidierDB = "86993f9b-bbba-4084-97c5-ee15961ad48b" TidierStrings = "248e6834-d0f8-40ef-8fbb-8e711d883e9c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -TidierDates = "20186a3f-b5d3-468e-823e-77aae96fe2d8" \ No newline at end of file +TidierDates = "20186a3f-b5d3-468e-823e-77aae96fe2d8" +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" \ No newline at end of file From 274cac832e7bc86a43fd2e00ad95bd8863e4680e Mon Sep 17 00:00:00 2001 From: Daniel Rizk Date: Thu, 10 Oct 2024 08:26:11 -0400 Subject: [PATCH 09/10] further improve gbq parsing to support arrays and json cols --- ext/GBQExt.jl | 70 ++++---------------------------------- ext/GBQ_to_DF.jl | 87 +++++++++++++++++++++++++++++++++++++++++++++++ src/TBD_macros.jl | 2 ++ src/TidierDB.jl | 2 +- src/docstrings.jl | 3 +- 5 files changed, 97 insertions(+), 67 deletions(-) create mode 100644 ext/GBQ_to_DF.jl diff --git a/ext/GBQExt.jl b/ext/GBQExt.jl index 5848cfa..2ff9392 100644 --- a/ext/GBQExt.jl +++ b/ext/GBQExt.jl @@ -5,6 +5,8 @@ using DataFrames using GoogleCloud, HTTP, JSON3, Dates __init__() = println("Extension was loaded!") +include("GBQ_to_DF.jl") + mutable struct GBQ projectname::String session::GoogleSession @@ -13,65 +15,6 @@ mutable struct GBQ location::String end - -function apply_type_conversion_gbq(df, col_index, col_type) - -end - -function parse_gbq_df(df, column_types) - for (i, col_type) in enumerate(column_types) - # Check if column index is within bounds of DataFrame columns - if i <= size(df, 2) - try - apply_type_conversion_gbq(df, i, col_type) - catch e - # @warn "Failed to convert column $(i) to $(col_type): $e" - end - else - # @warn "Column index $(i) is out of bounds for the current DataFrame." 
- end - end; - return df -end - -type_map = Dict( - "STRING" => String, - "FLOAT" => Float64, - "INTEGER" => Int64, - "DATE" => Date, - "DATETIME" => DateTime, - "ARRAY" => Array, - "STRUCT" => Struct -) - -function convert_df_types!(df::DataFrame, new_names::Vector{String}, new_types::Vector{String}) - for (name, type_str) in zip(new_names, new_types) - if haskey(type_map, type_str) - # Get the corresponding Julia type - target_type = type_map[type_str] - - # Check if the DataFrame has the column - if hasproperty(df, name) - # Convert the column to the target type - if target_type == Float64 - df[!, name] = [x === nothing || ismissing(x) ? missing : parse(Float64, x) for x in df[!, name]] - elseif target_type == Int64 - df[!, name] = [x === nothing || ismissing(x) ? missing : parse(Int64, x) for x in df[!, name]] - elseif target_type == Date - df[!, name] = [x === nothing || ismissing(x) ? missing : Date(x) for x in df[!, name]] - else - df[!, name] = convert.(target_type, df[!, name]) - end - else - println("Warning: Column $name not found in DataFrame.") - end - else - println("Warning: Type $type_str is not recognized.") - end - end - return df -end - function TidierDB.connect(::gbq, json_key_path::String, location::String) # Expand the user's path to the JSON key creds_path = expanduser(json_key_path) @@ -123,15 +66,12 @@ function collect_gbq(conn, query) # Convert rows to DataFrame # First, extract column names from the schema column_names = [field["name"] for field in response_data["schema"]["fields"]] - # println(column_names) column_types = [field["type"] for field in response_data["schema"]["fields"]] - # println(column_types) - # Then, convert each row's data (currently nested inside dicts with key "v") into arrays of dicts + if !isempty(rows) # Return an empty DataFrame with the correct columns but 0 rows data = [get(row["f"][i], "v", missing) for row in rows, i in 1:length(column_names)] df = DataFrame(data, Symbol.(column_names)) - # df = TidierDB.parse_gbq_df(df, column_types) convert_df_types!(df, column_names, column_types) return df @@ -147,6 +87,7 @@ function collect_gbq(conn, query) end function TidierDB.get_table_metadata(conn::GoogleSession{JSONCredentials}, table_name::String) + set_sql_mode(gbq()); query = " SELECT * FROM $table_name LIMIT 0 ;" @@ -179,7 +120,8 @@ function TidierDB.final_collect(sqlquery::SQLQuery, ::Type{<:gbq}) return collect_gbq(sqlquery.db, final_query) end -function TidierDB.show_tables(con::GoogleSession{JSONCredentials}, project_id, datasetname) +function TidierDB.show_tables(con::GoogleSession{JSONCredentials}, datasetname) + project_id = gbq_instance.projectname query = """ SELECT table_name FROM `$project_id.$datasetname.INFORMATION_SCHEMA.TABLES` diff --git a/ext/GBQ_to_DF.jl b/ext/GBQ_to_DF.jl new file mode 100644 index 0000000..a219e67 --- /dev/null +++ b/ext/GBQ_to_DF.jl @@ -0,0 +1,87 @@ + +type_map = Dict( + "STRING" => String, + "FLOAT" => Float64, + "INTEGER" => Int64, + "DATE" => Date, + "DATETIME" => DateTime, + "BOOLEAN" => Bool, + "JSON" => Any # Map JSON to Any +) + +# Function to get Julia type from BigQuery type string +function get_julia_type(type_str::String) + if startswith(type_str, "ARRAY<") && endswith(type_str, ">") + element_type_str = type_str[7:end-1] + element_type = get(type_map, element_type_str, Any) + return Vector{element_type} + else + return get(type_map, type_str, Any) + end +end + +# Helper function to parse scalar values +function parse_scalar_value(x, target_type; type_str="") + if target_type == Date + 
return Date(x) + elseif target_type == DateTime + return DateTime(x) + elseif target_type == String + return String(x) + elseif target_type <: Number + return parse(target_type, x) + elseif target_type == Bool + return x in ("true", "1", 1, true) + elseif type_str == "JSON" + try + # Ensure x is a String or Vector{UInt8} + if isa(x, AbstractString) || isa(x, Vector{UInt8}) + return JSON3.read(x) + else + # Convert x to String if possible + x_str = String(x) + return JSON3.read(x_str) + end + catch e + println("Failed to parse JSON value '$x' of type $(typeof(x)): ", e) + return missing + end + else + return convert(target_type, x) + end +end + + +# Helper function to parse array elements +function parse_array_elements(x::JSON3.Array, target_type) + element_type = eltype(target_type) + return [parse_scalar_value(v["v"], element_type) for v in x] +end + +function convert_df_types!(df::DataFrame, new_names::Vector{String}, new_types::Vector{String}) + for (name, type_str) in zip(new_names, new_types) + # Get the corresponding Julia type + target_type = get_julia_type(type_str) + + # Check if the DataFrame has the column + if hasproperty(df, name) + # Get the column data + column_data = df[!, name] + + # Replace `nothing` with `missing` + column_data = replace(column_data, nothing => missing) + + # Check if the data is an array of values + if !isempty(column_data) && isa(column_data[1], JSON3.Array) + # Handle arrays + df[!, name] = [ismissing(x) ? missing : parse_array_elements(x, target_type) for x in column_data] + else + # Handle scalar values + df[!, name] = [ismissing(x) ? missing : parse_scalar_value(x, target_type; type_str=type_str) for x in column_data] + end + else + println("Warning: Column $name not found in DataFrame.") + end + end + return df +end \ No newline at end of file diff --git a/src/TBD_macros.jl b/src/TBD_macros.jl index a485649..fc75cad 100644 --- a/src/TBD_macros.jl +++ b/src/TBD_macros.jl @@ -424,6 +424,8 @@ macro summarize(sqlquery, expressions...) 
# Check if there's already a SELECT clause and append, otherwise create new
         if startswith(existing_select, "SELECT")
             sq.select = existing_select * ", " * summary_clause
+        elseif isempty(summary_clause)
+            sq.select = "SUMMARIZE "
         else
             sq.select = "SELECT " * summary_clause
         end
diff --git a/src/TidierDB.jl b/src/TidierDB.jl
index b00999e..e27fcc5 100644
--- a/src/TidierDB.jl
+++ b/src/TidierDB.jl
@@ -162,7 +162,7 @@ function finalize_query(sqlquery::SQLQuery)
 "FROM )" => ")" ,  "SELECT SELECT " => "SELECT ", "SELECT  SELECT " => "SELECT ", "DISTINCT SELECT " => "DISTINCT ",
                "SELECT SELECT SELECT " => "SELECT ", "PARTITION BY GROUP BY" => "PARTITION BY", "GROUP BY GROUP BY" => "GROUP BY", "HAVING HAVING" => "HAVING",  r"var\"(.*?)\"" => s"\1", r"\"\\\$" => "\"\$", "WHERE \"" => "WHERE ", "WHERE \"NOT" => "WHERE NOT", "%')\"" =>"%\")", "NULL)\"" => "NULL)",
-                "NULL))\"" => "NULL))", r"(?i)INTERVAL(\d+)([a-zA-Z]+)" => s"INTERVAL \1 \2"
+                "NULL))\"" => "NULL))", r"(?i)INTERVAL(\d+)([a-zA-Z]+)" => s"INTERVAL \1 \2", "SELECT SUMMARIZE " => "SUMMARIZE "
                 )
     complete_query = replace(complete_query, ", AS " => " AS ", "OR \"" => "OR ")
     if current_sql_mode[] == postgres() || current_sql_mode[] == duckdb() || current_sql_mode[] == mysql() || current_sql_mode[] == mssql() || current_sql_mode[] == clickhouse() || current_sql_mode[] == athena() || current_sql_mode[] == gbq() || current_sql_mode[] == oracle() || current_sql_mode[] == snowflake() || current_sql_mode[] == databricks()
diff --git a/src/docstrings.jl b/src/docstrings.jl
index 2f23ab9..20943dd 100644
--- a/src/docstrings.jl
+++ b/src/docstrings.jl
@@ -1255,13 +1255,12 @@ julia> @chain db_table(db, :df_mem) begin
 
 const docstring_show_tables =
 """
-    show_tables(con; GBQ_project_id, GBQ_datasetname)
+    show_tables(con; GBQ_datasetname)
 
 Shows tables available in the database. Currently supports DuckDB, Databricks, Snowflake, GBQ, SQLite, and LibPQ.
 
 # Arguments
 - `con` : connection to backend
-- `GBQ_project_id` : string of project id
 - `GBQ_datasetname` : string of dataset name
 # Examples
 ```jldoctest
From f76bb860f2a48f4be2f5e27c1eccf27f0d0ecf70 Mon Sep 17 00:00:00 2001
From: Daniel Rizk
Date: Thu, 10 Oct 2024 17:25:55 -0400
Subject: [PATCH 10/10] add support for mult tdb queries in chains

---
 NEWS.md            |   4 +-
 src/joins_sq.jl    | 113 +++++++++++++++++++++++++++------------------
 src/structs.jl     |   5 +-
 test/comp_tests.jl |  13 +++++-
 test/runtests.jl   |   6 +++
 5 files changed, 90 insertions(+), 51 deletions(-)

diff --git a/NEWS.md b/NEWS.md
index ff88489..4efb544 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,9 +1,11 @@
 # TidierDB.jl updates
 ## v0.4.2 - 2024-10-08
-- add `dmy`, `mdy`, `ymd` support DuckDB, Postgres, GBQ, Clickhouse, MySQL, MsSQL, Athena
+- add support for joining more than two TidierDB queries in a single chain, plus additional tests
+- add `dmy`, `mdy`, `ymd` support for DuckDB, Postgres, GBQ, Clickhouse, MySQL, MsSQL, and Athena
 - add date related tests
 - improve Google Big Query type mapping when collecting df
 - change `gbq()`'s `connect()` to accept `location` as second argument
+- adds `copy_to` for MsSQL to write a dataframe to the database
 
 ## v0.4.1 - 2024-10-02
 - Adds 50 tests comparing TidierDB to TidierData to assure accuracy across a complex chains of operations, including combinations of `@mutate`, `@summarize`, `@filter`, `@select`, `@group_by` and `@join` operations.
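The first NEWS bullet above corresponds to the `join_count` changes below: each joined subquery's CTEs now get a growing `j` prefix (`jcte_...`, then `jjcte_...`), so their names no longer collide when more than two queries meet in one chain. A usage sketch, with `db`, `df_mem`, `df_a`, `df_b`, `id`, and `score` as hypothetical stand-ins:

```julia
# Sketch only; the connection, tables, and columns here are hypothetical.
import TidierDB as DB

db = DB.connect(DB.duckdb())
a = DB.@chain DB.db_table(db, :df_a) DB.@group_by(id) DB.@summarize(a_mean = mean(score))
b = DB.@chain DB.db_table(db, :df_b) DB.@group_by(id) DB.@summarize(b_mean = mean(score))

# Two TidierDB subqueries joined in one chain: the first join's CTEs are
# renamed with a `j` prefix, the second's with `jj`, avoiding collisions.
DB.@chain DB.db_table(db, :df_mem) begin
    DB.@left_join(DB.t(a), id, id)
    DB.@left_join(DB.t(b), id, id)
    DB.@collect
end
```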
diff --git a/src/joins_sq.jl b/src/joins_sq.jl
index c9b7e08..e452c29 100644
--- a/src/joins_sq.jl
+++ b/src/joins_sq.jl
@@ -96,14 +96,16 @@ macro left_join(sqlquery, join_table, lhs_column, rhs_column)
             if isa(jq, SQLQuery)
                 jq.cte_count += 1
                 # Handle when join_table is an SQLQuery
+                sq.join_count += 1
                 needs_new_cte_jq = !isempty(jq.select) || !isempty(jq.where) || jq.is_aggregated || !isempty(jq.ctes)
                 if needs_new_cte_jq
+                    joinc = repeat("j", sq.join_count)
                     for cte in jq.ctes
-                        cte.name = "j" * cte.name
+                        cte.name = joinc * cte.name
                     end
-                    cte_name_jq = "jcte_" * string(jq.cte_count)
-                    most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from
+                    cte_name_jq = joinc * "cte_" * string(jq.cte_count)
+                    most_recent_source_jq = !isempty(jq.ctes) ? joinc * "cte_" * string(jq.cte_count - 1) : jq.from
                     select_sql_jq = finalize_query_jq(jq, most_recent_source_jq)
                     new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq)
                     push!(jq.ctes, new_cte_jq)
@@ -142,14 +144,15 @@ macro left_join(sqlquery, join_table, lhs_column, rhs_column)
             if isa(jq, SQLQuery)
                 # Handle when join_table is an SQLQuery
                 needs_new_cte_jq = !isempty(jq.select) || !isempty(jq.where) || jq.is_aggregated || !isempty(jq.ctes)
-
+                sq.join_count += 1
                 if needs_new_cte_jq
+                    joinc = repeat("j", sq.join_count)
                     for cte in jq.ctes
-                        cte.name = "j" * cte.name
+                        cte.name = joinc * cte.name
                     end
                     jq.cte_count += 1
-                    cte_name_jq = "jcte_" * string(jq.cte_count) #
-                    most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from
+                    cte_name_jq = joinc * "cte_" * string(jq.cte_count)
+                    most_recent_source_jq = !isempty(jq.ctes) ? joinc * "cte_" * string(jq.cte_count - 1) : jq.from
                     select_sql_jq = finalize_query_jq(jq, most_recent_source_jq)
                     new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq)
                     push!(jq.ctes, new_cte_jq)
@@ -209,14 +212,16 @@ macro right_join(sqlquery, join_table, lhs_column, rhs_column)
             if isa(jq, SQLQuery)
                 jq.cte_count += 1
                 # Handle when join_table is an SQLQuery
+                sq.join_count += 1
                 needs_new_cte_jq = !isempty(jq.select) || !isempty(jq.where) || jq.is_aggregated || !isempty(jq.ctes)
                 if needs_new_cte_jq
+                    joinc = repeat("j", sq.join_count)
                     for cte in jq.ctes
-                        cte.name = "j" * cte.name
+                        cte.name = joinc * cte.name
                     end
-                    cte_name_jq = "jcte_" * string(jq.cte_count)
-                    most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from
+                    cte_name_jq = joinc * "cte_" * string(jq.cte_count)
+                    most_recent_source_jq = !isempty(jq.ctes) ? joinc * "cte_" * string(jq.cte_count - 1) : jq.from
                     select_sql_jq = finalize_query_jq(jq, most_recent_source_jq)
                     new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq)
                     push!(jq.ctes, new_cte_jq)
@@ -254,14 +259,15 @@ macro right_join(sqlquery, join_table, lhs_column, rhs_column)
             if isa(jq, SQLQuery)
                 # Handle when join_table is an SQLQuery
                 needs_new_cte_jq = !isempty(jq.select) || !isempty(jq.where) || jq.is_aggregated || !isempty(jq.ctes)
-
+                sq.join_count += 1
                 if needs_new_cte_jq
+                    joinc = repeat("j", sq.join_count)
                     for cte in jq.ctes
-                        cte.name = "j" * cte.name
+                        cte.name = joinc * cte.name
                     end
                     jq.cte_count += 1
-                    cte_name_jq = "jcte_" * string(jq.cte_count) #
-                    most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from
+                    cte_name_jq = joinc * "cte_" * string(jq.cte_count)
+                    most_recent_source_jq = !isempty(jq.ctes) ?
joinc * "cte_" * string(jq.cte_count - 1) : jq.from
                     select_sql_jq = finalize_query_jq(jq, most_recent_source_jq)
                     new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq)
                     push!(jq.ctes, new_cte_jq)
@@ -320,14 +326,16 @@ macro inner_join(sqlquery, join_table, lhs_column, rhs_column)
             if isa(jq, SQLQuery)
                 jq.cte_count += 1
                 # Handle when join_table is an SQLQuery
+                sq.join_count += 1
                 needs_new_cte_jq = !isempty(jq.select) || !isempty(jq.where) || jq.is_aggregated || !isempty(jq.ctes)
                 if needs_new_cte_jq
+                    joinc = repeat("j", sq.join_count)
                     for cte in jq.ctes
-                        cte.name = "j" * cte.name
+                        cte.name = joinc * cte.name
                     end
-                    cte_name_jq = "jcte_" * string(jq.cte_count)
-                    most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from
+                    cte_name_jq = joinc * "cte_" * string(jq.cte_count)
+                    most_recent_source_jq = !isempty(jq.ctes) ? joinc * "cte_" * string(jq.cte_count - 1) : jq.from
                     select_sql_jq = finalize_query_jq(jq, most_recent_source_jq)
                     new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq)
                     push!(jq.ctes, new_cte_jq)
@@ -365,14 +373,15 @@ macro inner_join(sqlquery, join_table, lhs_column, rhs_column)
             if isa(jq, SQLQuery)
                 # Handle when join_table is an SQLQuery
                 needs_new_cte_jq = !isempty(jq.select) || !isempty(jq.where) || jq.is_aggregated || !isempty(jq.ctes)
-
+                sq.join_count += 1
                 if needs_new_cte_jq
+                    joinc = repeat("j", sq.join_count)
                     for cte in jq.ctes
-                        cte.name = "j" * cte.name
+                        cte.name = joinc * cte.name
                     end
                     jq.cte_count += 1
-                    cte_name_jq = "jcte_" * string(jq.cte_count) #
-                    most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from
+                    cte_name_jq = joinc * "cte_" * string(jq.cte_count)
+                    most_recent_source_jq = !isempty(jq.ctes) ? joinc * "cte_" * string(jq.cte_count - 1) : jq.from
                     select_sql_jq = finalize_query_jq(jq, most_recent_source_jq)
                     new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq)
                     push!(jq.ctes, new_cte_jq)
@@ -431,13 +440,16 @@ macro full_join(sqlquery, join_table, lhs_column, rhs_column)
             if isa(jq, SQLQuery)
                 jq.cte_count += 1
                 # Handle when join_table is an SQLQuery
+                sq.join_count += 1
                 needs_new_cte_jq = !isempty(jq.select) || !isempty(jq.where) || jq.is_aggregated || !isempty(jq.ctes)
                 if needs_new_cte_jq
+                    joinc = repeat("j", sq.join_count)
                     for cte in jq.ctes
-                        cte.name = "j" * cte.name
-                    end
-                    cte_name_jq = "jcte_" * string(jq.cte_count)
-                    most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from
+                        cte.name = joinc * cte.name
+                    end
+
+                    cte_name_jq = joinc * "cte_" * string(jq.cte_count)
+                    most_recent_source_jq = !isempty(jq.ctes) ? joinc * "cte_" * string(jq.cte_count - 1) : jq.from
                     select_sql_jq = finalize_query_jq(jq, most_recent_source_jq)
                     new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq)
                     push!(jq.ctes, new_cte_jq)
@@ -476,14 +488,15 @@ macro full_join(sqlquery, join_table, lhs_column, rhs_column)
             if isa(jq, SQLQuery)
                 # Handle when join_table is an SQLQuery
                 needs_new_cte_jq = !isempty(jq.select) || !isempty(jq.where) || jq.is_aggregated || !isempty(jq.ctes)
-
+                sq.join_count += 1
                 if needs_new_cte_jq
+                    joinc = repeat("j", sq.join_count)
                     for cte in jq.ctes
-                        cte.name = "j" * cte.name
+                        cte.name = joinc * cte.name
                     end
                     jq.cte_count += 1
-                    cte_name_jq = "jcte_" * string(jq.cte_count) #
-                    most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from
+                    cte_name_jq = joinc * "cte_" * string(jq.cte_count)
+                    most_recent_source_jq = !isempty(jq.ctes) ?
joinc * "cte_" * string(jq.cte_count - 1) : jq.from
                     select_sql_jq = finalize_query_jq(jq, most_recent_source_jq)
                     new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq)
                     push!(jq.ctes, new_cte_jq)
@@ -544,14 +557,16 @@ macro semi_join(sqlquery, join_table, lhs_column, rhs_column)
             if isa(jq, SQLQuery)
                 jq.cte_count += 1
                 # Handle when join_table is an SQLQuery
+                sq.join_count += 1
                 needs_new_cte_jq = !isempty(jq.select) || !isempty(jq.where) || jq.is_aggregated || !isempty(jq.ctes)
                 if needs_new_cte_jq
+                    joinc = repeat("j", sq.join_count)
                     for cte in jq.ctes
-                        cte.name = "j" * cte.name
+                        cte.name = joinc * cte.name
                     end
-                    cte_name_jq = "jcte_" * string(jq.cte_count)
-                    most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from
+                    cte_name_jq = joinc * "cte_" * string(jq.cte_count)
+                    most_recent_source_jq = !isempty(jq.ctes) ? joinc * "cte_" * string(jq.cte_count - 1) : jq.from
                     select_sql_jq = finalize_query_jq(jq, most_recent_source_jq)
                     new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq)
                     push!(jq.ctes, new_cte_jq)
@@ -589,14 +604,15 @@ macro semi_join(sqlquery, join_table, lhs_column, rhs_column)
             if isa(jq, SQLQuery)
                 # Handle when join_table is an SQLQuery
                 needs_new_cte_jq = !isempty(jq.select) || !isempty(jq.where) || jq.is_aggregated || !isempty(jq.ctes)
-
+                sq.join_count += 1
                 if needs_new_cte_jq
+                    joinc = repeat("j", sq.join_count)
                     for cte in jq.ctes
-                        cte.name = "j" * cte.name
+                        cte.name = joinc * cte.name
                     end
                     jq.cte_count += 1
-                    cte_name_jq = "jcte_" * string(jq.cte_count) #
-                    most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from
+                    cte_name_jq = joinc * "cte_" * string(jq.cte_count)
+                    most_recent_source_jq = !isempty(jq.ctes) ? joinc * "cte_" * string(jq.cte_count - 1) : jq.from
                     select_sql_jq = finalize_query_jq(jq, most_recent_source_jq)
                     new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq)
                     push!(jq.ctes, new_cte_jq)
@@ -655,14 +671,16 @@ macro anti_join(sqlquery, join_table, lhs_column, rhs_column)
             if isa(jq, SQLQuery)
                 jq.cte_count += 1
                 # Handle when join_table is an SQLQuery
+                sq.join_count += 1
                 needs_new_cte_jq = !isempty(jq.select) || !isempty(jq.where) || jq.is_aggregated || !isempty(jq.ctes)
                 if needs_new_cte_jq
+                    joinc = repeat("j", sq.join_count)
                     for cte in jq.ctes
-                        cte.name = "j" * cte.name
+                        cte.name = joinc * cte.name
                     end
-                    cte_name_jq = "jcte_" * string(jq.cte_count)
-                    most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from
+                    cte_name_jq = joinc * "cte_" * string(jq.cte_count)
+                    most_recent_source_jq = !isempty(jq.ctes) ? joinc * "cte_" * string(jq.cte_count - 1) : jq.from
                     select_sql_jq = finalize_query_jq(jq, most_recent_source_jq)
                     new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq)
                     push!(jq.ctes, new_cte_jq)
@@ -700,14 +718,15 @@ macro anti_join(sqlquery, join_table, lhs_column, rhs_column)
             if isa(jq, SQLQuery)
                 # Handle when join_table is an SQLQuery
                 needs_new_cte_jq = !isempty(jq.select) || !isempty(jq.where) || jq.is_aggregated || !isempty(jq.ctes)
-
+                sq.join_count += 1
                 if needs_new_cte_jq
+                    joinc = repeat("j", sq.join_count)
                     for cte in jq.ctes
-                        cte.name = "j" * cte.name
+                        cte.name = joinc * cte.name
                     end
                     jq.cte_count += 1
-                    cte_name_jq = "jcte_" * string(jq.cte_count) #
-                    most_recent_source_jq = !isempty(jq.ctes) ? "jcte_" * string(jq.cte_count - 1) : jq.from
+                    cte_name_jq = joinc * "cte_" * string(jq.cte_count)
+                    most_recent_source_jq = !isempty(jq.ctes) ?
joinc * "cte_" * string(jq.cte_count - 1) : jq.from
                     select_sql_jq = finalize_query_jq(jq, most_recent_source_jq)
                     new_cte_jq = CTE(name=cte_name_jq, select=select_sql_jq)
                     push!(jq.ctes, new_cte_jq)
@@ -769,12 +788,14 @@ macro union(sqlquery, union_query)
             # Determine if uq needs a new CTE
             needs_new_cte_uq = !isempty(uq.select) || !isempty(uq.where) || uq.is_aggregated || !isempty(uq.ctes)
             if needs_new_cte_uq
+                sq.join_count += 1
+                joinc = repeat("j", sq.join_count)
                 for cte in uq.ctes
-                    cte.name = "j" * cte.name
+                    cte.name = joinc * cte.name
                 end
                 uq.cte_count += 1
-                cte_name_uq = "jcte_" * string(uq.cte_count)
-                most_recent_source_uq = !isempty(uq.ctes) ? "jcte_" * string(uq.cte_count - 1) : uq.from
+                cte_name_uq = joinc * "cte_" * string(uq.cte_count)
+                most_recent_source_uq = !isempty(uq.ctes) ? joinc * "cte_" * string(uq.cte_count - 1) : uq.from
                 select_sql_uq = finalize_query_jq(uq, most_recent_source_uq)
                 new_cte_uq = CTE(name=cte_name_uq, select=select_sql_uq)
                 push!(uq.ctes, new_cte_uq)
diff --git a/src/structs.jl b/src/structs.jl
index 75d7d39..0419a02 100644
--- a/src/structs.jl
+++ b/src/structs.jl
@@ -29,13 +29,14 @@ mutable struct SQLQuery
     athena_params::Any
     limit::String
     ch_settings::String
+    join_count::Int
     function SQLQuery(;select::String="", from::String="", where::String="", groupBy::String="", orderBy::String="", having::String="", window_order::String="", windowFrame::String="", is_aggregated::Bool=false, post_aggregation::Bool=false, metadata::DataFrame=DataFrame(), distinct::Bool=false, db::Any=nothing, ctes::Vector{CTE}=Vector{CTE}(), cte_count::Int=0, athena_params::Any=nothing, limit::String="",
-        ch_settings::String="")
+        ch_settings::String="", join_count::Int = 0)
         new(select, from, where, groupBy, orderBy, having, window_order, windowFrame, is_aggregated,
-            post_aggregation, metadata, distinct, db, ctes, cte_count, athena_params, limit, ch_settings)
+            post_aggregation, metadata, distinct, db, ctes, cte_count, athena_params, limit, ch_settings, join_count)
     end
 end
diff --git a/test/comp_tests.jl b/test/comp_tests.jl
index 2a849c7..fdc3760 100644
--- a/test/comp_tests.jl
+++ b/test/comp_tests.jl
@@ -150,9 +150,16 @@
     TDB_11 = @chain DB.t(test_db) DB.@full_join(DB.t(query), id2, id) DB.@select(!id2) DB.@mutate(score = replace_missing(score, 0)) DB.@collect
     TDF_12 = @chain test_df @mutate(value = value * 2, new_col = (value + percent)/2)
    TDB_12 = @chain DB.t(test_db) DB.@mutate(value = value * 2, new_col = (value + percent)/2) DB.@collect
+    # test joining multiple TidierDB queries in one chain
+    TDF_13 = @chain test_df @left_join(@filter(df2, score > 85 && str_detect(id2, "C")), id = id2) @mutate(score = replace_missing(score, 0)) @left_join((@chain df3 @filter(value2 != 20)), id = id3)
+    TDB_13 = @chain DB.t(test_db) DB.@left_join(DB.t(query), id2, id) DB.@mutate(score = replace_missing(score, 0)) DB.@left_join((@chain DB.t(join_db2) DB.@filter(value2 != 20)), id3, id) DB.@select(!id2, !id3) DB.@collect
     # testing as_string, as_float, as_integer
-    TDF_13 = @chain test_df @mutate(value = as_string(value)) @mutate(value2 = as_float(value), value3 = as_integer(value)) @filter(value2 > 4 && value3 < 10)
-    TDB_13 = @chain DB.t(test_db) DB.@mutate(value = as_string(value)) DB.@mutate(value2 = as_float(value), value3 = as_integer(value)) DB.@filter(value2 > 4 && value3 < 10) DB.@collect
+    TDF_14 = @chain test_df @mutate(value = as_string(value)) @mutate(value2 = as_float(value), value3 = as_integer(value)) @filter(value2 > 4 && value3 < 10)
+    TDB_14 = @chain DB.t(test_db) DB.@mutate(value = as_string(value)) DB.@mutate(value2 = as_float(value), value3 = as_integer(value)) DB.@filter(value2 > 4 && value3 < 10) DB.@collect
+    # test joining multiple TidierDB queries in one chain
+    TDF_15 = @chain test_df @full_join(@filter(df2, score > 85 && str_detect(id2, "C")), id = id2) @mutate(score = replace_missing(score, 0)) @right_join((@chain df3 @filter(value2 != 20)), id = id3) @relocate(id, after = score)
+    TDB_15 = @chain DB.t(test_db) DB.@full_join(DB.t(query), id2, id) DB.@mutate(score = replace_missing(score, 0)) DB.@right_join((@chain DB.t(join_db2) DB.@filter(value2 != 20)), id3, id) DB.@select(!id2, !id) DB.@collect
+
     @test all(isequal.(Array(TDF_1), Array(TDB_1)))
     @test all(isequal.(Array(TDF_2), Array(TDB_2)))
     @test all(isequal.(Array(TDF_3), Array(TDB_3)))
@@ -166,6 +173,8 @@
     @test all(isequal.(Array(TDF_11), Array(TDB_11)))
     @test all(isequal.(Array(TDF_12), Array(TDB_12)))
     @test all(isequal.(Array(TDF_13), Array(TDB_13)))
+    @test all(isequal.(Array(TDF_14), Array(TDB_14)))
+    @test all(isequal.(Array(TDF_15), Array(TDB_15)))
 end

 @testset "Mutate with Conditionals, Strings and then Filter" begin
diff --git a/test/runtests.jl b/test/runtests.jl
index 52af548..c30f3a0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -23,11 +23,17 @@ test_df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9],
 df2 = DataFrame(id2 = ["AA", "AC", "AE", "AG", "AI", "AK", "AM"],
                 category = ["X", "Y", "X", "Y", "X", "Y", "X"],
                 score = [88, 92, 77, 83, 95, 68, 74]);
+df3 = DataFrame(id3 = ["AA", "AG", "AI", "AM", "AN"],
+                description = ["Desc1", "Desc2", "Desc3", "Desc4", "Desc5"],
+                value2 = [10, 20, 30, 40, 50]);
+
 db = DB.connect(DB.duckdb());
 DB.copy_to(db, test_df, "test_df");
 DB.copy_to(db, df2, "df_join");
+DB.copy_to(db, df3, "df_join2");
 test_db = DB.db_table(db, "test_df");
 join_db = DB.db_table(db, "df_join");
+join_db2 = DB.db_table(db, "df_join2");

 @testset "TidierDB to TidierData comparisons" verbose = true begin
     include("comp_tests.jl")
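As a closing usage sketch of the multi-query join support this last patch adds, mirroring the new `TDB_13`/`TDB_15` tests above: the tables and data below are invented for illustration, and only calls that appear in the patch's own tests (`connect`, `copy_to`, `db_table`, `t`, `@chain`, `@filter`, `@left_join`, `@select`, `@collect`) are used.

```julia
# Hedged sketch of joining more than one TidierDB query in a single chain
# (illustration only; data and table names are hypothetical).
import TidierDB as DB
using DataFrames

df  = DataFrame(id  = ["AA", "AC", "AE"], value  = [1, 2, 3])
df2 = DataFrame(id2 = ["AA", "AC"],       score  = [88, 92])
df3 = DataFrame(id3 = ["AA", "AE"],       value2 = [10, 20])

db = DB.connect(DB.duckdb());
DB.copy_to(db, df, "df");
DB.copy_to(db, df2, "df2");
DB.copy_to(db, df3, "df3");

# Each TidierDB query carries its own CTE chain; the new join_count field gives
# every joined subquery a distinct repeated-"j" prefix (jcte_..., jjcte_...),
# so more than two queries can now share one chain without CTE name collisions.
q2 = DB.@chain DB.db_table(db, "df2") DB.@filter(score > 85)

result = DB.@chain DB.db_table(db, "df") begin
    DB.@left_join(DB.t(q2), id2, id)
    DB.@left_join((DB.@chain DB.db_table(db, "df3") DB.@filter(value2 != 20)), id3, id)
    DB.@select(!id2, !id3)
    DB.@collect
end
```

The same pattern should apply to `@right_join`, `@inner_join`, `@full_join`, `@semi_join`, `@anti_join`, and `@union`, all of which gained the `join_count`-based prefixing in the hunks above.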