Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vamb/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def trainmodel(
Output: None
"""
if nepochs < 1:
raise ValueError("Minimum 1 epoch, not {nepochs}")
raise ValueError(f"Minimum 1 epoch, not {nepochs}")

if batchsteps is None:
batchsteps_set: set[int] = set()
Expand Down
11 changes: 10 additions & 1 deletion vamb/parsecontigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
_os.path.join(_os.path.dirname(_os.path.abspath(__file__)), "kernel.npz")
)

# This kernel is precomputed, and used to project 4^4 mers into a 103-dimensional
# space. Code below assumes this shape
assert _KERNEL.shape == (256, 103)


class CompositionMetaData:
"""A class containing metadata of sequence composition.
Expand Down Expand Up @@ -140,7 +144,10 @@ def _project(fourmers: _np.ndarray, kernel: _np.ndarray = _KERNEL) -> _np.ndarra
s[s == 0] = 1.0
fourmers *= 1 / s
fourmers += -(1 / 256)
return _np.dot(fourmers, kernel)
projected = _np.dot(fourmers, kernel)
# We know this from the _KERNEL shape
assert projected.shape[1] == 103
return projected

@staticmethod
def _convert(raw: _vambtools.PushArray, projected: _vambtools.PushArray):
Expand Down Expand Up @@ -203,6 +210,8 @@ def from_file(
tnfs_arr = projected.take()
_vambtools.mask_lower_bits(tnfs_arr, 12)

# We have checked this in _project
assert tnfs_arr.shape[0] % 103 == 0
# Don't use reshape since it creates a new array object with shared memory
tnfs_arr.shape = (len(tnfs_arr) // 103, 103)
lengths_arr = lengths.take()
Expand Down
30 changes: 17 additions & 13 deletions vamb/parsemarkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import json
import numpy as np
from loguru import logger
import contextlib

MarkerID = NewType("MarkerID", int)
MarkerName = NewType("MarkerName", str)
Expand Down Expand Up @@ -206,20 +207,23 @@ def split_file(
names = set(contignames)
os.mkdir(tmpdir_to_create)
paths = [tmpdir_to_create.joinpath(str(i)) for i in range(n_splits)]
filehandles = [open(path, "w") for path in paths]
refhasher = RefHasher()
with Reader(input) as infile:
for i, (outfile, record) in enumerate(
zip(
itertools.cycle(filehandles),
filter(lambda x: x.identifier in names, byte_iterfasta(infile, None)),
)
):
refhasher.add_refname(record.identifier)
print(record.format(), file=outfile)

for filehandle in filehandles:
filehandle.close()
# Automatically close all the files on exit
with contextlib.ExitStack() as stack:
filehandles = [stack.enter_context(open(fname, "w")) for fname in paths]
refhasher = RefHasher()
with Reader(input) as infile:
for i, (outfile, record) in enumerate(
zip(
itertools.cycle(filehandles),
filter(
lambda x: x.identifier in names, byte_iterfasta(infile, None)
),
)
):
refhasher.add_refname(record.identifier)
print(record.format(), file=outfile)

refhash = refhasher.digest()
return (refhash, paths)

Expand Down
3 changes: 1 addition & 2 deletions vamb/semisupervised_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def calc_loss(self, labels_in, labels_out, mu, logsigma):
loss = ce_labels * ce_labels_weight + kld * kld_weight

_, labels_out_indices = labels_out.max(dim=1)
_, labels_in_indices = labels_in.max(dim=1)
return loss, ce_labels, kld, _torch.sum(labels_out_indices == labels_in_indices)

def trainepoch(self, data_loader, epoch, optimizer, batchsteps):
Expand Down Expand Up @@ -384,7 +383,7 @@ def trainmodel(
raise ValueError(f"Learning rate must be positive, not {lrate}")

if nepochs < 1:
raise ValueError("Minimum 1 epoch, not {nepochs}")
raise ValueError(f"Minimum 1 epoch, not {nepochs}")

if batchsteps is None:
batchsteps_set: set[int] = set()
Expand Down
7 changes: 7 additions & 0 deletions vamb/vambtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,18 +567,21 @@ def verify_refhash(
for i, (observed_id, target_id) in enumerate(
zip_longest(observed_ids, target_ids)
):
# target_ids is longer
if observed_id is None:
message += (
f"\nIdentifier mismatch: {obs_name} has only "
f"{i} identifier(s), which is fewer than {tgt_name}"
)
log_and_error(ValueError, message)
# observed_ids is longer
elif target_id is None:
message += (
f"\nIdentifier mismatch: {tgt_name} has only "
f"{i} identifier(s), which is fewer than {obs_name}"
)
log_and_error(ValueError, message)
# An element differs
elif observed_id != target_id:
message += (
f"\nIdentifier mismatch: Identifier number {i + 1} does not match "
Expand All @@ -587,6 +590,10 @@ def verify_refhash(
f'{tgt_name}: "{target_id}"'
)
log_and_error(ValueError, message)

# If the refhashes are different, then they must either contain
# different elements, or have different lengths.
# Therefore, this line can never be hit.
assert False
else:
log_and_error(ValueError, message)
Expand Down