Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/decima/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,8 @@ def __len__(self):
def validate_allele_seq(self, gene, variant):
    """Check the variant's alleles against the sequence of ``gene``.

    The gene sequence is fetched via ``self.result.gene_sequence`` for
    ``self.genome`` and sliced at the variant's relative position.

    Args:
        gene: Gene identifier passed through to ``gene_sequence``.
        variant: Object exposing ``rel_pos``, ``strand``, ``ref``, ``alt``,
            ``ref_tx`` and ``alt_tx``.

    Returns:
        A ``(ref_match, alt_match)`` pair: whether the slice of length
        ``len(variant.ref)`` equals ``variant.ref_tx`` and whether the slice
        of length ``len(variant.alt)`` equals ``variant.alt_tx``.
    """
    gene_seq = self.result.gene_sequence(gene, genome=self.genome)

    # On the "-" strand the window start is moved back by len(ref) - 1
    # (presumably because rel_pos then marks the last base of the
    # reference allele in sequence orientation -- confirm with callers).
    start = variant.rel_pos
    if variant.strand == "-":
        start = start - len(variant.ref) + 1

    def _window_equals(expected, width):
        return gene_seq[start : start + width] == expected

    return (
        _window_equals(variant.ref_tx, len(variant.ref)),
        _window_equals(variant.alt_tx, len(variant.alt)),
    )
Expand All @@ -889,6 +891,9 @@ def __getitem__(self, idx):
variant = self.variants.iloc[seq_idx]
rel_pos = variant.rel_pos + self.max_seq_shift

if variant.strand == "-":
rel_pos = rel_pos - len(variant.ref) + 1

# by default cache values are nan if matched with reference genome
# then it will be replaced with the predicted expression from cache.
pred_expr = {model_name: torch.full((self.result.shape[0],), torch.nan) for model_name in self.model_names}
Expand Down
2 changes: 2 additions & 0 deletions src/decima/utils/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """Finalize the parquet output when the ``with`` block ends.

    If a writer was opened, the optional metadata is attached (keys and
    values converted to ``str``) and the writer is closed. If no writer was
    ever opened, a warning is issued and an empty DataFrame is written to
    ``self.output_path``.

    Args:
        exc_type: Exception type raised inside the ``with`` block, if any.
        exc_val: Exception instance raised inside the ``with`` block, if any.
        exc_tb: Traceback of that exception, if any.

    Returns:
        None. The exception arguments are not inspected, so any exception
        raised inside the ``with`` block propagates to the caller.
    """
    if self.writer is not None:
        if self.metadata is not None:
            # Both keys and values are stringified before being attached.
            self.writer.add_key_value_metadata({str(k): str(v) for k, v in self.metadata.items()})
        self.writer.close()
    else:
        warnings.warn("NoDataFrameWrittenError: No dataframe was written to the parquet file.")
        pd.DataFrame({}).to_parquet(self.output_path)
    # NOTE(review): the chunk flag is assumed to be reset on every exit,
    # not only in the "nothing written" branch -- confirm against the
    # original file's indentation.
    self.first_chunk = True

def write(self, chunk: pd.DataFrame) -> None:
Expand Down
10 changes: 5 additions & 5 deletions src/decima/vep/attributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from decima.utils.io import read_vcf_chunks, VariantAttributionWriter
from decima.core.result import DecimaResult
from decima.data.dataset import VariantDataset
from decima.hub import load_decima_model
from decima.interpret.attributer import DecimaAttributer
from decima.model.metrics import WarningCounter
from decima.vep.vep import _log_vep_warnings, _write_vep_warnings
Expand Down Expand Up @@ -158,16 +159,15 @@ def variant_effect_attribution(
f"Unsupported input type: {type(variants)}. Must be pd.DataFrame or str (path to .tsv or .vcf)."
)

result = DecimaResult.load(metadata_anndata)

model = load_decima_model(model, device=device)
result = DecimaResult.load(metadata_anndata or model.name)
tasks, off_tasks = _get_on_off_tasks(result, tasks, off_tasks)
attributer = DecimaAttributer.load_decima_attributer(
model_name=model,
attributer = DecimaAttributer(
model=model,
tasks=tasks,
off_tasks=off_tasks,
method=method,
transform=transform,
device=device,
)

warning_counter = WarningCounter()
Expand Down
11 changes: 11 additions & 0 deletions tests/test_vep.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,17 @@ def test_VariantDataset_overlap_genes(df_variant):
})
df = VariantDataset.overlap_genes(df_variant, df_genes)

def test_VariantDataset_validate_allele_seq():
    """The reference allele of a two-base variant matches the SPG11 sequence."""
    chrom, position, ref_allele, alt_allele = "chr15", 44715509, "CC", "TT"
    df_variant = pd.DataFrame({
        "chrom": [chrom],
        "pos": [position],
        "ref": [ref_allele],
        "alt": [alt_allele],
    })
    dataset = VariantDataset(df_variant)

    # NOTE(review): the second row of ``dataset.variants`` is used although
    # only one variant is supplied; presumably the dataset expands variants
    # to one row per overlapping gene -- confirm this row belongs to SPG11.
    candidate = dataset.variants.iloc[1]
    ref_match, _alt_match = dataset.validate_allele_seq("SPG11", candidate)
    assert ref_match

def test_VariantDataset(df_variant):

dataset = VariantDataset(df_variant, model_name="v1_rep0")
Expand Down