1515from __future__ import annotations
1616
1717"""RPC functions implemented from vizier_service.proto."""
18+
1819import collections
1920import datetime
2021import threading
2425import grpc
2526import numpy as np
2627import sqlalchemy as sqla
27-
2828from vizier import pythia
2929from vizier import pyvizier as vz
3030from vizier ._src .service import constants
@@ -55,6 +55,10 @@ def _get_current_time() -> timestamp_pb2.Timestamp:
5555 return now
5656
5757
58+ StudyResource = resources .StudyResource
59+ TrialResource = resources .TrialResource
60+
61+
5862# TODO: remove context = None
5963# TODO: remove context = None
6064class VizierServicer (vizier_service_pb2_grpc .VizierServiceServicer ):
@@ -196,7 +200,7 @@ def CreateStudy(
196200 study_id = study .display_name
197201
198202 # Finally create study in database and return it.
199- study .name = resources . StudyResource (owner_id , study_id ).name
203+ study .name = StudyResource (owner_id , study_id ).name
200204 self .datastore .create_study (study )
201205 return study
202206
@@ -214,8 +218,8 @@ def ListStudies(
214218 context : Optional [grpc .ServicerContext ] = None ,
215219 ) -> vizier_service_pb2 .ListStudiesResponse :
216220 """Lists all the studies in a region for an associated project."""
217- list_of_studies = self .datastore .list_studies (request .parent )
218- return vizier_service_pb2 .ListStudiesResponse (studies = list_of_studies )
221+ studies = self .datastore .list_studies (request .parent )
222+ return vizier_service_pb2 .ListStudiesResponse (studies = studies )
219223
220224 def DeleteStudy (
221225 self ,
@@ -283,7 +287,7 @@ def SuggestTrials(
283287 )
284288 grpc_util .handle_exception (e , context )
285289
286- study_resource = resources . StudyResource .from_name (study_name )
290+ study_resource = StudyResource .from_name (study_name )
287291 study_id = study_resource .study_id
288292 owner_id = study_resource .owner_id
289293
@@ -306,14 +310,12 @@ def SuggestTrials(
306310 start_time = _get_current_time ()
307311 # Create a new Op if there aren't any active (not done) ops.
308312 try :
309- new_op_number = (
310- self .datastore .max_suggestion_operation_number (
311- study_name , request .client_id
312- )
313- + 1
313+ old_op_number = self .datastore .max_suggestion_operation_number (
314+ study_name , request .client_id
314315 )
315316 except custom_errors .NotFoundError :
316- new_op_number = 1
317+ old_op_number = 0
318+ new_op_number = old_op_number + 1
317319 new_op_name = resources .SuggestionOperationResource (
318320 owner_id , study_id , request .client_id , new_op_number
319321 ).name
@@ -441,9 +443,7 @@ def SuggestTrials(
441443 new_trial = new_trials .pop ()
442444 trial_id = self .datastore .max_trial_id (request .parent ) + 1
443445 new_trial .id = str (trial_id )
444- new_trial .name = resources .TrialResource (
445- owner_id , study_id , trial_id
446- ).name
446+ new_trial .name = TrialResource (owner_id , study_id , trial_id ).name
447447 new_trial .state = study_pb2 .Trial .State .ACTIVE
448448 new_trial .start_time .CopyFrom (start_time )
449449 new_trial .client_id = request .client_id
@@ -455,14 +455,12 @@ def SuggestTrials(
455455 ).SerializeToString ()
456456
457457 # Store remaining trials as REQUESTED if Pythia over-delivered.
458- for remaining_trial in new_trials :
458+ for remain_trial in new_trials :
459459 trial_id = self .datastore .max_trial_id (request .parent ) + 1
460- remaining_trial .id = str (trial_id )
461- remaining_trial .name = resources .TrialResource (
462- owner_id , study_id , trial_id
463- ).name
464- remaining_trial .state = study_pb2 .Trial .State .REQUESTED
465- self .datastore .create_trial (new_trial )
460+ remain_trial .id = str (trial_id )
461+ remain_trial .name = TrialResource (owner_id , study_id , trial_id ).name
462+ remain_trial .state = study_pb2 .Trial .State .REQUESTED
463+ self .datastore .create_trial (remain_trial )
466464
467465 output_op .done = True
468466 self .datastore .update_suggestion_operation (output_op )
@@ -491,11 +489,8 @@ def CreateTrial(
491489 trial = request .trial
492490 with self ._study_name_to_lock [request .parent ]:
493491 trial .id = str (self .datastore .max_trial_id (request .parent ) + 1 )
494- trial .name = (
495- resources .StudyResource .from_name (request .parent ).trial_resource (
496- trial_id = trial .id
497- )
498- ).name
492+ study_resource = StudyResource .from_name (request .parent )
493+ trial .name = (study_resource .trial_resource (trial .id )).name
499494
500495 if trial .state != study_pb2 .Trial .State .SUCCEEDED :
501496 trial .state = study_pb2 .Trial .State .REQUESTED
@@ -543,9 +538,7 @@ def AddTrialMeasurement(
543538 ImmutableStudyError: If study was already immutable.
544539 ImmutableTrialError: If the trial cannot be modified.
545540 """
546- study_name = resources .TrialResource .from_name (
547- request .trial_name
548- ).study_resource .name
541+ study_name = TrialResource .from_name (request .trial_name ).study_resource .name
549542 if self ._study_is_immutable (study_name ):
550543 e = custom_errors .ImmutableStudyError (
551544 'Study {} is immutable. Cannot add measurement.' .format (study_name )
@@ -577,9 +570,7 @@ def CompleteTrial(
577570 context : Optional [grpc .ServicerContext ] = None ,
578571 ) -> study_pb2 .Trial :
579572 """Marks a Trial as complete."""
580- study_name = resources .TrialResource .from_name (
581- request .name
582- ).study_resource .name
573+ study_name = TrialResource .from_name (request .name ).study_resource .name
583574 if self ._study_is_immutable (study_name ):
584575 e = custom_errors .ImmutableStudyError (
585576 'Study {} is immutable. Cannot complete trial.' .format (study_name )
@@ -625,9 +616,7 @@ def DeleteTrial(
625616 context : Optional [grpc .ServicerContext ] = None ,
626617 ) -> empty_pb2 .Empty :
627618 """Deletes a Trial."""
628- study_name = resources .TrialResource .from_name (
629- request .name
630- ).study_resource .name
619+ study_name = TrialResource .from_name (request .name ).study_resource .name
631620 if self ._study_is_immutable (study_name ):
632621 e = custom_errors .ImmutableStudyError (
633622 'Study {} is immutable. Cannot delete trial.' .format (study_name )
@@ -679,7 +668,7 @@ def CheckTrialEarlyStoppingState(
679668 ImmutableStudyError: If study was already immutable.
680669 ImmutableTrialError: If the trial cannot be modified.
681670 """
682- trial_resource = resources . TrialResource .from_name (request .trial_name )
671+ trial_resource = TrialResource .from_name (request .trial_name )
683672 study_name = trial_resource .study_resource .name
684673 if self ._study_is_immutable (study_name ):
685674 e = custom_errors .ImmutableStudyError (
@@ -841,9 +830,7 @@ def StopTrial(
841830 ImmutableStudyError: If study was already immutable.
842831 ImmutableTrialError: If the trial cannot be modified.
843832 """
844- study_name = resources .TrialResource .from_name (
845- request .name
846- ).study_resource .name
833+ study_name = TrialResource .from_name (request .name ).study_resource .name
847834 if self ._study_is_immutable (study_name ):
848835 e = custom_errors .ImmutableStudyError (
849836 'Study {} is immutable. Cannot stop trial.' .format (study_name )
@@ -926,12 +913,10 @@ def ListOptimalTrials(
926913 # Find Pareto optimal trials.
927914 ys = np .array (considered_trial_objective_vectors )
928915 n = ys .shape [0 ]
929- dominated = np .asarray (
930- [
931- [np .all (ys [i ] <= ys [j ]) & np .any (ys [j ] > ys [i ]) for i in range (n )]
932- for j in range (n )
933- ]
934- )
916+ dominated = np .asarray ([
917+ [np .all (ys [i ] <= ys [j ]) & np .any (ys [j ] > ys [i ]) for i in range (n )]
918+ for j in range (n )
919+ ])
935920 optimal_booleans = np .logical_not (np .any (dominated , axis = 0 ))
936921 optimal_trials = []
937922 for i , boolean in enumerate (list (optimal_booleans )):
0 commit comments