Skip to content

Commit

Permalink
checking valid methods and columns
Browse files Browse the repository at this point in the history
  • Loading branch information
amytangzheng committed Oct 23, 2024
1 parent c4200c5 commit 10f325d
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 11 deletions.
8 changes: 4 additions & 4 deletions examples/featurize.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@
output_file_path_chat_level = "./jury_TINY_output_chat_level_custom_agg.csv",
output_file_path_user_level = "./jury_TINY_output_user_level_custom_agg.csv",
output_file_path_conv_level = "./jury_TINY_output_conversation_level_custom_agg.csv",
# convo_methods = ['max', 'median'], # This will aggregate ONLY the "positive_bert" at the conversation level, using mean; it will aggregate ONLY "negative_bert" at the speaker/user level, using max.
# convo_columns = ['positive_bert'],
# user_methods = ['max', 'mean', 'min', 'median'],
# user_columns = ['positive_bert', 'negative_bert'],
convo_methods = ['max', 'median'], # This will aggregate ONLY the "positive_bert" at the conversation level, using mean; it will aggregate ONLY "negative_bert" at the speaker/user level, using max.
convo_columns = ['positive_bert'],
user_methods = ['max', 'mean', 'min', 'median', 'help'],
user_columns = ['positive_bert', 'negative_bert'],
turns = False,
)
tiny_juries_feature_builder_custom_aggregation.featurize(col="message")
Expand Down
27 changes: 23 additions & 4 deletions src/team_comm_tools/utils/calculate_conversation_level_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,12 @@ def __init__(self, chat_data: pd.DataFrame,
self.convo_methods = [col.lower() for col in self.convo_methods]
self.columns_to_summarize = [col.lower() for col in self.columns_to_summarize]

# replace interchangable words in columns_to_summarize
# check if columns are numeric
for col in self.columns_to_summarize:
if pd.api.types.is_numeric_dtype(self.columns_to_summarize[col]) is False:
print("WARNING: ", col, " is not numeric. Ignoring them.")

# replace interchangable words in convo_methods and remove invalid methods
for i in range(len(self.convo_methods)):
if self.convo_methods[i] == "average":
self.convo_methods[i] = "mean"
Expand All @@ -129,6 +134,10 @@ def __init__(self, chat_data: pd.DataFrame,
self.convo_methods[i] = "stdev"
if self.convo_methods[i] == "std":
self.convo_methods[i] = "stdev"

current = self.convo_methods[i]
if current != "mean" and current != "max" and current != "min" and current != "stdev" and current != "median":
print("Warning: ", current, "is not a valid user method. Ignoring them.")


# check if user inputted user_columns is None
Expand All @@ -146,7 +155,6 @@ def __init__(self, chat_data: pd.DataFrame,
"Warning: One or more requested user columns are not present in the data. Ignoring them."
)

print(user_columns_in_data, user_columns)

for i in user_columns:
matches = process.extract(i, self.chat_data.columns, limit=3)
Expand All @@ -165,7 +173,12 @@ def __init__(self, chat_data: pd.DataFrame,
self.user_methods = [col.lower() for col in self.user_methods]
self.user_columns = [col.lower() for col in self.user_columns]

# replace interchangable words in columns_to_summarize
# check if columns are numeric
for col in self.user_columns:
if pd.api.types.is_numeric_dtype(self.user_columns[col]) is False:
print("WARNING: ", col, " is not numeric. Ignoring them.")

# replace interchangable words in user_methods and remove invalid methods
for i in range(len(self.user_methods)):
if self.user_methods[i] == "average":
self.user_methods[i] = "mean"
Expand All @@ -179,7 +192,13 @@ def __init__(self, chat_data: pd.DataFrame,
self.user_methods[i] = "stdev"
if self.user_methods[i] == "std":
self.user_methods[i] = "stdev"


current = self.user_methods[i]
if current != "mean" and current != "max" and current != "min" and current != "stdev" and current != "median":
print("Warning: ", current, "is not a valid user method. Ignoring them.")
self.user_methods.remove(current)


self.summable_columns = ["num_words", "num_chars", "num_messages"]


Expand Down
15 changes: 12 additions & 3 deletions src/team_comm_tools/utils/calculate_user_level_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ def __init__(self, chat_data: pd.DataFrame,
"Warning: One or more requested user columns are not present in the data. Ignoring them."
)

# print(user_columns_in_data, user_columns)

for i in user_columns:
matches = process.extract(i, self.chat_data.columns, limit=3)
best_match, similarity = matches[0]
Expand All @@ -85,7 +83,12 @@ def __init__(self, chat_data: pd.DataFrame,
self.user_methods = [col.lower() for col in self.user_methods]
self.columns_to_summarize = [col.lower() for col in self.columns_to_summarize]

# replace interchangable words in columns_to_summarize
# check if columns are numeric
for col in self.columns_to_summarize:
if pd.api.types.is_numeric_dtype(self.columns_to_summarize[col]) is False:
print("WARNING: ", col, " is not numeric. Ignoring them.")

# replace interchangable words in user_methods and remove invalid methods
for i in range(len(self.user_methods)):
if self.user_methods[i] == "average":
self.user_methods[i] = "mean"
Expand All @@ -99,6 +102,12 @@ def __init__(self, chat_data: pd.DataFrame,
self.user_methods[i] = "stdev"
if self.user_methods[i] == "std":
self.user_methods[i] = "stdev"

current = self.user_methods[i]
if current != "mean" and current != "max" and current != "min" and current != "stdev" and current != "median":
print("Warning: ", current, "is not a valid user method. Ignoring them.")
self.user_methods.remove(current)


self.summable_columns = ["num_words", "num_chars", "num_messages"]

Expand Down

0 comments on commit 10f325d

Please sign in to comment.