diff --git a/.gitignore b/.gitignore
index 8147a833..d44ba574 100644
--- a/.gitignore
+++ b/.gitignore
@@ -59,3 +59,4 @@ variant-spark_2.11.iml
 build
 dist
 _build
+spark-warehouse
diff --git a/dev/dev-requirements.txt b/dev/dev-requirements.txt
index a7ec07a9..d12de827 100644
--- a/dev/dev-requirements.txt
+++ b/dev/dev-requirements.txt
@@ -1,3 +1,4 @@
 Sphinx>=1.3.5
 pylint==1.8.1
 decorator>=4.1.2
+typedecorator>=0.0.5
diff --git a/python/MANIFEST.in b/python/MANIFEST.in
new file mode 100644
index 00000000..9bc65829
--- /dev/null
+++ b/python/MANIFEST.in
@@ -0,0 +1,3 @@
+global-exclude *.py[cod] __pycache__ .DS_Store
+recursive-include target/jars *-all.jar
+include README.md
diff --git a/python/README.md b/python/README.md
new file mode 100644
index 00000000..735bf8b8
--- /dev/null
+++ b/python/README.md
@@ -0,0 +1 @@
+TBP:
\ No newline at end of file
diff --git a/python/examples/hipster_index.py b/python/examples/hipster_index.py
new file mode 100644
index 00000000..710ff95c
--- /dev/null
+++ b/python/examples/hipster_index.py
@@ -0,0 +1,25 @@
+'''
+Created on 24 Jan 2018
+
+@author: szu004
+'''
+import os
+from variants import VariantsContext
+from pyspark.sql import SparkSession
+
+PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
+
+def main():
+    spark = SparkSession.builder.appName("HipsterIndex") \
+        .getOrCreate()
+    vs = VariantsContext(spark)
+    features = vs.import_vcf(os.path.join(PROJECT_DIR, 'data/chr22_1000.vcf'))
+    labels = vs.load_label(os.path.join(PROJECT_DIR, 'data/chr22-labels.csv'), '22_16050408')
+    model = features.importance_analysis(labels, mtry_fraction=0.1, seed=13, n_trees=200)
+    print("Oob = %s" % model.oob_error())
+    for entry in model.important_variables(10):
+        print(entry)
+
+if __name__ == '__main__':
+    main()
+
diff --git a/python/pylintrc b/python/pylintrc
index b3053569..76f65ef2 100644
--- a/python/pylintrc
+++ b/python/pylintrc
@@ -29,7 +29,7 @@ profile=no
 
 # Add files or directories to the blacklist. They should be base names, not
 # paths.
-ignore=pyspark.heapq3
+ignore=pyspark
 
 # Pickle collected data for later comparisons.
 persistent=yes
diff --git a/python/setup.py b/python/setup.py
index fa757f76..dc9a27e1 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -1,8 +1,77 @@
+from __future__ import print_function
 from setuptools import setup, find_packages
+import sys
+import os
+
+if sys.version_info < (2, 7):
+    print("Python versions prior to 2.7 are not supported.", file=sys.stderr)
+    exit(-1)
 
-setup(
-    name = 'variants',
-    version = '0.1.0',
-    description ='VariantSpark Python API',
-    packages = find_packages()
-)
\ No newline at end of file
+
+#VERSION = __version__
+# A temporary path so we can access above the Python project root and fetch scripts and jars we need
+TEMP_PATH = "target"
+ROOT_DIR = os.path.abspath("../")
+
+# Provide guidance about how to use setup.py
+incorrect_invocation_message = """
+If you are installing variants from variant-spark source, you must first build variant-spark and
+run sdist.
+    To build with maven you can run:
+      ./build/mvn -DskipTests clean package
+    Building the source dist is done in the Python directory:
+      cd python
+      python setup.py sdist
+      pip install dist/*.tar.gz"""
+
+JARS_PATH = os.path.join(ROOT_DIR, "target")
+JARS_TARGET = os.path.join(TEMP_PATH, "jars")
+
+in_src = os.path.isfile("../pom.xml")
+
+if (in_src):
+    # Construct links for setup
+    try:
+        os.mkdir(TEMP_PATH)
+    except:
+        print("Temp path for symlink to parent already exists {0}".format(TEMP_PATH),
+              file=sys.stderr)
+        exit(-1)
+
+
+try:
+
+    if (in_src):
+        # Construct the symlink farm - this is necessary since we can't refer to the path above the
+        # package root and we need to copy the jars and scripts which are up above the python root.
+        os.symlink(JARS_PATH, JARS_TARGET)
+
+    setup(
+        name='variants',
+        version='0.1.0',
+        description='VariantSpark Python API',
+        packages=find_packages(exclude=["*.test"]) + ['variants.jars'],
+        install_requires=['typedecorator'],
+#        test_suite = 'variants.test',
+#        test_requires = [
+#            'pyspark>=2.1.0'
+#        ],
+        include_package_data=True,
+        package_dir={
+            'variants.jars': 'target/jars',
+        },
+        package_data={
+            'variants.jars': ['*-all.jar'],
+        },
+        entry_points='''
+        [console_scripts]
+        variants-jar=variants.cli:cli
+        ''',
+    )
+finally:
+    # We only clean up the symlink farm if we were in the variant-spark source tree, otherwise we
+    # are installing rather than packaging.
+    if (in_src):
+        # Depending on cleaning up the symlink farm or copied version
+        os.remove(os.path.join(TEMP_PATH, "jars"))
+        os.rmdir(TEMP_PATH)
diff --git a/python/variants/__init__.py b/python/variants/__init__.py
index e69de29b..c0238475 100644
--- a/python/variants/__init__.py
+++ b/python/variants/__init__.py
@@ -0,0 +1,5 @@
+try:
+    from variants.core import VariantsContext
+except:
+    pass
+from variants.setup import find_jar
diff --git a/python/variants/cli.py b/python/variants/cli.py
new file mode 100644
index 00000000..2a349a3b
--- /dev/null
+++ b/python/variants/cli.py
@@ -0,0 +1,6 @@
+from __future__ import print_function
+
+from variants import find_jar
+
+def cli():
+    print(find_jar())
diff --git a/python/variants/core.py b/python/variants/core.py
new file mode 100644
index 00000000..ccefaf7c
--- /dev/null
+++ b/python/variants/core.py
@@ -0,0 +1,136 @@
+import sys
+from typedecorator import params, Nullable, Union, setup_typecheck
+from pyspark import SparkConf
+from pyspark.sql import SQLContext
+from variants.setup import find_jar
+
+class VariantsContext(object):
+    """The main entry point for VariantSpark functionality.
+    """
+
+    @classmethod
+    def spark_conf(cls, conf=SparkConf()):
+        """ Adds the necessary option to the spark configuration.
+        Note: In client mode these need to be set up using --jars or --driver-class-path
+        """
+        return conf.set("spark.jars", find_jar())
+
+    def __init__(self, ss=None):
+        """The main entry point for VariantSpark functionality.
+        :param ss: SparkSession
+        :type ss: :class:`.pyspark.SparkSession`
+        """
+        self.sc = ss.sparkContext
+        self.sql = SQLContext.getOrCreate(self.sc)
+        self._jsql = self.sql._jsqlContext
+        self._jvm = self.sc._jvm
+        self._vs_api = getattr(self._jvm, 'au.csiro.variantspark.api')
+        jss = ss._jsparkSession
+        self._jvsc = self._vs_api.VSContext.apply(jss)
+
+        setup_typecheck()
+
+        sys.stderr.write('Running on Apache Spark version {}\n'.format(self.sc.version))
+        if self.sc._jsc.sc().uiWebUrl().isDefined():
+            sys.stderr.write('SparkUI available at {}\n'.format(self.sc._jsc.sc().uiWebUrl().get()))
+        sys.stderr.write(
+            'Welcome to\n'
+            ' _    __           _             __  _____                  __    \n'
+            '| |  / /___ ______(_)___ _____  / /_/ ___/____  ____ ______/ /__  \n'
+            '| | / / __ `/ ___/ / __ `/ __ \/ __/\__ \/ __ \/ __ `/ ___/ //_/  \n'
+            '| |/ / /_/ / /  / / /_/ / / / / /_ ___/ / /_/ / /_/ / /  / ,<     \n'
+            '|___/\__,_/_/  /_/\__,_/_/ /_/\__//____/ .___/\__,_/_/  /_/|_|    \n'
+            '                                      /_/                         \n')
+
+    @params(self=object, vcf_file_path=str)
+    def import_vcf(self, vcf_file_path):
+        """ Import features from a VCF file.
+        """
+        return FeatureSource(self._jvm, self._vs_api,
+                             self._jsql, self.sql, self._jvsc.importVCF(vcf_file_path))
+
+    @params(self=object, label_file_path=str, col_name=str)
+    def load_label(self, label_file_path, col_name):
+        """ Loads the label source file
+
+        :param label_file_path: The file path for the label source file
+        :param col_name: the name of the column containing labels
+        """
+        return self._jvsc.loadLabel(label_file_path, col_name)
+
+    def stop(self):
+        """ Shut down the VariantsContext.
+        """
+
+        self.sc.stop()
+        self.sc = None
+
+
+class FeatureSource(object):
+
+    def __init__(self, _jvm, _vs_api, _jsql, sql, _jfs):
+        self._jfs = _jfs
+        self._jvm = _jvm
+        self._vs_api = _vs_api
+        self._jsql = _jsql
+        self.sql = sql
+
+    @params(self=object, label_source=object, n_trees=Nullable(int), mtry_fraction=Nullable(float),
+            oob=Nullable(bool), seed=Nullable(Union(int, long)), batch_size=Nullable(int),
+            var_ordinal_levels=Nullable(int))
+    def importance_analysis(self, label_source, n_trees=1000, mtry_fraction=None,
+                            oob=True, seed=None, batch_size=100, var_ordinal_levels=3):
+        """Builds random forest classifier.
+
+        :param label_source: The ingested label source
+        :param int n_trees: The number of trees to build in the forest.
+        :param float mtry_fraction: The fraction of variables to try at each split.
+        :param bool oob: Should OOB error be calculated.
+        :param long seed: Random seed to use.
+        :param int batch_size: The number of trees to build in one batch.
+        :param int var_ordinal_levels:
+
+        :return: Importance analysis model.
+        :rtype: :py:class:`ImportanceAnalysis`
+        """
+        jrf_params = self._jvm.au.csiro.variantspark.algo.RandomForestParams(bool(oob),
+                        float(mtry_fraction), True, float('nan'), True, long(seed))
+        jia = self._vs_api.ImportanceAnalysis(self._jsql, self._jfs, label_source,
+                        jrf_params, n_trees, batch_size, var_ordinal_levels)
+        return ImportanceAnalysis(jia, self.sql)
+
+
+class ImportanceAnalysis(object):
+    """ Model for random forest based importance analysis
+    """
+
+    def __init__(self, _jia, sql):
+        self._jia = _jia
+        self.sql = sql
+
+    @params(self=object, limit=Nullable(int))
+    def important_variables(self, limit=10):
+        """ Gets the top limit important variables as a list of tuples (name, importance) where:
+        - name: string - variable name
+        - importance: double - gini importance
+        """
+        jimpvarmap = self._jia.importantVariablesJavaMap(limit)
+        return sorted(jimpvarmap.items(), key=lambda x: x[1], reverse=True)
+
+    def oob_error(self):
+        """ OOB (Out of Bag) error estimate for the model
+
+        :rtype: float
+        """
+        return self._jia.oobError()
+
+    def variable_importance(self):
+        """ Returns a DataFrame with the gini importance of variables.
+        The DataFrame has two columns:
+        - variable: string - variable name
+        - importance: double - gini importance
+        """
+        jdf = self._jia.variableImportance()
+        jdf.count()
+        jdf.createTempView("df")
+        return self.sql.table("df")
diff --git a/python/variants/setup.py b/python/variants/setup.py
new file mode 100644
index 00000000..016803fe
--- /dev/null
+++ b/python/variants/setup.py
@@ -0,0 +1,10 @@
+import glob
+import os
+import pkg_resources
+
+def find_jar():
+    """Gets the path to the variant spark jar bundled with the
+    python distribution
+    """
+    jars_dir = pkg_resources.resource_filename(__name__, "jars")
+    return glob.glob(os.path.join(jars_dir, "*-all.jar"))[0]
diff --git a/python/variants/test/__init__.py b/python/variants/test/__init__.py
new file mode 100644
index 00000000..5c560c9e
--- /dev/null
+++ b/python/variants/test/__init__.py
@@ -0,0 +1,10 @@
+import os
+import glob
+
+THIS_DIR = os.path.dirname(os.path.abspath(__file__))
+PROJECT_DIR = os.path.abspath(os.path.join(THIS_DIR, os.pardir, os.pardir, os.pardir))
+
+def find_variants_jar():
+    jar_candidates = glob.glob(os.path.join(PROJECT_DIR, 'target', 'variant-spark_*-all.jar'))
+    assert len(jar_candidates) == 1, "Expecting one jar, but found: %s" % str(jar_candidates)
+    return jar_candidates[0]
diff --git a/python/variants/test/test_core.py b/python/variants/test/test_core.py
new file mode 100644
index 00000000..bfb39a9b
--- /dev/null
+++ b/python/variants/test/test_core.py
@@ -0,0 +1,59 @@
+import os
+import unittest
+
+from pyspark import SparkConf
+from pyspark.sql import SparkSession
+
+from variants import VariantsContext
+from variants.test import find_variants_jar, PROJECT_DIR
+
+THIS_DIR = os.path.dirname(os.path.abspath(__file__))
+
+class VariantSparkPySparkTestCase(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(self):
+        sconf = SparkConf(loadDefaults=False)\
+            .set("spark.sql.files.openCostInBytes", 53687091200L)\
+            .set("spark.sql.files.maxPartitionBytes", 53687091200L)\
+            .set("spark.driver.extraClassPath", find_variants_jar())
+        spark = SparkSession.builder.config(conf=sconf)\
+            .appName("test").master("local").getOrCreate()
+        self.sc = spark.sparkContext
+
+    @classmethod
+    def tearDownClass(self):
+        self.sc.stop()
+
+
+class VariantSparkAPITestCase(VariantSparkPySparkTestCase):
+
+    def setUp(self):
+        self.spark = SparkSession(self.sc)
+        self.vc = VariantsContext(self.spark)
+
+    def test_variants_context_parameter_type(self):
+        with self.assertRaises(TypeError) as cm:
+            self.vc.load_label(label_file_path=123, col_name=456)
+        self.assertEqual('keyword argument label_file_path = 123 doesn\'t match signature str',
+                         str(cm.exception))
+
+    def test_importance_analysis_from_vcf(self):
+        label_data_path = os.path.join(PROJECT_DIR, 'data/chr22-labels.csv')
+        label = self.vc.load_label(label_file_path=label_data_path, col_name='22_16050408')
+        feature_data_path = os.path.join(PROJECT_DIR, 'data/chr22_1000.vcf')
+        features = self.vc.import_vcf(vcf_file_path=feature_data_path)
+
+        imp_analysis = features.importance_analysis(label, 100, float('nan'), True, 13, 100, 3)
+        imp_vars = imp_analysis.important_variables(20)
+        most_imp_var = imp_vars[0][0]
+        self.assertEqual('22_16050408', most_imp_var)
+        df = imp_analysis.variable_importance()
+        self.assertEqual('22_16050408',
+                         str(df.orderBy('importance', ascending=False).collect()[0][0]))
+        oob_error = imp_analysis.oob_error()
+        self.assertAlmostEqual(0.016483516483516484, oob_error, 4)
+
+
+if __name__ == '__main__':
+    unittest.main(verbosity=2)
diff --git a/src/main/scala/au/csiro/variantspark/api/ImportanceAnalysis.scala b/src/main/scala/au/csiro/variantspark/api/ImportanceAnalysis.scala
index 3e80c218..4454d1aa 100644
--- a/src/main/scala/au/csiro/variantspark/api/ImportanceAnalysis.scala
+++ b/src/main/scala/au/csiro/variantspark/api/ImportanceAnalysis.scala
@@ -1,19 +1,15 @@
 package au.csiro.variantspark.api
 
-import au.csiro.variantspark.input.FeatureSource
-import au.csiro.variantspark.input.LabelSource
-import au.csiro.variantspark.algo.RandomForestParams
-import au.csiro.variantspark.algo.ByteRandomForest
-import au.csiro.variantspark.data.BoundedOrdinal
 import au.csiro.pbdava.ssparkle.spark.SparkUtils
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.types.StructField
-import org.apache.spark.sql.types.StringType
-import org.apache.spark.sql.types.DoubleType
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.Row
+import au.csiro.variantspark.algo.{ByteRandomForest, RandomForestParams}
+import au.csiro.variantspark.data.BoundedOrdinal
+import au.csiro.variantspark.input.{FeatureSource, LabelSource}
 import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
+
 import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
 
 /**
  * A class to represent an instance of the Importance Analysis
@@ -47,7 +43,9 @@ class ImportanceAnalysis(val sqlContext:SQLContext, val featureSource:FeatureSou
     val trainingData = inputData.map{ case (f, i) => (f.values, i)}
     rf.batchTrain(trainingData, dataType, labels, nTrees, rfBatchSize)
   }
-  
+
+  val oobError: Double = rfModel.oobError
+
   private lazy val br_normalizedVariableImportance = {
     val indexImportance = rfModel.normalizedVariableImportance()
     sc.broadcast(new Long2DoubleOpenHashMap(indexImportance.asInstanceOf[Map[java.lang.Long, java.lang.Double]]))
   }
@@ -69,8 +67,13 @@ class ImportanceAnalysis(val sqlContext:SQLContext, val featureSource:FeatureSou
     }
     
     topImportantVariables.map({ case (i, importance) => (index(i), importance)})
-  }
-  
+  }
+
+  def importantVariablesJavaMap(nTopLimit:Int = 100) = {
+    val impVarMap = collection.mutable.Map(importantVariables(nTopLimit).toMap.toSeq: _*)
+    impVarMap.map{ case (k, v) => k -> double2Double(v) }
+    impVarMap.asJava
+  }
 }
 
 object ImportanceAnalysis {