Skip to content

Commit 22609f7

Browse files
xingyousongcopybara-github
authored andcommitted
Aesthetic + small bug fixes to Vizier service
PiperOrigin-RevId: 646600873
1 parent 086ab0a commit 22609f7

File tree

1 file changed

+30
-45
lines changed

1 file changed

+30
-45
lines changed

vizier/_src/service/vizier_service.py

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
"""RPC functions implemented from vizier_service.proto."""
18+
1819
import collections
1920
import datetime
2021
import threading
@@ -24,7 +25,6 @@
2425
import grpc
2526
import numpy as np
2627
import sqlalchemy as sqla
27-
2828
from vizier import pythia
2929
from vizier import pyvizier as vz
3030
from vizier._src.service import constants
@@ -55,6 +55,10 @@ def _get_current_time() -> timestamp_pb2.Timestamp:
5555
return now
5656

5757

58+
StudyResource = resources.StudyResource
59+
TrialResource = resources.TrialResource
60+
61+
5862
# TODO: remove context = None
5963
# TODO: remove context = None
6064
class VizierServicer(vizier_service_pb2_grpc.VizierServiceServicer):
@@ -196,7 +200,7 @@ def CreateStudy(
196200
study_id = study.display_name
197201

198202
# Finally create study in database and return it.
199-
study.name = resources.StudyResource(owner_id, study_id).name
203+
study.name = StudyResource(owner_id, study_id).name
200204
self.datastore.create_study(study)
201205
return study
202206

@@ -214,8 +218,8 @@ def ListStudies(
214218
context: Optional[grpc.ServicerContext] = None,
215219
) -> vizier_service_pb2.ListStudiesResponse:
216220
"""Lists all the studies in a region for an associated project."""
217-
list_of_studies = self.datastore.list_studies(request.parent)
218-
return vizier_service_pb2.ListStudiesResponse(studies=list_of_studies)
221+
studies = self.datastore.list_studies(request.parent)
222+
return vizier_service_pb2.ListStudiesResponse(studies=studies)
219223

220224
def DeleteStudy(
221225
self,
@@ -283,7 +287,7 @@ def SuggestTrials(
283287
)
284288
grpc_util.handle_exception(e, context)
285289

286-
study_resource = resources.StudyResource.from_name(study_name)
290+
study_resource = StudyResource.from_name(study_name)
287291
study_id = study_resource.study_id
288292
owner_id = study_resource.owner_id
289293

@@ -306,14 +310,12 @@ def SuggestTrials(
306310
start_time = _get_current_time()
307311
# Create a new Op if there aren't any active (not done) ops.
308312
try:
309-
new_op_number = (
310-
self.datastore.max_suggestion_operation_number(
311-
study_name, request.client_id
312-
)
313-
+ 1
313+
old_op_number = self.datastore.max_suggestion_operation_number(
314+
study_name, request.client_id
314315
)
315316
except custom_errors.NotFoundError:
316-
new_op_number = 1
317+
old_op_number = 0
318+
new_op_number = old_op_number + 1
317319
new_op_name = resources.SuggestionOperationResource(
318320
owner_id, study_id, request.client_id, new_op_number
319321
).name
@@ -441,9 +443,7 @@ def SuggestTrials(
441443
new_trial = new_trials.pop()
442444
trial_id = self.datastore.max_trial_id(request.parent) + 1
443445
new_trial.id = str(trial_id)
444-
new_trial.name = resources.TrialResource(
445-
owner_id, study_id, trial_id
446-
).name
446+
new_trial.name = TrialResource(owner_id, study_id, trial_id).name
447447
new_trial.state = study_pb2.Trial.State.ACTIVE
448448
new_trial.start_time.CopyFrom(start_time)
449449
new_trial.client_id = request.client_id
@@ -455,14 +455,12 @@ def SuggestTrials(
455455
).SerializeToString()
456456

457457
# Store remaining trials as REQUESTED if Pythia over-delivered.
458-
for remaining_trial in new_trials:
458+
for remain_trial in new_trials:
459459
trial_id = self.datastore.max_trial_id(request.parent) + 1
460-
remaining_trial.id = str(trial_id)
461-
remaining_trial.name = resources.TrialResource(
462-
owner_id, study_id, trial_id
463-
).name
464-
remaining_trial.state = study_pb2.Trial.State.REQUESTED
465-
self.datastore.create_trial(new_trial)
460+
remain_trial.id = str(trial_id)
461+
remain_trial.name = TrialResource(owner_id, study_id, trial_id).name
462+
remain_trial.state = study_pb2.Trial.State.REQUESTED
463+
self.datastore.create_trial(remain_trial)
466464

467465
output_op.done = True
468466
self.datastore.update_suggestion_operation(output_op)
@@ -491,11 +489,8 @@ def CreateTrial(
491489
trial = request.trial
492490
with self._study_name_to_lock[request.parent]:
493491
trial.id = str(self.datastore.max_trial_id(request.parent) + 1)
494-
trial.name = (
495-
resources.StudyResource.from_name(request.parent).trial_resource(
496-
trial_id=trial.id
497-
)
498-
).name
492+
study_resource = StudyResource.from_name(request.parent)
493+
trial.name = (study_resource.trial_resource(trial.id)).name
499494

500495
if trial.state != study_pb2.Trial.State.SUCCEEDED:
501496
trial.state = study_pb2.Trial.State.REQUESTED
@@ -543,9 +538,7 @@ def AddTrialMeasurement(
543538
ImmutableStudyError: If study was already immutable.
544539
ImmutableTrialError: If the trial cannot be modified.
545540
"""
546-
study_name = resources.TrialResource.from_name(
547-
request.trial_name
548-
).study_resource.name
541+
study_name = TrialResource.from_name(request.trial_name).study_resource.name
549542
if self._study_is_immutable(study_name):
550543
e = custom_errors.ImmutableStudyError(
551544
'Study {} is immutable. Cannot add measurement.'.format(study_name)
@@ -577,9 +570,7 @@ def CompleteTrial(
577570
context: Optional[grpc.ServicerContext] = None,
578571
) -> study_pb2.Trial:
579572
"""Marks a Trial as complete."""
580-
study_name = resources.TrialResource.from_name(
581-
request.name
582-
).study_resource.name
573+
study_name = TrialResource.from_name(request.name).study_resource.name
583574
if self._study_is_immutable(study_name):
584575
e = custom_errors.ImmutableStudyError(
585576
'Study {} is immutable. Cannot complete trial.'.format(study_name)
@@ -625,9 +616,7 @@ def DeleteTrial(
625616
context: Optional[grpc.ServicerContext] = None,
626617
) -> empty_pb2.Empty:
627618
"""Deletes a Trial."""
628-
study_name = resources.TrialResource.from_name(
629-
request.name
630-
).study_resource.name
619+
study_name = TrialResource.from_name(request.name).study_resource.name
631620
if self._study_is_immutable(study_name):
632621
e = custom_errors.ImmutableStudyError(
633622
'Study {} is immutable. Cannot delete trial.'.format(study_name)
@@ -679,7 +668,7 @@ def CheckTrialEarlyStoppingState(
679668
ImmutableStudyError: If study was already immutable.
680669
ImmutableTrialError: If the trial cannot be modified.
681670
"""
682-
trial_resource = resources.TrialResource.from_name(request.trial_name)
671+
trial_resource = TrialResource.from_name(request.trial_name)
683672
study_name = trial_resource.study_resource.name
684673
if self._study_is_immutable(study_name):
685674
e = custom_errors.ImmutableStudyError(
@@ -841,9 +830,7 @@ def StopTrial(
841830
ImmutableStudyError: If study was already immutable.
842831
ImmutableTrialError: If the trial cannot be modified.
843832
"""
844-
study_name = resources.TrialResource.from_name(
845-
request.name
846-
).study_resource.name
833+
study_name = TrialResource.from_name(request.name).study_resource.name
847834
if self._study_is_immutable(study_name):
848835
e = custom_errors.ImmutableStudyError(
849836
'Study {} is immutable. Cannot stop trial.'.format(study_name)
@@ -926,12 +913,10 @@ def ListOptimalTrials(
926913
# Find Pareto optimal trials.
927914
ys = np.array(considered_trial_objective_vectors)
928915
n = ys.shape[0]
929-
dominated = np.asarray(
930-
[
931-
[np.all(ys[i] <= ys[j]) & np.any(ys[j] > ys[i]) for i in range(n)]
932-
for j in range(n)
933-
]
934-
)
916+
dominated = np.asarray([
917+
[np.all(ys[i] <= ys[j]) & np.any(ys[j] > ys[i]) for i in range(n)]
918+
for j in range(n)
919+
])
935920
optimal_booleans = np.logical_not(np.any(dominated, axis=0))
936921
optimal_trials = []
937922
for i, boolean in enumerate(list(optimal_booleans)):

0 commit comments

Comments
 (0)