i13 python api (aehrc#58)
* WIP: Initial commit - Python wrappers for VariantSpark core API classes

* WIP: Test now works

* Missed file

* Fixed Travis CI feedback.

* Fixed Travis CI feedback

* Updates following review: doc fixes, file renames, new analysis methods, helper methods in Scala codebase, returned data structures in proper Python format

* Missed commit of deleted file

* Re-ordered imports as per pylint

* Made tests discoverable

* Removed old test file

* Added configuration to make test environment and outcomes consistent

* Ignore pyspark import issues

* Backed out pyspark install

* Added classpath to variant-spark jar in test context and removed pyspark from dependencies

* Refactored tests to use PROJECT_DIR

* Added inclusion of jar file to python distribution

* Added a simple example and ignore exception when pyspark not present

* Fixed pylint errors

* For now remove running of the test from the build

* Added python docs
piotrszul authored Jan 25, 2018
1 parent aae79de commit 5c6bf2b
Showing 14 changed files with 350 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -59,3 +59,4 @@ variant-spark_2.11.iml
build
dist
_build
spark-warehouse
1 change: 1 addition & 0 deletions dev/dev-requirements.txt
@@ -1,3 +1,4 @@
Sphinx>=1.3.5
pylint==1.8.1
decorator>=4.1.2
typedecorator>=0.0.5
3 changes: 3 additions & 0 deletions python/MANIFEST.in
@@ -0,0 +1,3 @@
global-exclude *.py[cod] __pycache__ .DS_Store
recursive-include target/jars *-all.jar
include README.md
1 change: 1 addition & 0 deletions python/README.md
@@ -0,0 +1 @@
TBP:
25 changes: 25 additions & 0 deletions python/examples/hipster_index.py
@@ -0,0 +1,25 @@
'''
Created on 24 Jan 2018
@author: szu004
'''
import os
from variants import VariantsContext
from pyspark.sql import SparkSession

PROJECT_DIR=os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))

def main():
    spark = SparkSession.builder.appName("HipsterIndex") \
        .getOrCreate()
    vs = VariantsContext(spark)
    features = vs.import_vcf(os.path.join(PROJECT_DIR, 'data/chr22_1000.vcf'))
    labels = vs.load_label(os.path.join(PROJECT_DIR, 'data/chr22-labels.csv'), '22_16050408')
    model = features.importance_analysis(labels, mtry_fraction = 0.1, seed = 13, n_trees = 200)
    print("Oob = %s" % model.oob_error())
    for entry in model.important_variables(10):
        print(entry)

if __name__ == '__main__':
    main()
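
The example assumes the variant-spark jar is already on the Spark classpath (for instance via spark-submit --jars). A minimal sketch, not part of this commit, of attaching the bundled jar through the spark_conf helper added in core.py instead:

    from pyspark import SparkConf
    from pyspark.sql import SparkSession
    from variants import VariantsContext

    # Sketch only: spark_conf() points spark.jars at the jar bundled with the package,
    # so the script does not rely on an externally supplied classpath.
    conf = VariantsContext.spark_conf(SparkConf().setAppName("HipsterIndex").setMaster("local[*]"))
    spark = SparkSession.builder.config(conf=conf).getOrCreate()
    vs = VariantsContext(spark)
    # ... then import_vcf, load_label and importance_analysis as in main() above

Per the spark_conf docstring further down, in client mode the jar may still need to be passed with --jars or --driver-class-path.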

2 changes: 1 addition & 1 deletion python/pylintrc
@@ -29,7 +29,7 @@ profile=no

# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=pyspark.heapq3
ignore=pyspark

# Pickle collected data for later comparisons.
persistent=yes
81 changes: 75 additions & 6 deletions python/setup.py
@@ -1,8 +1,77 @@
from __future__ import print_function
from setuptools import setup, find_packages
import sys
import os

if sys.version_info < (2, 7):
    print("Python versions prior to 2.7 are not supported.", file=sys.stderr)
    exit(-1)

setup(
    name = 'variants',
    version = '0.1.0',
    description ='VariantSpark Python API',
    packages = find_packages()
)

#VERSION = __version__
# A temporary path so we can access above the Python project root and fetch scripts and jars we need
TEMP_PATH = "target"
ROOT_DIR = os.path.abspath("../")

# Provide guidance about how to use setup.py
incorrect_invocation_message = """
If you are installing variants from variant-spark source, you must first build variant-spark and
run sdist.
To build with maven you can run:
./build/mvn -DskipTests clean package
Building the source dist is done in the Python directory:
cd python
python setup.py sdist
pip install dist/*.tar.gz"""

JARS_PATH = os.path.join(ROOT_DIR, "target")
JARS_TARGET = os.path.join(TEMP_PATH, "jars")

in_src = os.path.isfile("../pom.xml")

if (in_src):
    # Construct links for setup
    try:
        os.mkdir(TEMP_PATH)
    except:
        print("Temp path for symlink to parent already exists {0}".format(TEMP_PATH),
              file=sys.stderr)
        exit(-1)


try:

    if (in_src):
        # Construct the symlink farm - this is necessary since we can't refer to the path above the
        # package root and we need to copy the jars and scripts which are up above the python root.
        os.symlink(JARS_PATH, JARS_TARGET)

    setup(
        name='variants',
        version='0.1.0',
        description='VariantSpark Python API',
        packages=find_packages(exclude=["*.test"]) + ['variants.jars'],
        install_requires=['typedecorator'],
        # test_suite = 'variants.test',
        # test_requires = [
        #     'pyspark>=2.1.0'
        # ],
        include_package_data=True,
        package_dir={
            'variants.jars': 'target/jars',
        },
        package_data={
            'variants.jars': ['*-all.jar'],
        },
        entry_points='''
            [console_scripts]
            variants-jar=variants.cli:cli
        ''',
    )
finally:
    # We only cleanup the symlink farm if we were in Spark, otherwise we are installing rather than
    # packaging.
    if (in_src):
        # Depending on cleaning up the symlink farm or copied version
        os.remove(os.path.join(TEMP_PATH, "jars"))
        os.rmdir(TEMP_PATH)
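
After building the jar with maven and installing the sdist as described in incorrect_invocation_message, a quick sanity check (a sketch, assuming the package installed cleanly) is to ask the installed package where its bundled jar ended up:

    from variants import find_jar

    # Should print the absolute path of the variant-spark *-all.jar shipped
    # inside the installed variants package (see variants/setup.py below).
    print(find_jar())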
5 changes: 5 additions & 0 deletions python/variants/__init__.py
@@ -0,0 +1,5 @@
try:
    from variants.core import VariantsContext
except:
    pass
from variants.setup import find_jar
6 changes: 6 additions & 0 deletions python/variants/cli.py
@@ -0,0 +1,6 @@
from __future__ import print_function

from variants import find_jar

def cli():
    print(find_jar())
136 changes: 136 additions & 0 deletions python/variants/core.py
@@ -0,0 +1,136 @@
import sys
from typedecorator import params, Nullable, Union, setup_typecheck
from pyspark import SparkConf
from pyspark.sql import SQLContext
from variants.setup import find_jar

class VariantsContext(object):
    """The main entry point for VariantSpark functionality.
    """

    @classmethod
    def spark_conf(cls, conf = SparkConf()):
        """ Adds the necessary option to the spark configuration.
        Note: In client mode these need to be set up using --jars or --driver-class-path
        """
        return conf.set("spark.jars", find_jar())

    def __init__(self, ss=None):
        """The main entry point for VariantSpark functionality.
        :param ss: SparkSession
        :type ss: :class:`.pyspark.SparkSession`
        """
        self.sc = ss.sparkContext
        self.sql = SQLContext.getOrCreate(self.sc)
        self._jsql = self.sql._jsqlContext
        self._jvm = self.sc._jvm
        self._vs_api = getattr(self._jvm, 'au.csiro.variantspark.api')
        jss = ss._jsparkSession
        self._jvsc = self._vs_api.VSContext.apply(jss)

        setup_typecheck()

        sys.stderr.write('Running on Apache Spark version {}\n'.format(self.sc.version))
        if self.sc._jsc.sc().uiWebUrl().isDefined():
            sys.stderr.write('SparkUI available at {}\n'.format(self.sc._jsc.sc().uiWebUrl().get()))
        sys.stderr.write(
            'Welcome to\n'
            ' _ __ _ __ _____ __ \n'
            '| | / /___ ______(_)___ _____ / /_/ ___/____ ____ ______/ /__ \n'
            '| | / / __ `/ ___/ / __ `/ __ \/ __/\__ \/ __ \/ __ `/ ___/ //_/ \n'
            '| |/ / /_/ / / / / /_/ / / / / /_ ___/ / /_/ / /_/ / / / ,< \n'
            '|___/\__,_/_/ /_/\__,_/_/ /_/\__//____/ .___/\__,_/_/ /_/|_| \n'
            ' /_/ \n')

    @params(self=object, vcf_file_path=str)
    def import_vcf(self, vcf_file_path):
        """ Import features from a VCF file.
        """
        return FeatureSource(self._jvm, self._vs_api,
                             self._jsql, self.sql, self._jvsc.importVCF(vcf_file_path))

    @params(self=object, label_file_path=str, col_name=str)
    def load_label(self, label_file_path, col_name):
        """ Loads the label source file
        :param label_file_path: The file path for the label source file
        :param col_name: the name of the column containing labels
        """
        return self._jvsc.loadLabel(label_file_path, col_name)

    def stop(self):
        """ Shut down the VariantsContext.
        """

        self.sc.stop()
        self.sc = None


class FeatureSource(object):

    def __init__(self, _jvm, _vs_api, _jsql, sql, _jfs):
        self._jfs = _jfs
        self._jvm = _jvm
        self._vs_api = _vs_api
        self._jsql = _jsql
        self.sql = sql

    @params(self=object, label_source=object, n_trees=Nullable(int), mtry_fraction=Nullable(float),
            oob=Nullable(bool), seed=Nullable(Union(int, long)), batch_size=Nullable(int),
            var_ordinal_levels=Nullable(int))
    def importance_analysis(self, label_source, n_trees=1000, mtry_fraction=None,
                            oob=True, seed=None, batch_size=100, var_ordinal_levels=3):
        """Builds random forest classifier.
        :param label_source: The ingested label source
        :param int n_trees: The number of trees to build in the forest.
        :param float mtry_fraction: The fraction of variables to try at each split.
        :param bool oob: Should OOB error be calculated.
        :param long seed: Random seed to use.
        :param int batch_size: The number of trees to build in one batch.
        :param int var_ordinal_levels:
        :return: Importance analysis model.
        :rtype: :py:class:`ImportanceAnalysis`
        """
        jrf_params = self._jvm.au.csiro.variantspark.algo.RandomForestParams(bool(oob),
            float(mtry_fraction), True, float('nan'), True, long(seed))
        jia = self._vs_api.ImportanceAnalysis(self._jsql, self._jfs, label_source,
            jrf_params, n_trees, batch_size, var_ordinal_levels)
        return ImportanceAnalysis(jia, self.sql)


class ImportanceAnalysis(object):
    """ Model for random forest based importance analysis
    """

    def __init__(self, _jia, sql):
        self._jia = _jia
        self.sql = sql

    @params(self=object, limit=Nullable(int))
    def important_variables(self, limit=10):
        """ Gets the top limit important variables as a list of tuples (name, importance) where:
        - name: string - variable name
        - importance: double - gini importance
        """
        jimpvarmap = self._jia.importantVariablesJavaMap(limit)
        return sorted(jimpvarmap.items(), key=lambda x: x[1], reverse=True)

    def oob_error(self):
        """ OOB (Out of Bag) error estimate for the model
        :rtype: float
        """
        return self._jia.oobError()

    def variable_importance(self):
        """ Returns a DataFrame with the gini importance of variables.
        The DataFrame has two columns:
        - variable: string - variable name
        - importance: double - gini importance
        """
        jdf = self._jia.variableImportance()
        jdf.count()
        jdf.createTempView("df")
        return self.sql.table("df")
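
For context, a minimal sketch of how the three classes above chain together, reusing the chr22 test data referenced elsewhere in this commit; the n_trees, mtry_fraction and seed values here are arbitrary:

    vs = VariantsContext(spark)                      # spark: an existing SparkSession
    features = vs.import_vcf('data/chr22_1000.vcf')  # returns a FeatureSource
    labels = vs.load_label('data/chr22-labels.csv', '22_16050408')
    ia = features.importance_analysis(labels, n_trees=500, mtry_fraction=0.1, seed=17)
    print(ia.oob_error())                            # float OOB error estimate
    # variable_importance() returns a Spark DataFrame with (variable, importance) columns
    ia.variable_importance().orderBy('importance', ascending=False).show(10)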
10 changes: 10 additions & 0 deletions python/variants/setup.py
@@ -0,0 +1,10 @@
import glob
import os
import pkg_resources

def find_jar():
    """Gets the path to the variant spark jar bundled with the
    python distribution
    """
    jars_dir = pkg_resources.resource_filename(__name__, "jars")
    return glob.glob(os.path.join(jars_dir, "*-all.jar"))[0]
10 changes: 10 additions & 0 deletions python/variants/test/__init__.py
@@ -0,0 +1,10 @@
import os
import glob

THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_DIR = os.path.abspath(os.path.join(THIS_DIR, os.pardir, os.pardir, os.pardir))

def find_variants_jar():
    jar_candidates = glob.glob(os.path.join(PROJECT_DIR, 'target', 'variant-spark_*-all.jar'))
    assert len(jar_candidates) == 1, "Expecting one jar, but found: %s" % str(jar_candidates)
    return jar_candidates[0]
59 changes: 59 additions & 0 deletions python/variants/test/test_core.py
@@ -0,0 +1,59 @@
import os
import unittest

from pyspark import SparkConf
from pyspark.sql import SparkSession

from variants import VariantsContext
from variants.test import find_variants_jar, PROJECT_DIR

THIS_DIR = os.path.dirname(os.path.abspath(__file__))

class VariantSparkPySparkTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(self):
        sconf = SparkConf(loadDefaults=False)\
            .set("spark.sql.files.openCostInBytes", 53687091200L)\
            .set("spark.sql.files.maxPartitionBytes", 53687091200L)\
            .set("spark.driver.extraClassPath", find_variants_jar())
        spark = SparkSession.builder.config(conf=sconf)\
            .appName("test").master("local").getOrCreate()
        self.sc = spark.sparkContext

    @classmethod
    def tearDownClass(self):
        self.sc.stop()


class VariantSparkAPITestCase(VariantSparkPySparkTestCase):

    def setUp(self):
        self.spark = SparkSession(self.sc)
        self.vc = VariantsContext(self.spark)

    def test_variants_context_parameter_type(self):
        with self.assertRaises(TypeError) as cm:
            self.vc.load_label(label_file_path=123, col_name=456)
        self.assertEqual('keyword argument label_file_path = 123 doesn\'t match signature str',
                         str(cm.exception))

    def test_importance_analysis_from_vcf(self):
        label_data_path = os.path.join(PROJECT_DIR, 'data/chr22-labels.csv')
        label = self.vc.load_label(label_file_path=label_data_path, col_name='22_16050408')
        feature_data_path = os.path.join(PROJECT_DIR, 'data/chr22_1000.vcf')
        features = self.vc.import_vcf(vcf_file_path=feature_data_path)

        imp_analysis = features.importance_analysis(label, 100, float('nan'), True, 13, 100, 3)
        imp_vars = imp_analysis.important_variables(20)
        most_imp_var = imp_vars[0][0]
        self.assertEqual('22_16050408', most_imp_var)
        df = imp_analysis.variable_importance()
        self.assertEqual('22_16050408',
                         str(df.orderBy('importance', ascending=False).collect()[0][0]))
        oob_error = imp_analysis.oob_error()
        self.assertAlmostEqual(0.016483516483516484, oob_error, 4)


if __name__ == '__main__':
    unittest.main(verbosity=2)
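
The suite is picked up by standard unittest discovery ("Made tests discoverable" in the commit messages above). A sketch of running it programmatically, assuming the working directory is python/ and the *-all.jar has already been built into ../target:

    import unittest

    # Discover and run everything under variants/test, equivalent to
    # python -m unittest discover -s variants/test -v
    suite = unittest.defaultTestLoader.discover('variants/test')
    unittest.TextTestRunner(verbosity=2).run(suite)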
