Merge branch 'master' of https://github.com/apache/spark into SPARK-3…
Fokko committed Nov 25, 2020
2 parents b628156 + 19f3b89 commit c75cd57
Showing 81 changed files with 3,142 additions and 1,696 deletions.
3 changes: 2 additions & 1 deletion python/docs/source/reference/pyspark.mllib.rst
@@ -216,6 +216,8 @@ Statistics
ChiSqTestResult
MultivariateGaussian
KernelDensity
ChiSqTestResult
KolmogorovSmirnovTestResult


Tree
@@ -250,4 +252,3 @@ Utilities
Loader
MLUtils
Saveable

87 changes: 87 additions & 0 deletions python/mypy.ini
@@ -16,12 +16,99 @@
;

[mypy]
strict_optional = True
no_implicit_optional = True
disallow_untyped_defs = True

; Allow untyped def in internal modules and tests

[mypy-pyspark.daemon]
disallow_untyped_defs = False

[mypy-pyspark.find_spark_home]
disallow_untyped_defs = False

[mypy-pyspark._globals]
disallow_untyped_defs = False

[mypy-pyspark.install]
disallow_untyped_defs = False

[mypy-pyspark.java_gateway]
disallow_untyped_defs = False

[mypy-pyspark.join]
disallow_untyped_defs = False

[mypy-pyspark.ml.tests.*]
disallow_untyped_defs = False

[mypy-pyspark.mllib.tests.*]
disallow_untyped_defs = False

[mypy-pyspark.rddsampler]
disallow_untyped_defs = False

[mypy-pyspark.resource.tests.*]
disallow_untyped_defs = False

[mypy-pyspark.serializers]
disallow_untyped_defs = False

[mypy-pyspark.shuffle]
disallow_untyped_defs = False

[mypy-pyspark.streaming.tests.*]
disallow_untyped_defs = False

[mypy-pyspark.streaming.util]
disallow_untyped_defs = False

[mypy-pyspark.sql.tests.*]
disallow_untyped_defs = False

[mypy-pyspark.sql.pandas.serializers]
disallow_untyped_defs = False

[mypy-pyspark.sql.pandas.types]
disallow_untyped_defs = False

[mypy-pyspark.sql.pandas.typehints]
disallow_untyped_defs = False

[mypy-pyspark.sql.pandas.utils]
disallow_untyped_defs = False

[mypy-pyspark.sql.pandas._typing.protocols.*]
disallow_untyped_defs = False

[mypy-pyspark.sql.utils]
disallow_untyped_defs = False

[mypy-pyspark.tests.*]
disallow_untyped_defs = False

[mypy-pyspark.testing.*]
disallow_untyped_defs = False

[mypy-pyspark.traceback_utils]
disallow_untyped_defs = False

[mypy-pyspark.util]
disallow_untyped_defs = False

[mypy-pyspark.worker]
disallow_untyped_defs = False

; Ignore errors in embedded third party code

no_implicit_optional = True

[mypy-pyspark.cloudpickle.*]
ignore_errors = True

; Ignore missing imports for external untyped packages

[mypy-py4j.*]
ignore_missing_imports = True

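
For context on what this configuration does: `disallow_untyped_defs = True` in the top-level `[mypy]` section makes mypy report any function definition that lacks annotations, and the per-module `disallow_untyped_defs = False` sections above relax that check for internal modules and test packages. A minimal sketch of the behaviour (hypothetical code, not part of this commit):

```python
# Hypothetical module checked under the [mypy] section above.
# With disallow_untyped_defs = True, mypy reports the first definition;
# the second, fully annotated one passes.

def first_word(line):  # error: Function is missing a type annotation
    return line.split()[0]

def first_word_typed(line: str) -> str:  # OK: parameters and return are annotated
    return line.split()[0]
```
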
10 changes: 5 additions & 5 deletions python/pyspark/broadcast.pyi
@@ -17,7 +17,7 @@
# under the License.

import threading
from typing import Any, Dict, Generic, Optional, TypeVar
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar

T = TypeVar("T")

@@ -32,14 +32,14 @@ class Broadcast(Generic[T]):
path: Optional[Any] = ...,
sock_file: Optional[Any] = ...,
) -> None: ...
def dump(self, value: Any, f: Any) -> None: ...
def load_from_path(self, path: Any): ...
def load(self, file: Any): ...
def dump(self, value: T, f: Any) -> None: ...
def load_from_path(self, path: Any) -> T: ...
def load(self, file: Any) -> T: ...
@property
def value(self) -> T: ...
def unpersist(self, blocking: bool = ...) -> None: ...
def destroy(self, blocking: bool = ...) -> None: ...
def __reduce__(self): ...
def __reduce__(self) -> Tuple[Callable[[int], T], Tuple[int]]: ...

class BroadcastPickleRegistry(threading.local):
def __init__(self) -> None: ...
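
The practical effect of replacing `Any` with `T` in `dump`, `load`, and `load_from_path` is that the element type of a broadcast variable flows through to call sites. A small hedged sketch (it assumes the stub for `SparkContext.broadcast`, which is not part of this diff, returns `Broadcast[T]`):

```python
# Hypothetical usage, not part of this commit.
from pyspark import SparkContext

sc = SparkContext.getOrCreate()
lookup = sc.broadcast([1, 2, 3])   # inferred as Broadcast[List[int]] under the stubs
total: int = sum(lookup.value)     # value is typed as List[int], so sum(...) -> int checks
sc.stop()
```
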
25 changes: 21 additions & 4 deletions python/pyspark/context.pyi
@@ -16,7 +16,19 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
NoReturn,
Optional,
Tuple,
Type,
TypeVar,
)
from types import TracebackType

from py4j.java_gateway import JavaGateway, JavaObject # type: ignore[import]

@@ -51,9 +63,14 @@ class SparkContext:
jsc: Optional[JavaObject] = ...,
profiler_cls: type = ...,
) -> None: ...
def __getnewargs__(self): ...
def __enter__(self): ...
def __exit__(self, type, value, trace): ...
def __getnewargs__(self) -> NoReturn: ...
def __enter__(self) -> SparkContext: ...
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
trace: Optional[TracebackType],
) -> None: ...
@classmethod
def getOrCreate(cls, conf: Optional[SparkConf] = ...) -> SparkContext: ...
def setLogLevel(self, logLevel: str) -> None: ...
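
Typing `__enter__` to return `SparkContext` and giving `__exit__` the standard optional exception-triple signature means the `with` form is now understood by the checker. A hedged usage sketch (not part of this commit):

```python
from pyspark import SparkConf, SparkContext

# __enter__ is typed to return SparkContext, so `sc` is fully typed inside the
# block and a misspelled attribute such as sc.setLoglevel(...) would be reported.
with SparkContext.getOrCreate(SparkConf().setAppName("typing-demo")) as sc:
    sc.setLogLevel("WARN")
```
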
6 changes: 3 additions & 3 deletions python/pyspark/ml/classification.pyi
@@ -107,7 +107,7 @@ class _JavaProbabilisticClassifier(
class _JavaProbabilisticClassificationModel(
ProbabilisticClassificationModel, _JavaClassificationModel[T]
):
def predictProbability(self, value: Any): ...
def predictProbability(self, value: Vector) -> Vector: ...

class _ClassificationSummary(JavaWrapper):
@property
@@ -543,7 +543,7 @@ class RandomForestClassificationModel(
@property
def trees(self) -> List[DecisionTreeClassificationModel]: ...
def summary(self) -> RandomForestClassificationTrainingSummary: ...
def evaluate(self, dataset) -> RandomForestClassificationSummary: ...
def evaluate(self, dataset: DataFrame) -> RandomForestClassificationSummary: ...

class RandomForestClassificationSummary(_ClassificationSummary): ...
class RandomForestClassificationTrainingSummary(
@@ -891,7 +891,7 @@ class FMClassifier(
solver: str = ...,
thresholds: Optional[Any] = ...,
seed: Optional[Any] = ...,
): ...
) -> FMClassifier: ...
def setFactorSize(self, value: int) -> FMClassifier: ...
def setFitLinear(self, value: bool) -> FMClassifier: ...
def setMiniBatchFraction(self, value: float) -> FMClassifier: ...
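
Narrowing `predictProbability` from `Any` to `Vector -> Vector` and `evaluate` to take a `DataFrame` lets misuse be caught statically. A hedged sketch (hypothetical helper function, not part of this commit):

```python
from pyspark.ml.classification import RandomForestClassificationModel
from pyspark.ml.linalg import Vector, Vectors

def probabilities_at_origin(model: RandomForestClassificationModel) -> Vector:
    # Under the tightened stub a Vector goes in and a Vector comes out;
    # passing a plain Python list here would now be flagged by the checker.
    return model.predictProbability(Vectors.dense([0.0, 0.0, 0.0]))
```
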
10 changes: 8 additions & 2 deletions python/pyspark/ml/common.pyi
@@ -16,5 +16,11 @@
# specific language governing permissions and limitations
# under the License.

def callJavaFunc(sc, func, *args): ...
def inherit_doc(cls): ...
from typing import Any, TypeVar

import pyspark.context

C = TypeVar("C", bound=type)

def callJavaFunc(sc: pyspark.context.SparkContext, func: Any, *args: Any) -> Any: ...
def inherit_doc(cls: C) -> C: ...
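
Binding `C` to `type` and typing `inherit_doc` as `(C) -> C` keeps the decorated class's own type instead of collapsing it to `Any`. A hedged sketch (hypothetical classes; it assumes `inherit_doc` is importable from `pyspark.ml.common` as the stub above declares):

```python
from pyspark.ml.common import inherit_doc

class Base:
    """Parent docstring that inherit_doc can copy to the subclass."""

@inherit_doc
class Derived(Base):
    pass

# Because inherit_doc is typed (C) -> C, Derived is still Derived to the
# checker after decoration, so constructing it keeps full type information.
d = Derived()
```
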
24 changes: 13 additions & 11 deletions python/pyspark/ml/evaluation.pyi
@@ -39,9 +39,12 @@ from pyspark.ml.param.shared (
HasWeightCol,
)
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.sql.dataframe import DataFrame

class Evaluator(Params, metaclass=abc.ABCMeta):
def evaluate(self, dataset, params: Optional[ParamMap] = ...) -> float: ...
def evaluate(
self, dataset: DataFrame, params: Optional[ParamMap] = ...
) -> float: ...
def isLargerBetter(self) -> bool: ...

class JavaEvaluator(JavaParams, Evaluator, metaclass=abc.ABCMeta):
@@ -75,16 +78,15 @@ class BinaryClassificationEvaluator(
def setLabelCol(self, value: str) -> BinaryClassificationEvaluator: ...
def setRawPredictionCol(self, value: str) -> BinaryClassificationEvaluator: ...
def setWeightCol(self, value: str) -> BinaryClassificationEvaluator: ...

def setParams(
self,
*,
rawPredictionCol: str = ...,
labelCol: str = ...,
metricName: BinaryClassificationEvaluatorMetricType = ...,
weightCol: Optional[str] = ...,
numBins: int = ...
) -> BinaryClassificationEvaluator: ...
def setParams(
self,
*,
rawPredictionCol: str = ...,
labelCol: str = ...,
metricName: BinaryClassificationEvaluatorMetricType = ...,
weightCol: Optional[str] = ...,
numBins: int = ...
) -> BinaryClassificationEvaluator: ...

class RegressionEvaluator(
JavaEvaluator,
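
With `Evaluator.evaluate` now typed to take a `DataFrame` and return `float`, evaluator call sites are checked end to end. A hedged sketch (hypothetical prediction frame, not part of this commit):

```python
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.sql import DataFrame

def area_under_roc(predictions: DataFrame) -> float:
    evaluator = BinaryClassificationEvaluator(metricName="areaUnderROC")
    # evaluate is typed as (DataFrame, Optional[ParamMap]) -> float,
    # so returning its result against the float annotation checks.
    return evaluator.evaluate(predictions)
```
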
20 changes: 13 additions & 7 deletions python/pyspark/ml/feature.pyi
@@ -100,9 +100,9 @@ class _LSHParams(HasInputCol, HasOutputCol):
def getNumHashTables(self) -> int: ...

class _LSH(Generic[JM], JavaEstimator[JM], _LSHParams, JavaMLReadable, JavaMLWritable):
def setNumHashTables(self: P, value) -> P: ...
def setInputCol(self: P, value) -> P: ...
def setOutputCol(self: P, value) -> P: ...
def setNumHashTables(self: P, value: int) -> P: ...
def setInputCol(self: P, value: str) -> P: ...
def setOutputCol(self: P, value: str) -> P: ...

class _LSHModel(JavaModel, _LSHParams):
def setInputCol(self: P, value: str) -> P: ...
@@ -1518,7 +1518,7 @@ class ChiSqSelector(
fpr: float = ...,
fdr: float = ...,
fwe: float = ...
): ...
) -> ChiSqSelector: ...
def setSelectorType(self, value: str) -> ChiSqSelector: ...
def setNumTopFeatures(self, value: int) -> ChiSqSelector: ...
def setPercentile(self, value: float) -> ChiSqSelector: ...
@@ -1602,7 +1602,10 @@ class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol):
def getVarianceThreshold(self) -> float: ...

class VarianceThresholdSelector(
JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
JavaEstimator[VarianceThresholdSelectorModel],
_VarianceThresholdSelectorParams,
JavaMLReadable[VarianceThresholdSelector],
JavaMLWritable,
):
def __init__(
self,
@@ -1615,13 +1618,16 @@ class VarianceThresholdSelector(
featuresCol: str = ...,
outputCol: Optional[str] = ...,
varianceThreshold: float = ...,
): ...
) -> VarianceThresholdSelector: ...
def setVarianceThreshold(self, value: float) -> VarianceThresholdSelector: ...
def setFeaturesCol(self, value: str) -> VarianceThresholdSelector: ...
def setOutputCol(self, value: str) -> VarianceThresholdSelector: ...

class VarianceThresholdSelectorModel(
JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
JavaModel,
_VarianceThresholdSelectorParams,
JavaMLReadable[VarianceThresholdSelectorModel],
JavaMLWritable,
):
def setFeaturesCol(self, value: str) -> VarianceThresholdSelectorModel: ...
def setOutputCol(self, value: str) -> VarianceThresholdSelectorModel: ...
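
Parameterizing `JavaEstimator[VarianceThresholdSelectorModel]` and `JavaMLReadable[...]` lets the checker infer the concrete model type produced by `fit` and `load`. A hedged sketch (it assumes the generic `JavaEstimator.fit` stub, defined outside this diff, returns its model type parameter):

```python
from pyspark.ml.feature import VarianceThresholdSelector, VarianceThresholdSelectorModel
from pyspark.sql import DataFrame

def fit_selector(train: DataFrame) -> VarianceThresholdSelectorModel:
    selector = VarianceThresholdSelector(
        featuresCol="features", outputCol="selected", varianceThreshold=0.5
    )
    # fit() now resolves to VarianceThresholdSelectorModel, so the model-specific
    # setOutputCol below and the return annotation both type-check.
    return selector.fit(train).setOutputCol("selected")
```
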