i13 python api #58
Merged 21 commits on Jan 25, 2018
1 change: 1 addition & 0 deletions .gitignore
@@ -59,3 +59,4 @@ variant-spark_2.11.iml
build
dist
_build
spark-warehouse
1 change: 1 addition & 0 deletions dev/dev-requirements.txt
@@ -1,3 +1,4 @@
Sphinx>=1.3.5
pylint==1.8.1
decorator>=4.1.2
typedecorator>=0.0.5
3 changes: 3 additions & 0 deletions python/MANIFEST.in
@@ -0,0 +1,3 @@
global-exclude *.py[cod] __pycache__ .DS_Store
recursive-include target/jars *-all.jar
include README.md
1 change: 1 addition & 0 deletions python/README.md
@@ -0,0 +1 @@
TBP:
25 changes: 25 additions & 0 deletions python/examples/hipster_index.py
@@ -0,0 +1,25 @@
'''
Created on 24 Jan 2018

@author: szu004
'''
import os
from variants import VariantsContext
from pyspark.sql import SparkSession

PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))

def main():
    spark = SparkSession.builder.appName("HipsterIndex") \
        .getOrCreate()
    vs = VariantsContext(spark)
    features = vs.import_vcf(os.path.join(PROJECT_DIR, 'data/chr22_1000.vcf'))
    labels = vs.load_label(os.path.join(PROJECT_DIR, 'data/chr22-labels.csv'), '22_16050408')
    model = features.importance_analysis(labels, mtry_fraction=0.1, seed=13, n_trees=200)
    print("Oob = %s" % model.oob_error())
    for entry in model.important_variables(10):
        print(entry)

if __name__ == '__main__':
    main()

2 changes: 1 addition & 1 deletion python/pylintrc
@@ -29,7 +29,7 @@ profile=no

# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=pyspark.heapq3
ignore=pyspark

# Pickle collected data for later comparisons.
persistent=yes
81 changes: 75 additions & 6 deletions python/setup.py
@@ -1,8 +1,77 @@
from __future__ import print_function
from setuptools import setup, find_packages
import sys
import os

if sys.version_info < (2, 7):
    print("Python versions prior to 2.7 are not supported.", file=sys.stderr)
    exit(-1)

setup(
    name = 'variants',
    version = '0.1.0',
    description ='VariantSpark Python API',
    packages = find_packages()
)

#VERSION = __version__
# A temporary path so we can access above the Python project root and fetch scripts and jars we need
TEMP_PATH = "target"
ROOT_DIR = os.path.abspath("../")

# Provide guidance about how to use setup.py
incorrect_invocation_message = """
If you are installing variants from variant-spark source, you must first build variant-spark and
run sdist.
To build with maven you can run:
./build/mvn -DskipTests clean package
Building the source dist is done in the Python directory:
cd python
python setup.py sdist
pip install dist/*.tar.gz"""

JARS_PATH = os.path.join(ROOT_DIR, "target")
JARS_TARGET = os.path.join(TEMP_PATH, "jars")

in_src = os.path.isfile("../pom.xml")

if (in_src):
    # Construct links for setup
    try:
        os.mkdir(TEMP_PATH)
    except:
        print("Temp path for symlink to parent already exists {0}".format(TEMP_PATH),
              file=sys.stderr)
        exit(-1)


try:

    if (in_src):
        # Construct the symlink farm - this is necessary since we can't refer to the path above the
        # package root and we need to copy the jars and scripts which are up above the python root.
        os.symlink(JARS_PATH, JARS_TARGET)

    setup(
        name='variants',
        version='0.1.0',
        description='VariantSpark Python API',
        packages=find_packages(exclude=["*.test"]) + ['variants.jars'],
        install_requires=['typedecorator'],
        # test_suite = 'variants.test',
        # test_requires = [
        #     'pyspark>=2.1.0'
        # ],
        include_package_data=True,
        package_dir={
            'variants.jars': 'target/jars',
        },
        package_data={
            'variants.jars': ['*-all.jar'],
        },
        entry_points='''
        [console_scripts]
        variants-jar=variants.cli:cli
        ''',
    )
finally:
    # We only clean up the symlink farm if we were in the source tree, otherwise we are installing
    # rather than packaging.
    if (in_src):
        # Remove the symlinked jars directory and the temporary path
        os.remove(os.path.join(TEMP_PATH, "jars"))
        os.rmdir(TEMP_PATH)
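The temporary symlink farm under target/jars lets setuptools pick the assembly jar up as package data for the variants.jars package. A hedged post-install check that an installed sdist actually ships the jar (the resource names are assumptions taken from the configuration above; find_jar() in python/variants/setup.py below does the same lookup via glob):

# Sketch only, not part of this diff: verify the installed package bundles the assembly jar.
import pkg_resources

print(pkg_resources.resource_listdir('variants', 'jars'))  # expect something like ['variant-spark_2.11-...-all.jar']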
5 changes: 5 additions & 0 deletions python/variants/__init__.py
@@ -0,0 +1,5 @@
try:
    from variants.core import VariantsContext
except:
    pass
from variants.setup import find_jar
6 changes: 6 additions & 0 deletions python/variants/cli.py
@@ -0,0 +1,6 @@
from __future__ import print_function

from variants import find_jar

def cli():
    print(find_jar())
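The variants-jar console script defined above simply prints the path returned by find_jar() (added in python/variants/setup.py further down). A minimal sketch of attaching that bundled jar to a Spark session by hand, assuming the package and pyspark are installed; VariantsContext.spark_conf() in core.py below applies the same spark.jars setting:

# Sketch only, not part of this diff; the app name is an arbitrary placeholder.
from pyspark.sql import SparkSession
from variants import find_jar

spark = (SparkSession.builder
         .appName("variants-shell")
         .config("spark.jars", find_jar())  # path to the bundled variant-spark *-all.jar
         .getOrCreate())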
136 changes: 136 additions & 0 deletions python/variants/core.py
@@ -0,0 +1,136 @@
import sys
from typedecorator import params, Nullable, Union, setup_typecheck
from pyspark import SparkConf
from pyspark.sql import SQLContext
from variants.setup import find_jar

class VariantsContext(object):
    """The main entry point for VariantSpark functionality.
    """

    @classmethod
    def spark_conf(cls, conf=SparkConf()):
        """ Adds the necessary options to the Spark configuration.
        Note: in client mode these need to be set up using --jars or --driver-class-path.
        """
        return conf.set("spark.jars", find_jar())

    def __init__(self, ss=None):
        """The main entry point for VariantSpark functionality.

        :param ss: SparkSession
        :type ss: :class:`.pyspark.SparkSession`
        """
        self.sc = ss.sparkContext
        self.sql = SQLContext.getOrCreate(self.sc)
        self._jsql = self.sql._jsqlContext
        self._jvm = self.sc._jvm
        self._vs_api = getattr(self._jvm, 'au.csiro.variantspark.api')
        jss = ss._jsparkSession
        self._jvsc = self._vs_api.VSContext.apply(jss)

        setup_typecheck()

        sys.stderr.write('Running on Apache Spark version {}\n'.format(self.sc.version))
        if self.sc._jsc.sc().uiWebUrl().isDefined():
            sys.stderr.write('SparkUI available at {}\n'.format(self.sc._jsc.sc().uiWebUrl().get()))
        sys.stderr.write(
            'Welcome to\n'
            ' _ __ _ __ _____ __ \n'
            '| | / /___ ______(_)___ _____ / /_/ ___/____ ____ ______/ /__ \n'
            '| | / / __ `/ ___/ / __ `/ __ \/ __/\__ \/ __ \/ __ `/ ___/ //_/ \n'
            '| |/ / /_/ / / / / /_/ / / / / /_ ___/ / /_/ / /_/ / / / ,< \n'
            '|___/\__,_/_/ /_/\__,_/_/ /_/\__//____/ .___/\__,_/_/ /_/|_| \n'
            ' /_/ \n')

    @params(self=object, vcf_file_path=str)
    def import_vcf(self, vcf_file_path):
        """ Imports features from a VCF file.
        """
        return FeatureSource(self._jvm, self._vs_api,
                             self._jsql, self.sql, self._jvsc.importVCF(vcf_file_path))

    @params(self=object, label_file_path=str, col_name=str)
    def load_label(self, label_file_path, col_name):
        """ Loads the label source file.

        :param label_file_path: The file path for the label source file
        :param col_name: The name of the column containing labels
        """
        return self._jvsc.loadLabel(label_file_path, col_name)

    def stop(self):
        """ Shuts down the VariantsContext.
        """
        self.sc.stop()
        self.sc = None


class FeatureSource(object):

    def __init__(self, _jvm, _vs_api, _jsql, sql, _jfs):
        self._jfs = _jfs
        self._jvm = _jvm
        self._vs_api = _vs_api
        self._jsql = _jsql
        self.sql = sql

    @params(self=object, label_source=object, n_trees=Nullable(int), mtry_fraction=Nullable(float),
            oob=Nullable(bool), seed=Nullable(Union(int, long)), batch_size=Nullable(int),
            var_ordinal_levels=Nullable(int))
    def importance_analysis(self, label_source, n_trees=1000, mtry_fraction=None,
                            oob=True, seed=None, batch_size=100, var_ordinal_levels=3):
        """Builds a random forest classifier.

        :param label_source: The ingested label source
        :param int n_trees: The number of trees to build in the forest
        :param float mtry_fraction: The fraction of variables to try at each split
        :param bool oob: Whether the OOB error should be calculated
        :param long seed: Random seed to use
        :param int batch_size: The number of trees to build in one batch
        :param int var_ordinal_levels: The number of levels for ordinal variables

        :return: Importance analysis model.
        :rtype: :py:class:`ImportanceAnalysis`
        """
        jrf_params = self._jvm.au.csiro.variantspark.algo.RandomForestParams(bool(oob),
            float(mtry_fraction), True, float('nan'), True, long(seed))
        jia = self._vs_api.ImportanceAnalysis(self._jsql, self._jfs, label_source,
            jrf_params, n_trees, batch_size, var_ordinal_levels)
        return ImportanceAnalysis(jia, self.sql)


class ImportanceAnalysis(object):
    """ Model for random forest based importance analysis.
    """

    def __init__(self, _jia, sql):
        self._jia = _jia
        self.sql = sql

    @params(self=object, limit=Nullable(int))
    def important_variables(self, limit=10):
        """ Gets the top `limit` important variables as a list of (name, importance) tuples, where:

        - name: string - variable name
        - importance: double - Gini importance
        """
        jimpvarmap = self._jia.importantVariablesJavaMap(limit)
        return sorted(jimpvarmap.items(), key=lambda x: x[1], reverse=True)

    def oob_error(self):
        """ OOB (Out of Bag) error estimate for the model.

        :rtype: float
        """
        return self._jia.oobError()

    def variable_importance(self):
        """ Returns a DataFrame with the Gini importance of variables.

        The DataFrame has two columns:

        - variable: string - variable name
        - importance: double - Gini importance
        """
        jdf = self._jia.variableImportance()
        jdf.count()
        jdf.createTempView("df")
        return self.sql.table("df")
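Taken together, VariantsContext, FeatureSource, and ImportanceAnalysis give the end-to-end flow that examples/hipster_index.py and the tests below exercise. A hedged sketch, assuming a local Spark session with the VariantSpark jar on the classpath and the repository's data files:

# Sketch only, not part of this diff. mtry_fraction and seed are passed explicitly
# because importance_analysis() converts them with float()/long(), which would fail
# on the None defaults.
from pyspark.sql import SparkSession
from variants import VariantsContext

spark = SparkSession.builder.appName("vs-sketch").getOrCreate()
vc = VariantsContext(spark)
features = vc.import_vcf('data/chr22_1000.vcf')
label = vc.load_label('data/chr22-labels.csv', '22_16050408')
model = features.importance_analysis(label, n_trees=200, mtry_fraction=0.1, seed=13)
print("OOB error: %s" % model.oob_error())
for name, importance in model.important_variables(10):
    print("%s: %s" % (name, importance))  # top variables by Gini importance
model.variable_importance().orderBy('importance', ascending=False).show(10)
vc.stop()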
10 changes: 10 additions & 0 deletions python/variants/setup.py
@@ -0,0 +1,10 @@
import glob
import os
import pkg_resources

def find_jar():
    """Gets the path to the VariantSpark jar bundled with the
    Python distribution.
    """
    jars_dir = pkg_resources.resource_filename(__name__, "jars")
    return glob.glob(os.path.join(jars_dir, "*-all.jar"))[0]
10 changes: 10 additions & 0 deletions python/variants/test/__init__.py
@@ -0,0 +1,10 @@
import os
import glob

THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_DIR = os.path.abspath(os.path.join(THIS_DIR, os.pardir, os.pardir, os.pardir))

def find_variants_jar():
    jar_candidates = glob.glob(os.path.join(PROJECT_DIR, 'target', 'variant-spark_*-all.jar'))
    assert len(jar_candidates) == 1, "Expecting one jar, but found: %s" % str(jar_candidates)
    return jar_candidates[0]
59 changes: 59 additions & 0 deletions python/variants/test/test_core.py
@@ -0,0 +1,59 @@
import os
import unittest

from pyspark import SparkConf
from pyspark.sql import SparkSession

from variants import VariantsContext
from variants.test import find_variants_jar, PROJECT_DIR

THIS_DIR = os.path.dirname(os.path.abspath(__file__))

class VariantSparkPySparkTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(self):
        sconf = SparkConf(loadDefaults=False)\
            .set("spark.sql.files.openCostInBytes", 53687091200L)\
            .set("spark.sql.files.maxPartitionBytes", 53687091200L)\
            .set("spark.driver.extraClassPath", find_variants_jar())
        spark = SparkSession.builder.config(conf=sconf)\
            .appName("test").master("local").getOrCreate()
        self.sc = spark.sparkContext

    @classmethod
    def tearDownClass(self):
        self.sc.stop()


class VariantSparkAPITestCase(VariantSparkPySparkTestCase):

    def setUp(self):
        self.spark = SparkSession(self.sc)
        self.vc = VariantsContext(self.spark)

    def test_variants_context_parameter_type(self):
        with self.assertRaises(TypeError) as cm:
            self.vc.load_label(label_file_path=123, col_name=456)
        self.assertEqual('keyword argument label_file_path = 123 doesn\'t match signature str',
                         str(cm.exception))

    def test_importance_analysis_from_vcf(self):
        label_data_path = os.path.join(PROJECT_DIR, 'data/chr22-labels.csv')
        label = self.vc.load_label(label_file_path=label_data_path, col_name='22_16050408')
        feature_data_path = os.path.join(PROJECT_DIR, 'data/chr22_1000.vcf')
        features = self.vc.import_vcf(vcf_file_path=feature_data_path)

        imp_analysis = features.importance_analysis(label, 100, float('nan'), True, 13, 100, 3)
        imp_vars = imp_analysis.important_variables(20)
        most_imp_var = imp_vars[0][0]
        self.assertEqual('22_16050408', most_imp_var)
        df = imp_analysis.variable_importance()
        self.assertEqual('22_16050408',
                         str(df.orderBy('importance', ascending=False).collect()[0][0]))
        oob_error = imp_analysis.oob_error()
        self.assertAlmostEqual(0.016483516483516484, oob_error, 4)


if __name__ == '__main__':
    unittest.main(verbosity=2)
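A hedged way to run the new test module outside setup.py (its test_suite entry is still commented out), assuming pyspark is installed and the assembly jar has been built as described in setup.py's incorrect_invocation_message so that find_variants_jar() resolves it:

# Sketch only, not part of this diff: run from the python/ directory.
import unittest

suite = unittest.defaultTestLoader.discover('variants/test', pattern='test_*.py')
unittest.TextTestRunner(verbosity=2).run(suite)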