From b684bdba4ccf8579cb03e3240aa5bd6943eca812 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Tue, 23 Jul 2019 17:32:01 +0200 Subject: [PATCH 1/5] Fix #81 --- root_pandas/readwrite.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/root_pandas/readwrite.py b/root_pandas/readwrite.py index ddb7ba4..07751e3 100644 --- a/root_pandas/readwrite.py +++ b/root_pandas/readwrite.py @@ -95,11 +95,11 @@ def get_nonscalar_columns(array): def get_matching_variables(branches, patterns, fail=True): # Convert branches to a set to make x "in branches" O(1) on average branches = set(branches) - patterns = set(patterns) # Find any trivial matches - selected = list(branches.intersection(patterns)) + selected = sorted(branches.intersection(patterns), + key=lambda s: patterns.index(s)) # Any matches that weren't trivial need to be looped over... - for pattern in patterns.difference(selected): + for pattern in set(patterns).difference(selected): found = False # Avoid using fnmatch if the pattern if possible if re.findall(r'(\*)|(\?)|(\[.*\])|(\[\!.*\])', pattern): From 9e611a3c9b9c0e3758e8890712b07883523362be Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Tue, 23 Jul 2019 17:32:19 +0200 Subject: [PATCH 2/5] Fix #82 --- root_pandas/readwrite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/root_pandas/readwrite.py b/root_pandas/readwrite.py index 07751e3..3edc04a 100644 --- a/root_pandas/readwrite.py +++ b/root_pandas/readwrite.py @@ -317,7 +317,7 @@ def convert_to_dataframe(array, start_index=None): # Filter to remove __index__ columns columns = [c for c in array.dtype.names if c in df.columns] assert len(columns) == len(df.columns), (columns, df.columns) - df = df.reindex_axis(columns, axis=1, copy=False) + df = df.reindex(columns, axis=1, copy=False) # Convert categorical columns back to categories for c in df.columns: From f8316e0023b9340ae392b13ae21fb6aeee6f2e1e Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Tue, 23 Jul 2019 17:32:32 +0200 Subject: [PATCH 3/5] Bump version to 0.7.0 --- root_pandas/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/root_pandas/version.py b/root_pandas/version.py index 4049b2d..cabbbce 100644 --- a/root_pandas/version.py +++ b/root_pandas/version.py @@ -4,6 +4,6 @@ 'version_info', ] -__version__ = '0.6.1' +__version__ = '0.7.0' version = __version__ version_info = tuple(__version__.split('.')) From cea43d80c4c38c8d8cdd48834de30705d25bc25c Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Tue, 23 Jul 2019 17:45:39 +0200 Subject: [PATCH 4/5] Fix #80 --- root_pandas/readwrite.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/root_pandas/readwrite.py b/root_pandas/readwrite.py index 3edc04a..63b33cc 100644 --- a/root_pandas/readwrite.py +++ b/root_pandas/readwrite.py @@ -2,6 +2,7 @@ """ A module that extends pandas to support the ROOT data format. """ +from collections import Counter import numpy as np from numpy.lib.recfunctions import append_fields @@ -366,6 +367,11 @@ def to_root(df, path, key='my_ttree', mode='w', store_index=True, *args, **kwarg else: raise ValueError('Unknown mode: {}. Must be "a" or "w".'.format(mode)) + column_name_counts = Counter(df.columns) + if max(column_name_counts.values()) > 1: + raise ValueError('DataFrame contains duplicated column names: ' + + ' '.join({k for k, v in column_name_counts.items() if v > 1})) + from root_numpy import array2tree # We don't want to modify the user's DataFrame here, so we make a shallow copy df_ = df.copy(deep=False) From ea2ec6b2c2dd2125fb2a968ba884db721184fc8e Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Tue, 23 Jul 2019 17:45:47 +0200 Subject: [PATCH 5/5] Add tests --- tests/test_issues.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_issues.py b/tests/test_issues.py index caaa170..f031d04 100644 --- a/tests/test_issues.py +++ b/tests/test_issues.py @@ -24,3 +24,20 @@ def test_issue_63(): assert all(len(df) == 1 for df in result) os.remove('tmp_1.root') os.remove('tmp_2.root') + + +def test_issue_80(): + df = pd.DataFrame({'a': [1, 2], 'b': [4, 5]}) + df.columns = ['a', 'a'] + try: + root_pandas.to_root(df, '/tmp/example.root') + except ValueError as e: + assert 'DataFrame contains duplicated column names' in e.args[0] + else: + raise Exception('ValueError is expected') + + +def test_issue_82(): + variables = ['MET_px', 'MET_py', 'EventWeight'] + df = root_pandas.read_root('http://scikit-hep.org/uproot/examples/HZZ.root', 'events', columns=variables) + assert list(df.columns) == variables