Skip to content

Commit a04ba92

Browse files
committed
start datafusion problems
1 parent 043b3ab commit a04ba92

File tree

2 files changed

+285
-0
lines changed

2 files changed

+285
-0
lines changed

problems/datafusion.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,39 @@ def problem_180(logs: pa.Table) -> datafusion.dataframe.DataFrame:
9696
return result
9797

9898

99+
def problem_550(activity: pa.Table) -> datafusion.dataframe.DataFrame:
    """Report the fraction of players who logged in the day after their first login.

    Write a solution to report the fraction of players that logged in again on the
    day after the day they first logged in, rounded to 2 decimal places. In other
    words, you need to count the number of players that logged in for at least two
    consecutive days starting from their first login date, then divide that number by
    the total number of players.

    Parameters
    ----------
    activity : pa.Table
        This table shows the activity of players of some games.

    Returns
    -------
    datafusion.dataframe.DataFrame

    """
    ctx = datafusion.SessionContext()
    # Registering with a name makes the table addressable from SQL below.
    ctx.from_arrow(activity, name="activity")
    # For each player, find the first login date, then look for a login exactly
    # one day later.  The fraction is:
    #   (players who logged in the day after their first login) / (all players).
    # COUNT(DISTINCT next_day.player_id) ignores the NULLs produced by the LEFT
    # JOIN for players with no next-day login, while COUNT(DISTINCT
    # first_login.player_id) counts every player exactly once.
    return ctx.sql(
        """
        SELECT
            ROUND(
                CAST(COUNT(DISTINCT next_day.player_id) AS DOUBLE)
                / COUNT(DISTINCT first_login.player_id),
                2
            ) AS fraction
        FROM (
            SELECT player_id, MIN(event_date) AS first_event_date
            FROM activity
            GROUP BY player_id
        ) AS first_login
        LEFT JOIN activity AS next_day
            ON next_day.player_id = first_login.player_id
            AND next_day.event_date = first_login.first_event_date + INTERVAL '1 day'
        """
    )
99132
def problem_584(customer: pa.Table) -> datafusion.dataframe.DataFrame:
100133
"""Find names of customers not referred by the customer with ID = 2.
101134
@@ -255,6 +288,33 @@ def problem_1068(sales: pa.Table, product: pa.Table) -> datafusion.dataframe.Dat
255288
).select("product_name", "year", "price")
256289

257290

291+
def problem_1075(project: pa.Table, employee: pa.Table) -> pa.Table:
    """Report each project's average employee experience, rounded to 2 digits.

    Return the result table in any order.

    Parameters
    ----------
    project : pa.Table
        Table shows employees (employee_id) working on projects (project_id).
    employee : pa.Table
        This table contains information about one employee.

    Returns
    -------
    pa.Table

    """
    # Attach each employee's experience to every project they work on.
    staffed = project.join(employee, keys="employee_id", join_type="inner")
    # Average experience per project; pyarrow names the aggregate column
    # "experience_years_mean".
    per_project = staffed.group_by("project_id").aggregate(
        [("experience_years", "mean")]
    )
    rounded = pc.round(per_project["experience_years_mean"], 2)
    # Replace the aggregate column (index 1) with the rounded, renamed values.
    return per_project.set_column(1, "average_years", rounded)
258318
def problem_1148(views: pa.Table) -> datafusion.dataframe.DataFrame:
259319
"""Find all the authors that viewed at least one of their own articles.
260320
@@ -297,6 +357,147 @@ def problem_1148(views: pa.Table) -> datafusion.dataframe.DataFrame:
297357
)
298358

299359

360+
def problem_1174(delivery: pa.Table) -> pa.Table:
    """Find the percentage of immediate orders in the first orders of all customers.

    If the customer's preferred delivery date is the same as the order date, then the
    order is called immediate; otherwise, it is called scheduled. The first order of a
    customer is the order with the earliest order date that the customer made. It is
    guaranteed that a customer has precisely one first order.

    Round the result to 2 decimal places.

    Parameters
    ----------
    delivery : pa.Table
        Table shows the order date, customer name, and preferred delivery date.

    Returns
    -------
    pa.Table

    """
    # Flag orders whose preferred delivery date matches the order date.
    immediate_flag = pc.equal(
        delivery["order_date"], delivery["customer_pref_delivery_date"]
    )
    flagged = delivery.append_column("is_immediate", immediate_flag)
    # The earliest order date per customer identifies each first order.
    earliest = flagged.group_by("customer_id").aggregate([("order_date", "min")])
    first_orders = flagged.join(
        earliest,
        keys=["customer_id", "order_date"],
        right_keys=["customer_id", "order_date_min"],
        join_type="inner",
    )
    # Mean of the 0/1 flags is the immediate share; scale to percent and round.
    share = pc.mean(pc.cast(first_orders["is_immediate"], pa.int16()))
    percentage = pc.round(pc.multiply(share, pa.scalar(100.0)), 2)
    # Single-row result table holding the scalar percentage.
    return pa.Table.from_arrays(
        [pa.array([percentage])], names=["immediate_percentage"]
    )
409+
def problem_1211(queries: pa.Table) -> pa.Table:
    """Find each query_name, the quality and poor_query_percentage.

    We define query quality as:
        The average of the ratio between query rating and its position.
    We also define poor query percentage as:
        The percentage of all queries with rating less than 3.

    Both quality and poor_query_percentage should be rounded to 2 decimal places.

    Return the result table in any order.

    Parameters
    ----------
    queries : pa.Table
        This table contains information collected from some queries on a database.

    Returns
    -------
    pa.Table

    """
    # Per-row quality ratio and a 100/0 indicator for poor (rating < 3) queries;
    # averaging the indicator per group yields the poor-query percentage directly.
    quality_ratio = pc.divide(queries["rating"], queries["position"])
    poor_indicator = pc.if_else(pc.less(queries["rating"], pa.scalar(3)), 100, 0)
    enriched = queries.append_column("quality", quality_ratio).append_column(
        "poor_query_percentage", poor_indicator
    )

    per_query = enriched.group_by("query_name").aggregate(
        [("quality", "mean"), ("poor_query_percentage", "mean")]
    )

    # Swap the "*_mean" aggregate columns for rounded, conventionally named ones.
    rounded_quality = pc.round(per_query["quality_mean"], 2)
    rounded_poor = pc.round(per_query["poor_query_percentage_mean"], 2)
    return per_query.set_column(1, "quality", rounded_quality).set_column(
        2, "poor_query_percentage", rounded_poor
    )
451+
def problem_1251(prices: pa.Table, units_sold: pa.Table) -> pa.Table:
    """Find the average selling price for each product.

    average_price should be rounded to 2 decimal places. If a product does not have any
    sold units, its average selling price is assumed to be 0.

    Return the result table in any order.

    Parameters
    ----------
    prices : pa.Table
        Table shows product prices by product_id for a date range.
    units_sold : pa.Table
        Table indicates the date, units, and product_id of each product sold.

    Returns
    -------
    pa.Table

    """
    merged = prices.join(units_sold, keys="product_id")
    # Keep sales that fall inside the price window; also keep rows with a null
    # purchase_date (products with no sales) so they still appear in the output.
    in_window = pc.and_(
        pc.greater_equal(merged["purchase_date"], merged["start_date"]),
        pc.less_equal(merged["purchase_date"], merged["end_date"]),
    )
    merged = merged.filter(
        pc.or_kleene(in_window, pc.is_null(merged["purchase_date"]))
    )
    # Revenue per row; null for unsold products, which nulls the sums below.
    merged = merged.append_column(
        "total", pc.multiply(merged["price"], merged["units"])
    )
    totals = merged.group_by("product_id").aggregate(
        [("units", "sum"), ("total", "sum")]
    )
    # total/units in float; fill_null(0) maps unsold products to 0 before rounding.
    average = pc.round(
        pc.fill_null(
            pc.divide(pc.cast(totals["total_sum"], pa.float64()), totals["units_sum"]),
            0,
        ),
        2,
    )
    return totals.append_column("average_price", average).select(
        ["product_id", "average_price"]
    )
300501
def problem_1321(customer: pa.Table) -> datafusion.dataframe.DataFrame:
301502
"""Compute the moving average of how much the customer paid in a seven days window.
302503
@@ -509,6 +710,37 @@ def problem_1517(users: pa.Table) -> datafusion.dataframe.DataFrame:
509710
""")
510711

511712

713+
def problem_1633(users: pa.Table, register: pa.Table) -> pa.Table:
    """Find the percentage of the users registered in each contest.

    Return the result table ordered by percentage in descending order. In case of a
    tie, order it by contest_id in ascending order. The result should be rounded to two
    decimals.

    Parameters
    ----------
    users : pa.Table
        This table contains the name and the id of a user.
    register : pa.Table
        This table contains the id of a user and the contest they registered into.

    Returns
    -------
    pa.Table

    """
    register_agg = register.group_by("contest_id").aggregate([("user_id", "count")])
    total_users = pa.scalar(float(users.num_rows))
    return register_agg.set_column(
        1,
        "percentage",
        # Scale the registered fraction to a percentage before rounding;
        # dividing by a float scalar promotes the integer counts to float64.
        pc.round(
            pc.multiply(
                pc.divide(register_agg["user_id_count"], total_users),
                pa.scalar(100.0),
            ),
            2,
        ),
        # Secondary ascending contest_id key breaks percentage ties, as required.
    ).sort_by([("percentage", "descending"), ("contest_id", "ascending")])
512744
def problem_1683(tweets: pa.Table) -> datafusion.dataframe.DataFrame:
513745
"""Find the IDs of the invalid tweets.
514746

tests/test_datafusion.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from problems.datafusion import (
77
problem_176,
88
problem_180,
9+
problem_550,
910
problem_584,
1011
problem_595,
1112
problem_620,
@@ -108,6 +109,58 @@ def test_problem_180(input_data, expected_data):
108109
assert result.to_arrow_table().equals(expected_table)
109110

110111

112+
@pytest.mark.parametrize(
    "input_data, expected_data",
    [
        # Player 1 logs in on 2016-03-01 and again the next day; players 2 and 3
        # do not return the day after their first login -> 1/3 = 0.33.
        pytest.param(
            {
                "player_id": [1, 1, 2, 3, 3],
                "device_id": [2, 2, 3, 1, 4],
                "event_date": [
                    datetime(2016, 3, 1),
                    datetime(2016, 3, 2),
                    datetime(2017, 6, 25),
                    datetime(2016, 3, 2),
                    datetime(2018, 7, 3),
                ],
                "games_played": [5, 6, 1, 0, 5],
            },
            {"fraction": [0.33]},
            id="happy_path_basic",
        ),
        # Both players log in on the day after their first login (player 1 on
        # 2023-01-02, player 2 on 2023-01-02) -> 2/2 = 1.0.
        pytest.param(
            {
                "player_id": [1, 1, 1, 2, 2],
                "event_date": [
                    datetime(2023, 1, 1),
                    datetime(2023, 1, 2),
                    datetime(2023, 1, 3),
                    datetime(2023, 1, 1),
                    datetime(2023, 1, 2),
                ],
                "games_played": [1, 2, 3, 4, 5],
            },
            {"fraction": [1.0]},
            id="happy_path_multiple_dates",
        ),
        # A single login can never be followed up the next day -> 0/1 = 0.0.
        pytest.param(
            {
                "player_id": [1],
                "event_date": [datetime(2023, 1, 1)],
                "games_played": [1],
            },
            {"fraction": [0.0]},
            id="edge_case_single_entry",
        ),
    ],
)
def test_problem_550(input_data, expected_data):
    # problem_550 returns a datafusion DataFrame; compare it as an Arrow table
    # so both values and schema are checked.
    input_table = pa.Table.from_pydict(input_data)
    expected_table = pa.Table.from_pydict(expected_data)
    result = problem_550(input_table)
    assert result.to_arrow_table().equals(expected_table)
163+
111164
@pytest.mark.parametrize(
112165
"input_data, expected_data",
113166
[

0 commit comments

Comments
 (0)