Add serialization test and fix NumUniqueSeparators #122

Merged
merged 6 commits on Apr 7, 2022
1 change: 1 addition & 0 deletions docs/source/changelog.rst
@@ -6,6 +6,7 @@ Future Release
==============
* Enhancements
* Fixes
* Fix ``NumUniqueSeparators`` to allow for serialization and deserialization (:pr:`122`)
* Changes
* Speed up LSA primitive initialization (:pr:`118`)
* Documentation Changes
16 changes: 8 additions & 8 deletions nlp_primitives/num_unique_separators.py
@@ -3,24 +3,24 @@
from woodwork.column_schema import ColumnSchema
from woodwork.logical_types import IntegerNullable, NaturalLanguage

NATURAL_LANGUAGE_SEPARATORS = " .,!?;\n"
NATURAL_LANGUAGE_SEPARATORS = [" ", ".", ",", "!", "?", ";", "\n"]


class NumUniqueSeparators(TransformPrimitive):
"""Calculates the number of unique separators.

Description:
Given a string and an iterable of separators, determine
Given a string and a list of separators, determine
the number of unique separators in each string. If a string
is null, as determined by pd.isnull, return pd.NA.

Args:
separators (str, optional): an iterable of characters to count.
" .,!?;\n" is used by default.
separators (list, optional): a list of separator characters to count.
`[`" ", ".", ",", "!", "?", ";", "\n"]` is used by default.

Examples:
>>> x = ['First. Line.', 'This. is the second, line!', 'notinlist@#$%^%&']
>>> num_unique_separators = NumUniqueSeparators(".,!")
>>> x = ["First. Line.", "This. is the second, line!", "notinlist@#$%^%&"]
>>> num_unique_separators = NumUniqueSeparators([".", ",", "!"])
>>> num_unique_separators(x).tolist()
[1, 3, 0]
"""
@@ -31,13 +31,13 @@ class NumUniqueSeparators(TransformPrimitive):

def __init__(self, separators=NATURAL_LANGUAGE_SEPARATORS):
assert separators is not None, "separators needs to be defined"
self.separators = set(separators)
self.separators = separators

def get_function(self):
def count_unique_separator(s):
if pd.isnull(s):
return pd.NA
return len(self.separators.intersection(set(s)))
return len(set(self.separators).intersection(set(s)))

def get_separator_count(column):
return column.apply(count_unique_separator)
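The changes above are the heart of the fix: the primitive previously converted the user-supplied argument to a set and stored that on the instance, whereas it now stores the list as given and only builds a set at call time. A likely explanation (an assumption, not verified here) is that ``save_features`` writes primitive arguments out as JSON, which rejects a Python set but round-trips a list cleanly. A minimal sketch of that distinction and of the call-time set semantics; ``count_unique_separators`` below is a hypothetical standalone helper, not part of the primitive:

import json

# A set is rejected by the JSON encoder, which would break feature serialization.
old_separators = set(" .,!?;\n")
try:
    json.dumps({"separators": old_separators})
except TypeError as err:
    print(f"set does not serialize: {err}")

# A plain list of the same characters round-trips through JSON without loss.
new_separators = [" ", ".", ",", "!", "?", ";", "\n"]
restored = json.loads(json.dumps({"separators": new_separators}))["separators"]
assert restored == new_separators

# Set semantics are still used, but only when the transform runs.
def count_unique_separators(text, separators=new_separators):
    return len(set(separators).intersection(set(text)))

assert count_unique_separators("First. Line.") == 2  # space and period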
33 changes: 0 additions & 33 deletions nlp_primitives/tests/test_lsa.py
@@ -1,11 +1,5 @@
import numpy as np
import pandas as pd
from featuretools import (
calculate_feature_matrix,
dfs,
load_features,
save_features
)

from ..lsa import LSA
from .test_utils import PrimitiveT, find_applicable_primitives, valid_dfs
@@ -56,30 +50,3 @@ def test_with_featuretools(self, es):
primitive_instance = self.primitive()
transform.append(primitive_instance)
valid_dfs(es, aggregation, transform, self.primitive.name.upper(), multi_output=True)

def test_serialize(self, es):
features = dfs(entityset=es,
target_dataframe_name="log",
trans_primitives=[self.primitive],
max_features=-1,
max_depth=3,
features_only=True)

feat_to_serialize = None
for feature in features:
if feature.primitive.__class__ == self.primitive:
feat_to_serialize = feature
break
for base_feature in feature.get_dependencies(deep=True):
if base_feature.primitive.__class__ == self.primitive:
feat_to_serialize = base_feature
break
assert feat_to_serialize is not None

df1 = calculate_feature_matrix([feat_to_serialize], entityset=es)

new_feat = load_features(save_features([feat_to_serialize]))[0]

df2 = calculate_feature_matrix([new_feat], entityset=es)

assert df1.equals(df2)
40 changes: 39 additions & 1 deletion nlp_primitives/tests/test_utils.py
@@ -8,7 +8,13 @@

import featuretools as ft
import pytest
from featuretools import dfs, list_primitives
from featuretools import (
calculate_feature_matrix,
dfs,
list_primitives,
load_features,
save_features
)
from featuretools.tests.testing_utils import make_ecommerce_entityset

ft.primitives._load_primitives()
@@ -50,6 +56,38 @@ def test_arg_init(self):
if parameter.default is not parameter.empty:
assert hasattr(primitive_, name)

def test_serialize(self, es):
features = dfs(entityset=es,
target_dataframe_name="log",
trans_primitives=[self.primitive],
max_features=-1,
max_depth=3,
features_only=True)

feat_to_serialize = None
for feature in features:
if feature.primitive.__class__ == self.primitive:
feat_to_serialize = feature
break
for base_feature in feature.get_dependencies(deep=True):
if base_feature.primitive.__class__ == self.primitive:
feat_to_serialize = base_feature
break
assert feat_to_serialize is not None

# Skip calculating the feature matrix for long-running primitives
skip_primitives = ["elmo"]

if self.primitive.name not in skip_primitives:
df1 = calculate_feature_matrix([feat_to_serialize], entityset=es)

new_feat = load_features(save_features([feat_to_serialize]))[0]
assert isinstance(new_feat, ft.FeatureBase)

if self.primitive.name not in skip_primitives:
df2 = calculate_feature_matrix([new_feat], entityset=es)
assert df1.equals(df2)


def find_applicable_primitives(primitive):
from featuretools.primitives.utils import (
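For reference, the round trip exercised by the new shared ``test_serialize`` can also be reproduced outside the test suite. A minimal sketch, assuming the ecommerce entityset used by the tests and assuming ``NumUniqueSeparators`` is importable from the package top level (both assumptions on my part):

from featuretools import (
    calculate_feature_matrix,
    dfs,
    load_features,
    save_features,
)
from featuretools.tests.testing_utils import make_ecommerce_entityset

from nlp_primitives import NumUniqueSeparators  # assumed top-level export

es = make_ecommerce_entityset()

# Build only NumUniqueSeparators-based transform features on the "log" dataframe.
features = dfs(
    entityset=es,
    target_dataframe_name="log",
    agg_primitives=[],
    trans_primitives=[NumUniqueSeparators],
    features_only=True,
)

# Serialize to a string and back, then confirm the computed matrices match.
restored = load_features(save_features(features))
before = calculate_feature_matrix(features, entityset=es)
after = calculate_feature_matrix(restored, entityset=es)
assert before.equals(after)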