From 47e03b90c70e73f592019dbdf9726212cb2eda3f Mon Sep 17 00:00:00 2001 From: james hadfield Date: Tue, 2 Jul 2024 11:04:50 +1200 Subject: [PATCH] [curate rename] update behaviour to match expected behaviour in tests. The main changes functional changes are around the order of fields, where we now rename "in-place" rather than adding the renamed column at the end (which for TSV output is the last column). More sanity checks are performed on arguments and they are cross-referenced with the provided records. Note that this relies on each record having the same fields, and this is not asserted here. See --- augur/curate/rename.py | 55 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/augur/curate/rename.py b/augur/curate/rename.py index 57172e944..64e6d2841 100644 --- a/augur/curate/rename.py +++ b/augur/curate/rename.py @@ -3,8 +3,9 @@ """ from typing import Iterable -from augur.io.print import print_err import argparse +from augur.io.print import print_err +from augur.errors import AugurError def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("rename", @@ -25,29 +26,63 @@ def register_parser(parent_subparsers): return parser -def run(args: argparse.Namespace, records: Iterable[dict]) -> Iterable[dict]: +def parse_field_map(field_map_arg: 'list[str]') -> 'dict[str,str]': + seen_old, seen_new = set(), set() + field_map = {} - for field in args.field_map: + for field in field_map_arg: old_name, new_name = field.split('=') + # Sanity check the requests to catch typos etc + if old_name in seen_old: + raise AugurError(f"Asked to rename field {old_name!r} multiple times.") + if new_name in seen_new: + raise AugurError(f"Asked to rename multiple fields to {new_name!r}.") + seen_old.add(old_name) + seen_new.add(new_name) + if old_name == new_name: continue field_map[old_name] = new_name + return field_map - for record in records: - record = record.copy() - for old_field, new_field in field_map.items(): +def transform_columns(existing_fields: 'list[str]', field_map: 'dict[str,str]', force: bool) -> 'list[tuple[str,str]]': + """ + Calculate the mapping of old column names to new column names + """ + # check that all columns to be renamed exist + for name in field_map: + if name not in existing_fields: + raise AugurError(f"Asked to rename field {name!r} (to {field_map[name]!r}) but it doesn't exist in the input data.") - if record.get(new_field) and not args.force: + # iterate through field_map and remove rename requests if they would drop an existing column + # doing this ahead-of-time allows us to preserve the order of fields using a simpler implementation + if not force: + for old_field, new_field in list(field_map.items()): + if new_field in existing_fields: print_err( f"WARNING: skipping rename of {old_field} because record", f"already has a field named {new_field}." ) - continue + del field_map[old_field] + + m = [] + for field in existing_fields: + if field in field_map: + m.append((field, field_map[field])) + elif field in field_map.values(): + pass # another column is renamed to this name, so we drop it + else: + m.append((field, field)) # no change to field name + return m - record[new_field] = record.pop(old_field, '') - yield(record) +def run(args: argparse.Namespace, records: Iterable[dict]) -> Iterable[dict]: + col_map: False | 'list[tuple[str,str]]' = False + for record in records: + if not col_map: # initialise using first record + col_map = transform_columns(list(record.keys()), parse_field_map(args.field_map), args.force) + yield({new_field:record[old_field] for old_field, new_field in col_map})