Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/JoelNiklaus/LEXTREME into main
Browse files Browse the repository at this point in the history
  • Loading branch information
kapllan committed Jun 1, 2023
2 parents 56688d1 + a479846 commit a8f6c30
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions utils/create_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,26 @@ def read_report_specs(self, report_spec_name):
return report_spec

def prepare_revision_lookup_table(self):
    """Build the nested revision lookup table from ``self.report_specs``.

    Every entry of ``self.report_specs['_name_or_path']`` is expected to have
    the form ``"<name_or_path>@<revision>"``. The result maps each finetuning
    task to its own ``{name_or_path: revision}`` dictionary::

        {finetuning_task: {_name_or_path: revision}}

    Returns:
        dict: The nested lookup table; empty when no finetuning tasks are
        listed in the report specs.

    Raises:
        ValueError: If a model spec does not contain exactly one ``'@'``.
    """
    revision_lookup_table = dict()
    for finetuning_task in self.report_specs['finetuning_task']:
        # Each task gets its own dict so that a later per-task modification
        # cannot leak into the entries of the other tasks.
        entry = dict()
        for name_or_path_and_revision in self.report_specs['_name_or_path']:
            parts = name_or_path_and_revision.split('@')
            if len(parts) != 2:
                # Same exception type as the bare tuple unpacking would give,
                # but naming the offending spec.
                raise ValueError(
                    "Expected a model spec of the form 'name_or_path@revision', "
                    f"got {name_or_path_and_revision!r}")
            _name_or_path, revision = parts
            entry[_name_or_path] = revision
        revision_lookup_table[finetuning_task] = entry

    return revision_lookup_table

def insert_revision(self, finetuning_task, _name_or_path):
    """Return the revision configured for a model on a given finetuning task.

    ``self.revision_lookup_table`` is keyed by finetuning task first and by
    model name second, so both levels have to be resolved. Testing the model
    name against the top-level (task) keys would practically never match and
    would silently yield ``'main'`` for every model.

    Args:
        finetuning_task: Name of the finetuning task.
        _name_or_path: Model name or path, without the ``@revision`` suffix.

    Returns:
        The revision stored for this task/model pair, or ``'main'`` when the
        task or the model has no entry in the lookup table.
    """
    revisions_for_task = self.revision_lookup_table.get(finetuning_task, {})
    return revisions_for_task.get(_name_or_path, 'main')

def revision_does_match(self, _name_or_path, revision_from_wandb):
if _name_or_path in self.revision_lookup_table.keys():
if revision_from_wandb == self.revision_lookup_table[_name_or_path]:
def revision_does_match(self, finetuning_task, _name_or_path, revision_from_wandb):
if finetuning_task in self.revision_lookup_table.keys():
if revision_from_wandb == self.revision_lookup_table[finetuning_task][_name_or_path]:
return True
else:
return False
Expand Down Expand Up @@ -436,19 +440,21 @@ def edit_result_dataframe(self, results, name_editing=True):
elif 'state' in results.columns:
results = results[results.state == "finished"]

results = results.drop_duplicates(
["seed", "finetuning_task", "_name_or_path", "language"])

# TODO: Remove this constraint in the future
# results = results[results._name_or_path.str.contains('joelito') == False]
results = results[
results.finetuning_task.str.contains('turkish_constitutional_court_decisions_judgment') == False]

# Check if revisions match
results['revisions_match'] = results.apply(
lambda row: self.revision_does_match(row['_name_or_path'], row['revision']), axis=1)
lambda row: self.revision_does_match(finetuning_task=row['finetuning_task'],
_name_or_path=row['_name_or_path'],
revision_from_wandb=row['revision']), axis=1)
results = results[results.revisions_match == True]

results = results.drop_duplicates(
["seed", "finetuning_task", "_name_or_path", "language"], keep='last')

return results

def remove_languages(self, languages):
Expand Down Expand Up @@ -542,7 +548,8 @@ def check_seed_per_task(self, task_constraint: list = [], model_constraint: list
if which_language is not None:
report_df = self.filter_by_language(report_df, which_language)

report_df['revision'] = report_df._name_or_path.apply(self.insert_revision)
report_df['revision'] = report_df.apply(
lambda x: self.insert_revision(x['finetuning_task'], x['_name_or_path']), axis=1)

report_df = insert_responsibilities(report_df)

Expand Down Expand Up @@ -826,7 +833,7 @@ def insert_abbreviations(self, dataframe):

return dataframe

def round_value(self, value, places=2):
def round_value(self, value, places=1):
if isinstance(value, float):
return round(value * 100, places)
else:
Expand Down Expand Up @@ -1230,8 +1237,14 @@ def get_language_aggregated_score(self, write_to_csv=True, task_constraint: list
self.language_aggregated_score = self.language_aggregated_score[
self.language_aggregated_score.index.isin(model_constraint)]

# Order of columns
column_order = sorted(self.language_aggregated_score.columns.tolist())
column_order = [c for c in column_order if c != 'Agg.']
self.language_aggregated_score = self.language_aggregated_score[column_order + ['Agg.']]

language_aggregated_score = deepcopy(self.language_aggregated_score)
language_aggregated_score = self.postprocess_columns(language_aggregated_score)

if write_to_csv:
language_aggregated_score.to_csv(
f'{self.output_dir}/language_aggregated_scores.csv')
Expand Down

0 comments on commit a8f6c30

Please sign in to comment.