Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ax/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import json
from abc import abstractmethod
from collections.abc import Iterable
from copy import deepcopy
from functools import reduce
from hashlib import md5
from io import StringIO
Expand Down Expand Up @@ -539,6 +540,10 @@ def from_multiple_data(

return data_out

def clone(self) -> Data:
    """Create an independent copy of this ``Data`` object.

    Returns:
        A new ``Data`` instance constructed from a deep copy of ``self.df``
        and the same ``description``. Only these two values are passed to
        the constructor.
    """
    df_copy = deepcopy(self.df)
    return Data(df=df_copy, description=self.description)


def set_single_trial(data: Data) -> Data:
"""Returns a new Data object where we set all rows to have the same
Expand Down
4 changes: 3 additions & 1 deletion ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,13 +1648,15 @@ def clone_with(
if isinstance(trial, BatchTrial) or isinstance(trial, Trial):
trial.clone_to(cloned_experiment)
trial_data, timestamp = self.lookup_data_for_trial(trial_index)
# Clone the data to avoid overwriting the original in the DB.
trial_data = trial_data.clone()
if timestamp != -1:
data_by_trial[trial_index] = OrderedDict([(timestamp, trial_data)])
else:
raise NotImplementedError(f"Cloning of {type(trial)} is not supported.")
if data is not None:
# If user passed in data, use it.
cloned_experiment.attach_data(data)
cloned_experiment.attach_data(data.clone())
else:
# Otherwise, attach the data extracted from the original experiment.
cloned_experiment._data_by_trial = data_by_trial
Expand Down
17 changes: 15 additions & 2 deletions ax/core/map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@
from __future__ import annotations

from collections.abc import Iterable, Sequence

from copy import deepcopy
from logging import Logger
from typing import Any, Generic, Optional, TypeVar

import numpy as np

import pandas as pd
from ax.core.data import Data
from ax.core.types import TMapTrialEvaluation
Expand Down Expand Up @@ -69,6 +68,10 @@ def default_value(self) -> T:
def value_type(self) -> type:
return type(self._default_value)

def clone(self) -> MapKeyInfo[T]:
    """Build a new ``MapKeyInfo`` mirroring this one.

    Returns:
        A ``MapKeyInfo`` with the same ``key`` and a deep copy of
        ``default_value``.
    """
    key = self.key
    default_copy = deepcopy(self.default_value)
    return MapKeyInfo(key=key, default_value=default_copy)


class MapData(Data):
"""Class storing mapping-like results for an experiment.
Expand Down Expand Up @@ -326,6 +329,16 @@ def deserialize_init_args(
]
return super().deserialize_init_args(args=args)

def clone(self) -> MapData:
    """Create an independent copy of this ``MapData`` object.

    Returns:
        A new ``MapData`` built from a deep copy of ``self.map_df``, a list
        of individually cloned map key infos, and the same ``description``.
    """
    df_copy = deepcopy(self.map_df)
    infos_copy = [info.clone() for info in self.map_key_infos]
    return MapData(
        df=df_copy,
        map_key_infos=infos_copy,
        description=self.description,
    )

def subsample(
self,
map_key: Optional[str] = None,
Expand Down
12 changes: 12 additions & 0 deletions ax/core/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ def test_Data(self) -> None:
0.5,
)

def test_clone(self) -> None:
    """``Data.clone`` yields an equal, distinct object without the db id."""
    original = Data(df=self.df, description="test")
    original._db_id = 1234
    copied = original.clone()
    # The clone carries the same contents as the original.
    self.assertTrue(original.df.equals(copied.df))
    self.assertEqual(original.description, copied.description)
    # Neither the clone nor its dataframe is the original object.
    self.assertIsNot(original, copied)
    self.assertIsNot(original.df, copied.df)
    # The db id assigned to the original must not be carried over.
    self.assertIsNone(copied._db_id)

def test_BadData(self) -> None:
df = pd.DataFrame([{"bad_field": "0_0", "bad_field_2": {"x": 0, "y": "a"}}])
with self.assertRaises(ValueError):
Expand Down
12 changes: 12 additions & 0 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
from ax.runners.synthetic import SyntheticRunner
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties
from ax.storage.sqa_store.db import init_test_engine_and_session_factory
from ax.storage.sqa_store.load import load_experiment
from ax.storage.sqa_store.save import save_experiment
from ax.utils.common.constants import EXPERIMENT_IS_TEST_WARNING, Keys
from ax.utils.common.random import set_rng_seed
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -1018,12 +1021,15 @@ def test_is_test_warning(self) -> None:
)

def test_clone_with(self) -> None:
init_test_engine_and_session_factory(force_init=True)
experiment = get_branin_experiment(
with_batch=True,
with_completed_trial=True,
with_status_quo=True,
with_choice_parameter=True,
)
# Save the experiment to set db_ids.
save_experiment(experiment)

larger_search_space = SearchSpace(
parameters=[
Expand Down Expand Up @@ -1086,6 +1092,12 @@ def test_clone_with(self) -> None:
checked_cast(Arm, experiment.status_quo).parameters, {"x1": 0.0, "x2": 0.0}
)

# Save the cloned experiment to db and make sure the original
# experiment is unchanged in the db.
save_experiment(cloned_experiment)
reloaded_experiment = load_experiment(experiment.name)
self.assertEqual(experiment, reloaded_experiment)

# clone specific trials and new data only
df = pd.DataFrame(
{
Expand Down
13 changes: 13 additions & 0 deletions ax/core/tests/test_map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,19 @@ def test_properties(self) -> None:
self.assertEqual(self.mmd.map_keys, ["epoch"])
self.assertEqual(self.mmd.map_key_to_type, {"epoch": int})

def test_clone(self) -> None:
    """``MapData.clone`` yields an equal, distinct object without the db id."""
    source = self.mmd
    source._db_id = 1234
    copied = source.clone()
    # The clone carries the same dataframes, map key infos and description.
    self.assertTrue(copied.map_df.equals(source.map_df))
    self.assertTrue(copied.df.equals(source.df))
    self.assertEqual(copied.map_key_infos, source.map_key_infos)
    self.assertEqual(copied.description, source.description)
    # Neither the clone nor its map dataframe is the original object.
    self.assertIsNot(copied, source)
    self.assertIsNot(copied.map_df, source.map_df)
    # The db id assigned to the source must not be carried over.
    self.assertIsNone(copied._db_id)

def test_combine(self) -> None:
data = MapData.from_multiple_map_data([])
self.assertEqual(data.map_df.size, 0)
Expand Down