Skip to content

Commit 85379e0

Browse files
authored
Harden against memory leak (#53)
The extensions' implementation of packstream `Structure` could leak memory when being part of a reference cycle. In reality this doesn't matter because the driver never constructs cyclic `Structure`s. Every packstream value is a tree in terms of references (both directions: packing and unpacking). This change is meant to harden the extensions against introducing effective memory leaks in the driver should the driver's usage of `Structure` change in the future. See also https://pyo3.rs/v0.22.0/class/protocols#garbage-collector-integration
1 parent 3e4f68e commit 85379e0

File tree

4 files changed

+74
-5
lines changed

4 files changed

+74
-5
lines changed

changelog.d/53.improve.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Harden `Structure` class against memory leak<ISSUES_LIST>.
2+
The extensions' implementation of packstream `Structure` could leak memory when being part of a reference cycle.
3+
In reality this doesn't matter because the driver never constructs cyclic `Structure`s.
4+
Every packstream value is a tree in terms of references (both directions: packing and unpacking).
5+
This change is meant to harden the extensions against introducing effective memory leaks in the driver should the driver's usage of `Structure` change in the future.

src/lib.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use pyo3::basic::CompareOp;
1919
use pyo3::exceptions::PyValueError;
2020
use pyo3::prelude::*;
2121
use pyo3::types::{PyBytes, PyTuple};
22-
use pyo3::IntoPyObjectExt;
22+
use pyo3::{IntoPyObjectExt, PyTraverseError, PyVisit};
2323

2424
#[pymodule(gil_used = false)]
2525
#[pyo3(name = "_rust")]
@@ -114,4 +114,15 @@ impl Structure {
114114
}
115115
Ok(fields_hash.wrapping_add(self.tag.into()))
116116
}
117+
118+
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
119+
for field in &self.fields {
120+
visit.call(field)?;
121+
}
122+
Ok(())
123+
}
124+
125+
fn __clear__(&mut self) {
126+
self.fields.clear();
127+
}
117128
}

src/v1/unpack.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,7 @@ impl<'a> PackStreamDecoder<'a> {
135135
_ => {
136136
// raise ValueError("Unknown PackStream marker %02X" % marker)
137137
return Err(PyErr::new::<PyValueError, _>(format!(
138-
"Unknown PackStream marker {:02X}",
139-
marker
138+
"Unknown PackStream marker {marker:02X}",
140139
)));
141140
}
142141
})
@@ -243,8 +242,7 @@ impl<'a> PackStreamDecoder<'a> {
243242
STRING_16 => self.read_u16(),
244243
STRING_32 => self.read_u32(),
245244
_ => Err(PyErr::new::<PyValueError, _>(format!(
246-
"Invalid string length marker: {}",
247-
marker
245+
"Invalid string length marker: {marker}",
248246
))),
249247
}
250248
}

tests/test_structure.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
from __future__ import annotations
18+
19+
import gc
20+
from contextlib import contextmanager
21+
22+
from neo4j._codec.packstream import Structure
23+
24+
25+
@contextmanager
26+
def gc_disabled():
27+
try:
28+
gc.disable()
29+
yield
30+
finally:
31+
gc.enable()
32+
gc.collect()
33+
34+
35+
class StructureHolder:
36+
s: Structure | None = None
37+
38+
39+
def test_memory_leak() -> None:
40+
iterations = 10_000
41+
42+
gc.collect()
43+
with gc_disabled():
44+
for _ in range(iterations):
45+
# create a reference cycle
46+
holder1 = StructureHolder()
47+
structure1 = Structure(b"\x00", [holder1])
48+
holder2 = StructureHolder()
49+
structure2 = Structure(b"\x01", [holder2])
50+
holder1.s = structure2
51+
holder2.s = structure1
52+
del structure1, structure2, holder1, holder2
53+
54+
cleaned = gc.collect()
55+
assert cleaned >= 4 * iterations

0 commit comments

Comments
 (0)