diff --git a/src/main/scala/com/fulcrumgenomics/vcf/AssessPhasing.scala b/src/main/scala/com/fulcrumgenomics/vcf/AssessPhasing.scala index bd4ade169..b573a1d04 100644 --- a/src/main/scala/com/fulcrumgenomics/vcf/AssessPhasing.scala +++ b/src/main/scala/com/fulcrumgenomics/vcf/AssessPhasing.scala @@ -24,9 +24,6 @@ package com.fulcrumgenomics.vcf -import java.nio.file.Paths -import java.util -import java.util.Comparator import com.fulcrumgenomics.FgBioDef._ import com.fulcrumgenomics.cmdline.{ClpGroups, FgBioTool} import com.fulcrumgenomics.commons.io.{Io, PathUtil} @@ -35,15 +32,15 @@ import com.fulcrumgenomics.fasta.SequenceDictionary import com.fulcrumgenomics.sopt._ import com.fulcrumgenomics.util.{GenomicSpan, Metric, ProgressLogger} import com.fulcrumgenomics.vcf.PhaseCigarOp.PhaseCigarOp -import com.fulcrumgenomics.vcf.api.{Variant, VcfCount, VcfFieldType, VcfFormatHeader, VcfHeader, VcfSource, VcfWriter} +import com.fulcrumgenomics.vcf.api._ import htsjdk.samtools.util.{IntervalList, OverlapDetector} -import htsjdk.variant.variantcontext.writer.{Options, VariantContextWriter, VariantContextWriterBuilder} -import htsjdk.variant.variantcontext.{Genotype, GenotypeBuilder, VariantContextBuilder} -import htsjdk.variant.vcf._ +import java.nio.file.Paths +import java.util +import java.util.Comparator import scala.annotation.tailrec -import scala.jdk.CollectionConverters._ import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters._ @clp( description = @@ -91,7 +88,6 @@ class AssessPhasing // check the sequence dictionaries. 
val dict = { - import com.fulcrumgenomics.fasta.Converters.FromSAMSequenceDictionary val calledReader = VcfSource(calledVcf) val truthReader = VcfSource(truthVcf) calledReader.header.dict.sameAs(truthReader.header.dict) @@ -99,7 +95,6 @@ class AssessPhasing truthReader.close() calledReader.header.dict } - val knownIntervalList = knownIntervals.map { intv => IntervalList.fromFile(intv.toFile).uniqued() } val calledBlockLengthCounter = new NumericCounter[Long]() val truthBlockLengthCounter = new NumericCounter[Long]() @@ -125,7 +120,6 @@ class AssessPhasing intervalList.getIntervals.map { i => i.getContig }.toSet } - // NB: could parallelize! dict .iterator @@ -279,7 +273,6 @@ class AssessPhasing contig: String, contigLength: Int, intervalList: Option[IntervalList] = None): Iterator[Variant] = { - import com.fulcrumgenomics.fasta.Converters.FromSAMSequenceDictionary val sampleName = reader.header.samples.head val baseIter: Iterator[Variant] = intervalList match { case Some(intv) => ByIntervalListVariantContextIterator(reader.iterator, intv, dict=reader.header.dict) @@ -438,7 +431,6 @@ object PhaseBlock extends LazyLogging { } progress.record(ctx.getContig, ctx.getStart) } - val blocksIn = new util.TreeSet[PhaseBlock](new Comparator[PhaseBlock] { /** Compares the two blocks based on start position, then returns the shorter block. 
*/ def compare(a: PhaseBlock, b: PhaseBlock): Int = { diff --git a/src/main/scala/com/fulcrumgenomics/vcf/ByIntervalListVariantContextIterator.scala b/src/main/scala/com/fulcrumgenomics/vcf/ByIntervalListVariantContextIterator.scala index 6424bc92b..7fcbee6dd 100644 --- a/src/main/scala/com/fulcrumgenomics/vcf/ByIntervalListVariantContextIterator.scala +++ b/src/main/scala/com/fulcrumgenomics/vcf/ByIntervalListVariantContextIterator.scala @@ -26,12 +26,10 @@ package com.fulcrumgenomics.vcf import java.util.NoSuchElementException - import com.fulcrumgenomics.FgBioDef._ import com.fulcrumgenomics.fasta.SequenceDictionary +import com.fulcrumgenomics.vcf.api.{Variant, VcfSource} import htsjdk.samtools.util._ -import htsjdk.variant.variantcontext.VariantContext -import htsjdk.variant.vcf.VCFFileReader import scala.annotation.tailrec @@ -42,9 +40,9 @@ object ByIntervalListVariantContextIterator { * * All variants will be read in and compared to the intervals. */ - def apply(iterator: Iterator[VariantContext], + def apply(iterator: Iterator[Variant], intervalList: IntervalList, - dict: SequenceDictionary): Iterator[VariantContext] = { + dict: SequenceDictionary): Iterator[Variant] = { new OverlapDetectionVariantContextIterator(iterator, intervalList, dict) } @@ -56,29 +54,29 @@ object ByIntervalListVariantContextIterator { * The VCF will be queried when moving to the next variant context, and so may be quite slow * if we jump around the VCF a lot. 
*/ - def apply(reader: VCFFileReader, - intervalList: IntervalList): Iterator[VariantContext] = { + def apply(reader: VcfSource, + intervalList: IntervalList): Iterator[Variant] = { new IndexQueryVariantContextIterator(reader, intervalList) } } -private class OverlapDetectionVariantContextIterator(val iter: Iterator[VariantContext], +private class OverlapDetectionVariantContextIterator(val iter: Iterator[Variant], val intervalList: IntervalList, val dict: SequenceDictionary) - extends Iterator[VariantContext] { + extends Iterator[Variant] { require(dict != null) private val intervals = intervalList.uniqued(false).iterator().buffered - private var nextVariantContext: Option[VariantContext] = None + private var nextVariant: Option[Variant] = None this.advance() - def hasNext: Boolean = this.nextVariantContext.isDefined + def hasNext: Boolean = this.nextVariant.isDefined - def next(): VariantContext = { - this.nextVariantContext match { + def next(): Variant = { + this.nextVariant match { case None => throw new NoSuchElementException("Called next when hasNext is false") - case Some(ctx) => yieldAndThen(ctx) { this.nextVariantContext = None; this.advance() } + case Some(ctx) => yieldAndThen(ctx) { this.nextVariant = None; this.advance() } } } @@ -108,16 +106,16 @@ private class OverlapDetectionVariantContextIterator(val iter: Iterator[VariantC } if (intervals.isEmpty) { } // no more intervals - else if (overlaps(ctx, intervals.head)) { nextVariantContext = Some(ctx) } // overlaps!!! + else if (overlaps(ctx, intervals.head)) { nextVariant = Some(ctx) } // overlaps!!! else if (iter.isEmpty) { } // no more variants else { this.advance() } // move to the next context } } /** NB: if a variant overlaps multiple intervals, only returns it once. 
*/ -private class IndexQueryVariantContextIterator(private val reader: VCFFileReader, intervalList: IntervalList) - extends Iterator[VariantContext] { - private var iter: Option[Iterator[VariantContext]] = None +private class IndexQueryVariantContextIterator(private val reader: VcfSource, intervalList: IntervalList) + extends Iterator[Variant] { + private var iter: Option[Iterator[Variant]] = None private val intervals = intervalList.iterator() private var previousInterval: Option[Interval] = None @@ -127,7 +125,7 @@ private class IndexQueryVariantContextIterator(private val reader: VCFFileReader this.iter.exists(_.hasNext) } - def next(): VariantContext = { + def next(): Variant = { this.iter match { case None => throw new NoSuchElementException("Called next when hasNext is false") case Some(i) => yieldAndThen(i.next()) { advance() } @@ -136,10 +134,10 @@ private class IndexQueryVariantContextIterator(private val reader: VCFFileReader def remove(): Unit = throw new UnsupportedOperationException - private def overlapsInterval(ctx: VariantContext, interval: Interval): Boolean = { - if (!ctx.getContig.equals(interval.getContig)) false // different contig - else if (interval.getStart <= ctx.getStart && ctx.getStart <= interval.getEnd) true // start falls within this interval, count it - else if (ctx.getStart < interval.getStart && interval.getEnd <= ctx.getEnd) true // the variant encloses the interval + private def overlapsInterval(variant: Variant, interval: Interval): Boolean = { + if (!variant.getContig.equals(interval.getContig)) false // different contig + else if (interval.getStart <= variant.getStart && variant.getStart <= interval.getEnd) true // start falls within this interval, count it + else if (variant.getStart < interval.getStart && interval.getEnd <= variant.getEnd) true // the variant encloses the interval else false } diff --git a/src/main/scala/com/fulcrumgenomics/vcf/JointVariantContextIterator.scala 
b/src/main/scala/com/fulcrumgenomics/vcf/JointVariantContextIterator.scala index 6efbc62e3..d82a2a162 100644 --- a/src/main/scala/com/fulcrumgenomics/vcf/JointVariantContextIterator.scala +++ b/src/main/scala/com/fulcrumgenomics/vcf/JointVariantContextIterator.scala @@ -24,10 +24,8 @@ package com.fulcrumgenomics.vcf -import com.fulcrumgenomics.FgBioDef._ import com.fulcrumgenomics.fasta.SequenceDictionary import com.fulcrumgenomics.vcf.api.Variant -import htsjdk.variant.variantcontext.{VariantContext, VariantContextComparator} object JointVariantContextIterator { def apply(iters: Seq[Iterator[Variant]], @@ -40,7 +38,7 @@ object JointVariantContextIterator { } def apply(iters: Seq[Iterator[Variant]], - comp: VariantContextComparator + comp: VariantComparator ): JointVariantContextIterator = { new JointVariantContextIterator( iters=iters, @@ -54,16 +52,15 @@ object JointVariantContextIterator { * across the iterators. If samples is given, we subset each variant context to just that sample. */ class JointVariantContextIterator private(iters: Seq[Iterator[Variant]], - dictOrComp: Either[SequenceDictionary, VariantContextComparator] + dictOrComp: Either[SequenceDictionary, VariantComparator] ) -extends Iterator[Seq[Option[VariantContext]]] { - import com.fulcrumgenomics.fasta.Converters.ToSAMSequenceDictionary +extends Iterator[Seq[Option[Variant]]] { if (iters.isEmpty) throw new IllegalArgumentException("No iterators given") private val iterators = iters.map(_.buffered) private val comparator = dictOrComp match { - case Left(dict) => new VariantContextComparator(dict.asSam) + case Left(dict) => VariantComparator(dict) case Right(comp) => comp } @@ -72,7 +69,7 @@ extends Iterator[Seq[Option[VariantContext]]] { def next(): Seq[Option[Variant]] = { val minCtx = iterators.filter(_.nonEmpty).map(_.head).sortWith { // this just checks that the variants are the the same position? Shouldn't be difficult to replace. 
- case (left: Variant, right: Variant) => comparator.compare(VcfConversions.toJavaVariant(left), right) < 0 + case (left: Variant, right: Variant) => comparator.compare(left, right) < 0 }.head // TODO: could use a TreeSet to store the iterators, examine the head of each iterator, then pop the iterator with the min, // and add that iterator back in. @@ -82,3 +79,21 @@ extends Iterator[Seq[Option[VariantContext]]] { } } } + +private object VariantComparator { + def apply(dict: SequenceDictionary): VariantComparator = { + new VariantComparator(dict) + } +} + +/** A class for comparing Variants using a sequence dictionary */ +private class VariantComparator(dict: SequenceDictionary) { + /** Function for comparing two variants. Returns negative if left < right, zero if they share a contig and position, and positive if left > right + * To mimic the VariantContextComparator, throws an exception if the contig isn't found. */ + def compare(left: Variant, right: Variant): Int = { + val idx1 = this.dict(name=left.chrom).index + val idx2 = this.dict(name=right.chrom).index + if (idx1 - idx2 == 0) left.pos - right.pos + else idx1 - idx2 + } +} diff --git a/src/test/scala/com/fulcrumgenomics/vcf/AssessPhasingTest.scala b/src/test/scala/com/fulcrumgenomics/vcf/AssessPhasingTest.scala index 891460913..d00c53a21 100644 --- a/src/test/scala/com/fulcrumgenomics/vcf/AssessPhasingTest.scala +++ b/src/test/scala/com/fulcrumgenomics/vcf/AssessPhasingTest.scala @@ -31,7 +31,7 @@ import com.fulcrumgenomics.testing.VcfBuilder.Gt import com.fulcrumgenomics.testing.{ErrorLogLevel, UnitSpec, VcfBuilder} import com.fulcrumgenomics.util.Metric import com.fulcrumgenomics.vcf.PhaseCigar.IlluminaSwitchErrors -import com.fulcrumgenomics.vcf.api.{Variant, VcfHeader, VcfSource, VcfWriter} +import com.fulcrumgenomics.vcf.api.{Variant, VcfCount, VcfFieldType, VcfFormatHeader, VcfHeader, VcfSource, VcfWriter} import htsjdk.samtools.SAMFileHeader import htsjdk.samtools.util.{Interval, IntervalList} import htsjdk.variant.vcf.VCFFileReader @@ -86,8
+86,8 @@ object AssessPhasingTest { // NB: call has blocks lengths 4, 5, 1, and 13; truth has block lengths 4, 6, and 13. } - val readBuilderCall: VcfSource = VcfSource(builderCall.toTempFile()) - val readBuilderTruth: VcfSource = VcfSource(builderTruth.toTempFile()) + val readBuilderCall: VcfSource = builderCall.toSource + val readBuilderTruth: VcfSource = builderTruth.toSource private def addPhaseSetId(ctx: Variant): Variant = { if (ctx.getStart <= 10) withPhasingSetId(ctx, 1) @@ -113,7 +113,8 @@ object AssessPhasingTest { }.toSeq } - val Header = readBuilderCall.header + val Header = readBuilderCall.header.copy( + formats = readBuilderCall.header.formats :+ VcfFormatHeader("PS", VcfCount.Fixed(1), VcfFieldType.Integer, "Phase set")) } /** @@ -124,11 +125,9 @@ class AssessPhasingTest extends ErrorLogLevel { private def writeVariants(variants: Seq[Variant]): PathToVcf = { val path = Files.createTempFile("AssessPhasingTest.", ".vcf.gz") - path.toFile.deleteOnExit() - val writer = VcfWriter(path, Header) - - variants.foreach(writer.write) + writer.write(variants) + writer.close() path } @@ -152,9 +151,11 @@ class AssessPhasingTest extends ErrorLogLevel { callVariants: Seq[Variant] = CallVariants, debugVcf: Boolean = false ): Unit = { - // input files + // input files ) val truthVcf = writeVariants(truthVariants) val callVcf = writeVariants(callVariants) + println(truthVcf) + println(callVcf) // output files val output = makeTempFile("AssessPhasingTest.", "prefix") @@ -164,7 +165,6 @@ class AssessPhasingTest extends ErrorLogLevel { // run it new AssessPhasing(truthVcf=truthVcf, calledVcf=callVcf, knownIntervals=intervals, output=output, debugVcf=debugVcf).execute() - // read the metrics val phasingMetrics = Metric.read[AssessPhasingMetric](path=phasingMetricsPath) val blockLengthMetrics = Metric.read[PhaseBlockLengthMetric](path=blockLengthMetricsPath) @@ -178,6 +178,8 @@ class AssessPhasingTest extends ErrorLogLevel { val expectedBlockLengthMetrics = 
Metric.read[PhaseBlockLengthMetric](path=expectedBlockLengthMetricsPath) // compare the metrics + println(phasingMetrics) + println(expectedPhasingMetrics) phasingMetrics should contain theSameElementsInOrderAs expectedPhasingMetrics blockLengthMetrics should contain theSameElementsInOrderAs expectedBlockLengthMetrics diff --git a/src/test/scala/com/fulcrumgenomics/vcf/ByIntervalListVariantContextIteratorTest.scala b/src/test/scala/com/fulcrumgenomics/vcf/ByIntervalListVariantContextIteratorTest.scala index 8a625229e..092af8e74 100644 --- a/src/test/scala/com/fulcrumgenomics/vcf/ByIntervalListVariantContextIteratorTest.scala +++ b/src/test/scala/com/fulcrumgenomics/vcf/ByIntervalListVariantContextIteratorTest.scala @@ -24,13 +24,11 @@ package com.fulcrumgenomics.vcf -import com.fulcrumgenomics.FgBioDef._ import com.fulcrumgenomics.testing.VcfBuilder.Gt import com.fulcrumgenomics.testing.{UnitSpec, VcfBuilder} +import com.fulcrumgenomics.vcf.api.{Variant, VcfSource} import htsjdk.samtools.SAMFileHeader import htsjdk.samtools.util.{Interval, IntervalList} -import htsjdk.variant.variantcontext.VariantContext -import htsjdk.variant.vcf.VCFFileReader /** * Tests for ByIntervalListVariantContextIterator. 
@@ -44,49 +42,45 @@ class ByIntervalListVariantContextIteratorTest extends UnitSpec { new IntervalList(header) } - private def toIterator(reader: VCFFileReader, + private def toIterator(reader: VcfSource, intervalList: IntervalList, - useIndex: Boolean): Iterator[VariantContext] = { + useIndex: Boolean): Iterator[Variant] = { if (useIndex) { ByIntervalListVariantContextIterator(reader=reader, intervalList=intervalList) } else { - import com.fulcrumgenomics.fasta.Converters.FromSAMSequenceDictionary - val dict = reader.getFileHeader.getSequenceDictionary.fromSam - ByIntervalListVariantContextIterator(iterator=reader.iterator(), intervalList=intervalList, dict=dict) + val dict = reader.header.dict + ByIntervalListVariantContextIterator(iterator=reader.iterator, intervalList=intervalList, dict=dict) } } "ByIntervalListVariantContextIterator" should "return no variants if the interval list is empty" in { Iterator(true, false).foreach { useIndex => - val vcfBuilder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=1, alleles=Seq("A"), gts=Seq(Gt(sample="s1", gt="0/0"))) - val builder = new VCFFileReader(vcfBuilder.toTempFile()) - val iterator = toIterator(reader=builder, intervalList=emtpyIntervalList(), useIndex=useIndex) + val builder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=1, alleles=Seq("A"), gts=Seq(Gt(sample="s1", gt="0/0"))) + val iterator = toIterator(reader=builder.toSource, intervalList=emtpyIntervalList(), useIndex=useIndex) iterator.isEmpty shouldBe true } } it should "return no variants if the variants are empty" in { Iterator(true, false).foreach { useIndex => - val vcfBuilder = VcfBuilder(samples=Seq("s1")) - val builder = new VCFFileReader(vcfBuilder.toTempFile()) + val builder = VcfBuilder(samples=Seq("s1")) val intervalList = emtpyIntervalList() intervalList.add(new Interval(dict.getSequence(0).getSequenceName, 1, 1000, false, "foo")) - val iterator = toIterator(reader=builder, intervalList=emtpyIntervalList(), useIndex=useIndex) + 
val iterator = toIterator(reader=builder.toSource, intervalList=emtpyIntervalList(), useIndex=useIndex) iterator.isEmpty shouldBe true } } it should "return a variant context if it overlaps an interval" in { Iterator(true, false).foreach { useIndex => - val vcfBuilder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=500, alleles=Seq("A"), gts=Seq(Gt(sample="s1", gt="0/0"))) - val builder = new VCFFileReader(vcfBuilder.toTempFile()) + val builder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=500, alleles=Seq("A"), gts=Seq(Gt(sample="s1", gt="0/0"))) val intervalList = emtpyIntervalList() intervalList.add(new Interval(dict.getSequence(0).getSequenceName, 1, 1000, false, "foo")) - val iterator = toIterator(reader=builder, intervalList=intervalList, useIndex=useIndex) + val iterator = toIterator(reader=builder.toSource, intervalList=intervalList, useIndex=useIndex) iterator.isEmpty shouldBe false val actual = iterator.next() - val expected = builder.iterator().next() + val expected = builder.iterator.next() actual.getContig shouldBe expected.getContig actual.getStart shouldBe expected.getStart actual.getEnd shouldBe expected.getEnd @@ -96,31 +90,28 @@ class ByIntervalListVariantContextIteratorTest extends UnitSpec { it should "not return a variant context if it doesn't overlap an interval (same chromosome)" in { Iterator(true, false).foreach { useIndex => - val vcfBuilder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=500, alleles=Seq("A"), gts=Seq(Gt(sample="s1", gt="0/0"))) - val builder = new VCFFileReader(vcfBuilder.toTempFile()) + val builder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=500, alleles=Seq("A"), gts=Seq(Gt(sample="s1", gt="0/0"))) val intervalList = emtpyIntervalList() intervalList.add(new Interval(dict.getSequence(0).getSequenceName, 750, 1000, false, "foo")) - val iterator = toIterator(reader=builder, intervalList=intervalList, useIndex=useIndex) + val iterator = toIterator(reader=builder.toSource, 
intervalList=intervalList, useIndex=useIndex) iterator.isEmpty shouldBe true } } it should "not return a variant context if it doesn't overlap an interval (different chromosome)" in { Iterator(true, false).foreach { useIndex => - val vcfBuilder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=500, alleles=Seq("A"), gts=Seq(Gt(sample="s1", gt="0/0"))) - val builder = new VCFFileReader(vcfBuilder.toTempFile()) + val builder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=500, alleles=Seq("A"), gts=Seq(Gt(sample="s1", gt="0/0"))) val intervalList = emtpyIntervalList() intervalList.add(new Interval(dict.getSequence(1).getSequenceName, 1, 1000, false, "foo")) - val iterator = toIterator(reader=builder, intervalList=intervalList, useIndex=useIndex) + val iterator = toIterator(reader=builder.toSource, intervalList=intervalList, useIndex=useIndex) iterator.isEmpty shouldBe true } } it should "throw an exception when next() is call but hasNext() is false" in { Iterator(true, false).foreach { useIndex => - val vcfBuilder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=1, alleles=Seq("A"), gts=Seq(Gt(sample="s1", gt="0/0"))) - val builder = new VCFFileReader(vcfBuilder.toTempFile()) - val iterator = toIterator(reader=builder, intervalList=emtpyIntervalList(), useIndex=useIndex) + val builder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=1, alleles=Seq("A"), gts=Seq(Gt(sample="s1", gt="0/0"))) + val iterator = toIterator(reader=builder.toSource, intervalList=emtpyIntervalList(), useIndex=useIndex) iterator.hasNext shouldBe false an[Exception] should be thrownBy iterator.next() } @@ -128,14 +119,13 @@ class ByIntervalListVariantContextIteratorTest extends UnitSpec { it should "return a variant context if it encloses an interval" in { Iterator(true, false).foreach { useIndex => - val vcfBuilder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=495, alleles=Seq("AAAAA", "A"), gts=Seq(Gt(sample="s1", gt="1/1"))) - val builder = new 
VCFFileReader(vcfBuilder.toTempFile()) + val builder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=495, alleles=Seq("AAAAA", "A"), gts=Seq(Gt(sample="s1", gt="1/1"))) val intervalList = emtpyIntervalList() intervalList.add(new Interval(dict.getSequence(0).getSequenceName, 496, 496, false, "foo")) - val iterator = toIterator(reader=builder, intervalList=intervalList, useIndex=useIndex) + val iterator = toIterator(reader=builder.toSource, intervalList=intervalList, useIndex=useIndex) iterator.isEmpty shouldBe false val actual = iterator.next() - val expected = builder.iterator().next() + val expected = builder.iterator.next() actual.getContig shouldBe expected.getContig actual.getStart shouldBe expected.getStart actual.getEnd shouldBe expected.getEnd @@ -145,15 +135,14 @@ class ByIntervalListVariantContextIteratorTest extends UnitSpec { it should "return a variant context only once if it overlaps multiple intervals" in { Iterator(true, false).foreach { useIndex => - val vcfBuilder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=495, alleles=Seq("AAAAA", "A"), gts=Seq(Gt(sample="s1", gt="1/1"))) - val builder = new VCFFileReader(vcfBuilder.toTempFile()) + val builder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=495, alleles=Seq("AAAAA", "A"), gts=Seq(Gt(sample="s1", gt="1/1"))) val intervalList = emtpyIntervalList() intervalList.add(new Interval(dict.getSequence(0).getSequenceName, 496, 496, false, "foo")) intervalList.add(new Interval(dict.getSequence(0).getSequenceName, 500, 500, false, "foo")) - val iterator = toIterator(reader=builder, intervalList=intervalList, useIndex=useIndex) + val iterator = toIterator(reader=builder.toSource, intervalList=intervalList, useIndex=useIndex) iterator.isEmpty shouldBe false val actual = iterator.next() - val expected = builder.iterator().next() + val expected = builder.iterator.next() actual.getContig shouldBe expected.getContig actual.getStart shouldBe expected.getStart actual.getEnd shouldBe 
expected.getEnd @@ -162,14 +151,13 @@ class ByIntervalListVariantContextIteratorTest extends UnitSpec { } it should "throw an exception when intervals are given out of order when using the VCF index" in { - val vcfBuilder = VcfBuilder(samples=Seq("s1")) + val builder = VcfBuilder(samples=Seq("s1")) .add(chrom="chr1", pos=495, alleles=Seq("AAAAA", "A"), gts=Seq(Gt(sample="s1", gt="1/1"))) .add(chrom="chr1", pos=595, alleles=Seq("AAAAA", "A"), gts=Seq(Gt(sample="s1", gt="1/1"))) - val builder = new VCFFileReader(vcfBuilder.toTempFile()) val intervalList = emtpyIntervalList() intervalList.add(new Interval(dict.getSequence(0).getSequenceName, 494, 500, false, "foo")) intervalList.add(new Interval(dict.getSequence(0).getSequenceName, 500, 500, false, "foo")) - val iterator = toIterator(reader=builder, intervalList=intervalList, useIndex=true) + val iterator = toIterator(reader=builder.toSource, intervalList=intervalList, useIndex=true) // OK, since we are overlapping the first interval iterator.isEmpty shouldBe false // NOK, since the intervals were overlapping when we pre-fetch the second variant context @@ -178,11 +166,10 @@ class ByIntervalListVariantContextIteratorTest extends UnitSpec { it should "ignore a variant context if does not overlap an interval" in { Iterator(true, false).foreach { useIndex => - val vcfBuilder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=495, alleles=Seq("A", "C"), gts=Seq(Gt(sample="s1", gt="1/1"))) - val builder = new VCFFileReader(vcfBuilder.toTempFile()) + val builder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=495, alleles=Seq("A", "C"), gts=Seq(Gt(sample="s1", gt="1/1"))) val intervalList = emtpyIntervalList() intervalList.add(new Interval(dict.getSequence(0).getSequenceName, 500, 500, false, "foo")) - val iterator = toIterator(reader=builder, intervalList=intervalList, useIndex=useIndex) + val iterator = toIterator(reader=builder.toSource, intervalList=intervalList, useIndex=useIndex) iterator.isEmpty shouldBe 
true } } diff --git a/src/test/scala/com/fulcrumgenomics/vcf/JointVariantContextIteratorTest.scala b/src/test/scala/com/fulcrumgenomics/vcf/JointVariantContextIteratorTest.scala index 9f53e7193..ffad4aa59 100644 --- a/src/test/scala/com/fulcrumgenomics/vcf/JointVariantContextIteratorTest.scala +++ b/src/test/scala/com/fulcrumgenomics/vcf/JointVariantContextIteratorTest.scala @@ -25,10 +25,7 @@ package com.fulcrumgenomics.vcf import com.fulcrumgenomics.testing.{UnitSpec, VcfBuilder} -import htsjdk.variant.variantcontext.VariantContext -import htsjdk.variant.vcf.VCFFileReader - -import scala.jdk.CollectionConverters.IteratorHasAsScala +import com.fulcrumgenomics.vcf.api.Variant /** * Tests for JointVariantContextIterator. @@ -36,51 +33,46 @@ import scala.jdk.CollectionConverters.IteratorHasAsScala class JointVariantContextIteratorTest extends UnitSpec { private val dict = VcfBuilder(samples=Seq("s1")).header.dict - private def compareVariantContexts(actual: VariantContext, expected: VariantContext): Unit = { + private def compareVariantContexts(actual: Variant, expected: Variant): Unit = { actual.getContig shouldBe expected.getContig actual.getStart shouldBe expected.getStart actual.getEnd shouldBe expected.getEnd } "JointVariantContextIterator" should "iterate variant contexts given a single iterator" in { - val vcfBuilder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=1, alleles=Seq("A")) - val builder = new VCFFileReader(vcfBuilder.toTempFile()) + val builder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=1, alleles=Seq("A")) - val iterator = JointVariantContextIterator(iters=Seq(builder.iterator().asScala), dict=dict) - compareVariantContexts(actual=iterator.next().head.get, expected=builder.iterator().next()) + val iterator = JointVariantContextIterator(iters=Seq(builder.iterator), dict=dict) + compareVariantContexts(actual=iterator.next().head.get, expected=builder.iterator.next()) } it should "not return a variant context if all the iterators are 
empty" in { - val vcfBuilder = VcfBuilder(samples=Seq("s1")) - val builder = new VCFFileReader(vcfBuilder.toTempFile()) + val builder = VcfBuilder(samples=Seq("s1")) - val iterator = JointVariantContextIterator(iters=Seq(builder.iterator().asScala, builder.iterator().asScala), dict=dict) + val iterator = JointVariantContextIterator(iters=Seq(builder.iterator, builder.iterator), dict=dict) iterator.hasNext shouldBe false an[NoSuchElementException] should be thrownBy iterator.next() } it should "return a pair of variant contexts at the same position" in { - val vcfBuilder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=1, alleles=Seq("A")) - val builder = new VCFFileReader(vcfBuilder.toTempFile()) + val builder = VcfBuilder(samples=Seq("s1")).add(chrom="chr1", pos=1, alleles=Seq("A")) - val iterator = JointVariantContextIterator(iters=Seq(builder.iterator().asScala, builder.iterator().asScala), dict=dict) + val iterator = JointVariantContextIterator(iters=Seq(builder.iterator, builder.iterator), dict=dict) iterator.hasNext shouldBe true val Seq(left, right) = iterator.next().flatten compareVariantContexts(left, right) } it should "return a None for an iterator that doesn't have a variant context for a given covered site" in { - val vcfBuilderLeft = VcfBuilder(samples=Seq("s1")) + val builderLeft = VcfBuilder(samples=Seq("s1")) .add(chrom="chr1", pos=10, alleles=Seq("A")) .add(chrom="chr1", pos=30, alleles=Seq("A")) - val builderLeft = new VCFFileReader(vcfBuilderLeft.toTempFile()) - val vcfBuilderRight = VcfBuilder(samples=Seq("s1")) + val builderRight = VcfBuilder(samples=Seq("s1")) .add(chrom="chr1", pos=10, alleles=Seq("A")) .add(chrom="chr1", pos=20, alleles=Seq("A")) - val builderRight = new VCFFileReader(vcfBuilderRight.toTempFile()) - val iterator = JointVariantContextIterator(iters=Seq(builderLeft.iterator().asScala, builderRight.iterator().asScala), dict=dict) + val iterator = JointVariantContextIterator(iters=Seq(builderLeft.iterator, 
builderRight.iterator), dict=dict) // pos: 10 status: both iterator.hasNext shouldBe true iterator.next().flatten match { @@ -89,12 +81,46 @@ class JointVariantContextIteratorTest extends UnitSpec { // pos: 20 status: right iterator.hasNext shouldBe true iterator.next() match { - case Seq(None, Some(right)) => compareVariantContexts(right, builderRight.iterator().asScala.toSeq.last) + case Seq(None, Some(right)) => compareVariantContexts(right, builderRight.iterator.toSeq.last) } // pos: 30 status: left iterator.hasNext shouldBe true iterator.next() match { - case Seq(Some(left), None) => compareVariantContexts(left, builderLeft.iterator().asScala.toSeq.last) + case Seq(Some(left), None) => compareVariantContexts(left, builderLeft.iterator.toSeq.last) } } } + +/** + * Tests for VariantComparator. + */ +class VariantComparatorTest extends UnitSpec { + private val dict = VcfBuilder(samples = Seq("s1")).header.dict + + "VariantComparator" should "correctly compare variant positions" in { + val builder1 = VcfBuilder(samples = Seq("s1")) + .add(chrom = "chr1", pos = 1, alleles = Seq("A")) // same variant + .add(chrom = "chr1", pos = 2, alleles = Seq("A")) // same chrom, lower position + .add(chrom = "chr1", pos = 5, alleles = Seq("A")) // same chrom, higher position + .add(chrom = "chr1", pos = 6, alleles = Seq("A")) // lower chrom + .add(chrom = "chr4", pos = 1, alleles = Seq("A")) // higher chrom + + val builder2 = VcfBuilder(samples = Seq("s2")) + .add(chrom = "chr1", pos = 1, alleles = Seq("A")) + .add(chrom = "chr1", pos = 3, alleles = Seq("A")) + .add(chrom = "chr1", pos = 4, alleles = Seq("A")) + .add(chrom = "chr2", pos = 6, alleles = Seq("A")) + .add(chrom = "chr3", pos = 1, alleles = Seq("A")) + + val answers = Seq(0, -1, 1, -1, 1) + + val comparator = VariantComparator(dict) + + var i = 0 + builder1.zip(builder2).foreach { case (v1, v2) => + val comp = comparator.compare(v1, v2) + comp shouldBe answers(i) + i += 1 + } + } +} \ No newline at end of file diff 
--git a/src/test/scala/com/fulcrumgenomics/vcf/VariantMaskTest.scala b/src/test/scala/com/fulcrumgenomics/vcf/VariantMaskTest.scala index d42ccb039..03a25a62a 100644 --- a/src/test/scala/com/fulcrumgenomics/vcf/VariantMaskTest.scala +++ b/src/test/scala/com/fulcrumgenomics/vcf/VariantMaskTest.scala @@ -65,7 +65,7 @@ class VariantMaskTest extends UnitSpec { it should "mask all deleted bases for deletions, plus the upstream base" in { val builder = VcfBuilder(samples = Seq("S1")) - builder.add(chrom="chr1.", pos=100, alleles=Seq("AA", "A")) + builder.add(chrom="chr1", pos=100, alleles=Seq("AA", "A")) val mask = VariantMask(builder.iterator, dict=dict) mask.isVariant(1, 99) shouldBe false