Skip to content

Commit 68524ba

Browse files
committed
truncate groups by global_max_flow_len
1 parent c46797d commit 68524ba

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

netshare/pre_post_processors/netshare/preprocess_helper.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,27 @@ def split_per_chunk(
203203
split_name = config["split_name"]
204204
metadata_cols = [m for m in config["metadata"]]
205205

206+
# Truncate groups with length greater than global_max_flow_len
207+
def process_group(group):
208+
if len(group) > global_max_flow_len:
209+
processed_group = group.head(global_max_flow_len)
210+
else:
211+
processed_group = group
212+
return processed_group
213+
214+
def truncate_group(raw_df, metadata_cols):
215+
grouped = raw_df.groupby([m.column for m in metadata_cols])
216+
processed = grouped.apply(process_group)
217+
218+
# reset the index of the resulting DataFrame
219+
processed = processed.reset_index(drop=True)
220+
221+
return processed
222+
223+
print("Before truncation, df_per_chunk:", df_per_chunk.shape)
224+
df_per_chunk = truncate_group(df_per_chunk, metadata_cols)
225+
print("After truncation, df_per_chunk:", df_per_chunk.shape)
226+
206227
df_per_chunk, new_metadata_list = apply_per_field(
207228
original_df=df_per_chunk,
208229
config_fields=config["metadata"],

0 commit comments

Comments
 (0)