Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ venvs
.DS_Store
build
uv.lock
.venv*
12 changes: 8 additions & 4 deletions docs/generate_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,10 +813,14 @@ def safe_default_value(p):
if value is inspect.Parameter.empty:
return p

replacement = next(
(i for i in ("os.environ", "sys.stdin", "sys.stdout", "sys.stderr") if value is eval(i)),
None,
)
# Resolve a small whitelist of global objects without using eval
_safe_globals = {
"os.environ": os.environ,
"sys.stdin": sys.stdin,
"sys.stdout": sys.stdout,
"sys.stderr": sys.stderr,
}
replacement = next((name for name, obj in _safe_globals.items() if value is obj), None)
if not replacement:
if isinstance(value, CPUDispatcher):
replacement = value.py_func.__name__
Expand Down
202 changes: 201 additions & 1 deletion vectorbt/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@

from copy import copy
from string import Template
import ast

from vectorbt import _typing as tp
from vectorbt.utils import checks
from vectorbt.utils.config import set_dict_item, get_func_arg_names, merge_dicts
from vectorbt.utils.docs import SafeToStr, prepare_for_doc

# Allowlist of attributes on mapped names that may be accessed/called from templates.
# Keys are names as they appear in the `mapping` (e.g. 'np'), values are sets of attribute names.
# Keep this intentionally small and conservative; expand only when necessary.
TEMPLATE_ALLOWED_ATTRS = {
'np': {'prod'},
}


class Sub(SafeToStr):
"""Template to substitute parts of the string with the respective values from `mapping`.
Expand Down Expand Up @@ -105,7 +113,199 @@ def eval(self, mapping: tp.Optional[tp.Mapping] = None) -> tp.Any:

Merges `mapping` and `RepEval.mapping`."""
mapping = merge_dicts(self.mapping, mapping)
return eval(self.expression, {}, mapping)
# Use a restricted AST evaluator to avoid arbitrary code execution

def _handle_constant(node):
return node.value

def _handle_name(node):
if node.id in mapping:
return mapping[node.id]
raise NameError(f"name '{node.id}' is not defined")

def _handle_binop(node):
left = _eval_node(node.left)
right = _eval_node(node.right)
if isinstance(node.op, ast.Add):
return left + right
if isinstance(node.op, ast.Sub):
return left - right
if isinstance(node.op, ast.Mult):
return left * right
if isinstance(node.op, ast.Div):
return left / right
if isinstance(node.op, ast.FloorDiv):
return left // right
if isinstance(node.op, ast.Mod):
return left % right
if isinstance(node.op, ast.Pow):
return left ** right
raise ValueError(f"unsupported binary operator: {node.op}")

def _handle_unaryop(node):
operand = _eval_node(node.operand)
if isinstance(node.op, ast.USub):
return -operand
if isinstance(node.op, ast.UAdd):
return +operand
if isinstance(node.op, ast.Not):
return not operand
raise ValueError(f"unsupported unary operator: {node.op}")

def _handle_boolop(node):
values = [_eval_node(v) for v in node.values]
if isinstance(node.op, ast.And):
return all(values)
if isinstance(node.op, ast.Or):
return any(values)
raise ValueError(f"unsupported boolean operator: {node.op}")

def _handle_compare(node):
left = _eval_node(node.left)
ops_map = {
ast.Eq: lambda a, b: a == b,
ast.NotEq: lambda a, b: a != b,
ast.Is: lambda a, b: a is b,
ast.IsNot: lambda a, b: a is not b,
ast.In: lambda a, b: a in b,
ast.NotIn: lambda a, b: a not in b,
ast.Lt: lambda a, b: a < b,
ast.LtE: lambda a, b: a <= b,
ast.Gt: lambda a, b: a > b,
ast.GtE: lambda a, b: a >= b,
}
for op, comparator in zip(node.ops, node.comparators):
right = _eval_node(comparator)
func = ops_map.get(type(op))
if func is None:
raise ValueError(f"unsupported comparison operator: {op}")
if not func(left, right):
return False
left = right
return True

def _handle_attribute(node):
# Only allow attribute access on top-level names from mapping and only allowed attrs
if isinstance(node.value, ast.Name):
base_name = node.value.id
if base_name not in mapping:
raise NameError(f"name '{base_name}' is not defined")
allowed = TEMPLATE_ALLOWED_ATTRS.get(base_name, set())
if node.attr not in allowed:
raise ValueError(f"access to attribute '{node.attr}' of '{base_name}' is not allowed")
base_obj = mapping[base_name]
return getattr(base_obj, node.attr)
raise ValueError("attribute access is only allowed on top-level mapped names")

def _handle_call(node):
# Allow calls only when calling an attribute of a mapped name, e.g. np.prod(...)
func_node = node.func
if isinstance(func_node, ast.Attribute) and isinstance(func_node.value, ast.Name):
base_name = func_node.value.id
if base_name not in mapping:
raise NameError(f"name '{base_name}' is not defined")
allowed = TEMPLATE_ALLOWED_ATTRS.get(base_name, set())
if func_node.attr not in allowed:
raise ValueError(f"call to '{func_node.attr}' of '{base_name}' is not allowed")
base_obj = mapping[base_name]
func = getattr(base_obj, func_node.attr)
if not callable(func):
raise ValueError(f"object '{func_node.attr}' of '{base_name}' is not callable")
args = [_eval_node(a) for a in node.args]
kwargs = {kw.arg: _eval_node(kw.value) for kw in node.keywords}
return func(*args, **kwargs)
raise ValueError("only calls to mapped attributes are allowed")

def _handle_subscript(node):
val = _eval_node(node.value)
# Handle slice objects properly
s = node.slice
if isinstance(s, ast.Slice):
lower = _eval_node(s.lower) if s.lower is not None else None
upper = _eval_node(s.upper) if s.upper is not None else None
step = _eval_node(s.step) if s.step is not None else None
return val[slice(lower, upper, step)]
# Tuple of indices (multi-dimensional)
if isinstance(s, ast.Tuple):
idx = tuple(_eval_node(elt) for elt in s.elts)
return val[idx]
# Other single index types
idx = _eval_node(s)
return val[idx]

def _handle_list(node):
result = []
for elt in node.elts:
if isinstance(elt, ast.Starred):
val = _eval_node(elt.value)
try:
result.extend(list(val))
except Exception:
raise ValueError("can't unpack starred expression")
else:
result.append(_eval_node(elt))
return result

def _handle_tuple(node):
result = []
for elt in node.elts:
if isinstance(elt, ast.Starred):
val = _eval_node(elt.value)
try:
result.extend(list(val))
except Exception:
raise ValueError("can't unpack starred expression")
else:
result.append(_eval_node(elt))
return tuple(result)

def _handle_joinedstr(node):
parts = []
for v in node.values:
if isinstance(v, ast.Constant):
parts.append(str(v.value))
elif isinstance(v, ast.FormattedValue):
val = _eval_node(v.value)
parts.append('' if val is None else str(val))
else:
parts.append(str(_eval_node(v)))
return ''.join(parts)

def _handle_dict(node):
return {_eval_node(k): _eval_node(v) for k, v in zip(node.keys, node.values)}

def _handle_ifexp(node):
# Ternary conditional expression: body if test else orelse
test_val = _eval_node(node.test)
if test_val:
return _eval_node(node.body)
return _eval_node(node.orelse)

handlers = {
ast.Constant: _handle_constant,
ast.Name: _handle_name,
ast.BinOp: _handle_binop,
ast.UnaryOp: _handle_unaryop,
ast.BoolOp: _handle_boolop,
ast.Compare: _handle_compare,
ast.Attribute: _handle_attribute,
ast.Call: _handle_call,
ast.Subscript: _handle_subscript,
ast.List: _handle_list,
ast.Tuple: _handle_tuple,
ast.Dict: _handle_dict,
ast.IfExp: _handle_ifexp,
ast.JoinedStr: _handle_joinedstr,
}

def _eval_node(node):
handler = handlers.get(type(node))
if handler is not None:
return handler(node)
raise ValueError(f"unsupported expression: {type(node).__name__}")

parsed = ast.parse(self.expression, mode="eval")
return _eval_node(parsed.body)

def __str__(self) -> str:
return f"{self.__class__.__name__}(" \
Expand Down