Skip to content

Commit

Permalink
Bug fix: Return data associated with new runs created by FileStore.cr…
Browse files Browse the repository at this point in the history
…eate_run() (mlflow#1328)
  • Loading branch information
dbczumar authored and aarondav committed May 29, 2019
1 parent 727764b commit 63d0406
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 31 deletions.
10 changes: 8 additions & 2 deletions mlflow/entities/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from mlflow.entities._mlflow_object import _MLflowObject
from mlflow.entities.run_data import RunData
from mlflow.entities.run_info import RunInfo
from mlflow.exceptions import MlflowException
from mlflow.protos.service_pb2 import Run as ProtoRun


Expand All @@ -11,7 +12,7 @@ class Run(_MLflowObject):

def __init__(self, run_info, run_data):
if run_info is None:
raise Exception("run_info cannot be None")
raise MlflowException("run_info cannot be None")
self._info = run_info
self._data = run_data

Expand Down Expand Up @@ -45,4 +46,9 @@ def from_proto(cls, proto):
return cls(RunInfo.from_proto(proto.info), RunData.from_proto(proto.data))

def to_dictionary(self):
return {"info": dict(self.info), "data": self.data.to_dictionary()}
run_dict = {
"info": dict(self.info),
}
if self.data:
run_dict["data"] = self.data.to_dictionary()
return run_dict
2 changes: 1 addition & 1 deletion mlflow/store/file_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def create_run(self, experiment_id, user_id, start_time, tags):
mkdir(run_dir, FileStore.ARTIFACTS_FOLDER_NAME)
for tag in tags:
self.set_tag(run_uuid, tag)
return Run(run_info=run_info, run_data=None)
return self.get_run(run_id=run_uuid)

def get_run(self, run_id):
"""
Expand Down
54 changes: 39 additions & 15 deletions tests/entities/test_run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import pytest

from mlflow.entities import Run, Metric, RunData, RunStatus, RunInfo, LifecycleStage
from mlflow.exceptions import MlflowException
from tests.entities.test_run_data import TestRunData
from tests.entities.test_run_info import TestRunInfo

Expand All @@ -19,26 +22,41 @@ def test_creation_and_hydration(self):

self._check_run(run1, run_info, metrics, params, tags)

as_dict = {"info": {"run_uuid": run_id,
"run_id": run_id,
"experiment_id": experiment_id,
"user_id": user_id,
"status": status,
"start_time": start_time,
"end_time": end_time,
"lifecycle_stage": lifecycle_stage,
"artifact_uri": artifact_uri,
},
"data": {"metrics": {m.key: m.value for m in metrics},
"params": {p.key: p.value for p in params},
"tags": {t.key: t.value for t in tags}}
}
self.assertEqual(run1.to_dictionary(), as_dict)
expected_info_dict = {
"run_uuid": run_id,
"run_id": run_id,
"experiment_id": experiment_id,
"user_id": user_id,
"status": status,
"start_time": start_time,
"end_time": end_time,
"lifecycle_stage": lifecycle_stage,
"artifact_uri": artifact_uri,
}
self.assertEqual(
run1.to_dictionary(),
{
"info": expected_info_dict,
"data": {
"metrics": {m.key: m.value for m in metrics},
"params": {p.key: p.value for p in params},
"tags": {t.key: t.value for t in tags},
}
}
)

proto = run1.to_proto()
run2 = Run.from_proto(proto)
self._check_run(run2, run_info, metrics, params, tags)

run3 = Run(run_info, None)
self.assertEqual(
run3.to_dictionary(),
{
"info": expected_info_dict,
}
)

def test_string_repr(self):
run_info = RunInfo(
run_uuid="hi", run_id="hi", experiment_id=0,
Expand All @@ -53,3 +71,9 @@ def test_string_repr(self):
"lifecycle_stage='active', run_id='hi', run_uuid='hi', "
"start_time=0, status=4, user_id='user-id'>>")
assert str(run1) == expected

def test_creating_run_with_absent_info_throws_exception(self):
run_data = TestRunData._create()[0]
with pytest.raises(MlflowException) as no_info_exc:
Run(None, run_data)
assert "run_info cannot be None" in str(no_info_exc)
24 changes: 23 additions & 1 deletion tests/store/test_file_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import mock
import pytest

from mlflow.entities import Metric, Param, RunTag, ViewType, LifecycleStage, RunStatus
from mlflow.entities import Metric, Param, RunTag, ViewType, LifecycleStage, RunStatus, RunData
from mlflow.exceptions import MlflowException, MissingConfigException
from mlflow.store import SEARCH_MAX_RESULTS_DEFAULT
from mlflow.store.file_store import FileStore
Expand Down Expand Up @@ -270,6 +270,28 @@ def test_create_run_in_deleted_experiment(self):
with pytest.raises(Exception):
fs.create_run(exp_id, 'user', 0, [])

def test_create_run_returns_expected_run_data(self):
fs = FileStore(self.test_root)
no_tags_run = fs.create_run(
experiment_id=FileStore.DEFAULT_EXPERIMENT_ID, user_id='user', start_time=0, tags=[])
assert isinstance(no_tags_run.data, RunData)
assert len(no_tags_run.data.tags) == 0

tags_dict = {
"my_first_tag": "first",
"my-second-tag": "2nd",
}
tags_entities = [
RunTag(key, value) for key, value in tags_dict.items()
]
tags_run = fs.create_run(
experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
user_id='user',
start_time=0,
tags=tags_entities)
assert isinstance(tags_run.data, RunData)
assert tags_run.data.tags == tags_dict

def _experiment_id_edit_func(self, old_dict):
old_dict["experiment_id"] = int(old_dict["experiment_id"])
return old_dict
Expand Down
25 changes: 13 additions & 12 deletions tests/tracking/test_rest_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,18 +265,19 @@ def test_create_run_all_args(mlflow_client, parent_run_id_kwarg):
created_run = mlflow_client.create_run(experiment_id, **create_run_kwargs)
run_id = created_run.info.run_id
print("Run id=%s" % run_id)
run = mlflow_client.get_run(run_id)
assert run.info.run_id == run_id
assert run.info.run_uuid == run_id
assert run.info.experiment_id == experiment_id
assert run.info.user_id == user
assert run.info.start_time == create_run_kwargs["start_time"]
for tag in create_run_kwargs["tags"]:
assert tag in run.data.tags
assert run.data.tags.get(MLFLOW_USER) == user
assert run.data.tags.get(MLFLOW_RUN_NAME) == "my name"
assert run.data.tags.get(MLFLOW_PARENT_RUN_ID) == parent_run_id_kwarg or "7"
assert mlflow_client.list_run_infos(experiment_id) == [run.info]
fetched_run = mlflow_client.get_run(run_id)
for run in [created_run, fetched_run]:
assert run.info.run_id == run_id
assert run.info.run_uuid == run_id
assert run.info.experiment_id == experiment_id
assert run.info.user_id == user
assert run.info.start_time == create_run_kwargs["start_time"]
for tag in create_run_kwargs["tags"]:
assert tag in run.data.tags
assert run.data.tags.get(MLFLOW_USER) == user
assert run.data.tags.get(MLFLOW_RUN_NAME) == "my name"
assert run.data.tags.get(MLFLOW_PARENT_RUN_ID) == parent_run_id_kwarg or "7"
assert mlflow_client.list_run_infos(experiment_id) == [run.info]


def test_create_run_defaults(mlflow_client):
Expand Down

0 comments on commit 63d0406

Please sign in to comment.