Skip to content

Commit 715b9ec

Browse files
committed
feat(pandas): add problem 1321
1 parent f16cbea commit 715b9ec

File tree

3 files changed

+141
-2
lines changed

3 files changed

+141
-2
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ Fiddling around with DataFusion, pandas, and PyArrow.
5858
- [X] [] [] 1978. [Employees Whose Manager Left the Company](https://leetcode.com/problems/employees-whose-manager-left-the-company) - Easy
5959
- [X] [] [] 626. [Exchange Seats](https://leetcode.com/problems/exchange-seats) - Medium
6060
- [X] [] [] 1341. [Movie Rating](https://leetcode.com/problems/movie-rating) - Medium
61-
- [] [X] [] 1321. [Restaurant Growth](https://leetcode.com/problems/restaurant-growth) - Medium
61+
- [] [X] [X] 1321. [Restaurant Growth](https://leetcode.com/problems/restaurant-growth) - Medium
6262
- [X] [] [] 602. [Friend Requests II: Who Has the Most Friends](https://leetcode.com/problems/friend-requests-ii-who-has-the-most-friends) - Medium
6363
- [X] [] [] 585. [Investments in 2016](https://leetcode.com/problems/investments-in-2016) - Medium
6464
- [X] [] [] 185. [Department Top Three Salaries](https://leetcode.com/problems/department-top-three-salaries) - Hard

problems/pandas.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,33 @@ def problem_176(employee: pd.DataFrame) -> pd.DataFrame:
2929
if result.empty:
3030
return pd.DataFrame([None], columns=["SecondHighestSalary"])
3131
return result
32+
33+
34+
def problem_1321(customer: pd.DataFrame) -> pd.DataFrame:
35+
"""Compute the moving average of how much the customer paid in a seven days window.
36+
37+
You are the restaurant owner and you want to analyze a possible expansion (there
38+
will be at least one customer every day). Seven day window refers to current day +
39+
6 days before. `average_amount` should be rounded to two decimal places.
40+
41+
Return the result table ordered by visited_on in ascending order.
42+
43+
Parameters
44+
----------
45+
customer : pa.Table
46+
Table shows the amount paid by a customer on a certain day.
47+
48+
Returns
49+
-------
50+
pd.DataFrame
51+
52+
"""
53+
grouped = customer.groupby(["visited_on"]).aggregate(
54+
amount=pd.NamedAgg("amount", "sum")
55+
)
56+
grouped = (
57+
grouped.assign(amount=grouped["amount"].rolling("7D").sum())
58+
.reset_index()
59+
.loc[6:]
60+
)
61+
return grouped.assign(average_amount=(grouped["amount"] / 7).round(2))

tests/test_pandas.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from datetime import datetime
2+
13
import pandas as pd
24
import pytest
35

4-
from problems.pandas import problem_176
6+
from problems.pandas import problem_176, problem_1321
57

68

79
@pytest.mark.parametrize(
@@ -33,3 +35,110 @@ def test_problem_176(input_data, expected_data):
3335
expected_table = pd.DataFrame(expected_data)
3436
result = problem_176(table)
3537
assert result.equals(expected_table)
38+
39+
40+
@pytest.mark.parametrize(
41+
"input_data, expected_data",
42+
[
43+
pytest.param(
44+
{
45+
"customer_id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 3],
46+
"name": [
47+
"Jhon",
48+
"Daniel",
49+
"Jade",
50+
"Khaled",
51+
"Winston",
52+
"Elvis",
53+
"Anna",
54+
"Maria",
55+
"Jaze",
56+
"Jhon",
57+
"Jade",
58+
],
59+
"visited_on": [
60+
datetime(2019, 1, 1),
61+
datetime(2019, 1, 2),
62+
datetime(2019, 1, 3),
63+
datetime(2019, 1, 4),
64+
datetime(2019, 1, 5),
65+
datetime(2019, 1, 6),
66+
datetime(2019, 1, 7),
67+
datetime(2019, 1, 8),
68+
datetime(2019, 1, 9),
69+
datetime(2019, 1, 10),
70+
datetime(2019, 1, 10),
71+
],
72+
"amount": [100, 110, 120, 130, 110, 140, 150, 80, 110, 130, 150],
73+
},
74+
{
75+
"visited_on": [
76+
datetime(2019, 1, 7),
77+
datetime(2019, 1, 8),
78+
datetime(2019, 1, 9),
79+
datetime(2019, 1, 10),
80+
],
81+
"amount": [860, 840, 840, 1000],
82+
"average_amount": [122.86, 120, 120, 142.86],
83+
},
84+
id="happy_path",
85+
),
86+
pytest.param(
87+
{
88+
"customer_id": [1, 2, 3, 1, 4, 5, 6, 1, 7, 8, 9],
89+
"name": [
90+
"Jhon",
91+
"Daniel",
92+
"Jade",
93+
"Jhon",
94+
"Khaled",
95+
"Winston",
96+
"Elvis",
97+
"Jhon",
98+
"Anna",
99+
"Maria",
100+
"Jaze",
101+
],
102+
"visited_on": [
103+
datetime(2019, 1, 1),
104+
datetime(2019, 1, 2),
105+
datetime(2019, 1, 3),
106+
datetime(2019, 1, 1),
107+
datetime(2019, 1, 4),
108+
datetime(2019, 1, 5),
109+
datetime(2019, 1, 6),
110+
datetime(2019, 1, 1),
111+
datetime(2019, 1, 7),
112+
datetime(2019, 1, 8),
113+
datetime(2019, 1, 9),
114+
],
115+
"amount": [100, 110, 120, 50, 130, 110, 140, 40, 150, 80, 110],
116+
},
117+
{
118+
"visited_on": [
119+
datetime(2019, 1, 7),
120+
datetime(2019, 1, 8),
121+
datetime(2019, 1, 9),
122+
],
123+
"amount": [950, 840, 840],
124+
"average_amount": [135.71, 120, 120],
125+
},
126+
id="duplicated_days",
127+
),
128+
],
129+
)
130+
def test_problem_1321(input_data, expected_data):
131+
table = pd.DataFrame(input_data)
132+
expected_table = pd.DataFrame(expected_data).reset_index(drop=True)
133+
result = (
134+
problem_1321(table)
135+
.reset_index(drop=True)
136+
.astype(expected_table.dtypes.to_dict())
137+
)
138+
assert list(result.index) == list(
139+
expected_table.index
140+
), f"Index mismatch: {result.index} vs {expected_table.index}"
141+
for col in expected_table.columns:
142+
assert result[col].equals(expected_table[col]), f"Mismatch in column '{col}'"
143+
144+
assert result.equals(expected_table)

0 commit comments

Comments
 (0)