Skip to content

Commit 8f3c1e6

Browse files
authored
Merge branch 'linus_dev' into sentry
2 parents cff1f4c + 02bfc83 commit 8f3c1e6

File tree

6 files changed

+78
-8
lines changed

6 files changed

+78
-8
lines changed

mlchain/cli/run.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import click
33
import importlib
44
import sys
5+
import copy
56
import GPUtil
67
from mlchain import logger
78
from mlchain.server import MLServer
@@ -92,25 +93,33 @@ def run_command(entry_file, host, port, bind, wrapper, server, workers, config,
9293
name, mode, api_format, ngrok, kws):
9394
kws = list(kws)
9495
if isinstance(entry_file, str) and not os.path.exists(entry_file):
95-
kws = [entry_file] + kws
96+
kws = [f'--entry_file={entry_file}'] + kws
9697
entry_file = None
9798
from mlchain import config as mlconfig
9899
default_config = False
100+
99101
if config is None:
100102
default_config = True
101103
config = 'mlconfig.yaml'
102104

103-
if os.path.isfile(config):
104-
config = mlconfig.load_file(config)
105+
config_path = copy.deepcopy(config)
106+
if os.path.isfile(config_path) and os.path.exists(config_path):
107+
config = mlconfig.load_file(config_path)
105108
if config is None:
106-
raise AssertionError("Not support file config {0}".format(config))
109+
raise SystemExit("Config file {0} are not supported".format(config_path))
107110
else:
108111
if not default_config:
109-
raise FileNotFoundError("Not found file {0}".format(config))
110-
config = {}
112+
raise SystemExit("Can't find config file {0}".format(config_path))
113+
else:
114+
raise SystemExit("Can't find mlchain config file. Please double check your current working directory. Or use `mlchain init` to initialize a new ones here.")
111115
if 'mode' in config and 'env' in config['mode']:
112116
if mode in config['mode']['env']:
113117
config['mode']['default'] = mode
118+
elif mode is not None:
119+
available_mode = list(config['mode']['env'].keys())
120+
available_mode = [each for each in available_mode if each != 'default']
121+
raise SystemExit(
122+
f"No {mode} mode are available. Found these mode in config file: {available_mode}")
114123
mlconfig.load_config(config)
115124
for kw in kws:
116125
if kw.startswith('--'):
@@ -124,6 +133,8 @@ def run_command(entry_file, host, port, bind, wrapper, server, workers, config,
124133
raise AssertionError("Unexpected param {0}".format(kw))
125134
model_id = mlconfig.get_value(None, config, 'model_id', None)
126135
entry_file = mlconfig.get_value(entry_file, config, 'entry_file', 'server.py')
136+
if not os.path.exists(entry_file):
137+
raise SystemExit(f"Entry file {entry_file} not found in current working directory.")
127138
host = mlconfig.get_value(host, config, 'host', 'localhost')
128139
port = mlconfig.get_value(port, config, 'port', 5000)
129140
server = mlconfig.get_value(server, config, 'server', 'flask')

mlchain/config.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@ def load_config(self, path, mode=None):
102102
for mode in ['default', default]:
103103
if mode in data['mode']['env']:
104104
for k, v in data['mode']['env'][mode].items():
105-
environ[k] = str(v)
105+
if k in environ:
106+
data['mode']['env'][mode][k] = environ[k]
107+
else:
108+
environ[k] = str(v)
106109
self.update(data['mode']['env'][mode])
107110

108111
def get_client_config(self, name):
@@ -150,6 +153,11 @@ def load_config(data):
150153
if 'env' in data['mode']:
151154
for mode in ['default', default]:
152155
if mode in data['mode']['env']:
156+
for k, v in data['mode']['env'][mode].items():
157+
if k in environ:
158+
data['mode']['env'][mode][k] = environ[k]
159+
else:
160+
environ[k] = str(v)
153161
mlconfig.update(data['mode']['env'][mode])
154162

155163
if mlconfig.MLCHAIN_SENTRY_DSN is not None and data.get('wrapper', None) != 'gunicorn':

tests/dummy_server/server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@ def __init__(self):
1010
pass
1111

1212

13-
def predict(self, image: np.ndarray):
13+
def predict(self, image: np.ndarray = None):
1414
"""
1515
Resize input to 100 by 100.
1616
Args:
1717
image (numpy.ndarray): An input image.
1818
Returns:
1919
The image (np.ndarray) at 100 by 100.
2020
"""
21+
if image is None:
22+
return 'Hihi'
2123
image = cv2.resize(image, (100, 100))
2224
return image
2325

tests/test_limiter.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,30 @@ def test_limiter_2(self):
3030
total_time = time.time() - start_time
3131
assert total_time >= 3
3232

33+
def test_limiter_fail(self):
34+
try:
35+
limiter = RateLimiter(max_calls=1, period=0)
36+
except ValueError:
37+
pass
38+
39+
try:
40+
limiter = RateLimiter(max_calls=0, period=1)
41+
except ValueError:
42+
pass
43+
44+
def test_limiter_with_callback(self):
45+
start_time = time.time()
46+
global abc
47+
abc = 0
48+
def callback(i):
49+
global abc
50+
abc += 1
51+
limiter = RateLimiter(max_calls=3, period=1, callback=callback)
52+
for i in range(10):
53+
with limiter:
54+
pass
55+
total_time = time.time() - start_time
56+
assert total_time >= 3
57+
3358
if __name__ == '__main__':
3459
unittest.main()

tests/test_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import logging
2+
import unittest
3+
4+
from mlchain.base.utils import *
5+
logger = logging.getLogger()
6+
7+
class TestUtils(unittest.TestCase):
8+
def __init__(self, *args, **kwargs):
9+
unittest.TestCase.__init__(self, *args, **kwargs)
10+
logger.info("Running utils test")
11+
12+
def test_nothing(self):
13+
pass
14+
15+
if __name__ == '__main__':
16+
unittest.main()

tests/test_workflow.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ def dummy_task():
104104
logger.info(x)
105105
background.stop()
106106

107+
try:
108+
background = Background(task, interval=0.01).run(pass_fail_job=False)
109+
time.sleep(0.02)
110+
logger.info(x)
111+
background.stop()
112+
except:
113+
pass
114+
107115
def test_mlchain_async_task(self):
108116
async def dummy_task(n):
109117
return n+1

0 commit comments

Comments
 (0)