forked from aehrc/VariantSpark
Commit
* WIP: Initial commit - Python wrappers for VariantSpark core API classes
* WIP: Test now works
* Missed file
* Fixed Travis CI feedback.
* Fixed Travis CI feedback
* Updates following review: doc fixes, file renames, new analysis methods, helper methods in Scala codebase, returned data structures in proper Python format
* Missed commit of deleted file
* Re-ordered imports as per pylint
* Made tests discoverable
* Removed old test file
* Added configuration to make test environment and outcomes consistent
* Ignore pyspark import issues
* Backed out pyspark install
* Added classpath to variant-spark jar in test context and removed pyspark from dependencies
* Refactored tests to use PROJECT_DIR
* Added inclusion of jar file to python distribution
* Added a simple example and ignore exception when pyspark not present
* Fixed pylint errors
* For now remove running of the test from the build
* Added Python docs
Showing 14 changed files with 350 additions and 21 deletions.
@@ -59,3 +59,4 @@ variant-spark_2.11.iml
build
dist
_build
spark-warehouse
@@ -1,3 +1,4 @@
Sphinx>=1.3.5
pylint==1.8.1
decorator>=4.1.2
typedecorator>=0.0.5
@@ -0,0 +1,3 @@
global-exclude *.py[cod] __pycache__ .DS_Store
recursive-include target/jars *-all.jar
include README.md
@@ -0,0 +1 @@
TBP:
@@ -0,0 +1,25 @@
'''
Created on 24 Jan 2018

@author: szu004
'''
import os

from variants import VariantsContext
from pyspark.sql import SparkSession

PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))


def main():
    spark = SparkSession.builder.appName("HipsterIndex") \
        .getOrCreate()
    vs = VariantsContext(spark)
    features = vs.import_vcf(os.path.join(PROJECT_DIR, 'data/chr22_1000.vcf'))
    labels = vs.load_label(os.path.join(PROJECT_DIR, 'data/chr22-labels.csv'), '22_16050408')
    model = features.importance_analysis(labels, mtry_fraction=0.1, seed=13, n_trees=200)
    print("Oob = %s" % model.oob_error())
    for entry in model.important_variables(10):
        print(entry)


if __name__ == '__main__':
    main()
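The example assumes the variant-spark jar is already on the Spark classpath. A minimal sketch (assuming the variants package was installed from the sdist, so the jar is bundled) that attaches the jar explicitly via the spark_conf helper defined in variants/core.py below:

    from pyspark import SparkConf
    from pyspark.sql import SparkSession
    from variants import VariantsContext

    # spark_conf() points spark.jars at the bundled variant-spark jar; in client
    # mode the jar may also need to be passed via --driver-class-path.
    conf = VariantsContext.spark_conf(SparkConf())
    spark = SparkSession.builder.config(conf=conf).appName("HipsterIndex").getOrCreate()
    vs = VariantsContext(spark)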
@@ -1,8 +1,77 @@
from __future__ import print_function
from setuptools import setup, find_packages
import sys
import os

if sys.version_info < (2, 7):
    print("Python versions prior to 2.7 are not supported.", file=sys.stderr)
    exit(-1)

setup(
    name = 'variants',
    version = '0.1.0',
    description = 'VariantSpark Python API',
    packages = find_packages()
)

# VERSION = __version__
# A temporary path so we can access above the Python project root and fetch scripts and jars we need
TEMP_PATH = "target"
ROOT_DIR = os.path.abspath("../")

# Provide guidance about how to use setup.py
incorrect_invocation_message = """
If you are installing variants from variant-spark source, you must first build variant-spark and
run sdist.
To build with maven you can run:
  ./build/mvn -DskipTests clean package
Building the source dist is done in the Python directory:
  cd python
  python setup.py sdist
  pip install dist/*.tar.gz"""

JARS_PATH = os.path.join(ROOT_DIR, "target")
JARS_TARGET = os.path.join(TEMP_PATH, "jars")

in_src = os.path.isfile("../pom.xml")

if in_src:
    # Construct links for setup
    try:
        os.mkdir(TEMP_PATH)
    except OSError:
        print("Temp path for symlink to parent already exists {0}".format(TEMP_PATH),
              file=sys.stderr)
        exit(-1)

try:
    if in_src:
        # Construct the symlink farm - this is necessary since we can't refer to the path above the
        # package root and we need to copy the jars and scripts which are up above the python root.
        os.symlink(JARS_PATH, JARS_TARGET)

    setup(
        name='variants',
        version='0.1.0',
        description='VariantSpark Python API',
        packages=find_packages(exclude=["*.test"]) + ['variants.jars'],
        install_requires=['typedecorator'],
        # test_suite='variants.test',
        # test_requires=[
        #     'pyspark>=2.1.0'
        # ],
        include_package_data=True,
        package_dir={
            'variants.jars': 'target/jars',
        },
        package_data={
            'variants.jars': ['*-all.jar'],
        },
        entry_points='''
        [console_scripts]
        variants-jar=variants.cli:cli
        ''',
    )
finally:
    # We only clean up the symlink farm if we were in the source tree; otherwise we are
    # installing rather than packaging.
    if in_src:
        # Depending on cleaning up the symlink farm or copied version
        os.remove(os.path.join(TEMP_PATH, "jars"))
        os.rmdir(TEMP_PATH)
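As a quick sanity check of the packaging above (a sketch; it assumes the sdist was built and installed as the invocation message describes), the bundled jar should then resolve from the installed package:

    # After: pip install dist/variants-0.1.0.tar.gz
    from variants import find_jar

    # find_jar() locates the *-all.jar shipped in the variants.jars package.
    print(find_jar())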
@@ -0,0 +1,5 @@
try:
    from variants.core import VariantsContext
except ImportError:
    pass
from variants.setup import find_jar
@@ -0,0 +1,6 @@
from __future__ import print_function

from variants import find_jar

def cli():
    print(find_jar())
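The variants-jar console script just prints the bundled jar path so it can be spliced into a spark-submit invocation; the same can be done programmatically. A hypothetical sketch (the example script path is illustrative, not from this commit):

    import subprocess
    from variants import find_jar

    # Shell equivalent: spark-submit --jars "$(variants-jar)" examples/hipster_index.py
    subprocess.check_call([
        "spark-submit",
        "--jars", find_jar(),          # bundled variant-spark jar
        "examples/hipster_index.py",   # hypothetical script path
    ])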
@@ -0,0 +1,136 @@
import sys
from typedecorator import params, Nullable, Union, setup_typecheck
from pyspark import SparkConf
from pyspark.sql import SQLContext
from variants.setup import find_jar


class VariantsContext(object):
    """The main entry point for VariantSpark functionality.
    """

    @classmethod
    def spark_conf(cls, conf=SparkConf()):
        """ Adds the necessary option to the spark configuration.
        Note: In client mode these need to be set up using --jars or --driver-class-path.
        """
        return conf.set("spark.jars", find_jar())

    def __init__(self, ss=None):
        """The main entry point for VariantSpark functionality.

        :param ss: SparkSession
        :type ss: :class:`.pyspark.SparkSession`
        """
        self.sc = ss.sparkContext
        self.sql = SQLContext.getOrCreate(self.sc)
        self._jsql = self.sql._jsqlContext
        self._jvm = self.sc._jvm
        self._vs_api = getattr(self._jvm, 'au.csiro.variantspark.api')
        jss = ss._jsparkSession
        self._jvsc = self._vs_api.VSContext.apply(jss)

        setup_typecheck()

        sys.stderr.write('Running on Apache Spark version {}\n'.format(self.sc.version))
        if self.sc._jsc.sc().uiWebUrl().isDefined():
            sys.stderr.write('SparkUI available at {}\n'.format(self.sc._jsc.sc().uiWebUrl().get()))
        sys.stderr.write(
            'Welcome to\n'
            ' _ __ _ __ _____ __ \n'
            '| | / /___ ______(_)___ _____ / /_/ ___/____ ____ ______/ /__ \n'
            '| | / / __ `/ ___/ / __ `/ __ \/ __/\__ \/ __ \/ __ `/ ___/ //_/ \n'
            '| |/ / /_/ / / / / /_/ / / / / /_ ___/ / /_/ / /_/ / / / ,< \n'
            '|___/\__,_/_/ /_/\__,_/_/ /_/\__//____/ .___/\__,_/_/ /_/|_| \n'
            ' /_/ \n')

    @params(self=object, vcf_file_path=str)
    def import_vcf(self, vcf_file_path):
        """ Imports features from a VCF file.
        """
        return FeatureSource(self._jvm, self._vs_api,
                             self._jsql, self.sql, self._jvsc.importVCF(vcf_file_path))

    @params(self=object, label_file_path=str, col_name=str)
    def load_label(self, label_file_path, col_name):
        """ Loads the label source file.

        :param label_file_path: The file path for the label source file
        :param col_name: The name of the column containing labels
        """
        return self._jvsc.loadLabel(label_file_path, col_name)

    def stop(self):
        """ Shuts down the VariantsContext.
        """
        self.sc.stop()
        self.sc = None


class FeatureSource(object):

    def __init__(self, _jvm, _vs_api, _jsql, sql, _jfs):
        self._jfs = _jfs
        self._jvm = _jvm
        self._vs_api = _vs_api
        self._jsql = _jsql
        self.sql = sql

    @params(self=object, label_source=object, n_trees=Nullable(int), mtry_fraction=Nullable(float),
            oob=Nullable(bool), seed=Nullable(Union(int, long)), batch_size=Nullable(int),
            var_ordinal_levels=Nullable(int))
    def importance_analysis(self, label_source, n_trees=1000, mtry_fraction=None,
                            oob=True, seed=None, batch_size=100, var_ordinal_levels=3):
        """Builds a random forest classifier.

        :param label_source: The ingested label source
        :param int n_trees: The number of trees to build in the forest.
        :param float mtry_fraction: The fraction of variables to try at each split.
        :param bool oob: Whether the OOB error should be calculated.
        :param long seed: Random seed to use.
        :param int batch_size: The number of trees to build in one batch.
        :param int var_ordinal_levels: The number of ordinal levels in the variables.
        :return: Importance analysis model.
        :rtype: :py:class:`ImportanceAnalysis`
        """
        jrf_params = self._jvm.au.csiro.variantspark.algo.RandomForestParams(bool(oob),
            float(mtry_fraction), True, float('nan'), True, long(seed))
        jia = self._vs_api.ImportanceAnalysis(self._jsql, self._jfs, label_source,
            jrf_params, n_trees, batch_size, var_ordinal_levels)
        return ImportanceAnalysis(jia, self.sql)


class ImportanceAnalysis(object):
    """ Model for random forest based importance analysis.
    """

    def __init__(self, _jia, sql):
        self._jia = _jia
        self.sql = sql

    @params(self=object, limit=Nullable(int))
    def important_variables(self, limit=10):
        """ Gets the top `limit` important variables as a list of tuples (name, importance), where:
        - name: string - variable name
        - importance: double - gini importance
        """
        jimpvarmap = self._jia.importantVariablesJavaMap(limit)
        return sorted(jimpvarmap.items(), key=lambda x: x[1], reverse=True)

    def oob_error(self):
        """ OOB (Out of Bag) error estimate for the model.

        :rtype: float
        """
        return self._jia.oobError()

    def variable_importance(self):
        """ Returns a DataFrame with the gini importance of variables.
        The DataFrame has two columns:
        - variable: string - variable name
        - importance: double - gini importance
        """
        jdf = self._jia.variableImportance()
        jdf.count()  # force evaluation before exposing the data through a view
        jdf.createTempView("df")
        return self.sql.table("df")
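To illustrate the API above, a short sketch consuming an ImportanceAnalysis both ways, reusing the features and labels objects from the earlier example. Note that mtry_fraction and seed are passed explicitly here, since the None defaults would fail the float()/long() conversions inside importance_analysis:

    model = features.importance_analysis(labels, n_trees=200, mtry_fraction=0.1, seed=13)

    # As a sorted list of (variable, importance) tuples:
    for name, importance in model.important_variables(limit=10):
        print(name, importance)

    # Or as a Spark DataFrame with columns `variable` and `importance`:
    df = model.variable_importance()
    df.orderBy('importance', ascending=False).show(10)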
@@ -0,0 +1,10 @@
import glob
import os
import pkg_resources

def find_jar():
    """Gets the path to the variant-spark jar bundled with the
    python distribution.
    """
    jars_dir = pkg_resources.resource_filename(__name__, "jars")
    return glob.glob(os.path.join(jars_dir, "*-all.jar"))[0]
@@ -0,0 +1,10 @@
import os
import glob

THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_DIR = os.path.abspath(os.path.join(THIS_DIR, os.pardir, os.pardir, os.pardir))

def find_variants_jar():
    jar_candidates = glob.glob(os.path.join(PROJECT_DIR, 'target', 'variant-spark_*-all.jar'))
    assert len(jar_candidates) == 1, "Expecting one jar, but found: %s" % str(jar_candidates)
    return jar_candidates[0]
@@ -0,0 +1,59 @@
import os
import unittest

from pyspark import SparkConf
from pyspark.sql import SparkSession

from variants import VariantsContext
from variants.test import find_variants_jar, PROJECT_DIR

THIS_DIR = os.path.dirname(os.path.abspath(__file__))


class VariantSparkPySparkTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        sconf = SparkConf(loadDefaults=False)\
            .set("spark.sql.files.openCostInBytes", 53687091200L)\
            .set("spark.sql.files.maxPartitionBytes", 53687091200L)\
            .set("spark.driver.extraClassPath", find_variants_jar())
        spark = SparkSession.builder.config(conf=sconf)\
            .appName("test").master("local").getOrCreate()
        cls.sc = spark.sparkContext

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()


class VariantSparkAPITestCase(VariantSparkPySparkTestCase):

    def setUp(self):
        self.spark = SparkSession(self.sc)
        self.vc = VariantsContext(self.spark)

    def test_variants_context_parameter_type(self):
        with self.assertRaises(TypeError) as cm:
            self.vc.load_label(label_file_path=123, col_name=456)
        self.assertEqual('keyword argument label_file_path = 123 doesn\'t match signature str',
                         str(cm.exception))

    def test_importance_analysis_from_vcf(self):
        label_data_path = os.path.join(PROJECT_DIR, 'data/chr22-labels.csv')
        label = self.vc.load_label(label_file_path=label_data_path, col_name='22_16050408')
        feature_data_path = os.path.join(PROJECT_DIR, 'data/chr22_1000.vcf')
        features = self.vc.import_vcf(vcf_file_path=feature_data_path)

        imp_analysis = features.importance_analysis(label, 100, float('nan'), True, 13, 100, 3)
        imp_vars = imp_analysis.important_variables(20)
        most_imp_var = imp_vars[0][0]
        self.assertEqual('22_16050408', most_imp_var)
        df = imp_analysis.variable_importance()
        self.assertEqual('22_16050408',
                         str(df.orderBy('importance', ascending=False).collect()[0][0]))
        oob_error = imp_analysis.oob_error()
        self.assertAlmostEqual(0.016483516483516484, oob_error, 4)


if __name__ == '__main__':
    unittest.main(verbosity=2)
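The suite runs under plain unittest discovery once the variant-spark jar has been built (a sketch; it assumes the runner is launched from the python/ directory so find_variants_jar() can resolve the jar in the project's target directory):

    import unittest

    # Discover and run all tests under variants/test; setUpClass needs the
    # built *-all.jar on the driver classpath, located via find_variants_jar().
    suite = unittest.defaultTestLoader.discover('variants/test')
    unittest.TextTestRunner(verbosity=2).run(suite)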