Skip to content

Commit 04ba40d

Browse files
Rocketknight1 authored and duanjunwen committed
More ReDOS fixes! (huggingface#36964)
* More ReDOS fixes!
* Slight regex cleanup
* Cleanup regex replacement
* Drop that regex entirely too
* The regex didn't match config.json, let's make sure we don't either
* Cleanup allowed_value_chars a little
* Cleanup the import search
* Catch multi-condition blocks too
* Trigger tests
* Trigger tests
1 parent eb3e7a7 commit 04ba40d

File tree

3 files changed

+71
-26
lines changed

3 files changed

+71
-26
lines changed

src/transformers/commands/chat.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import json
1818
import os
1919
import platform
20-
import re
20+
import string
2121
import time
2222
from argparse import ArgumentParser, Namespace
2323
from dataclasses import dataclass, field
@@ -44,6 +44,10 @@
4444

4545
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
4646

47+
# Characters accepted in the key of a `set key=value` chat command: ASCII letters, whitespace and
# the underscore. The underscore is needed so that snake_case setting names are accepted.
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace + "_")
# Characters accepted in the value of a `set key=value` chat command. The semicolon is deliberately
# absent because it separates consecutive assignments.
# NOTE(review): the punctuation literal is a raw string, so the backslashes in `\"` and `\-` are
# literal characters and a backslash is therefore an allowed value character - confirm this is intended.
ALLOWED_VALUE_CHARS = set(
    string.ascii_letters + string.digits + string.whitespace + r".!\"#$%&'()*+,\-/:<=>?@[]^_`{|}~"
)
4751

4852
HELP_STRING = """\
4953
@@ -71,8 +75,6 @@
7175
"repetition_penalty",
7276
]
7377

74-
SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$"
75-
7678
DEFAULT_EXAMPLES = {
7779
"llama": {"text": "There is a Llama in my lawn, how can I get rid of it?"},
7880
"code": {
@@ -438,6 +440,36 @@ def register_subcommand(parser: ArgumentParser):
438440
def __init__(self, args):
439441
self.args = args
440442

443+
@staticmethod
def is_valid_setting_command(s: str) -> bool:
    """
    Check whether `s` is a well-formed `set key=value[; key=value ...]` command.

    Args:
        s (`str`): The raw user input to validate.

    Returns:
        `bool`: `True` when `s` starts with `"set "`, contains at least one `=`, and every non-empty
        `;`-separated assignment after the prefix has exactly one `=`, a non-empty key made only of
        characters in `ALLOWED_KEY_CHARS` and a non-empty value made only of characters in
        `ALLOWED_VALUE_CHARS`. `False` otherwise.
    """
    # Cheap structural checks before looking at individual assignments
    if not s.startswith("set ") or "=" not in s:
        return False

    # Everything after the 4-character "set " prefix is a `;`-separated list of assignments
    for piece in s[4:].split(";"):
        assignment = piece.strip()
        if not assignment:
            # Empty pieces (e.g. from a trailing `;`) are ignored
            continue

        # An assignment must contain exactly one '='
        if assignment.count("=") != 1:
            return False

        raw_key, _, raw_value = assignment.partition("=")
        key, value = raw_key.strip(), raw_value.strip()
        if not (key and value):
            return False

        # Both sides are restricted to their respective character allow-lists
        if not ALLOWED_KEY_CHARS.issuperset(key):
            return False
        if not ALLOWED_VALUE_CHARS.issuperset(value):
            return False

    return True
472+
441473
def run(self):
442474
if not is_rich_available():
443475
raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
@@ -499,7 +531,7 @@ def run(self):
499531
interface.print_green(f"Chat saved in {filename}!")
500532
continue
501533

502-
if re.match(SETTING_RE, user_input):
534+
if self.is_valid_setting_command(user_input):
503535
current_args, success = parse_settings(user_input, current_args, interface)
504536
if success:
505537
chat = []

src/transformers/configuration_utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import copy
1818
import json
1919
import os
20-
import re
2120
import warnings
2221
from typing import Any, Optional, Union
2322

@@ -44,8 +43,6 @@
4443

4544
logger = logging.get_logger(__name__)
4645

47-
_re_configuration_file = re.compile(r"config\.(.*)\.json")
48-
4946

5047
class PretrainedConfig(PushToHubMixin):
5148
# no-format
@@ -1160,9 +1157,8 @@ def get_configuration_file(configuration_files: list[str]) -> str:
11601157
"""
11611158
configuration_files_map = {}
11621159
for file_name in configuration_files:
1163-
search = _re_configuration_file.search(file_name)
1164-
if search is not None:
1165-
v = search.groups()[0]
1160+
if file_name.startswith("config.") and file_name.endswith(".json") and file_name != "config.json":
1161+
v = file_name.removeprefix("config.").removesuffix(".json")
11661162
configuration_files_map[v] = file_name
11671163
available_versions = sorted(configuration_files_map.keys())
11681164

src/transformers/dynamic_module_utils.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""Utilities to dynamically load objects from the Hub."""
1515

16+
import ast
1617
import filecmp
1718
import hashlib
1819
import importlib
@@ -148,22 +149,38 @@ def get_imports(filename: Union[str, os.PathLike]) -> list[str]:
148149
"""
149150
with open(filename, encoding="utf-8") as f:
150151
content = f.read()
151-
152-
# filter out try/except block so in custom code we can have try/except imports
153-
content = re.sub(r"\s*try\s*:.*?except.*?:", "", content, flags=re.DOTALL)
154-
155-
# filter out imports under is_flash_attn_2_available block for avoid import issues in cpu only environment
156-
content = re.sub(
157-
r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+", "", content, flags=re.MULTILINE
158-
)
159-
160-
# Imports of the form `import xxx`
161-
imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
162-
# Imports of the form `from xxx import yyy`
163-
imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
164-
# Only keep the top-level module
165-
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
166-
return list(set(imports))
152+
imported_modules = set()
153+
154+
def recursive_look_for_imports(node):
    """
    Walk the AST rooted at `node` and add the top-level name of every absolute import to the
    `imported_modules` set defined in the enclosing scope.

    The whole subtree is skipped (so its imports are ignored) when `node` is a `try` statement, or an
    `if` statement whose condition contains a call to a bare name starting with `is_flash_attn`.

    Args:
        node (`ast.AST`): The node to inspect; its children are visited recursively.
    """
    if isinstance(node, ast.Try):
        return  # Don't recurse into Try blocks and ignore imports in them
    elif isinstance(node, ast.If):
        for condition_node in ast.walk(node.test):
            if not isinstance(condition_node, ast.Call):
                continue
            # Only a call on a bare name (`ast.Name`) has an `id`. Calls such as
            # `torch.cuda.is_available()` have an `ast.Attribute` as `func`, so fall back to ""
            # instead of raising AttributeError.
            called_name = getattr(condition_node.func, "id", "")
            if called_name.startswith("is_flash_attn"):
                # Don't recurse into "if flash_attn_available()" blocks and ignore imports in them
                return
    elif isinstance(node, ast.Import):
        # Handle 'import x' statements
        for alias in node.names:
            top_module = alias.name.split(".")[0]
            if top_module:
                imported_modules.add(top_module)
    elif isinstance(node, ast.ImportFrom):
        # Handle 'from x import y' statements, ignoring relative imports
        if node.level == 0 and node.module:
            top_module = node.module.split(".")[0]
            if top_module:
                imported_modules.add(top_module)

    # Recursively visit all children
    for child in ast.iter_child_nodes(node):
        recursive_look_for_imports(child)
179+
180+
tree = ast.parse(content)
181+
recursive_look_for_imports(tree)
182+
183+
return sorted(imported_modules)
167184

168185

169186
def check_imports(filename: Union[str, os.PathLike]) -> list[str]:

0 commit comments

Comments
 (0)