diff --git a/README.md b/README.md index e19b85c..5b425a1 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,14 @@ Apply PRS to multiple files by using file patterns: ./pgs-calc --ref PGS000018.txt.gz test.chr*.vcf.gz --out scores.txt ``` +#### Multiple scores + +Apply multiple score files: + +``` +./pgs-calc --ref PGS000018.txt.gz,PGS000027.txt.gz test.chr*.vcf.gz --out scores.txt +``` + #### Filter by Imputation Qualitity diff --git a/pom.xml b/pom.xml index b798877..d8aad80 100644 --- a/pom.xml +++ b/pom.xml @@ -3,9 +3,9 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 - genepi + lukfor pgs-calc - 0.8.2 + 0.8.3 jar riskscore diff --git a/src/main/java/genepi/riskscore/App.java b/src/main/java/genepi/riskscore/App.java index 59a0516..379e252 100644 --- a/src/main/java/genepi/riskscore/App.java +++ b/src/main/java/genepi/riskscore/App.java @@ -8,7 +8,7 @@ public class App { public static final String APP = "pgs-calc"; - public static final String VERSION = "0.8.2"; + public static final String VERSION = "0.8.3"; public static final String URL = "https://github.com/lukfor/pgs-calc"; diff --git a/src/main/java/genepi/riskscore/commands/ApplyScoreCommand.java b/src/main/java/genepi/riskscore/commands/ApplyScoreCommand.java index 114cf84..b9b20f8 100644 --- a/src/main/java/genepi/riskscore/commands/ApplyScoreCommand.java +++ b/src/main/java/genepi/riskscore/commands/ApplyScoreCommand.java @@ -9,6 +9,7 @@ import genepi.riskscore.io.Chunk; import genepi.riskscore.io.OutputFile; import genepi.riskscore.model.RiskScoreFormat; +import genepi.riskscore.model.RiskScoreSummary; import genepi.riskscore.tasks.ApplyScoreTask; import htsjdk.samtools.util.StopWatch; import picocli.CommandLine.ArgGroup; @@ -80,23 +81,30 @@ public Integer call() throws Exception { System.out.println(); ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFilename(ref); + + String[] refs = parseRef(ref); + task.setRiskScoreFilenames(refs); + if (format != null) { + RiskScoreFormat riskScoreFormat = RiskScoreFormat.load(format); + for (String file : refs) { + task.setRiskScoreFormat(file, riskScoreFormat); + } + } else { + for (String file : refs) { + String autoFormat = file + ".format"; + if (new File(autoFormat).exists()) { + RiskScoreFormat riskScoreFormat = RiskScoreFormat.load(autoFormat); + task.setRiskScoreFormat(file, riskScoreFormat); + } + } + } + if (chunk != null) { task.setChunk(chunk); } task.setVcfFilenames(vcfs); task.setMinR2(minR2); task.setGenotypeFormat(genotypeFormat); - if (format != null) { - RiskScoreFormat riskScoreFormat = RiskScoreFormat.load(format); - task.setRiskScoreFormat(riskScoreFormat); - } else { - String autoFormat = ref + ".format"; - if (new File(autoFormat).exists()) { - RiskScoreFormat riskScoreFormat = RiskScoreFormat.load(autoFormat); - task.setRiskScoreFormat(riskScoreFormat); - } - } task.setOutputVariantFilename(outputVariantFilename); task.setIncludeVariantFilename(includeVariantFilename); @@ -104,7 +112,7 @@ public Integer call() throws Exception { watch.start(); task.run(); - OutputFile output = new OutputFile(task.getRiskScores()); + OutputFile output = new OutputFile(task.getRiskScores(), task.getSummaries()); output.save(out); watch.stop(); @@ -119,19 +127,11 @@ public Integer call() throws Exception { System.out.println(" - Samples: " + number(task.getCountSamples())); System.out.println(" - Variants: " + number(task.getCountVariants())); System.out.println(); - System.out.println(" Risk Score:"); - System.out.println(" - Variants: " + number(task.getCountVariantsRiskScore())); - System.out.println(" - Variants used: " + number(task.getCountVariantsUsed()) + " (" - + percentage(task.getCountVariantsUsed(), task.getCountVariantsRiskScore()) + ")"); - System.out.println(" - Found in target and filtered by: "); - System.out.println(" - allele mismatch: " + number(task.getCountVariantsAlleleMissmatch())); - System.out.println(" - multi allelic or indels: " + number(task.getCountVariantsMultiAllelic())); - System.out.println(" - low R2 value: " + number(task.getCountVariantsFilteredR2())); - System.out.println(" - variants file: " + number(task.getCountFiltered())); - - int notFound = task.getCountVariantsRiskScore() - - (task.getCountVariantsUsed() + task.getCountFiltered() + task.getCountVariantsAlleleMissmatch() - + task.getCountVariantsMultiAllelic() + task.getCountVariantsFilteredR2()); + + for (RiskScoreSummary summary : task.getSummaries()) { + System.out.println(summary); + System.out.println(); + } // System.out.println(" - Not found in target: " + number(notFound)); System.out.println(); @@ -141,6 +141,14 @@ public Integer call() throws Exception { } + private String[] parseRef(String ref) { + String[] refs = ref.split(","); + for (int i = 0; i < refs.length; i++) { + refs[i] = refs[i].trim(); + } + return refs; + } + public static String number(long number) { DecimalFormat formatter = new DecimalFormat("###,###,###"); return formatter.format(number); diff --git a/src/main/java/genepi/riskscore/io/OutputFile.java b/src/main/java/genepi/riskscore/io/OutputFile.java index 6f2e2a6..3545503 100644 --- a/src/main/java/genepi/riskscore/io/OutputFile.java +++ b/src/main/java/genepi/riskscore/io/OutputFile.java @@ -9,6 +9,7 @@ import genepi.io.table.writer.CsvTableWriter; import genepi.io.table.writer.ITableWriter; import genepi.riskscore.model.RiskScore; +import genepi.riskscore.model.RiskScoreSummary; public class OutputFile { @@ -28,15 +29,24 @@ public OutputFile() { } - public OutputFile(RiskScore[] finalScores) { - samples = new Vector(); - data = new Vector[1]; - data[0] = new Vector(); + public OutputFile(RiskScore[] finalScores, RiskScoreSummary[] summaries) { + scores = new Vector(); - scores.add(COLUMN_SCORE); + for (RiskScoreSummary summary : summaries) { + scores.add(summary.getName()); + } + + samples = new Vector(); + data = new Vector[scores.size()]; + for (int i = 0; i < scores.size(); i++) { + data[i] = new Vector(); + } + for (RiskScore riskScore : finalScores) { samples.add(riskScore.getSample()); - data[0].add(riskScore.getScore()); + for (int i = 0; i < scores.size(); i++) { + data[i].add(riskScore.getScore(i)); + } } } diff --git a/src/main/java/genepi/riskscore/io/PGSCatalog.java b/src/main/java/genepi/riskscore/io/PGSCatalog.java new file mode 100644 index 0000000..e8812de --- /dev/null +++ b/src/main/java/genepi/riskscore/io/PGSCatalog.java @@ -0,0 +1,45 @@ +package genepi.riskscore.io; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; +import java.text.MessageFormat; + +import genepi.io.FileUtil; + +public class PGSCatalog { + + public static String USER_HOME = System.getProperty("user.home"); + + public static String CACHE_DIR = FileUtil.path(USER_HOME, ".pgs-calc", "pgs-catalog"); + + public static String FILE_URL = "http://ftp.ebi.ac.uk/pub/databases/spot/pgs/scores/{0}/ScoringFiles/{0}.txt.gz"; + + public static String getFilenameById(String id) throws IOException { + + String filename = FileUtil.path(CACHE_DIR, id + ".txt.gz"); + + if ((new File(filename)).exists()) { + System.out.println("Score '" + id + "' found in local cache " + filename); + return filename; + } + + FileUtil.createDirectory(CACHE_DIR); + + MessageFormat format = new MessageFormat(FILE_URL); + String url = format.format(new Object[] { id }); + + System.out.println("Downloading score '" + id + "' from " + url + "..."); + + InputStream in = new URL(url).openStream(); + Files.copy(in, Paths.get(filename), StandardCopyOption.REPLACE_EXISTING); + + return filename; + + } + +} diff --git a/src/main/java/genepi/riskscore/io/RiskScoreFile.java b/src/main/java/genepi/riskscore/io/RiskScoreFile.java index aaae848..eba46d9 100644 --- a/src/main/java/genepi/riskscore/io/RiskScoreFile.java +++ b/src/main/java/genepi/riskscore/io/RiskScoreFile.java @@ -36,13 +36,22 @@ public RiskScoreFile(String filename, RiskScoreFormat format) throws Exception { variants = new HashMap(); if (!new File(filename).exists()) { - throw new Exception("File '" + filename + "' not found."); + + // check if its a PGS id + if (filename.startsWith("PGS") && filename.length() == 9 && !filename.endsWith(".txt.gz")) { + String id = filename; + this.filename = PGSCatalog.getFilenameById(id); + } else { + + throw new Exception("File '" + filename + "' not found."); + + } } - DataInputStream in = openTxtOrGzipStream(filename); + DataInputStream in = openTxtOrGzipStream(this.filename); ITableReader reader = new CsvTableReader(in, RiskScoreFormat.SEPARATOR); - checkFileFormat(reader, filename); + checkFileFormat(reader, this.filename); reader.close(); } @@ -70,7 +79,7 @@ private void checkFileFormat(ITableReader reader, String filename) throws Except } public void buildIndex(String chromosome) throws IOException { - buildIndex(chromosome, new Chunk()); + buildIndex(chromosome, new Chunk()); } public void buildIndex(String chromosome, Chunk chunk) throws IOException { diff --git a/src/main/java/genepi/riskscore/io/vcf/MinimalVariantContext.java b/src/main/java/genepi/riskscore/io/vcf/MinimalVariantContext.java index 081d02a..5921083 100644 --- a/src/main/java/genepi/riskscore/io/vcf/MinimalVariantContext.java +++ b/src/main/java/genepi/riskscore/io/vcf/MinimalVariantContext.java @@ -32,19 +32,21 @@ public class MinimalVariantContext { private boolean[] genotypes; - private String[] genotypesParsed; - + private float[] genotypesParsed; + private String id = null; private String genotype = null; private String info = null; - + private Map infos; + + private boolean dirtyGenotypes = true; public MinimalVariantContext(int samples) { genotypes = new boolean[samples]; - genotypesParsed = new String[samples]; + genotypesParsed = new float[samples]; } public int getHetCount() { @@ -133,6 +135,7 @@ public int getNSamples() { public void setRawLine(String rawLine) { this.rawLine = rawLine; this.id = null; + this.dirtyGenotypes = true; } public String getRawLine() { @@ -199,20 +202,20 @@ public String toString() { } return id; } - + public void setInfo(String info) { this.info = info; infos = null; } - + public String getInfo(String key) { - //lazy loadding + // lazy loadding if (infos == null) { infos = new HashMap(); String[] tiles = this.info.split(";"); for (int i = 0; i < tiles.length; i++) { String[] tiles2 = tiles[i].split("="); - if(tiles2.length == 2) { + if (tiles2.length == 2) { infos.put(tiles2[0], tiles2[1]); } } @@ -228,27 +231,49 @@ public double getInfoAsDouble(String key, double defaultValue) { return defaultValue; } } - - public String[] getGenotypes(String field) throws IOException { - String tiles[] = rawLine.split("\t", 10); - String[] formats = tiles[8].split(":"); - int index = -1; - for (int i = 0; i < formats.length; i++) { - if (formats[i].equals(field)) { - index = i; + + public float[] getGenotypeDosages(String field) throws IOException { + + if (dirtyGenotypes) { + + String tiles[] = rawLine.split("\t", 10); + String[] formats = tiles[8].split(":"); + int index = -1; + for (int i = 0; i < formats.length; i++) { + if (formats[i].equals(field)) { + index = i; + } } - } - - if (index == -1) { - throw new IOException("field '" + field + "' not found in FORMAT"); - } - String[] values = tiles[9].split("\t"); - for (int i = 0; i < genotypesParsed.length; i++) { - String value = values[i]; - String[] tiles2 = value.split(":"); - genotypesParsed[i] = tiles2[index]; + + if (index == -1) { + throw new IOException("field '" + field + "' not found in FORMAT"); + } + String[] values = tiles[9].split("\t"); + for (int i = 0; i < genotypesParsed.length; i++) { + String value = values[i]; + String[] tiles2 = value.split(":"); + String genotype = tiles2[index]; + + float dosage = 0; + // genotypes + if (genotype.equals("0|0")) { + dosage = 0; + } else if (genotype.equals("0|1") || genotype.equals("1|0")) { + dosage = 1; + } else if (genotype.equals("1|1")) { + dosage = 2; + } else { + // dosage + dosage = Float.parseFloat(genotype); + } + genotypesParsed[i] = dosage; + + } + + dirtyGenotypes = false; + } return genotypesParsed; } - + } diff --git a/src/main/java/genepi/riskscore/model/RiskScore.java b/src/main/java/genepi/riskscore/model/RiskScore.java index 25893b4..5add809 100644 --- a/src/main/java/genepi/riskscore/model/RiskScore.java +++ b/src/main/java/genepi/riskscore/model/RiskScore.java @@ -6,19 +6,20 @@ public class RiskScore { private String chromosome; - private float score = 0; + private float[] scores; - public RiskScore(String chromosome, String sample) { + public RiskScore(String chromosome, String sample, int numberOfScores) { this.chromosome = chromosome; this.sample = sample; + this.scores = new float[numberOfScores]; } - public void setScore(float score) { - this.score = score; + public void setScore(int index, float score) { + this.scores[index] = score; } - public float getScore() { - return score; + public float getScore(int index) { + return scores[index]; } public String getSample() { @@ -31,6 +32,6 @@ public String getChromosome() { @Override public String toString() { - return sample + ": " + score; + return sample + ": " + scores; } } diff --git a/src/main/java/genepi/riskscore/model/RiskScoreSummary.java b/src/main/java/genepi/riskscore/model/RiskScoreSummary.java new file mode 100644 index 0000000..c1f6fe3 --- /dev/null +++ b/src/main/java/genepi/riskscore/model/RiskScoreSummary.java @@ -0,0 +1,134 @@ +package genepi.riskscore.model; + +import java.text.DecimalFormat; + +public class RiskScoreSummary { + + private String name; + + private int variants = 0; + + private int variantsUsed = 0; + + private int variantsSwitched = 0; + + private int variantsMultiAllelic = 0; + + private int variantsAlleleMissmatch = 0; + + private int R2Filtered = 0; + + private int NotFound = 0; + + private int Filtered = 0; + + public RiskScoreSummary(String name) { + this.name = name; + } + + public String getName() { + return name; + } + + public int getVariantsUsed() { + return variantsUsed; + } + + public void incVariantsUsed() { + this.variantsUsed++; + } + + public int getSwitched() { + return variantsSwitched; + } + + public void incSwitched() { + this.variantsSwitched++; + } + + public int getMultiAllelic() { + return variantsMultiAllelic; + } + + public void incMultiAllelic() { + this.variantsMultiAllelic++; + } + + public int getAlleleMissmatch() { + return variantsAlleleMissmatch; + } + + public void incAlleleMissmatch() { + this.variantsAlleleMissmatch++; + } + + public int getR2Filtered() { + return R2Filtered; + } + + public void incR2Filtered() { + this.R2Filtered++; + } + + public int getVariants() { + return variants; + } + + public void setVariants(int count) { + this.variants = count; + } + + public int getNotFound() { + return NotFound; + } + + public void incNotFound() { + this.NotFound++; + } + + public int getFiltered() { + return Filtered; + } + + public void incFiltered() { + this.Filtered++; + } + + public int getVariantsNotUsed() { + return (variants - variantsUsed); + } + + @Override + public String toString() { + + StringBuffer buffer = new StringBuffer(); + + buffer.append(" " + name + ":\n"); + buffer.append(" - Variants: " + number(getVariants()) + "\n"); + buffer.append(" - Variants used: " + number(getVariantsUsed()) + " (" + + percentage(getVariantsUsed(), getVariants()) + ")\n"); + buffer.append(" - Found in target and filtered by:\n"); + buffer.append(" - allele mismatch: " + number(getAlleleMissmatch()) + "\n"); + buffer.append(" - multi allelic or indels: " + number(getMultiAllelic()) + "\n"); + buffer.append(" - low R2 value: " + number(getR2Filtered()) + "\n"); + buffer.append(" - variants file: " + number(getFiltered()) + "\n"); + + int notFound = getVariants() + - (getVariantsUsed() + getFiltered() + getAlleleMissmatch() + getMultiAllelic() + getR2Filtered()); + + return buffer.toString(); + + } + + public static String number(long number) { + DecimalFormat formatter = new DecimalFormat("###,###,###"); + return formatter.format(number); + } + + public static String percentage(double obtained, double total) { + double percentage = (obtained / total) * 100; + DecimalFormat df = new DecimalFormat("###.##'%'"); + return df.format(percentage); + } + +} diff --git a/src/main/java/genepi/riskscore/tasks/ApplyScoreTask.java b/src/main/java/genepi/riskscore/tasks/ApplyScoreTask.java index 5ef77dd..fabfb03 100644 --- a/src/main/java/genepi/riskscore/tasks/ApplyScoreTask.java +++ b/src/main/java/genepi/riskscore/tasks/ApplyScoreTask.java @@ -1,7 +1,9 @@ package genepi.riskscore.tasks; import java.io.IOException; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Vector; import genepi.io.table.writer.CsvTableWriter; @@ -14,6 +16,7 @@ import genepi.riskscore.model.ReferenceVariant; import genepi.riskscore.model.RiskScore; import genepi.riskscore.model.RiskScoreFormat; +import genepi.riskscore.model.RiskScoreSummary; public class ApplyScoreTask { @@ -21,30 +24,12 @@ public class ApplyScoreTask { private List vcfs = null; - private String riskScoreFilename = null; + private String riskScoreFilenames[] = null; private int countSamples = 0; private int countVariants = 0; - private int countVariantsUsed = 0; - - private int countVariantsSwitched = 0; - - private int countVariantsMultiAllelic = 0; - - private int countVariantsNotUsed = 0; - - private int countVariantsAlleleMissmatch = 0; - - private int countR2Filtered = 0; - - private int countVariantsRiskScore = 0; - - private int countNotFound = 0; - - private int countFiltered = 0; - private Chunk chunk = null; private float minR2 = 0; @@ -55,16 +40,25 @@ public class ApplyScoreTask { private CsvTableWriter variantFile; - private RiskScoreFormat format = new PGSCatalogFormat(); + private RiskScoreFormat defaultFormat = new PGSCatalogFormat(); + + private Map formats = new HashMap(); private String genotypeFormat = DOSAGE_FORMAT; + private int numberRiskScores = 0; + + private RiskScoreSummary[] summaries; + public static final String INFO_R2 = "R2"; public static final String DOSAGE_FORMAT = "DS"; - public void setRiskScoreFilename(String filename) { - this.riskScoreFilename = filename; + public void setRiskScoreFilenames(String... filenames) { + this.riskScoreFilenames = filenames; + for (String filename : filenames) { + formats.put(filename, defaultFormat); + } } public void setChunk(Chunk chunk) { @@ -100,8 +94,8 @@ public void run() throws Exception { throw new Exception("Please specify at leat one vcf file."); } - if (riskScoreFilename == null) { - throw new Exception("Reference can not be null."); + if (riskScoreFilenames == null || riskScoreFilenames.length == 0) { + throw new Exception("Reference can not be null or empty."); } long start = System.currentTimeMillis(); @@ -111,23 +105,31 @@ public void run() throws Exception { variantFile.setColumns(new String[] { VariantFile.CHROMOSOME, VariantFile.POSITION }); } + numberRiskScores = riskScoreFilenames.length; + summaries = new RiskScoreSummary[numberRiskScores]; + for (int i = 0; i < numberRiskScores; i++) { + if (i == 0) { + summaries[i] = new RiskScoreSummary("score"); + } else { + summaries[i] = new RiskScoreSummary("score_" + i); + } + } + for (String vcfFilename : vcfs) { - processVCF(vcfFilename, riskScoreFilename); + processVCF(vcfFilename, riskScoreFilenames); } if (variantFile != null) { variantFile.close(); } - countVariantsNotUsed = (countVariantsRiskScore - countVariantsUsed); - long end = System.currentTimeMillis(); System.out.println("Execution Time: " + ((end - start) / 1000.0 / 60.0) + " min"); } - private void processVCF(String vcfFilename, String riskScoreFilename) throws Exception { + private void processVCF(String vcfFilename, String... riskScoreFilenames) throws Exception { // read chromosome from first variant String chromosome = null; @@ -147,30 +149,37 @@ private void processVCF(String vcfFilename, String riskScoreFilename) throws Exc includeVariants.buildIndex(chromosome); System.out.println("Loaded " + includeVariants.getCacheSize() + " variants for chromosome " + chromosome); } + RiskScoreFile[] riskscores = new RiskScoreFile[numberRiskScores]; + for (int i = 0; i < numberRiskScores; i++) { + + System.out.println("Loading file " + riskScoreFilenames[i] + "..."); + + RiskScoreFormat format = formats.get(riskScoreFilenames[i]); + RiskScoreFile riskscore = new RiskScoreFile(riskScoreFilenames[i], format); + + if (chunk != null) { + riskscore.buildIndex(chromosome, chunk); + } else { + riskscore.buildIndex(chromosome); + } - RiskScoreFile riskscore = new RiskScoreFile(riskScoreFilename, format); - System.out.println("Loading file " + riskScoreFilename + "..."); + summaries[i].setVariants(riskscore.getTotalVariants()); - if (chunk != null) { - riskscore.buildIndex(chromosome, chunk); - } else { - riskscore.buildIndex(chromosome); + System.out.println("Loaded " + riskscore.getCacheSize() + " weights for chromosome " + chromosome); + riskscores[i] = riskscore; } - if (countVariantsRiskScore == 0) { - countVariantsRiskScore = riskscore.getTotalVariants(); - } - System.out.println("Loaded " + riskscore.getCacheSize() + " weights for chromosome " + chromosome); - System.out.println("Loading file " + vcfFilename + "..."); - + vcfReader = new FastVCFFileReader(vcfFilename); countSamples = vcfReader.getGenotypedSamples().size(); if (riskScores == null) { riskScores = new RiskScore[countSamples]; for (int i = 0; i < countSamples; i++) { - riskScores[i] = new RiskScore(chromosome, vcfReader.getGenotypedSamples().get(i)); + riskScores[i] = new RiskScore(chromosome, vcfReader.getGenotypedSamples().get(i), + riskScoreFilenames.length); + } } else { if (riskScores.length != countSamples) { @@ -193,88 +202,83 @@ private void processVCF(String vcfFilename, String riskScoreFilename) throws Exc int position = variant.getStart(); - boolean isPartOfRiskScore = riskscore.contains(position); + for (int j = 0; j < riskScoreFilenames.length; j++) { - if (!isPartOfRiskScore) { - countNotFound++; - continue; - } + RiskScoreSummary summary = summaries[j]; - if (includeVariants != null) { - if (!includeVariants.contains(position)) { - countFiltered++; + RiskScoreFile riskscore = riskscores[j]; + boolean isPartOfRiskScore = riskscore.contains(position); + + if (!isPartOfRiskScore) { + summary.incNotFound(); continue; } - } - // Imputation Quality Filter - double r2 = variant.getInfoAsDouble(INFO_R2, 0); - if (r2 < minR2) { - countR2Filtered++; - continue; - } + if (includeVariants != null) { + if (!includeVariants.contains(position)) { + summary.incFiltered(); + continue; + } + } - ReferenceVariant referenceVariant = riskscore.getVariant(position); + // Imputation Quality Filter + double r2 = variant.getInfoAsDouble(INFO_R2, 0); + if (r2 < minR2) { + summary.incR2Filtered(); + continue; + } - if (variant.isComplexIndel()) { - countVariantsMultiAllelic++; - continue; - } + ReferenceVariant referenceVariant = riskscore.getVariant(position); + + if (variant.isComplexIndel()) { + summary.incMultiAllelic(); + continue; + } - float effectWeight = -referenceVariant.getEffectWeight(); + float effectWeight = -referenceVariant.getEffectWeight(); - char referenceAllele = variant.getReferenceAllele().charAt(0); + char referenceAllele = variant.getReferenceAllele().charAt(0); - // ignore deletions - if (variant.getAlternateAllele().length() == 0) { - countVariantsMultiAllelic++; - continue; - } + // ignore deletions + if (variant.getAlternateAllele().length() == 0) { + summary.incMultiAllelic(); + continue; + } - char alternateAllele = variant.getAlternateAllele().charAt(0); + char alternateAllele = variant.getAlternateAllele().charAt(0); - if (!referenceVariant.hasAllele(referenceAllele) || !referenceVariant.hasAllele(alternateAllele)) { - countVariantsAlleleMissmatch++; - continue; - } + if (!referenceVariant.hasAllele(referenceAllele) || !referenceVariant.hasAllele(alternateAllele)) { + summary.incAlleleMissmatch(); + continue; + } - if (!referenceVariant.isEffectAllele(referenceAllele)) { - effectWeight = -effectWeight; - countVariantsSwitched++; - } + if (!referenceVariant.isEffectAllele(referenceAllele)) { + effectWeight = -effectWeight; + summary.incSwitched(); + } - if (variantFile != null) { - variantFile.setString(VariantFile.CHROMOSOME, variant.getContig()); - variantFile.setInteger(VariantFile.POSITION, variant.getStart()); - variantFile.next(); - } + if (variantFile != null) { + variantFile.setString(VariantFile.CHROMOSOME, variant.getContig()); + variantFile.setInteger(VariantFile.POSITION, variant.getStart()); + variantFile.next(); + } - String[] values = variant.getGenotypes(genotypeFormat); + float[] dosages = variant.getGenotypeDosages(genotypeFormat); - for (int i = 0; i < countSamples; i++) { - float dosage = 0; - // genotypes - if (values[i].equals("0|0")) { - dosage = 0; - } else if (values[i].equals("0|1") || values[i].equals("1|0")) { - dosage = 1; - } else if (values[i].equals("1|1")) { - dosage = 2; - } else { - // dosage - dosage = Float.parseFloat(values[i]); + for (int i = 0; i < countSamples; i++) { + float dosage = dosages[i]; + float score = riskScores[i].getScore(j) + (dosage * effectWeight); + riskScores[i].setScore(j, score); } - float score = riskScores[i].getScore() + (dosage * effectWeight); - riskScores[i].setScore(score); - } - countVariantsUsed++; + summary.incVariantsUsed(); + } } vcfReader.close(); - System.out.println("Loaded " + getRiskScores().length + " samples and " + getCountVariants() + " variants."); + System.out.println("Loaded " + getRiskScores().length + " samples and " + countVariants + " variants."); } @@ -282,8 +286,17 @@ public void setMinR2(float minR2) { this.minR2 = minR2; } - public void setRiskScoreFormat(RiskScoreFormat format) { - this.format = format; + public void setDefaultRiskScoreFormat(RiskScoreFormat defaultFormat) { + this.defaultFormat = defaultFormat; + if (riskScoreFilenames != null) { + for (String file : riskScoreFilenames) { + setRiskScoreFormat(file, defaultFormat); + } + } + } + + public void setRiskScoreFormat(String file, RiskScoreFormat format) { + this.formats.put(file, format); } public int getCountSamples() { @@ -294,44 +307,11 @@ public RiskScore[] getRiskScores() { return riskScores; } - public int getCountVariants() { - return countVariants; - } - - public int getCountVariantsUsed() { - return countVariantsUsed; + public RiskScoreSummary[] getSummaries() { + return summaries; } - public int getCountVariantsSwitched() { - return countVariantsSwitched; - } - - public int getCountVariantsMultiAllelic() { - return countVariantsMultiAllelic; - } - - public int getCountVariantsNotUsed() { - return countVariantsNotUsed; - } - - public int getCountVariantsAlleleMissmatch() { - return countVariantsAlleleMissmatch; - } - - public int getCountVariantsFilteredR2() { - return countR2Filtered; - } - - public int getCountVariantsRiskScore() { - return countVariantsRiskScore; - } - - public int getCountVariantsNotFound() { - return countNotFound; - } - - public int getCountFiltered() { - return countFiltered; + public int getCountVariants() { + return countVariants; } - } diff --git a/src/test/java/genepi/riskscore/commands/ApplyScoreCommandTest.java b/src/test/java/genepi/riskscore/commands/ApplyScoreCommandTest.java index f2fb925..2541b56 100644 --- a/src/test/java/genepi/riskscore/commands/ApplyScoreCommandTest.java +++ b/src/test/java/genepi/riskscore/commands/ApplyScoreCommandTest.java @@ -28,7 +28,65 @@ public void testCall() { } assertEquals(EXPECTED_SAMPLES, samples); + reader.close(); + } + + + @Test + public void testCallWithPGSID() { + + String[] args = { "test-data/chr20.dose.vcf.gz", "--ref", "PGS000028", "--out", "output.csv" }; + int result = new CommandLine(new ApplyScoreCommand()).execute(args); + assertEquals(0, result); + + int samples = 0; + ITableReader reader = new CsvTableReader("output.csv", ','); + while (reader.next()) { + samples++; + + } + assertEquals(EXPECTED_SAMPLES, samples); + reader.close(); + } + + + @Test + public void testCallWithMultipleScores() throws IOException { + + String[] args = { "test-data/test.chr1.vcf", "test-data/test.chr2.vcf", "--ref", + "test-data/test.scores.csv,test-data/test.scores.csv,test-data/test.scores.csv", "--out", + "output.csv" }; + int result = new CommandLine(new ApplyScoreCommand()).execute(args); + assertEquals(0, result); + + int samples = 0; + ITableReader reader = new CsvTableReader("output.csv", ','); + while (reader.next()) { + samples++; + + } + assertEquals(4, reader.getColumns().length); + assertEquals(2, samples); + reader.close(); + + reader = new CsvTableReader("output.csv", ','); + + assertEquals(true, reader.next()); + + double score = reader.getDouble("score"); + String sample = reader.getString("sample"); + assertEquals("LF001", sample); + assertEquals(-(1 + 3), score, 0.0000001); + + assertEquals(true, reader.next()); + score = reader.getDouble("score"); + sample = reader.getString("sample"); + assertEquals("LF002", sample); + assertEquals(-(3 + 7), score, 0.0000001); + + assertEquals(false, reader.next()); + reader.close(); } @Test @@ -122,11 +180,12 @@ public void testCallWithMissingVcf() { assertEquals(1, result); } - + @Test public void testCallWithChunk() { - String[] args = { "test-data/chr20.dose.vcf.gz", "--ref", "test-data/chr20.scores.csv", "--out", "output.csv", "--start", "61795", "--end", "63231" }; + String[] args = { "test-data/chr20.dose.vcf.gz", "--ref", "test-data/chr20.scores.csv", "--out", "output.csv", + "--start", "61795", "--end", "63231" }; int result = new CommandLine(new ApplyScoreCommand()).execute(args); assertEquals(0, result); @@ -137,16 +196,17 @@ public void testCallWithChunk() { } assertEquals(EXPECTED_SAMPLES, samples); - + reader.close(); } - + @Test public void testCallWithStartOnly() { - String[] args = { "test-data/chr20.dose.vcf.gz", "--ref", "test-data/chr20.scores.csv", "--out", "output.csv", "--start", "61795"}; + String[] args = { "test-data/chr20.dose.vcf.gz", "--ref", "test-data/chr20.scores.csv", "--out", "output.csv", + "--start", "61795" }; int result = new CommandLine(new ApplyScoreCommand()).execute(args); assertEquals(2, result); - + } } diff --git a/src/test/java/genepi/riskscore/tasks/ApplyScoreTaskTest.java b/src/test/java/genepi/riskscore/tasks/ApplyScoreTaskTest.java index 29ffead..f3c2c74 100644 --- a/src/test/java/genepi/riskscore/tasks/ApplyScoreTaskTest.java +++ b/src/test/java/genepi/riskscore/tasks/ApplyScoreTaskTest.java @@ -8,6 +8,7 @@ import genepi.riskscore.io.VariantFile; import genepi.riskscore.model.RiskScore; import genepi.riskscore.model.RiskScoreFormat; +import genepi.riskscore.model.RiskScoreSummary; public class ApplyScoreTaskTest { @@ -17,17 +18,19 @@ public class ApplyScoreTaskTest { public void testPerformance() throws Exception { ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFormat(new RiskScoreFormat()); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); task.setVcfFilenames("test-data/chr20.dose.vcf.gz"); - task.setRiskScoreFilename("test-data/chr20.scores.csv"); + task.setRiskScoreFilenames("test-data/chr20.scores.csv"); task.run(); assertEquals(63480, task.getCountVariants()); - assertEquals(3, task.getCountVariantsUsed()); - assertEquals(1, task.getCountVariantsSwitched()); - assertEquals(1, task.getCountVariantsNotUsed()); - assertEquals(0, task.getCountVariantsMultiAllelic()); - assertEquals(0, task.getCountVariantsAlleleMissmatch()); + + RiskScoreSummary summary = task.getSummaries()[0]; + assertEquals(3, summary.getVariantsUsed()); + assertEquals(1, summary.getSwitched()); + assertEquals(1, summary.getVariantsNotUsed()); + assertEquals(0, summary.getMultiAllelic()); + assertEquals(0, summary.getAlleleMissmatch()); assertEquals(EXPECTED_SAMPLES, task.getCountSamples()); } @@ -36,40 +39,68 @@ public void testPerformance() throws Exception { public void testMultiPostion() throws Exception { ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFormat(new RiskScoreFormat()); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); task.setVcfFilenames("test-data/small.vcf"); - task.setRiskScoreFilename("test-data/chr20.scores.csv"); + task.setRiskScoreFilenames("test-data/chr20.scores.csv"); task.run(); assertEquals(4, task.getCountVariants()); - assertEquals(3, task.getCountVariantsUsed()); - assertEquals(1, task.getCountVariantsSwitched()); - assertEquals(1, task.getCountVariantsNotUsed()); - assertEquals(0, task.getCountVariantsMultiAllelic()); - assertEquals(1, task.getCountVariantsAlleleMissmatch()); + + RiskScoreSummary summary = task.getSummaries()[0]; + assertEquals(3, summary.getVariantsUsed()); + assertEquals(1, summary.getSwitched()); + assertEquals(1, summary.getVariantsNotUsed()); + assertEquals(0, summary.getMultiAllelic()); + assertEquals(1, summary.getAlleleMissmatch()); assertEquals(EXPECTED_SAMPLES, task.getCountSamples()); assertEquals(EXPECTED_SAMPLES, task.getRiskScores().length); } + @Test + public void testMultipleScores() throws Exception { + + ApplyScoreTask task = new ApplyScoreTask(); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); + task.setVcfFilenames("test-data/chr20.dose.vcf.gz"); + task.setRiskScoreFilenames("test-data/chr20.scores.csv", "test-data/chr20.scores.csv", + "test-data/chr20.scores.csv"); + task.run(); + + assertEquals(63480, task.getCountVariants()); + + assertEquals(3, task.getSummaries().length); + + RiskScoreSummary summary = task.getSummaries()[0]; + assertEquals(3, summary.getVariantsUsed()); + assertEquals(1, summary.getSwitched()); + assertEquals(1, summary.getVariantsNotUsed()); + assertEquals(0, summary.getMultiAllelic()); + assertEquals(0, summary.getAlleleMissmatch()); + assertEquals(EXPECTED_SAMPLES, task.getCountSamples()); + + } + @Test public void testScore() throws Exception { ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFormat(new RiskScoreFormat()); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); task.setVcfFilenames("test-data/single.vcf"); - task.setRiskScoreFilename("test-data/chr20.scores.csv"); + task.setRiskScoreFilenames("test-data/chr20.scores.csv"); task.run(); assertEquals(5, task.getCountVariants()); - assertEquals(3, task.getCountVariantsUsed()); - assertEquals(1, task.getCountVariantsSwitched()); - assertEquals(1, task.getCountVariantsNotUsed()); - assertEquals(0, task.getCountVariantsMultiAllelic()); - assertEquals(1, task.getCountVariantsAlleleMissmatch()); + + RiskScoreSummary summary = task.getSummaries()[0]; + assertEquals(3, summary.getVariantsUsed()); + assertEquals(1, summary.getSwitched()); + assertEquals(1, summary.getVariantsNotUsed()); + assertEquals(0, summary.getMultiAllelic()); + assertEquals(1, summary.getAlleleMissmatch()); assertEquals(1, task.getCountSamples()); assertEquals(1, task.getRiskScores().length); assertEquals("LF001", task.getRiskScores()[0].getSample()); - assertEquals(-0.4, task.getRiskScores()[0].getScore(), 0.00001); + assertEquals(-0.4, task.getRiskScores()[0].getScore(0), 0.00001); } @@ -77,44 +108,48 @@ public void testScore() throws Exception { public void testScoreSwitchEffectAllele() throws Exception { ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFormat(new RiskScoreFormat()); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); task.setVcfFilenames("test-data/single.vcf"); - task.setRiskScoreFilename("test-data/chr20.scores.2.csv"); + task.setRiskScoreFilenames("test-data/chr20.scores.2.csv"); task.run(); assertEquals(5, task.getCountVariants()); - assertEquals(3, task.getCountVariantsUsed()); - assertEquals(0, task.getCountVariantsSwitched()); - assertEquals(1, task.getCountVariantsNotUsed()); - assertEquals(0, task.getCountVariantsMultiAllelic()); - assertEquals(1, task.getCountVariantsAlleleMissmatch()); + + RiskScoreSummary summary = task.getSummaries()[0]; + assertEquals(3, summary.getVariantsUsed()); + assertEquals(0, summary.getSwitched()); + assertEquals(1, summary.getVariantsNotUsed()); + assertEquals(0, summary.getMultiAllelic()); + assertEquals(1, summary.getAlleleMissmatch()); assertEquals(1, task.getCountSamples()); assertEquals(1, task.getRiskScores().length); assertEquals("LF001", task.getRiskScores()[0].getSample()); - assertEquals(-0.6, task.getRiskScores()[0].getScore(), 0.00001); + assertEquals(-0.6, task.getRiskScores()[0].getScore(0), 0.00001); } @Test public void testMinR2_06() throws Exception { ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFormat(new RiskScoreFormat()); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); task.setVcfFilenames("test-data/two.vcf"); - task.setRiskScoreFilename("test-data/chr20.scores.csv"); + task.setRiskScoreFilenames("test-data/chr20.scores.csv"); task.setMinR2(0.6f); task.run(); assertEquals(5, task.getCountVariants()); - assertEquals(1, task.getCountVariantsUsed()); - assertEquals(0, task.getCountVariantsSwitched()); - assertEquals(3, task.getCountVariantsNotUsed()); - assertEquals(3, task.getCountVariantsFilteredR2()); - assertEquals(0, task.getCountVariantsMultiAllelic()); - assertEquals(0, task.getCountVariantsAlleleMissmatch()); + + RiskScoreSummary summary = task.getSummaries()[0]; + assertEquals(1, summary.getVariantsUsed()); + assertEquals(0, summary.getSwitched()); + assertEquals(3, summary.getVariantsNotUsed()); + assertEquals(3, summary.getR2Filtered()); + assertEquals(0, summary.getMultiAllelic()); + assertEquals(0, summary.getAlleleMissmatch()); assertEquals(2, task.getCountSamples()); assertEquals(2, task.getRiskScores().length); assertEquals("LF001", task.getRiskScores()[0].getSample()); - assertEquals(-0.2, task.getRiskScores()[0].getScore(), 0.00001); + assertEquals(-0.2, task.getRiskScores()[0].getScore(0), 0.00001); } @@ -122,23 +157,25 @@ public void testMinR2_06() throws Exception { public void testMinR2_05() throws Exception { ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFormat(new RiskScoreFormat()); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); task.setVcfFilenames("test-data/two.vcf"); - task.setRiskScoreFilename("test-data/chr20.scores.2.csv"); + task.setRiskScoreFilenames("test-data/chr20.scores.2.csv"); task.setMinR2(0.5f); task.run(); assertEquals(5, task.getCountVariants()); - assertEquals(2, task.getCountVariantsUsed()); - assertEquals(0, task.getCountVariantsSwitched()); - assertEquals(2, task.getCountVariantsNotUsed()); - assertEquals(0, task.getCountVariantsMultiAllelic()); - assertEquals(0, task.getCountVariantsAlleleMissmatch()); + + RiskScoreSummary summary = task.getSummaries()[0]; + assertEquals(2, summary.getVariantsUsed()); + assertEquals(0, summary.getSwitched()); + assertEquals(2, summary.getVariantsNotUsed()); + assertEquals(0, summary.getMultiAllelic()); + assertEquals(0, summary.getAlleleMissmatch()); assertEquals(2, task.getCountSamples()); - assertEquals(2, task.getCountVariantsFilteredR2()); + assertEquals(2, summary.getR2Filtered()); assertEquals(2, task.getRiskScores().length); assertEquals("LF001", task.getRiskScores()[0].getSample()); - assertEquals(-0.3, task.getRiskScores()[0].getScore(), 0.00001); + assertEquals(-0.3, task.getRiskScores()[0].getScore(0), 0.00001); } @@ -146,23 +183,25 @@ public void testMinR2_05() throws Exception { public void testMinR2_1() throws Exception { ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFormat(new RiskScoreFormat()); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); task.setVcfFilenames("test-data/two.vcf"); - task.setRiskScoreFilename("test-data/chr20.scores.2.csv"); + task.setRiskScoreFilenames("test-data/chr20.scores.2.csv"); task.setMinR2(1f); task.run(); assertEquals(5, task.getCountVariants()); - assertEquals(0, task.getCountVariantsUsed()); - assertEquals(0, task.getCountVariantsSwitched()); - assertEquals(4, task.getCountVariantsNotUsed()); - assertEquals(0, task.getCountVariantsMultiAllelic()); - assertEquals(0, task.getCountVariantsAlleleMissmatch()); + + RiskScoreSummary summary = task.getSummaries()[0]; + assertEquals(0, summary.getVariantsUsed()); + assertEquals(0, summary.getSwitched()); + assertEquals(4, summary.getVariantsNotUsed()); + assertEquals(0, summary.getMultiAllelic()); + assertEquals(0, summary.getAlleleMissmatch()); assertEquals(2, task.getCountSamples()); - assertEquals(4, task.getCountVariantsFilteredR2()); + assertEquals(4, summary.getR2Filtered()); assertEquals(2, task.getRiskScores().length); assertEquals("LF001", task.getRiskScores()[0].getSample()); - assertEquals(0.0, task.getRiskScores()[0].getScore(), 0.00001); + assertEquals(0.0, task.getRiskScores()[0].getScore(0), 0.00001); } @@ -170,30 +209,32 @@ public void testMinR2_1() throws Exception { public void testMultipleFiles() throws Exception { ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFormat(new RiskScoreFormat()); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); task.setVcfFilenames("test-data/test.chr1.vcf", "test-data/test.chr2.vcf"); - task.setRiskScoreFilename("test-data/test.scores.csv"); + task.setRiskScoreFilenames("test-data/test.scores.csv"); task.run(); assertEquals(10, task.getCountVariants()); - assertEquals(11, task.getCountVariantsRiskScore()); - assertEquals(7, task.getCountVariantsUsed()); - assertEquals(4, task.getCountVariantsNotUsed()); - assertEquals(0, task.getCountVariantsSwitched()); - assertEquals(0, task.getCountVariantsFilteredR2()); - assertEquals(0, task.getCountVariantsMultiAllelic()); - assertEquals(0, task.getCountVariantsAlleleMissmatch()); + + RiskScoreSummary summary = task.getSummaries()[0]; + assertEquals(11, summary.getVariants()); + assertEquals(7, summary.getVariantsUsed()); + assertEquals(4, summary.getVariantsNotUsed()); + assertEquals(0, summary.getSwitched()); + assertEquals(0, summary.getR2Filtered()); + assertEquals(0, summary.getMultiAllelic()); + assertEquals(0, summary.getAlleleMissmatch()); assertEquals(2, task.getCountSamples()); assertEquals(2, task.getRiskScores().length); RiskScore first = task.getRiskScores()[0]; assertEquals("LF001", first.getSample()); - assertEquals(-(1 + 3), first.getScore(), 0.0000001); + assertEquals(-(1 + 3), first.getScore(0), 0.0000001); RiskScore second = task.getRiskScores()[1]; assertEquals("LF002", second.getSample()); - assertEquals(-(3 + 7), second.getScore(), 0.0000001); + assertEquals(-(3 + 7), second.getScore(0), 0.0000001); } @@ -201,15 +242,17 @@ public void testMultipleFiles() throws Exception { public void testWriteVariantFile() throws Exception { ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFormat(new RiskScoreFormat()); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); task.setVcfFilenames("test-data/test.chr1.vcf", "test-data/test.chr2.vcf"); - task.setRiskScoreFilename("test-data/test.scores.csv"); + task.setRiskScoreFilenames("test-data/test.scores.csv"); task.setOutputVariantFilename("variants.txt"); task.run(); assertEquals(10, task.getCountVariants()); - assertEquals(11, task.getCountVariantsRiskScore()); - assertEquals(7, task.getCountVariantsUsed()); + + RiskScoreSummary summary = task.getSummaries()[0]; + assertEquals(11, summary.getVariants()); + assertEquals(7, summary.getVariantsUsed()); VariantFile variants = new VariantFile("variants.txt"); variants.buildIndex("1"); @@ -221,24 +264,26 @@ public void testWriteVariantFile() throws Exception { assertEquals(3, variants.getCacheSize()); } - + @Test public void testReadVariantsFile() throws Exception { ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFormat(new RiskScoreFormat()); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); task.setVcfFilenames("test-data/test.chr1.vcf", "test-data/test.chr2.vcf"); - task.setRiskScoreFilename("test-data/test.scores.csv"); + task.setRiskScoreFilenames("test-data/test.scores.csv"); task.setIncludeVariantFilename("test-data/variants.txt"); task.run(); assertEquals(10, task.getCountVariants()); - assertEquals(11, task.getCountVariantsRiskScore()); - assertEquals(5, task.getCountVariantsUsed()); - assertEquals(0, task.getCountVariantsSwitched()); - assertEquals(0, task.getCountVariantsFilteredR2()); - assertEquals(0, task.getCountVariantsMultiAllelic()); - assertEquals(0, task.getCountVariantsAlleleMissmatch()); + + RiskScoreSummary summary = task.getSummaries()[0]; + assertEquals(11, summary.getVariants()); + assertEquals(5, summary.getVariantsUsed()); + assertEquals(0, summary.getSwitched()); + assertEquals(0, summary.getR2Filtered()); + assertEquals(0, summary.getMultiAllelic()); + assertEquals(0, summary.getAlleleMissmatch()); assertEquals(2, task.getCountSamples()); assertEquals(2, task.getRiskScores().length); @@ -248,9 +293,9 @@ public void testReadVariantsFile() throws Exception { public void testWrongChromosome() throws Exception { ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFormat(new RiskScoreFormat()); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); task.setVcfFilenames("test-data/single.wrong_chr.vcf"); - task.setRiskScoreFilename("test-data/chr20.scores.2.csv"); + task.setRiskScoreFilenames("test-data/chr20.scores.2.csv"); task.setMinR2(1f); task.run(); @@ -260,9 +305,9 @@ public void testWrongChromosome() throws Exception { public void testDifferentSamples() throws Exception { ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFormat(new RiskScoreFormat()); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); task.setVcfFilenames("test-data/test.chr1.vcf", "test-data/test.chr2.wrong.vcf"); - task.setRiskScoreFilename("test-data/test.scores.csv"); + task.setRiskScoreFilenames("test-data/test.scores.csv"); task.setMinR2(1f); task.run(); @@ -272,9 +317,9 @@ public void testDifferentSamples() throws Exception { public void testWithChunk() throws Exception { ApplyScoreTask task = new ApplyScoreTask(); - task.setRiskScoreFormat(new RiskScoreFormat()); + task.setDefaultRiskScoreFormat(new RiskScoreFormat()); task.setVcfFilenames("test-data/chr20.dose.vcf.gz"); - task.setRiskScoreFilename("test-data/chr20.scores.csv"); + task.setRiskScoreFilenames("test-data/chr20.scores.csv"); Chunk chunk = new Chunk(); chunk.setStart(61795); chunk.setEnd(63231); @@ -282,14 +327,15 @@ public void testWithChunk() throws Exception { task.run(); assertEquals(63480, task.getCountVariants()); - assertEquals(2, task.getCountVariantsUsed()); - assertEquals(1, task.getCountVariantsSwitched()); - assertEquals(2, task.getCountVariantsNotUsed()); - assertEquals(0, task.getCountVariantsMultiAllelic()); - assertEquals(0, task.getCountVariantsAlleleMissmatch()); + + RiskScoreSummary summary = task.getSummaries()[0]; + assertEquals(2, summary.getVariantsUsed()); + assertEquals(1, summary.getSwitched()); + assertEquals(2, summary.getVariantsNotUsed()); + assertEquals(0, summary.getMultiAllelic()); + assertEquals(0, summary.getAlleleMissmatch()); assertEquals(EXPECTED_SAMPLES, task.getCountSamples()); } - }