Expected Behavior
When two distinct PyTorch tensor keys map to the same JAX key via torch_key_to_jax_key during concurrent safetensors loading, the loader should raise a ValueError indicating a duplicate key, preventing silent weight corruption.
Actual Behavior
The duplicate key check (if jax_key_mapped in file_loaded_tensors) and the subsequent dictionary write (file_loaded_tensors[jax_key_mapped] = current_arr) in tunix/models/safetensors_loader.py:219-223 are not protected by any lock, even though they run inside a ThreadPoolExecutor(max_workers=os.cpu_count()).
This creates a TOCTOU (time-of-check-to-time-of-use) race condition:
Thread A: reads file_loaded_tensors[key] → not found (CHECK passes)
Thread B: reads file_loaded_tensors[key] → not found (CHECK passes)
Thread A: writes file_loaded_tensors[key] = tensor_A
Thread B: writes file_loaded_tensors[key] = tensor_B (silently overwrites A)
The ValueError on line 220 that is intended to catch duplicates never fires. The model loads successfully with incorrect weights and no error is raised, making this extremely difficult to debug.
The existing file_lock only protects sf_file.get_tensor() (line 198-199), leaving the check and write on lines 219-223 completely unguarded:
```python
def process_key(k_name, f, sf_file, file_loaded_tensors):
    with file_lock:
        v = sf_file.get_tensor(k_name)  # protected
    jax_key_mapped, transform = torch_utils.torch_key_to_jax_key(key_map, k_name)
    # ... transform tensor ...
    if jax_key_mapped in file_loaded_tensors:  # CHECK — no lock
        raise ValueError(...)
    file_loaded_tensors[jax_key_mapped] = current_arr  # WRITE — no lock
```
Steps to Reproduce the Problem
- Provide a key mapping where two distinct PyTorch keys resolve to the same JAX key via torch_key_to_jax_key
- Load a safetensors checkpoint that contains both keys using load_params_from_safetensors
- Under concurrent execution (os.cpu_count() threads), both threads pass the duplicate check before either writes; the second thread silently overwrites the first tensor
The race is timing dependent and more likely to trigger on machines with higher core counts (larger thread pool) and when the mapped keys are close together in iteration order.
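The race can be reproduced deterministically in isolation by widening the check-to-write window. The sketch below is a minimal standalone demonstration of the same unguarded check-then-write pattern, not code from tunix; the barrier and sleep stand in for the natural thread-scheduling jitter:

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

def make_unsafe_writer(shared, barrier):
    """Check-then-write with no lock; the barrier and sleep widen the race window."""
    def write(key, value):
        barrier.wait()            # both threads reach the CHECK together
        if key in shared:         # CHECK: both see the key as absent
            raise ValueError(f'Duplicate key {key}')
        time.sleep(0.01)          # window between CHECK and WRITE
        shared[key] = value       # WRITE: second thread overwrites the first
    return write

shared = {}
write = make_unsafe_writer(shared, threading.Barrier(2))
with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(write, 'w', 'tensor_A'),
               pool.submit(write, 'w', 'tensor_B')]
    errors = [f.exception() for f in futures]

# Neither future raises; the dict holds whichever value was written last,
# and the duplicate-key ValueError never fires.
print(errors)       # [None, None]
print(len(shared))  # 1
```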
Suggested fix: guard the check and the write with a dedicated lock so the pair executes atomically:

```python
dict_lock = threading.Lock()

def process_key(k_name, f, sf_file, file_loaded_tensors):
    with file_lock:
        v = sf_file.get_tensor(k_name)
    # ... key mapping and transforms ...
    with dict_lock:
        if jax_key_mapped in file_loaded_tensors:
            raise ValueError(f'Duplicate key {jax_key_mapped} found within file {f.name}.')
        file_loaded_tensors[jax_key_mapped] = current_arr
```
Environment
OS: Platform independent (threading behavior)
Project Version: main branch (v0.1.6+, commit beca051)
Python: 3.11+
Checklist