
TOCTOU race condition in concurrent safetensors loading bypasses duplicate key detection #1259

@kbhujbal

Description


Expected Behavior

When two distinct PyTorch tensor keys map to the same JAX key via torch_key_to_jax_key during concurrent safetensors loading, the loader should raise a ValueError indicating a duplicate key, preventing silent weight corruption.

Actual Behavior

The duplicate key check (if jax_key_mapped in file_loaded_tensors) and the subsequent dictionary write (file_loaded_tensors[jax_key_mapped] = current_arr) in tunix/models/safetensors_loader.py:219-223 run inside a ThreadPoolExecutor(max_workers=os.cpu_count()) without any lock protecting them.

This creates a TOCTOU (time-of-check-to-time-of-use) race condition:

Thread A: reads file_loaded_tensors[key] → not found (CHECK passes)
Thread B: reads file_loaded_tensors[key] → not found (CHECK passes)
Thread A: writes file_loaded_tensors[key] = tensor_A
Thread B: writes file_loaded_tensors[key] = tensor_B (silently overwrites A)
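The interleaving above can be reproduced deterministically. The sketch below is not the tunix code: it replaces unlucky scheduling with a threading.Barrier that forces both threads to finish the CHECK before either WRITES, which is exactly the window the race exploits.

```python
import threading

shared = {}                      # stands in for file_loaded_tensors
barrier = threading.Barrier(2)   # forces both threads past the CHECK before either WRITE
errors = []

def process(name, tensor):
    # CHECK: unprotected duplicate detection
    if name in shared:
        errors.append(name)
        return
    barrier.wait()               # deterministic stand-in for an unlucky thread schedule
    # WRITE: the second writer silently clobbers the first
    shared[name] = tensor

threads = [threading.Thread(target=process, args=("w", t))
           for t in ("tensor_A", "tensor_B")]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(errors)   # [] — the duplicate check never fired
print(shared)   # one key, holding whichever tensor wrote last
```

Both threads observe an empty dict, so the duplicate branch is never taken and one tensor is lost without any error.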

The ValueError on line 220 that is intended to catch duplicates never fires. The model loads successfully with incorrect weights and no error is raised, making this extremely difficult to debug.

The existing file_lock only protects sf_file.get_tensor() (lines 198-199), leaving the check and write on lines 219-223 completely unguarded:

```python
def process_key(k_name, f, sf_file, file_loaded_tensors):
    with file_lock:
        v = sf_file.get_tensor(k_name)           # protected

    jax_key_mapped, transform = torch_utils.torch_key_to_jax_key(key_map, k_name)
    # ... transform tensor ...

    if jax_key_mapped in file_loaded_tensors:     # CHECK — no lock
        raise ValueError(...)
    file_loaded_tensors[jax_key_mapped] = current_arr  # WRITE — no lock
```

Steps to Reproduce the Problem

  1. Provide a key mapping where two distinct PyTorch keys resolve to the same JAX key via torch_key_to_jax_key
  2. Load a safetensors checkpoint that contains both keys using load_params_from_safetensors
  3. Under concurrent execution (os.cpu_count() threads), both threads pass the duplicate check before either writes; the second thread then silently overwrites the first tensor

The race is timing dependent and more likely to trigger on machines with higher core counts (larger thread pool) and when the mapped keys are close together in iteration order.
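For step 1, a key map along the following lines would trigger the collision. This is an illustration only: the pattern syntax and the to_jax_key helper are hypothetical stand-ins for torch_utils.torch_key_to_jax_key, whose real signature may differ; tied input/output embeddings are a plausible case of two PyTorch keys targeting one JAX parameter.

```python
import re

# Hypothetical key map: two distinct PyTorch keys collapse onto the
# same JAX parameter (e.g. tied input/output embeddings).
key_map = {
    r"model\.embed_tokens\.weight": "embedder/embedding",
    r"lm_head\.weight": "embedder/embedding",
}

def to_jax_key(key_map, torch_key):
    # Simplified stand-in for torch_utils.torch_key_to_jax_key.
    for pattern, jax_key in key_map.items():
        if re.fullmatch(pattern, torch_key):
            return jax_key
    return None

keys = ["model.embed_tokens.weight", "lm_head.weight"]
mapped = [to_jax_key(key_map, k) for k in keys]
print(mapped)  # both resolve to 'embedder/embedding'
```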

Suggested Fix

Guard the check and the write with a single lock so the pair is atomic:

```python
dict_lock = threading.Lock()

def process_key(k_name, f, sf_file, file_loaded_tensors):
    with file_lock:
        v = sf_file.get_tensor(k_name)

    # ... key mapping and transforms ...

    with dict_lock:
        if jax_key_mapped in file_loaded_tensors:
            raise ValueError(f'Duplicate key {jax_key_mapped} found within file {f.name}.')
        file_loaded_tensors[jax_key_mapped] = current_arr
```
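With the check-and-write made atomic, exactly one of two concurrent writers to the same key raises, regardless of scheduling. A minimal standalone sketch (simplified names, not the tunix loader) that demonstrates this:

```python
import threading
from concurrent.futures import ThreadPoolExecutor

dict_lock = threading.Lock()
loaded = {}

def store(jax_key, tensor):
    # Check and write under one lock: the pair is now atomic.
    with dict_lock:
        if jax_key in loaded:
            raise ValueError(f"Duplicate key {jax_key}")
        loaded[jax_key] = tensor

with ThreadPoolExecutor(max_workers=2) as ex:
    futures = [ex.submit(store, "w", t) for t in ("tensor_A", "tensor_B")]
    results = [f.exception() for f in futures]

dupes = [e for e in results if isinstance(e, ValueError)]
print(len(dupes))  # exactly one writer hits the duplicate check
```

Whichever thread enters the critical section second now sees the key and raises, restoring the intended ValueError behavior.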

Environment

OS: Platform independent (threading behavior)
Project Version: main branch (v0.1.6+, commit beca051)
Python: 3.11+

Checklist

  • I have searched the existing issues for a similar bug report.
  • I have provided all the required information in the "Environment" section.
  • I have provided a minimal, reproducible example.

Metadata

Labels

type:bug (Something isn't working)
