-
Notifications
You must be signed in to change notification settings - Fork 2.8k
/
estimator.py
80 lines (61 loc) · 2.28 KB
/
estimator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python
# coding: utf8
""" Utility functions for creating estimator. """
from pathlib import Path
from os.path import join
from tempfile import gettempdir
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.contrib import predictor
# pylint: enable=import-error
from ..model import model_fn, InputProviderFactory
from ..model.provider import get_default_model_provider
# Default exporting directory for predictor.
DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
def get_default_model_dir(model_dir):
"""
Transforms a string like 'spleeter:2stems' into an actual path.
:param model_dir:
:return:
"""
model_provider = get_default_model_provider()
return model_provider.get(model_dir)
def create_estimator(params, MWF):
"""
Initialize tensorflow estimator that will perform separation
Params:
- params: a dictionary of parameters for building the model
Returns:
a tensorflow estimator
"""
# Load model.
params['model_dir'] = get_default_model_dir(params['model_dir'])
params['MWF'] = MWF
# Setup config
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
config = tf.estimator.RunConfig(session_config=session_config)
# Setup estimator
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
params=params,
config=config
)
return estimator
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
""" Exports given estimator as predictor into the given directory
and returns associated tf.predictor instance.
:param estimator: Estimator to export.
:param directory: (Optional) path to write exported model into.
"""
input_provider = InputProviderFactory.get(estimator.params)
def receiver():
features = input_provider.get_input_dict_placeholders()
return tf.estimator.export.ServingInputReceiver(features, features)
estimator.export_saved_model(directory, receiver)
versions = [
model for model in Path(directory).iterdir()
if model.is_dir() and 'temp' not in str(model)]
latest = str(sorted(versions)[-1])
return predictor.from_saved_model(latest)