Skip to content

Commit f231fc3

Browse files
committed
add mode logical_and to pruning
1 parent 6cd2de4 commit f231fc3

File tree

1 file changed

+72
-11
lines changed

1 file changed

+72
-11
lines changed

src/pyhf/workspace.py

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -583,16 +583,40 @@ def _prune_and_rename(
583583
),
584584
)
585585
for modifier in sample['modifiers']
586-
if modifier['name'] not in prune_modifiers
587-
and modifier['type'] not in prune_modifier_types
586+
if (
587+
channel['name'] not in prune_channels
588+
and prune_channels != []
589+
) # want to remove only if channel is in prune_channels or if prune_channels is empty, i.e. we want to prune this modifier for every channel
590+
or (
591+
sample['name'] not in prune_samples
592+
and prune_samples != []
593+
) # want to remove only if sample is in prune_samples or if prune_samples is empty, i.e. we want to prune this modifier for every sample
594+
or (
595+
modifier['name'] not in prune_modifiers
596+
and modifier['type'] not in prune_modifier_types
597+
)
598+
or prune_measurements
599+
!= [] # need to keep the modifier in case it is used in another measurement
588600
],
589601
}
590602
for sample in channel['samples']
591-
if sample['name'] not in prune_samples
603+
if (
604+
channel['name'] not in prune_channels
605+
and prune_channels != []
606+
) # want to remove only if channel is in prune_channels or if prune_channels is empty, i.e. we want to prune this sample for every channel
607+
or sample['name'] not in prune_samples
608+
or prune_modifiers
609+
!= [] # we only want to remove this sample if we did not specify modifiers to prune
610+
or prune_modifier_types != []
592611
],
593612
}
594613
for channel in self['channels']
595614
if channel['name'] not in prune_channels
615+
or ( # we only want to remove this channel if we did not specify any samples or modifiers to prune
616+
prune_samples != []
617+
or prune_modifiers != []
618+
or prune_modifier_types != []
619+
)
596620
],
597621
'measurements': [
598622
{
@@ -607,8 +631,14 @@ def _prune_and_rename(
607631
parameter['name'], parameter['name']
608632
),
609633
)
610-
for parameter in measurement['config']['parameters']
611-
if parameter['name'] not in prune_modifiers
634+
for parameter in measurement['config'][
635+
'parameters'
636+
] # we only want to remove this parameter if measurement is in prune_measurements or if prune_measurements is empty
637+
if (
638+
measurement['name'] not in prune_measurements
639+
and prune_measurements != []
640+
)
641+
or parameter['name'] not in prune_modifiers
612642
],
613643
'poi': rename_modifiers.get(
614644
measurement['config']['poi'], measurement['config']['poi']
@@ -617,6 +647,8 @@ def _prune_and_rename(
617647
}
618648
for measurement in self['measurements']
619649
if measurement['name'] not in prune_measurements
650+
or prune_modifiers
651+
!= [] # we only want to remove this measurement if we did not specify parameters to remove
620652
],
621653
'observations': [
622654
dict(
@@ -625,6 +657,11 @@ def _prune_and_rename(
625657
)
626658
for observation in self['observations']
627659
if observation['name'] not in prune_channels
660+
or ( # we only want to remove this channel if we did not specify any samples or modifiers to prune
661+
prune_samples != []
662+
or prune_modifiers != []
663+
or prune_modifier_types != []
664+
)
628665
],
629666
'version': self['version'],
630667
}
@@ -637,6 +674,7 @@ def prune(
637674
samples=None,
638675
channels=None,
639676
measurements=None,
677+
mode="logical_or",
640678
):
641679
"""
642680
Return a new, pruned workspace specification. This will not modify the original workspace.
@@ -649,6 +687,7 @@ def prune(
649687
samples: A :obj:`list` of samples to prune.
650688
channels: A :obj:`list` of channels to prune.
651689
measurements: A :obj:`list` of measurements to prune.
690+
mode (:obj: string): `logical_or` or `logical_and` to chain pruning with a logical OR or a logical AND, respectively. Default: `logical_or`.
652691
653692
Returns:
654693
~pyhf.workspace.Workspace: A new workspace object with the specified components removed
@@ -657,19 +696,41 @@ def prune(
657696
~pyhf.exceptions.InvalidWorkspaceOperation: An item name to prune does not exist in the workspace.
658697
659698
"""
699+
700+
if mode not in ["logical_and", "logical_or"]:
701+
raise ValueError(
702+
"Pruning mode must be either `logical_and` or `logical_or`."
703+
)
704+
660705
# avoid mutable defaults
661706
modifiers = [] if modifiers is None else modifiers
662707
modifier_types = [] if modifier_types is None else modifier_types
663708
samples = [] if samples is None else samples
664709
channels = [] if channels is None else channels
665710
measurements = [] if measurements is None else measurements
666711

667-
return self._prune_and_rename(
668-
prune_modifiers=modifiers,
669-
prune_modifier_types=modifier_types,
670-
prune_samples=samples,
671-
prune_channels=channels,
672-
prune_measurements=measurements,
712+
if mode == "logical_and":
713+
if samples != [] and measurements != []:
714+
raise ValueError(
715+
"Pruning of measurements and samples cannot be run with mode `logical_and`."
716+
)
717+
if modifier_types != [] and measurements != []:
718+
raise ValueError(
719+
"Pruning of measurements and modifier_types cannot be run with mode `logical_and`."
720+
)
721+
return self._prune_and_rename(
722+
prune_modifiers=modifiers,
723+
prune_modifier_types=modifier_types,
724+
prune_samples=samples,
725+
prune_channels=channels,
726+
prune_measurements=measurements,
727+
)
728+
return (
729+
self._prune_and_rename(prune_modifiers=modifiers)
730+
._prune_and_rename(prune_modifier_types=modifier_types)
731+
._prune_and_rename(prune_samples=samples)
732+
._prune_and_rename(prune_channels=channels)
733+
._prune_and_rename(prune_measurements=measurements)
673734
)
674735

675736
def rename(self, modifiers=None, samples=None, channels=None, measurements=None):

0 commit comments

Comments
 (0)