[SPARK-2871] [PySpark] Add missing API #1791

Closed · wants to merge 24 commits
@@ -35,7 +35,7 @@ import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
import org.apache.spark._
import org.apache.spark.SparkContext._
- import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+ import org.apache.spark.api.java.{JavaDoubleRDD, JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -738,7 +738,7 @@ private[spark] object PythonRDD extends Logging {
}

/**
- * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by
+ * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
* PySpark.
*/
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
21 changes: 21 additions & 0 deletions python/pyspark/conf.py
@@ -54,6 +54,17 @@
(u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')]
"""

import functools


def check_readonly(f):
@functools.wraps(f)
def func(self, *a, **kw):
if self._readonly:
raise Exception("Configuration can not be changed after initialization")
return f(self, *a, **kw)
return func


class SparkConf(object):

@@ -96,32 +107,41 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):
_jvm = _jvm or SparkContext._jvm
self._jconf = _jvm.SparkConf(loadDefaults)

# Configuration can not be changed after initialization
self._readonly = False

@check_readonly
def set(self, key, value):
"""Set a configuration property."""
self._jconf.set(key, unicode(value))
return self

@check_readonly
def setIfMissing(self, key, value):
"""Set a configuration property, if not already set."""
if self.get(key) is None:
self.set(key, value)
return self

@check_readonly
def setMaster(self, value):
"""Set master URL to connect to."""
self._jconf.setMaster(value)
return self

@check_readonly
def setAppName(self, value):
"""Set application name."""
self._jconf.setAppName(value)
return self

@check_readonly
def setSparkHome(self, value):
"""Set path where Spark is installed on worker nodes."""
self._jconf.setSparkHome(value)
return self

@check_readonly
def setExecutorEnv(self, key=None, value=None, pairs=None):
"""Set an environment variable to be passed to executors."""
if (key is not None and pairs is not None) or (key is None and pairs is None):
@@ -133,6 +153,7 @@ def setExecutorEnv(self, key=None, value=None, pairs=None):
self._jconf.setExecutorEnv(k, v)
return self

@check_readonly
def setAll(self, pairs):
"""
Set multiple parameters, passed as a list of key-value pairs.
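A brief illustration of how the read-only guard above is intended to behave (a usage sketch, not part of the diff; it assumes the SparkConf passed to SparkContext is the same object the context marks read-only, as in _do_init below):

    from pyspark import SparkConf, SparkContext

    conf = SparkConf().setMaster("local").setAppName("demo")
    conf.set("spark.ui.port", "4040")    # fine: the context has not been created yet
    sc = SparkContext(conf=conf)         # _do_init flips conf._readonly to True
    conf.set("spark.ui.port", "4041")    # raises: configuration can not be changed
                                         # after initialization (via check_readonly)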
31 changes: 31 additions & 0 deletions python/pyspark/context.py
@@ -153,6 +153,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,

# Create the Java SparkContext through Py4J
self._jsc = self._initialize_context(self._conf._jconf)
self._conf._readonly = True

# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
@@ -260,6 +261,22 @@ def defaultMinPartitions(self):
"""
return self._jsc.sc().defaultMinPartitions()

@property
def isLocal(self):
"""
Whether the context runs locally.
"""
return self._jsc.isLocal()

@property
def conf(self):
Contributor
This needs a docstring. Also, the Scala equivalent of this clones the SparkConf because it cannot be changed at runtime. We might want to do the same thing here (to guard against misuse); I'm not sure how clone() interacts with Py4J objects; do we need to implement a custom clone method for objects with Py4J objects as fields that calls those objects' JVM clone methods?

"""
The L{SparkConf} object

Configuration can not be changed after initialization.
"""
return self._conf
Contributor
I agree with Josh here; you need to clone the conf before returning it.

Contributor Author
I will return a read-only copy of it.
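For reference, one possible shape of that read-only copy (a minimal sketch, not the code in this PR; it assumes the JVM SparkConf's clone() is reachable through Py4J and reuses the _readonly flag added in conf.py):

    @property
    def conf(self):
        """Return a read-only copy of the L{SparkConf} used by this context."""
        from pyspark.conf import SparkConf
        conf_copy = SparkConf(_jconf=self._conf._jconf.clone())
        conf_copy._readonly = True   # check_readonly will reject set*/setAll on the copy
        return conf_copy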


def stop(self):
"""
Shut down the SparkContext.
@@ -733,6 +750,13 @@ def sparkUser(self):
"""
return self._jsc.sc().sparkUser()

@property
def startTime(self):
"""
Return the start time of the context, in milliseconds.
Contributor
This startTime property isn't documented in the Scala API. Do we want to include it here? What's the use-case?

Contributor Author
I saw it in the Java API docs, so I added it here.

Contributor
The primary use of this, outside of SparkContext, seems to be printing the context's uptime. So, why not add an uptime method instead?

Contributor Author
Changing it to uptime would not improve anything; or should we just remove it?

"""
return self._jsc.startTime()
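If an uptime accessor were preferred over exposing startTime directly, a minimal sketch (hypothetical, not part of this PR) could derive it on the Python side, assuming startTime is epoch milliseconds as in the JVM:

    @property
    def uptime(self):
        """Milliseconds elapsed since this context was started (hypothetical)."""
        import time
        return int(time.time() * 1000) - self._jsc.startTime()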

def cancelJobGroup(self, groupId):
"""
Cancel active jobs for the specified group. See L{SparkContext.setJobGroup}
@@ -772,6 +796,13 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))
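As an illustration of runJob's signature, a doctest-style usage (the values assume the usual semantics: partitionFunc runs once per selected partition and the per-partition results are concatenated):

    >>> myRDD = sc.parallelize(range(6), 3)
    >>> sc.runJob(myRDD, lambda part: [x * x for x in part], partitions=[0, 2])
    [0, 1, 16, 25]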

# TODO
# def runApproximateJob(self, rdd, func, evaluator, timeout):
# """
# :: DeveloperApi ::
# Run a job that can return approximate results.
# """


def _test():
import atexit
47 changes: 37 additions & 10 deletions python/pyspark/rdd.py
@@ -17,22 +17,24 @@

from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
from collections import namedtuple
from itertools import chain, ifilter, imap
import operator
import os
import sys
import shlex
import traceback
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
import warnings
import heapq
import array
import bisect
import math
from collections import defaultdict, namedtuple
from itertools import chain, ifilter, imap
from random import Random
from math import sqrt, log
from bisect import bisect_right
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread

from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
@@ -1741,6 +1743,13 @@ def batch_as(rdd, batchSize):
other._jrdd_deserializer)
return RDD(pairRDD, self.ctx, deserializer)

# TODO
# def zipPartitions(self, other, f, preservesPartitioning=False):
# """
# Zip this RDD's partitions with one (or more) RDD(s) and return a
# new RDD by applying a function to the zipped partitions.
# """

def zipWithIndex(self):
"""
Zips this RDD with its element indices.
Contributor
The Scala documentation is much more descriptive about what this method does:

  /**
   * Zips this RDD with its element indices. The ordering is first based on the partition index
   * and then the ordering of items within each partition. So the first item in the first
   * partition gets index 0, and the last item in the last partition receives the largest index.
   * This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type.
   * This method needs to trigger a spark job when this RDD contains more than one partitions.
   */
  def zipWithIndex(): RDD[(T, Long)] = new ZippedWithIndexRDD(this)

The Python documentation should explain these subtleties, too.
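For example, the expanded Python docstring could carry a doctest that makes the ordering explicit (a sketch; the first item of the first partition gets index 0, and the last item of the last partition gets the largest index):

    >>> sc.parallelize(["a", "b", "c", "d"], 2).zipWithIndex().collect()
    [('a', 0), ('b', 1), ('c', 2), ('d', 3)]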

@@ -1850,10 +1859,28 @@ def _defaultReducePartitions(self):
else:
return self.getNumPartitions()

# TODO: `lookup` is disabled because we can't make direct comparisons based
# on the key; we need to compare the hash of the key to the hash of the
# keys in the pairs. This could be an expensive operation, since those
# hashes aren't retained.
# TODO
# def countApproxDistinctByKey(self, timeout, confidence=0.95):
# """
# :: Experimental ::
# Return approximate number of distinct values for each key in this RDD.
# """

# TODO
# def countByKeyApprox(self, timeout, confidence=0.95):
# """
# :: Experimental ::
# Approximate version of countByKey that can return a partial result
# if it does not finish within a timeout.
# """
#
# def countByValueApprox(self, timeout, confidence=0.95):
# """
Contributor
If you'd like, you can implement lookup() the same way as in Scala; it's not too bad.

# :: Experimental::
# Approximate version of countByValue().
#
# """
# return self.map(lambda x: (x, None)).countByKeyApprox(timeout, confidence)
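A minimal sketch of the lookup() the reviewer mentions above (an assumption about how it could look in PySpark, without the partitioner-based shortcut the Scala version uses when the RDD is partitioned by key):

    def lookup(self, key):
        """Return the list of values in this RDD for the given key (sketch)."""
        return self.filter(lambda kv: kv[0] == key).map(lambda kv: kv[1]).collect()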

def _is_pickled(self):
""" Return this RDD is serialized by Pickle or not. """