Skip to content

Commit

Permalink
Merge branch 'main' into mikealawrence-add-tft-model-support
Browse files Browse the repository at this point in the history
  • Loading branch information
Mlawrence95 authored Dec 4, 2022
2 parents 472768f + e693350 commit dde8ac0
Show file tree
Hide file tree
Showing 63 changed files with 1,236 additions and 107 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
syntax = "proto3";

package google.cloud.aiplatform.container.v1beta1;
package google.cloud.aiplatform.container.v1;

import "google/rpc/status.proto";

Expand All @@ -14,6 +14,10 @@ service MatchService {
// Returns the nearest neighbors for batch queries. If it is a sharded
// deployment, calls the other shards and aggregates the responses.
rpc BatchMatch(BatchMatchRequest) returns (BatchMatchResponse) {}

// Looks up the embeddings.
rpc BatchGetEmbeddings(BatchGetEmbeddingsRequest)
returns (BatchGetEmbeddingsResponse) {}
}

// Parameters for a match query.
Expand Down Expand Up @@ -56,6 +60,28 @@ message MatchRequest {
// not set or set to 0.0, query uses the default value specified in
// NearestNeighborSearchConfig.TreeAHConfig.leaf_nodes_to_search_percent.
int32 leaf_nodes_to_search_percent_override = 7;

// If set to true, besides the doc id, query result will also include the
// embedding. Set this value may impact the query performance (e.g, increase
// query latency, etc).
bool embedding_enabled = 8;
}

// Embedding on query result.
message Embedding {
// The id of the matched neighbor.
string id = 1;

// The embedding values.
repeated float float_val = 2;

// The list of restricts.
repeated Namespace restricts = 3;

// The attribute value used for crowding. The maximum number of neighbors
// to return per crowding attribute value
// (per_crowding_attribute_num_neighbors) is configured per-query.
int64 crowding_attribute = 4;
}

// Response of a match query.
Expand All @@ -66,9 +92,32 @@ message MatchResponse {

// The distances of the matches.
double distance = 2;

// If crowding is enabled, the crowding attribute of this neighbor will
// be stored here.
int64 crowding_attribute = 3;
}
// All its neighbors.
repeated Neighbor neighbor = 1;

// Embedding values for all returned neighbors.
// This is only set when query.embedding_enabled is set to true.
repeated Embedding embeddings = 2;
}

// Request of a Batch Get Embeddings query.
message BatchGetEmbeddingsRequest {
// The ID of the DeploydIndex that will serve the request.
string deployed_index_id = 1;

// The ids to be looked up.
repeated string id = 2;
}

// Response of a Batch Get Embeddings query.
message BatchGetEmbeddingsResponse {
// Embedding values for all ids in the query request.
repeated Embedding embeddings = 1;
}

// Parameters for a batch match query.
Expand Down
6 changes: 6 additions & 0 deletions google/cloud/aiplatform_v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@
from .types.featurestore_online_service import ReadFeatureValuesRequest
from .types.featurestore_online_service import ReadFeatureValuesResponse
from .types.featurestore_online_service import StreamingReadFeatureValuesRequest
from .types.featurestore_online_service import WriteFeatureValuesPayload
from .types.featurestore_online_service import WriteFeatureValuesRequest
from .types.featurestore_online_service import WriteFeatureValuesResponse
from .types.featurestore_service import BatchCreateFeaturesOperationMetadata
from .types.featurestore_service import BatchCreateFeaturesRequest
from .types.featurestore_service import BatchCreateFeaturesResponse
Expand Down Expand Up @@ -986,6 +989,9 @@
"Value",
"VizierServiceClient",
"WorkerPoolSpec",
"WriteFeatureValuesPayload",
"WriteFeatureValuesRequest",
"WriteFeatureValuesResponse",
"WriteTensorboardExperimentDataRequest",
"WriteTensorboardExperimentDataResponse",
"WriteTensorboardRunDataRequest",
Expand Down
10 changes: 10 additions & 0 deletions google/cloud/aiplatform_v1/gapic_metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,11 @@
"methods": [
"streaming_read_feature_values"
]
},
"WriteFeatureValues": {
"methods": [
"write_feature_values"
]
}
}
},
Expand All @@ -242,6 +247,11 @@
"methods": [
"streaming_read_feature_values"
]
},
"WriteFeatureValues": {
"methods": [
"write_feature_values"
]
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,136 @@ async def sample_streaming_read_feature_values():
# Done; return the response.
return response

async def write_feature_values(
self,
request: Union[
featurestore_online_service.WriteFeatureValuesRequest, dict
] = None,
*,
entity_type: str = None,
payloads: Sequence[
featurestore_online_service.WriteFeatureValuesPayload
] = None,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> featurestore_online_service.WriteFeatureValuesResponse:
r"""Writes Feature values of one or more entities of an
EntityType.
The Feature values are merged into existing entities if
any. The Feature values to be written must have
timestamp within the online storage retention.
.. code-block:: python
# This snippet has been automatically generated and should be regarded as a
# code template only.
# It will require modifications to work:
# - It may require correct/in-range values for request initialization.
# - It may require specifying regional endpoints when creating the service
# client as shown in:
# https://googleapis.dev/python/google-api-core/latest/client_options.html
from google.cloud import aiplatform_v1
async def sample_write_feature_values():
# Create a client
client = aiplatform_v1.FeaturestoreOnlineServingServiceAsyncClient()
# Initialize request argument(s)
payloads = aiplatform_v1.WriteFeatureValuesPayload()
payloads.entity_id = "entity_id_value"
request = aiplatform_v1.WriteFeatureValuesRequest(
entity_type="entity_type_value",
payloads=payloads,
)
# Make the request
response = await client.write_feature_values(request=request)
# Handle the response
print(response)
Args:
request (Union[google.cloud.aiplatform_v1.types.WriteFeatureValuesRequest, dict]):
The request object. Request message for
[FeaturestoreOnlineServingService.WriteFeatureValues][google.cloud.aiplatform.v1.FeaturestoreOnlineServingService.WriteFeatureValues].
entity_type (:class:`str`):
Required. The resource name of the EntityType for the
entities being written. Value format:
``projects/{project}/locations/{location}/featurestores/ {featurestore}/entityTypes/{entityType}``.
For example, for a machine learning model predicting
user clicks on a website, an EntityType ID could be
``user``.
This corresponds to the ``entity_type`` field
on the ``request`` instance; if ``request`` is provided, this
should not be set.
payloads (:class:`Sequence[google.cloud.aiplatform_v1.types.WriteFeatureValuesPayload]`):
Required. The entities to be written. Up to 100,000
feature values can be written across all ``payloads``.
This corresponds to the ``payloads`` field
on the ``request`` instance; if ``request`` is provided, this
should not be set.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
Returns:
google.cloud.aiplatform_v1.types.WriteFeatureValuesResponse:
Response message for
[FeaturestoreOnlineServingService.WriteFeatureValues][google.cloud.aiplatform.v1.FeaturestoreOnlineServingService.WriteFeatureValues].
"""
# Create or coerce a protobuf request object.
# Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([entity_type, payloads])
if request is not None and has_flattened_params:
raise ValueError(
"If the `request` argument is set, then none of "
"the individual field arguments should be set."
)

request = featurestore_online_service.WriteFeatureValuesRequest(request)

# If we have keyword arguments corresponding to fields on the
# request, apply these.
if entity_type is not None:
request.entity_type = entity_type
if payloads:
request.payloads.extend(payloads)

# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.write_feature_values,
default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)

# Certain fields should be provided within the metadata header;
# add these here.
metadata = tuple(metadata) + (
gapic_v1.routing_header.to_grpc_metadata(
(("entity_type", request.entity_type),)
),
)

# Send the request.
response = await rpc(
request,
retry=retry,
timeout=timeout,
metadata=metadata,
)

# Done; return the response.
return response

async def list_operations(
self,
request: operations_pb2.ListOperationsRequest = None,
Expand Down
Loading

0 comments on commit dde8ac0

Please sign in to comment.