Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
--------------------
[0.2.7] - 2026-XX-XX
--------------------

- Add support for very large columns and add the ``chunk_size`` parameter.
(jeromekelleher, #119).

--------------------
[0.2.6] - 2025-09-18
--------------------
Expand Down
26 changes: 25 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2019 Tskit Developers
# Copyright (c) 2019-2026 Tskit Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -34,6 +34,7 @@

import tszip
import tszip.cli as cli
from tszip import compat


def get_stdout_for_pytest():
Expand Down Expand Up @@ -98,6 +99,7 @@ def test_default_values(self):
self.assertEqual(args.decompress, False)
self.assertEqual(args.list, False)
self.assertEqual(args.stdout, False)
self.assertEqual(args.chunk_size, tszip.DEFAULT_CHUNK_SIZE)
self.assertEqual(args.variants_only, False)
self.assertEqual(args.suffix, ".tsz")

Expand All @@ -123,6 +125,14 @@ def test_decompress(self):
args = parser.parse_args([infile, "--decompress"])
self.assertTrue(args.decompress)

def test_chunk_size(self):
parser = cli.tszip_cli_parser()
infile = "tmp.trees.tsz"
args = parser.parse_args([infile, "-C", "1234"])
self.assertEqual(args.chunk_size, 1234)
args = parser.parse_args([infile, "--chunk-size=1234"])
self.assertTrue(args.chunk_size, 1234)


class TestCli(unittest.TestCase):
"""
Expand Down Expand Up @@ -248,6 +258,20 @@ def test_variants_only(self):
G2 = self.ts.genotype_matrix()
self.assertTrue(np.array_equal(G1, G2))

def test_chunk_size(self):
self.assertTrue(self.trees_path.exists())
self.run_tszip([str(self.trees_path), "--chunk-size=20"])
self.assertFalse(self.trees_path.exists())
outpath = pathlib.Path(str(self.trees_path) + ".tsz")
self.assertTrue(outpath.exists())
ts = tszip.decompress(outpath)
self.assertEqual(ts.tables, self.ts.tables)
store = compat.create_zip_store(str(outpath), mode="r")
root = compat.create_zarr_group(store=store)
for _, g in root.groups():
for _, a in g.arrays():
assert a.chunks == (20,)

def test_keep(self):
self.assertTrue(self.trees_path.exists())
self.run_tszip([str(self.trees_path), "--keep"])
Expand Down
47 changes: 44 additions & 3 deletions tests/test_compression.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2021 Tskit Developers
# Copyright (c) 2021-2026 Tskit Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -189,7 +189,7 @@ def test_small_msprime_complex_mutations(self):

def test_ref_seq(self):
ts = msprime.simulate(10, recombination_rate=1, mutation_rate=2, random_seed=2)
tables = ts.tables
tables = ts.dump_tables()
tables.reference_sequence.metadata_schema = (
tskit.MetadataSchema.permissive_json()
)
Expand Down Expand Up @@ -307,7 +307,12 @@ def test_provenance(self):
root = compat.create_zarr_group(store=store)
self.assertEqual(
root.attrs["provenance"],
provenance.get_provenance_dict({"variants_only": variants_only}),
provenance.get_provenance_dict(
{
"variants_only": variants_only,
"chunk_size": compression.DEFAULT_CHUNK_SIZE,
}
),
)

def write_file(self, attrs, path):
Expand Down Expand Up @@ -526,3 +531,39 @@ def test_issue95_metadata_dtype_regression(self):
assert len(ts_decompressed.metadata["reverse_node_map"]) == len(
ts_original.metadata["reverse_node_map"]
)


class TestChunkSize:
@pytest.mark.parametrize(
"chunk_size", [1, 2, 1000, 2**21, np.array([100], dtype=int)[0]]
)
def test_good_chunks(self, tmpdir, chunk_size):
files = pathlib.Path(__file__).parent / "files"
ts1 = tskit.load(files / "1.0.0.trees")
path = tmpdir / "out.trees.tsz"
tszip.compress(ts1, path, chunk_size=chunk_size)
ts2 = tszip.decompress(path)
assert ts1 == ts2

store = compat.create_zip_store(str(path), mode="r")
root = compat.create_zarr_group(store=store)
for _, g in root.groups():
for _, a in g.arrays():
assert a.chunks == (chunk_size,)

@pytest.mark.parametrize(
["chunk_size", "exception"],
[
(0, ValueError),
(-1, ValueError),
(1.1, TypeError),
("x", TypeError),
("10", TypeError),
],
)
def test_bad_chunks(self, tmpdir, chunk_size, exception):
files = pathlib.Path(__file__).parent / "files"
ts = tskit.load(files / "1.0.0.trees")
path = tmpdir / "out.trees.tsz"
with pytest.raises(exception):
tszip.compress(ts, path, chunk_size=chunk_size)
1 change: 1 addition & 0 deletions tszip/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# SOFTWARE.
from .compression import compress # NOQA
from .compression import decompress # NOQA
from .compression import DEFAULT_CHUNK_SIZE # NOQA
from .compression import load # NOQA
from .compression import print_summary # NOQA
from .provenance import __version__ # NOQA
12 changes: 11 additions & 1 deletion tszip/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ def tszip_cli_parser():
"-v", "--verbosity", action="count", default=0, help="Increase the verbosity"
)
parser.add_argument("files", nargs="+", help="The files to compress/decompress.")
parser.add_argument(
"-C",
"--chunk-size",
type=int,
default=tszip.DEFAULT_CHUNK_SIZE,
help="Sets the size of array chunks to be compressed to the specified "
f"number of elements. Default={tszip.DEFAULT_CHUNK_SIZE}",
)
parser.add_argument(
"--variants-only",
action="store_true",
Expand Down Expand Up @@ -125,7 +133,9 @@ def run_compress(args):
check_output(outfile, args)
if args.stdout:
outfile = get_stdout()
tszip.compress(ts, outfile, variants_only=args.variants_only)
tszip.compress(
ts, outfile, variants_only=args.variants_only, chunk_size=args.chunk_size
)
remove_input(infile, args)


Expand Down
50 changes: 35 additions & 15 deletions tszip/compression.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2019 Tskit Developers
# Copyright (c) 2019-2026 Tskit Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand All @@ -26,6 +26,7 @@
import functools
import json
import logging
import numbers
import os
import pathlib
import tempfile
Expand All @@ -47,6 +48,10 @@
FORMAT_NAME = "tszip"
FORMAT_VERSION = [1, 0]

# ~8 million elements, giving 32MiB chunks to be compressed
# for most columns
DEFAULT_CHUNK_SIZE = 8 * 2**20


def minimal_dtype(array):
"""
Expand All @@ -73,7 +78,7 @@ def minimal_dtype(array):
return dtype


def compress(ts, destination, variants_only=False):
def compress(ts, destination, variants_only=False, *, chunk_size=None):
"""
Compresses the specified tree sequence and writes it to the specified path
or file-like object. By default, fully lossless compression is used so that
Expand All @@ -87,6 +92,9 @@ def compress(ts, destination, variants_only=False):
we should write the compressed file to.
:param bool variants_only: If True, discard all information not necessary
to represent the variants in the input file.
:param int chunk_size: The number of array elements per chunk in the
Zarr encoding. Defaults to 8_388_608, resulting in
each encoded chunk of 4-byte integer data being 32MiB.
"""
try:
destination = pathlib.Path(destination).resolve()
Expand All @@ -100,15 +108,15 @@ def compress(ts, destination, variants_only=False):
logging.debug(f"Writing to temporary file {filename}")
with compat.create_zip_store(filename, mode="w") as store:
root = compat.create_zarr_group(store=store)
compress_zarr(ts, root, variants_only=variants_only)
compress_zarr(ts, root, variants_only=variants_only, chunk_size=chunk_size)
if is_path:
os.replace(filename, destination)
logging.info(f"Wrote {destination}")
else:
# Assume that destination is a file-like object open in "wb" mode.
with open(filename, "rb") as source:
chunk_size = 2**10 # 1MiB
for chunk in iter(functools.partial(source.read, chunk_size), b""):
read_chunk_size = 2**10 # 1MiB
for chunk in iter(functools.partial(source.read, read_chunk_size), b""):
destination.write(chunk)


Expand All @@ -131,16 +139,14 @@ class Column:
A single column that is stored in the compressed output.
"""

def __init__(self, name, array, delta_filter=False):
def __init__(self, name, array, chunk_size, delta_filter=False):
self.name = name
self.array = array
self.delta_filter = delta_filter
self.chunks = (chunk_size,)

def compress(self, root, compressor):
shape = self.array.shape
chunks = shape
if shape[0] == 0:
chunks = (1,)
dtype = minimal_dtype(self.array)
filters = None
if self.delta_filter:
Expand All @@ -150,7 +156,7 @@ def compress(self, root, compressor):
self.name,
shape=shape,
dtype=dtype,
chunks=chunks,
chunks=self.chunks,
filters=filters,
compressor=compressor,
)
Expand All @@ -170,8 +176,18 @@ def compress(self, root, compressor):
)


def compress_zarr(ts, root, variants_only=False):
provenance_dict = provenance.get_provenance_dict({"variants_only": variants_only})
def compress_zarr(ts, root, variants_only=False, chunk_size=None):
if chunk_size is None:
chunk_size = DEFAULT_CHUNK_SIZE
if not isinstance(chunk_size, numbers.Integral):
raise TypeError("Chunk size must be an integer")
if chunk_size < 1:
raise ValueError("Storage chunk size must be >= 1")
chunk_size = int(chunk_size) # Avoid issues with JSON serialisation

provenance_dict = provenance.get_provenance_dict(
{"variants_only": variants_only, "chunk_size": chunk_size}
)

if variants_only:
logging.info("Using lossy variants-only compression")
Expand Down Expand Up @@ -254,9 +270,13 @@ def compress_zarr(ts, root, variants_only=False):
cname="zstd", clevel=9, shuffle=numcodecs.Blosc.SHUFFLE
)
for name, data in columns.items():
Column(
name, data, delta_filter="_offset" in name or name in delta_filter_cols
).compress(root, compressor)
col = Column(
name,
data,
chunk_size,
delta_filter="_offset" in name or name in delta_filter_cols,
)
col.compress(root, compressor)


def check_format(root):
Expand Down