Skip to content

Commit

Permalink
Added a comparator to the iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
Kari Stromhaug committed Apr 29, 2022
1 parent 3da0aed commit 88f5148
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 118 deletions.
18 changes: 5 additions & 13 deletions src/main/scala/com/fulcrumgenomics/vcf/AssessPhasing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@

package com.fulcrumgenomics.vcf

import java.nio.file.Paths
import java.util
import java.util.Comparator
import com.fulcrumgenomics.FgBioDef._
import com.fulcrumgenomics.cmdline.{ClpGroups, FgBioTool}
import com.fulcrumgenomics.commons.io.{Io, PathUtil}
Expand All @@ -35,15 +32,15 @@ import com.fulcrumgenomics.fasta.SequenceDictionary
import com.fulcrumgenomics.sopt._
import com.fulcrumgenomics.util.{GenomicSpan, Metric, ProgressLogger}
import com.fulcrumgenomics.vcf.PhaseCigarOp.PhaseCigarOp
import com.fulcrumgenomics.vcf.api.{Variant, VcfCount, VcfFieldType, VcfFormatHeader, VcfHeader, VcfSource, VcfWriter}
import com.fulcrumgenomics.vcf.api._
import htsjdk.samtools.util.{IntervalList, OverlapDetector}
import htsjdk.variant.variantcontext.writer.{Options, VariantContextWriter, VariantContextWriterBuilder}
import htsjdk.variant.variantcontext.{Genotype, GenotypeBuilder, VariantContextBuilder}
import htsjdk.variant.vcf._

import java.nio.file.Paths
import java.util
import java.util.Comparator
import scala.annotation.tailrec
import scala.jdk.CollectionConverters._
import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._

@clp(
description =
Expand Down Expand Up @@ -91,15 +88,13 @@ class AssessPhasing

// check the sequence dictionaries.
val dict = {
import com.fulcrumgenomics.fasta.Converters.FromSAMSequenceDictionary
val calledReader = VcfSource(calledVcf)
val truthReader = VcfSource(truthVcf)
calledReader.header.dict.sameAs(truthReader.header.dict)
calledReader.close()
truthReader.close()
calledReader.header.dict
}

val knownIntervalList = knownIntervals.map { intv => IntervalList.fromFile(intv.toFile).uniqued() }
val calledBlockLengthCounter = new NumericCounter[Long]()
val truthBlockLengthCounter = new NumericCounter[Long]()
Expand All @@ -125,7 +120,6 @@ class AssessPhasing

intervalList.getIntervals.map { i => i.getContig }.toSet
}

// NB: could parallelize!
dict
.iterator
Expand Down Expand Up @@ -279,7 +273,6 @@ class AssessPhasing
contig: String,
contigLength: Int,
intervalList: Option[IntervalList] = None): Iterator[Variant] = {
import com.fulcrumgenomics.fasta.Converters.FromSAMSequenceDictionary
val sampleName = reader.header.samples.head
val baseIter: Iterator[Variant] = intervalList match {
case Some(intv) => ByIntervalListVariantContextIterator(reader.iterator, intv, dict=reader.header.dict)
Expand Down Expand Up @@ -438,7 +431,6 @@ object PhaseBlock extends LazyLogging {
}
progress.record(ctx.getContig, ctx.getStart)
}

val blocksIn = new util.TreeSet[PhaseBlock](new Comparator[PhaseBlock] {
/** Compares the two blocks based on start position, then returns the shorter block. */
def compare(a: PhaseBlock, b: PhaseBlock): Int = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@
package com.fulcrumgenomics.vcf

import java.util.NoSuchElementException

import com.fulcrumgenomics.FgBioDef._
import com.fulcrumgenomics.fasta.SequenceDictionary
import com.fulcrumgenomics.vcf.api.{Variant, VcfSource}
import htsjdk.samtools.util._
import htsjdk.variant.variantcontext.VariantContext
import htsjdk.variant.vcf.VCFFileReader

import scala.annotation.tailrec

Expand All @@ -42,9 +40,9 @@ object ByIntervalListVariantContextIterator {
*
* All variants will be read in and compared to the intervals.
*/
def apply(iterator: Iterator[VariantContext],
def apply(iterator: Iterator[Variant],
intervalList: IntervalList,
dict: SequenceDictionary): Iterator[VariantContext] = {
dict: SequenceDictionary): Iterator[Variant] = {
new OverlapDetectionVariantContextIterator(iterator, intervalList, dict)
}

Expand All @@ -56,29 +54,29 @@ object ByIntervalListVariantContextIterator {
* The VCF will be queried when moving to the next variant context, and so may be quite slow
* if we jump around the VCF a lot.
*/
def apply(reader: VCFFileReader,
intervalList: IntervalList): Iterator[VariantContext] = {
def apply(reader: VcfSource,
intervalList: IntervalList): Iterator[Variant] = {
new IndexQueryVariantContextIterator(reader, intervalList)
}
}

private class OverlapDetectionVariantContextIterator(val iter: Iterator[VariantContext],
private class OverlapDetectionVariantContextIterator(val iter: Iterator[Variant],
val intervalList: IntervalList,
val dict: SequenceDictionary)
extends Iterator[VariantContext] {
extends Iterator[Variant] {

require(dict != null)
private val intervals = intervalList.uniqued(false).iterator().buffered
private var nextVariantContext: Option[VariantContext] = None
private var nextVariant: Option[Variant] = None

this.advance()

def hasNext: Boolean = this.nextVariantContext.isDefined
def hasNext: Boolean = this.nextVariant.isDefined

def next(): VariantContext = {
this.nextVariantContext match {
def next(): Variant = {
this.nextVariant match {
case None => throw new NoSuchElementException("Called next when hasNext is false")
case Some(ctx) => yieldAndThen(ctx) { this.nextVariantContext = None; this.advance() }
case Some(ctx) => yieldAndThen(ctx) { this.nextVariant = None; this.advance() }
}
}

Expand Down Expand Up @@ -108,16 +106,16 @@ private class OverlapDetectionVariantContextIterator(val iter: Iterator[VariantC
}

if (intervals.isEmpty) { } // no more intervals
else if (overlaps(ctx, intervals.head)) { nextVariantContext = Some(ctx) } // overlaps!!!
else if (overlaps(ctx, intervals.head)) { nextVariant = Some(ctx) } // overlaps!!!
else if (iter.isEmpty) { } // no more variants
else { this.advance() } // move to the next context
}
}

/** NB: if a variant overlaps multiple intervals, only returns it once. */
private class IndexQueryVariantContextIterator(private val reader: VCFFileReader, intervalList: IntervalList)
extends Iterator[VariantContext] {
private var iter: Option[Iterator[VariantContext]] = None
private class IndexQueryVariantContextIterator(private val reader: VcfSource, intervalList: IntervalList)
extends Iterator[Variant] {
private var iter: Option[Iterator[Variant]] = None
private val intervals = intervalList.iterator()
private var previousInterval: Option[Interval] = None

Expand All @@ -127,7 +125,7 @@ private class IndexQueryVariantContextIterator(private val reader: VCFFileReader
this.iter.exists(_.hasNext)
}

def next(): VariantContext = {
def next(): Variant = {
this.iter match {
case None => throw new NoSuchElementException("Called next when hasNext is false")
case Some(i) => yieldAndThen(i.next()) { advance() }
Expand All @@ -136,10 +134,10 @@ private class IndexQueryVariantContextIterator(private val reader: VCFFileReader

def remove(): Unit = throw new UnsupportedOperationException

private def overlapsInterval(ctx: VariantContext, interval: Interval): Boolean = {
if (!ctx.getContig.equals(interval.getContig)) false // different contig
else if (interval.getStart <= ctx.getStart && ctx.getStart <= interval.getEnd) true // start falls within this interval, count it
else if (ctx.getStart < interval.getStart && interval.getEnd <= ctx.getEnd) true // the variant encloses the interval
private def overlapsInterval(variant: Variant, interval: Interval): Boolean = {
if (!variant.getContig.equals(interval.getContig)) false // different contig
else if (interval.getStart <= variant.getStart && variant.getStart <= interval.getEnd) true // start falls within this interval, count it
else if (variant.getStart < interval.getStart && interval.getEnd <= variant.getEnd) true // the variant encloses the interval
else false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@

package com.fulcrumgenomics.vcf

import com.fulcrumgenomics.FgBioDef._
import com.fulcrumgenomics.fasta.SequenceDictionary
import com.fulcrumgenomics.vcf.api.Variant
import htsjdk.variant.variantcontext.{VariantContext, VariantContextComparator}

object JointVariantContextIterator {
def apply(iters: Seq[Iterator[Variant]],
Expand All @@ -40,7 +38,7 @@ object JointVariantContextIterator {
}

def apply(iters: Seq[Iterator[Variant]],
comp: VariantContextComparator
comp: VariantComparator
): JointVariantContextIterator = {
new JointVariantContextIterator(
iters=iters,
Expand All @@ -54,16 +52,15 @@ object JointVariantContextIterator {
* across the iterators. If samples is given, we subset each variant context to just that sample.
*/
class JointVariantContextIterator private(iters: Seq[Iterator[Variant]],
dictOrComp: Either[SequenceDictionary, VariantContextComparator]
dictOrComp: Either[SequenceDictionary, VariantComparator]
)
extends Iterator[Seq[Option[VariantContext]]] {
import com.fulcrumgenomics.fasta.Converters.ToSAMSequenceDictionary
extends Iterator[Seq[Option[Variant]]] {

if (iters.isEmpty) throw new IllegalArgumentException("No iterators given")

private val iterators = iters.map(_.buffered)
private val comparator = dictOrComp match {
case Left(dict) => new VariantContextComparator(dict.asSam)
case Left(dict) => VariantComparator(dict)
case Right(comp) => comp
}

Expand All @@ -72,7 +69,7 @@ extends Iterator[Seq[Option[VariantContext]]] {
def next(): Seq[Option[Variant]] = {
val minCtx = iterators.filter(_.nonEmpty).map(_.head).sortWith {
// this just checks that the variants are the the same position? Shouldn't be difficult to replace.
case (left: Variant, right: Variant) => comparator.compare(VcfConversions.toJavaVariant(left), right) < 0
case (left: Variant, right: Variant) => comparator.compare(left, right) < 0
}.head
// TODO: could use a TreeSet to store the iterators, examine the head of each iterator, then pop the iterator with the min,
// and add that iterator back in.
Expand All @@ -82,3 +79,21 @@ extends Iterator[Seq[Option[VariantContext]]] {
}
}
}

private object VariantComparator {
def apply(dict: SequenceDictionary): VariantComparator = {
new VariantComparator(dict)
}
}

/** A class for comparing Variants using a sequence dictionary */
private class VariantComparator(dict: SequenceDictionary) {
/** Function for comparing two variants. Returns negative if left < right, and positive if right > left
* To mimic the VariantContextComparator, throws an exception if the contig isn't found. */
def compare(left: Variant, right: Variant): Int = {
val idx1 = this.dict(name=left.chrom).index
val idx2 = this.dict(name=right.chrom).index
if (idx1 - idx2 == 0) left.pos - right.pos
else idx1 - idx2
}
}
22 changes: 12 additions & 10 deletions src/test/scala/com/fulcrumgenomics/vcf/AssessPhasingTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import com.fulcrumgenomics.testing.VcfBuilder.Gt
import com.fulcrumgenomics.testing.{ErrorLogLevel, UnitSpec, VcfBuilder}
import com.fulcrumgenomics.util.Metric
import com.fulcrumgenomics.vcf.PhaseCigar.IlluminaSwitchErrors
import com.fulcrumgenomics.vcf.api.{Variant, VcfHeader, VcfSource, VcfWriter}
import com.fulcrumgenomics.vcf.api.{Variant, VcfCount, VcfFieldType, VcfFormatHeader, VcfHeader, VcfSource, VcfWriter}
import htsjdk.samtools.SAMFileHeader
import htsjdk.samtools.util.{Interval, IntervalList}
import htsjdk.variant.vcf.VCFFileReader
Expand Down Expand Up @@ -86,8 +86,8 @@ object AssessPhasingTest {
// NB: call has blocks lengths 4, 5, 1, and 13; truth has block lengths 4, 6, and 13.
}

val readBuilderCall: VcfSource = VcfSource(builderCall.toTempFile())
val readBuilderTruth: VcfSource = VcfSource(builderTruth.toTempFile())
val readBuilderCall: VcfSource = builderCall.toSource
val readBuilderTruth: VcfSource = builderTruth.toSource

private def addPhaseSetId(ctx: Variant): Variant = {
if (ctx.getStart <= 10) withPhasingSetId(ctx, 1)
Expand All @@ -113,7 +113,8 @@ object AssessPhasingTest {
}.toSeq
}

val Header = readBuilderCall.header
val Header = readBuilderCall.header.copy(
formats = readBuilderCall.header.formats :+ VcfFormatHeader("PS", VcfCount.Fixed(1), VcfFieldType.Integer, "Phase set"))
}

/**
Expand All @@ -124,11 +125,9 @@ class AssessPhasingTest extends ErrorLogLevel {

private def writeVariants(variants: Seq[Variant]): PathToVcf = {
val path = Files.createTempFile("AssessPhasingTest.", ".vcf.gz")
path.toFile.deleteOnExit()

val writer = VcfWriter(path, Header)

variants.foreach(writer.write)
writer.write(variants)
writer.close()
path
}

Expand All @@ -152,9 +151,11 @@ class AssessPhasingTest extends ErrorLogLevel {
callVariants: Seq[Variant] = CallVariants,
debugVcf: Boolean = false
): Unit = {
// input files
// input files )
val truthVcf = writeVariants(truthVariants)
val callVcf = writeVariants(callVariants)
println(truthVcf)
println(callVcf)

// output files
val output = makeTempFile("AssessPhasingTest.", "prefix")
Expand All @@ -164,7 +165,6 @@ class AssessPhasingTest extends ErrorLogLevel {

// run it
new AssessPhasing(truthVcf=truthVcf, calledVcf=callVcf, knownIntervals=intervals, output=output, debugVcf=debugVcf).execute()

// read the metrics
val phasingMetrics = Metric.read[AssessPhasingMetric](path=phasingMetricsPath)
val blockLengthMetrics = Metric.read[PhaseBlockLengthMetric](path=blockLengthMetricsPath)
Expand All @@ -178,6 +178,8 @@ class AssessPhasingTest extends ErrorLogLevel {
val expectedBlockLengthMetrics = Metric.read[PhaseBlockLengthMetric](path=expectedBlockLengthMetricsPath)

// compare the metrics
println(phasingMetrics)
println(expectedPhasingMetrics)
phasingMetrics should contain theSameElementsInOrderAs expectedPhasingMetrics
blockLengthMetrics should contain theSameElementsInOrderAs expectedBlockLengthMetrics

Expand Down
Loading

0 comments on commit 88f5148

Please sign in to comment.