Skip to content

Commit

Permalink
fix: replace eval with a safer alternative
Browse files Browse the repository at this point in the history
docs: update documentation with the new format for controller metrics and operations and details of rule evaluation

Signed-off-by: Harikrishnan Balagopal <harikrishmenon@gmail.com>
  • Loading branch information
HarikrishnanBalagopal committed May 10, 2024
1 parent 521a463 commit 012447a
Show file tree
Hide file tree
Showing 8 changed files with 253 additions and 17 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ dependencies = [
"trl",
"peft>=0.8.0",
"datasets>=2.15.0",
"fire"
"fire",
"simpleeval",
]

[project.optional-dependencies]
Expand Down
3 changes: 3 additions & 0 deletions tests/data/trainercontroller/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
TRAINER_CONFIG_TEST_LOSS_ON_THRESHOLD_WITH_TRAINER_STATE_YAML = os.path.join(
_DATA_DIR, "loss_on_threshold_with_trainer_state.yaml"
)
TRAINER_CONFIG_TEST_INVALID_TYPE_RULE_YAML = os.path.join(
_DATA_DIR, "loss_with_invalid_type_rule.yaml"
)
TRAINER_CONFIG_TEST_MALICIOUS_OS_RULE_YAML = os.path.join(
_DATA_DIR, "loss_with_malicious_os_rule.yaml"
)
Expand Down
10 changes: 10 additions & 0 deletions tests/data/trainercontroller/loss_with_invalid_type_rule.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Test fixture: a trainer-controller config whose rule evaluates to a
# non-boolean value (the integer 4). Used to verify that such rules are
# rejected with a TypeError instead of being silently truthy-evaluated.
controller-metrics:
  - name: loss
    class: Loss
controllers:
  # NOTE(review): name says "wrong-os-rule" but this is the invalid-type
  # fixture — presumably copy-pasted from loss_with_malicious_os_rule.yaml;
  # harmless (name is only used for logging) but worth confirming.
  - name: loss-controller-wrong-os-rule
    triggers:
      - on_log
    # Intentionally an int-valued expression, not a boolean condition.
    rule: "2+2"
    operations:
      - hfcontrols.should_training_stop
29 changes: 25 additions & 4 deletions tests/trainercontroller/test_tuning_trainercontroller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

# Standard
from dataclasses import dataclass
from typing import Any

# Third Party
from simpleeval import FunctionNotDefined
from transformers import IntervalStrategy, TrainerControl, TrainerState
import pytest

Expand All @@ -32,7 +32,6 @@
import tests.data.trainercontroller as td

# Local
from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler
import tuning.config.configs as config
import tuning.trainercontroller as tc

Expand Down Expand Up @@ -162,6 +161,25 @@ def test_custom_operation_invalid_action_handler():
)


def test_invalid_type_rule():
    """Verify that a rule evaluating to a non-boolean raises a TypeError.

    Loads `loss_with_invalid_type_rule.yaml`, whose rule is the integer
    expression "2+2", and checks that triggering the rule surfaces the
    type-usage error message.
    """
    test_data = _setup_data()
    with pytest.raises(TypeError) as exception_handler:
        callback = tc.TrainerControllerCallback(
            td.TRAINER_CONFIG_TEST_INVALID_TYPE_RULE_YAML
        )
        trainer_control = TrainerControl(should_training_stop=False)
        # on_init_end registers the metric/operation handlers to events.
        callback.on_init_end(
            args=test_data.args, state=test_data.state, control=trainer_control
        )
        # Firing on_log evaluates the (non-boolean) rule.
        callback.on_log(
            args=test_data.args, state=test_data.state, control=trainer_control
        )
    assert str(exception_handler.value) == "Rule failed due to incorrect type usage"


def test_malicious_os_rule():
"""Tests the malicious rule using configuration
`examples/trainer-controller-configs/loss_with_malicious_os_rule.yaml`
Expand Down Expand Up @@ -193,14 +211,17 @@ def test_malicious_input_rule():
td.TRAINER_CONFIG_TEST_MALICIOUS_INPUT_RULE_YAML
)
control = TrainerControl(should_training_stop=False)
with pytest.raises(TypeError) as exception_handler:
with pytest.raises(FunctionNotDefined) as exception_handler:
# Trigger on_init_end to perform registration of handlers to events
tc_callback.on_init_end(
args=test_data.args, state=test_data.state, control=control
)
# Trigger rule and test the condition
tc_callback.on_log(args=test_data.args, state=test_data.state, control=control)
assert str(exception_handler.value) == "Rule failed due to incorrect type usage"
assert (
str(exception_handler.value)
== "Function 'input' not defined, for expression 'input('Please enter your password:')'."
)


def test_invalid_trigger():
Expand Down
164 changes: 164 additions & 0 deletions tests/utils/test_evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Standard
from typing import Tuple

# Third Party
import numpy as np
import pytest

# Local
from tuning.utils.evaluator import get_evaluator


def test_mailicious_inputs_to_eval():
    """Tests the malicious rules.

    Each entry in `rules` is a (validation_error, expected_rule_is_true, rule)
    triple: an empty validation_error means the rule is expected to evaluate
    cleanly to `expected_rule_is_true`; a non-empty one is the exact exception
    message the evaluator must raise (messages mirror simpleeval's own
    wording, including its "evalute" typo — do not "fix" them here).
    """
    rules: list[Tuple[str, bool, str]] = [
        # Valid rules — arithmetic, comparisons, indexing, and the small
        # whitelist of functions (abs/len/sqrt/int/...) exposed by the evaluator.
        ("", False, "flags['is_training'] == False"),
        ("", False, "not flags['is_training']"),
        ("", True, "-10 < loss"),
        ("", True, "+1000 > loss"),
        ("", True, "~1000 < loss"),
        ("", True, "(10 + 10) < loss"),
        ("", True, "(20 - 10) < loss"),
        ("", True, "(20/10) < loss"),
        ("", True, "(20 % 10) < loss"),
        ("", False, "loss < 1.0"),
        ("", False, "(loss < 1.0)"),
        ("", False, "loss*loss < 1.0"),
        ("", False, "loss*loss*loss < 1.0"),
        ("", False, "(loss*loss)*loss < 1.0"),
        ("", True, "int(''.join(['3', '4'])) < loss"),
        ("", True, "loss < 9**9"),
        ("", False, "loss < sqrt(xs[0]*xs[0] + xs[1]*xs[1])"),
        ("", True, "len(xs) > 2"),
        ("", True, "loss < abs(-100)"),
        ("", True, "loss == flags.aaa.bbb[0].ccc"),
        ("", True, "array3d[0][1][1] == 4"),
        ("", True, "numpyarray[0][1][1] == 4"),
        (
            "",
            True,
            "len(xs) == 4 and xs[0] == 1 and (xs[1] == 0 or xs[2] == 0) and xs[3] == 2",
        ),
        # Invalid rules — undefined names, dunder access, sandbox escapes,
        # resource-exhaustion attempts; each must fail with the given message.
        (
            "'aaa' is not defined for expression 'loss == aaa.bbb[0].ccc'",
            False,
            "loss == aaa.bbb[0].ccc",
        ),
        ("0", False, "loss == flags[0].ccc"),  # KeyError
        (
            "Attribute 'ddd' does not exist in expression 'loss == flags.ddd[0].ccc'",
            False,
            "loss == flags.ddd[0].ccc",
        ),
        (
            "Sorry, access to __attributes or func_ attributes is not available. (__class__)",
            False,
            "'x'.__class__",
        ),
        (
            "Lambda Functions not implemented",
            False,
            "().__class__.__base__.__subclasses__()[141]('', '')()",  # Try to instantiate and call Quitter
        ),
        (
            "Lambda Functions not implemented",
            False,
            "[x for x in ().__class__.__base__.__subclasses__() if x.__name__ == 'Quitter'][0]('', '')()",
        ),
        (
            "Function 'getattr' not defined, for expression 'getattr((), '__class__')'.",
            False,
            "getattr((), '__class__')",
        ),
        (
            # Same escape via implicit string concatenation of '_' pieces.
            "Function 'getattr' not defined, for expression 'getattr((), '_' '_class_' '_')'.",
            False,
            "getattr((), '_' '_class_' '_')",
        ),
        (
            "Sorry, I will not evalute something that long.",
            False,
            '["hello"]*10000000000',
        ),
        (
            "Sorry, I will not evalute something that long.",
            False,
            "'i want to break free'.split() * 9999999999",
        ),
        (
            "Lambda Functions not implemented",
            False,
            "(lambda x='i want to break free'.split(): x * 9999999999)()",
        ),
        (
            "Sorry, NamedExpr is not available in this evaluator",
            False,
            "(x := 'i want to break free'.split()) and (x * 9999999999)",
        ),
        ("Sorry! I don't want to evaluate 9 ** 387420489", False, "9**9**9**9"),
        (
            "Function 'mymetric1' not defined, for expression 'mymetric1() > loss'.",
            True,
            "mymetric1() > loss",
        ),
        (
            "Function 'mymetric2' not defined, for expression 'mymetric2(loss) > loss'.",
            True,
            "mymetric2(loss) > loss",
        ),
    ]
    # Variables visible to the rules; loss=42.0 makes the comparisons above
    # come out as annotated in the table.
    metrics = {
        "loss": 42.0,
        "flags": {"is_training": True, "aaa": {"bbb": [{"ccc": 42.0}]}},
        "xs": [1, 0, 0, 2],
        "array3d": [
            [
                [1, 2],
                [3, 4],
            ],
            [
                [5, 6],
                [7, 8],
            ],
        ],
        # numpy mirror of array3d (+1 so values match it element-wise).
        "numpyarray": (np.arange(8).reshape((2, 2, 2)) + 1),
    }

    evaluator = get_evaluator(metrics=metrics)

    for validation_error, expected_rule_is_true, rule in rules:
        # Parsing must succeed for every rule; only evaluation may fail.
        rule_parsed = evaluator.parse(expr=rule)
        if validation_error == "":
            actual_rule_is_true = evaluator.eval(
                expr=rule,
                previously_parsed=rule_parsed,
            )
            assert (
                actual_rule_is_true == expected_rule_is_true
            ), "failed to execute the rule"
        else:
            with pytest.raises(Exception) as exception_handler:
                evaluator.eval(
                    expr=rule,
                    previously_parsed=rule_parsed,
                )
            assert str(exception_handler.value) == validation_error
37 changes: 26 additions & 11 deletions tuning/trainercontroller/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@

# Standard
from importlib import resources as impresources
from typing import List, Union
from typing import Dict, List, Union
import inspect
import os
import re

# Third Party
from simpleeval import EvalWithCompoundTypes, FeatureNotAvailable, NameNotDefined
from transformers import (
TrainerCallback,
TrainerControl,
Expand All @@ -43,6 +44,7 @@
from tuning.trainercontroller.operations import (
operation_handlers as default_operation_handlers,
)
from tuning.utils.evaluator import get_evaluator

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -174,7 +176,7 @@ def __init__(self, trainer_controller_config: Union[dict, str]):
self.register_operation_handlers(default_operation_handlers)

# controls
self.control_actions_on_event: dict[str, list[Control]] = {}
self.control_actions_on_event: Dict[str, list[Control]] = {}

# List of fields produced by the metrics
self.metrics = {}
Expand Down Expand Up @@ -208,23 +210,26 @@ def _compute_metrics(self, event_name: str, **kwargs):
self.metrics[m.get_name()] = m.compute(event_name=event_name, **kwargs)

def _take_control_actions(self, event_name: str, **kwargs):
"""Invokes the act() method for all the operations registered for a given event. \
Note here that the eval() is invoked with `__builtins__` set to None. \
This is a precaution to restric the scope of eval(), to only the \
fields produced by the metrics.
"""Invokes the act() method for all the operations registered for a given event.
Args:
event_name: str. Event name.
kwargs: List of arguments (key, value)-pairs.
"""
if event_name in self.control_actions_on_event:
evaluator = get_evaluator(metrics=self.metrics)
for control_action in self.control_actions_on_event[event_name]:
rule_succeeded = False
try:
# pylint: disable=eval-used
rule_succeeded = eval(
control_action.rule, {"__builtins__": None}, self.metrics
rule_succeeded = evaluator.eval(
expr=control_action.rule_str,
previously_parsed=control_action.rule,
)
if not isinstance(rule_succeeded, bool):
raise TypeError(
"expected the rule to evaluate to a boolean. actual type: %s"
% (type(rule_succeeded))
)
except TypeError as et:
raise TypeError("Rule failed due to incorrect type usage") from et
except ValueError as ev:
Expand All @@ -235,6 +240,14 @@ def _take_control_actions(self, event_name: str, **kwargs):
raise NameError(
"Rule failed due to use of disallowed variables"
) from en
except NameNotDefined as en1:
raise NameError(
"Rule failed because some of the variables are not defined"
) from en1
except FeatureNotAvailable as ef:
raise NotImplementedError(
"Rule failed because it uses some unsupported features"
) from ef
if rule_succeeded:
for operation_action in control_action.operation_actions:
logger.info(
Expand Down Expand Up @@ -374,9 +387,11 @@ def on_init_end(
% (controller_name, event_name)
)
# Generates the byte-code for the rule from the trainer configuration
curr_rule = controller[CONTROLLER_RULE_KEY]
control = Control(
name=controller_name,
rule=compile(controller_rule, "", "eval"),
name=controller[CONTROLLER_NAME_KEY],
rule_str=curr_rule,
rule=EvalWithCompoundTypes.parse(expr=curr_rule),
operation_actions=[],
)
for control_operation_name in controller_ops:
Expand Down
4 changes: 3 additions & 1 deletion tuning/trainercontroller/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# Standard
from dataclasses import dataclass
from typing import List, Optional
import ast

# Local
from tuning.trainercontroller.operations import Operation
Expand All @@ -36,5 +37,6 @@ class Control:
"""Stores the name of control, rule byte-code corresponding actions"""

name: str
rule: Optional[object] = None # stores bytecode of the compiled rule
rule_str: str
rule: Optional[ast.AST] = None # stores the abstract syntax tree of the parsed rule
operation_actions: Optional[List[OperationAction]] = None
20 changes: 20 additions & 0 deletions tuning/utils/evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Standard
from math import sqrt

# Third Party
from simpleeval import DEFAULT_FUNCTIONS, DEFAULT_NAMES, EvalWithCompoundTypes


def get_evaluator(metrics: dict) -> EvalWithCompoundTypes:
    """Build a restricted ``simpleeval`` evaluator for trainer-controller rules.

    Args:
        metrics: dict. Mapping of metric names to their current values; these
            become the variables visible inside rule expressions.

    Returns:
        EvalWithCompoundTypes: an evaluator limited to simple expressions and
        compound literals, exposing the simpleeval defaults plus ``abs``,
        ``len`` and ``sqrt``.
    """
    # Dict unpacking already produces fresh dicts, so the defaults are never
    # mutated — the previous explicit .copy() calls were redundant.
    # NOTE: DEFAULT_NAMES is spread last, so a metric that reuses a default
    # name is shadowed by the simpleeval default, not the other way around.
    all_names = {
        **metrics,
        **DEFAULT_NAMES,
    }
    all_funcs = {
        "abs": abs,
        "len": len,
        "sqrt": sqrt,
        **DEFAULT_FUNCTIONS,
    }
    return EvalWithCompoundTypes(functions=all_funcs, names=all_names)

0 comments on commit 012447a

Please sign in to comment.