Survival Analysis for Churn Prediction #275

Merged (20 commits) on Jul 23, 2019

1 change: 1 addition & 0 deletions README.md
@@ -15,6 +15,7 @@ The examples folder contains example solutions across a variety of Google Cloud
* [Cloud SQL Custom Metric](examples/cloud-sql-custom-metric) - An example of creating a Stackdriver custom metric monitoring Cloud SQL Private Services IP consumption.
* [CloudML Bank Marketing](examples/cloudml-bank-marketing) - Notebook for creating a classification model for marketing using CloudML.
* [CloudML Bee Health Detection](examples/cloudml-bee-health-detection) - Detect if a bee is unhealthy based on an image of it and its subspecies.
* [CloudML Churn Prediction](examples/cloudml-churn-prediction) - Predict users' propensity to churn using Survival Analysis.
* [CloudML Energy Price Forecasting](examples/cloudml-energy-price-forecasting) - Predicting the future energy price based on historical price and weather.
* [CloudML Fraud Detection](examples/cloudml-fraud-detection) - Fraud detection model for credit-cards transactions.
* [CloudML Sentiment Analysis](examples/cloudml-sentiment-analysis) - Sentiment analysis for movie reviews using TensorFlow `RNNEstimator`.
175 changes: 175 additions & 0 deletions examples/cloudml-churn-prediction/README.md
@@ -0,0 +1,175 @@
# Churn Prediction with Survival Analysis
This model uses Survival Analysis to classify customers into time-to-churn buckets. The model output can be used to calculate each user's churn score for different durations.

The same methodology can be used to predict customers' total lifetime from their "birth" (initial signup, or t = 0) and from the current state (t > 0).

## Why is Survival Analysis Helpful for Churn Prediction?
Survival Analysis is used to predict time-to-event when the event in question has not necessarily occurred yet. In this case, the event is a customer churning.

If a customer is still active ("censored" in Survival Analysis terminology), we do not know their final lifetime or when they will churn. Assuming that such a customer's lifetime ended at the time of prediction (or training) biases the results by underestimating lifetime, while throwing out active users altogether biases the results through information loss.

By using a Survival Analysis approach to churn prediction, the entire population (regardless of current tenure or status) can be included.

## Dataset
This example uses the public [Google Analytics Sample Dataset](https://support.google.com/analytics/answer/7586738?hl=en) on BigQuery and artificially generated subscription start and end dates as input.

To create a churn model with real data, omit the 'Generate Data' step in the Beam pipeline in preprocessor/preprocessor/preprocess.py. Instead of relying on randomly generated values, the BigQuery results should include the fields start_date, end_date, and active, which describe each user's subscription lifetime and censorship status.

## Setup
### Set up GCP credentials
```shell
gcloud auth login
gcloud auth application-default login
```

### Set up Python environment
```shell
virtualenv venv
source ./venv/bin/activate
pip install -r requirements.txt
```


## Preprocessing
Using Dataflow, the data preprocessing script reads user data from BigQuery, generates random (fake) time-to-churn labels, creates TFRecords, and writes them to Google Cloud Storage.

Each record should have three labels before preprocessing:
1. **active**: indicator for censorship. It is 0 if the user is inactive (uncensored) and 1 if the user is active (censored).
2. **start_date**: date when the user's lifetime began.
3. **end_date**: date when the user's lifetime ended. It is None if the user is still active.

`_generateFakeData` randomly generates these three fields in order to create fake sample data. In practice, these fields should be available in some form in the historical data.
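
For illustration only (this is not the actual `_generateFakeData` implementation; the function name, observation date, and distributions below are assumptions), such fields could be generated like this:

```python
import datetime
import random


def generate_fake_lifetime(observation_date=datetime.date(2017, 8, 1)):
    """Randomly assigns a subscription lifetime and censorship status."""
    start_date = observation_date - datetime.timedelta(
        days=random.randint(1, 365))
    active = random.random() < 0.5  # assume roughly half the users are still subscribed
    if active:
        end_date = None  # censored: churn has not been observed yet
    else:
        max_days = (observation_date - start_date).days
        end_date = start_date + datetime.timedelta(
            days=random.randint(1, max_days))
    return {'start_date': start_date, 'end_date': end_date, 'active': int(active)}
```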

During preprocessing, these fields are combined into a single `2*n`-dimensional indicator array, where n is the number of bounded lifetime buckets (e.g. n = 2 for the buckets 0-2 months and 2-3 months, with 3+ months left unbounded); see the sketch after this list:
+ indicator array = [survival array | failure array]
+ survival array = 1 if the individual survived the interval, 0 otherwise (for each of the n intervals)
+ failure array = 1 if the individual failed during the interval, 0 otherwise
+ If an individual is censored (still active), their failure array contains only 0s
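
As a concrete illustration (a minimal sketch, not the repository's preprocessing code; the helper name and the boundary handling at exact ceilings are assumptions), the indicator array for one user could be built as follows:

```python
# Upper bound, in days, of each bounded lifetime bucket (mirrors
# LABEL_CEILINGS in preprocessor/preprocessor/features.py).
LABEL_CEILINGS = [60, 120, 180, 240]


def build_indicator_array(duration_days, active, ceilings=LABEL_CEILINGS):
    """Returns [survival array | failure array] for one user."""
    n = len(ceilings)
    survival = [0.0] * n
    failure = [0.0] * n
    for i, ceiling in enumerate(ceilings):
        if duration_days >= ceiling:
            survival[i] = 1.0  # user survived past the end of this interval
        else:
            if not active:  # only uncensored (churned) users record a failure
                failure[i] = 1.0
            break
    return survival + failure


# A churned user with a 100-day lifetime survives the 0-60 day interval
# and fails during the 60-120 day interval:
print(build_indicator_array(100, active=False))
# [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
```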

### Set Constants
```shell
BUCKET="gs://[GCS Bucket]"
NOW="$(date +%Y%m%d%H%M%S)"
OUTPUT_DIR="${BUCKET}/output_data/${NOW}"
PROJECT="[PROJECT ID]"
```

### Run locally with Dataflow
```shell
python -m preprocessor.run_preprocessing \
--output_dir "${OUTPUT_DIR}" \
--project_id "${PROJECT}"
```

### Run on the Cloud with Dataflow
The top-level preprocessor directory should be the working directory when running the preprocessing script on the Cloud, since the setup.py file must be located in the working directory.

```shell
cd preprocessor

python -m run_preprocessing \
--cloud \
--output_dir "${OUTPUT_DIR}" \
--project_id "${PROJECT}"

cd ..
```


## Model Training
Model training minimizes the negative of the log likelihood function for a statistical Survival Analysis model with discrete-time intervals. The loss function is based on the paper [A scalable discrete-time survival model for neural networks](https://peerj.com/articles/6257.pdf).

For each record, the conditional hazard probability for an interval is the probability of failure during that interval, given that the individual survived to the beginning of it. The unconditional probability that a user survives through the end of a given interval is therefore the product of (1 - hazard) over that interval and all earlier ones.

For an individual who churned, the log likelihood is ln(hazard of the failure interval) plus the sum of ln(1 - hazard) over all earlier intervals; for a censored individual, it is the sum of ln(1 - hazard) over all intervals survived. Equivalently, each individual's log likelihood is `ln(1 - (1 if survived, 0 if not) * (prob of failure)) + ln(1 - (1 if failed, 0 if not) * (prob of survival))`, summed over all time intervals.
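
The following is a minimal sketch of such a loss in TensorFlow. It is not the repository's trainer code; the function name, the `N_INTERVALS` constant, and the epsilon guard are assumptions for illustration.

```python
import tensorflow as tf

N_INTERVALS = 4  # assumed to match len(LABEL_CEILINGS) in the preprocessor


def negative_log_likelihood(label_array, survival_probs):
    """Discrete-time survival loss.

    Args:
        label_array: [batch, 2 * N_INTERVALS] tensor; the first half holds the
            survival indicators and the second half the failure indicators.
        survival_probs: [batch, N_INTERVALS] predicted conditional
            probabilities of surviving each interval (1 - hazard).
    Returns:
        Scalar tensor with the mean negative log likelihood.
    """
    survived = label_array[:, :N_INTERVALS]
    failed = label_array[:, N_INTERVALS:]
    hazard = 1.0 - survival_probs
    epsilon = 1e-7  # guards against log(0)
    # ln(1 - survived * hazard) contributes ln(1 - hazard) for survived intervals;
    # ln(1 - failed * survival_prob) contributes ln(hazard) for the failure interval.
    log_likelihood = (
        tf.math.log(1.0 - survived * hazard + epsilon) +
        tf.math.log(1.0 - failed * survival_probs + epsilon))
    return -tf.reduce_mean(tf.reduce_sum(log_likelihood, axis=1))
```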

### Set Constants
The TFRecord output of the preprocessing job should be used as input to the training job.

Make sure to navigate back to the top-level directory.

```shell
INPUT_DIR="${OUTPUT_DIR}"
MODEL_DIR="${BUCKET}/model/$(date +%Y%m%d%H%M%S)"
```

### Train locally with AI Platform
```shell
gcloud ai-platform local train \
--module-name trainer.task \
--package-path trainer/trainer \
--job-dir ${MODEL_DIR} \
-- \
--input-dir "${INPUT_DIR}"
```

### Train on the Cloud with AI Platform
```shell
JOB_NAME="train_$(date +%Y%m%d%H%M%S)"

gcloud ai-platform jobs submit training ${JOB_NAME} \
--job-dir ${MODEL_DIR} \
--config trainer/config.yaml \
--module-name trainer.task \
--package-path trainer/trainer \
--region us-east1 \
--python-version 3.5 \
--runtime-version 1.13 \
-- \
--input-dir ${INPUT_DIR}
```

### Hyperparameter Tuning with AI Platform
```shell
JOB_NAME="hptuning_$(date +%Y%m%d%H%M%S)"

gcloud ai-platform jobs submit training ${JOB_NAME} \
--job-dir ${MODEL_DIR} \
--module-name trainer.task \
--package-path trainer/trainer \
--config trainer/hptuning_config.yaml \
--python-version 3.5 \
--runtime-version 1.13 \
-- \
--input-dir ${INPUT_DIR}
```

### Launch TensorBoard
```shell
tensorboard --logdir ${MODEL_DIR}
```

## Predictions
The model predicts the conditional probability that a user survives each interval, given that the user reached it. It outputs an n-dimensional vector in which each element is the predicted conditional probability of surviving through the end of the corresponding time interval (1 - hazard).

To determine the predicted class, compare the cumulative product of the conditional probabilities (the unconditional survival probability for each interval) to a threshold.
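
For example (an illustrative sketch, not the repository's serving code; the threshold value and helper name are assumptions), reusing the bucket labels from the preprocessor's features module:

```python
import numpy as np

LABEL_VALUES = ['0-2M', '2-4M', '4-6M', '6-8M', '8M+']  # as in features.py


def predicted_bucket(conditional_survival, threshold=0.5):
    """Maps per-interval conditional survival probabilities to a churn bucket."""
    # The cumulative product gives P(lifetime > end of interval i).
    cumulative_survival = np.cumprod(conditional_survival)
    for i, prob in enumerate(cumulative_survival):
        if prob < threshold:
            return LABEL_VALUES[i]  # user most likely churns during interval i
    return LABEL_VALUES[-1]  # user likely survives past the last bounded interval


print(predicted_bucket([0.9, 0.8, 0.5, 0.4]))  # '4-6M'
```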

### Deploy model on AI Platform
The SavedModel is exported to a timestamped subdirectory of `MODEL_DIR`.
```shell
MODEL_NAME="survival_model"
VERSION_NAME="demo_version"
SAVED_MODEL_DIR=$(gsutil ls $MODEL_DIR/export/export | tail -1)

gcloud ai-platform models create $MODEL_NAME \
--regions us-east1

gcloud ai-platform versions create $VERSION_NAME \
--model $MODEL_NAME \
--origin $SAVED_MODEL_DIR \
--runtime-version=1.13 \
--framework TENSORFLOW \
--python-version=3.5
```
### Run batch predictions
```shell
INPUT_PATHS=$INPUT_DIR/data/test/*
OUTPUT_PATH=<GCS directory for predictions>
JOB_NAME="predict_$(date +%Y%m%d%H%M%S)"

gcloud ai-platform jobs submit prediction $JOB_NAME \
--model $MODEL_NAME \
--input-paths $INPUT_PATHS \
--output-path $OUTPUT_PATH \
--region us-east1 \
--data-format TF_RECORD
```
@@ -0,0 +1,110 @@
# Copyright 2019 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature management for data preprocessing."""

import tensorflow as tf
import tensorflow_transform as tft
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import dataset_schema


BQ_FEATURES = [
    'fullVisitorId', 'totals.visits', 'totals.hits',
    'totals.pageviews', 'device.deviceCategory',
    'geoNetwork.continent', 'geoNetwork.subContinent', 'socialEngagementType',
    'channelGrouping']

CATEGORICAL_COLUMNS = [
    'deviceCategory',
    'continent',
    'subContinent',
    'socialEngagementType',
    'channelGrouping',
]

METADATA_COLUMNS = [
    'fullVisitorId'
]

NUMERIC_COLUMNS = [
    'visits',
    'hits',
    'pageviews',
    'duration'
]

LABEL_ARRAY_COLUMN = 'labelArray'

BOOLEAN_COLUMNS = []

LABEL_COLUMNS = [
    'start_date',
    'end_date',
    'active',
    'duration',
    'labelArray',
    'label'
]


LABEL_VALUES = ['0-2M', '2-4M', '4-6M', '6-8M', '8M+']
LABEL_CEILINGS = [60, 120, 180, 240]  # number of days for ceiling of each class


def get_raw_feature_spec():
    """Returns TF feature spec for preprocessing."""

    features = {}
    features.update(
        {key: tf.FixedLenFeature([], dtype=tf.string)
         for key in CATEGORICAL_COLUMNS}
    )
    features.update(
        {key: tf.FixedLenFeature([], dtype=tf.float32)
         for key in NUMERIC_COLUMNS}
    )
    features.update(
        {key: tf.FixedLenFeature([], dtype=tf.int64)
         for key in BOOLEAN_COLUMNS}
    )
    features[LABEL_ARRAY_COLUMN] = tf.FixedLenFeature(
        [2*len(LABEL_CEILINGS)], tf.float32)

    return features


RAW_FEATURE_SPEC = get_raw_feature_spec()


def get_raw_dataset_metadata():
    return dataset_metadata.DatasetMetadata(
        dataset_schema.from_feature_spec(RAW_FEATURE_SPEC))


def preprocess_fn(inputs):
    """TensorFlow transform preprocessing function.

    Args:
        inputs: Dict of key to Tensor.
    Returns:
        Dict of key to transformed Tensor.
    """
    outputs = inputs.copy()
    # For all categorical columns except the label column, we generate a
    # vocabulary but do not modify the feature. This vocabulary is instead
    # used in the trainer, by means of a feature column, to convert the feature
    # from a string to an integer id.
    for key in CATEGORICAL_COLUMNS:
        tft.vocabulary(inputs[key], vocab_filename=key)
    return outputs
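

# Illustrative appendix, not part of this module: a hedged sketch of how the
# vocabulary files written by tft.vocabulary above could be consumed in the
# trainer through feature columns, assuming the transform output is loaded
# with tft.TFTransformOutput. The function name and the transform-output
# directory argument are assumptions for illustration.
def example_categorical_feature_columns(transform_output_dir):
    """Builds indicator feature columns from the generated vocabularies."""
    tf_transform_output = tft.TFTransformOutput(transform_output_dir)
    columns = []
    for key in CATEGORICAL_COLUMNS:
        vocab_file = tf_transform_output.vocabulary_file_by_name(key)
        categorical = tf.feature_column.categorical_column_with_vocabulary_file(
            key=key, vocabulary_file=vocab_file)
        columns.append(tf.feature_column.indicator_column(categorical))
    return columns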