diff --git a/.gitignore b/.gitignore
index 5d32b237..783162de 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,3 +12,4 @@ system-test/*key.json
.DS_Store
package-lock.json
__pycache__
+.vscode
diff --git a/.kokoro/samples-test.sh b/.kokoro/samples-test.sh
index bab7ba4e..bdabc67b 100755
--- a/.kokoro/samples-test.sh
+++ b/.kokoro/samples-test.sh
@@ -15,12 +15,11 @@
# limitations under the License.
set -eo pipefail
-
export NPM_CONFIG_PREFIX=${HOME}/.npm-global
# Setup service account credentials.
export GOOGLE_APPLICATION_CREDENTIALS=${KOKORO_GFILE_DIR}/service-account.json
-export GCLOUD_PROJECT=long-door-651
+export GCLOUD_PROJECT=ucaip-sample-tests
cd $(dirname $0)/..
diff --git a/.kokoro/system-test.sh b/.kokoro/system-test.sh
index 8a084004..3103df45 100755
--- a/.kokoro/system-test.sh
+++ b/.kokoro/system-test.sh
@@ -20,7 +20,7 @@ export NPM_CONFIG_PREFIX=${HOME}/.npm-global
# Setup service account credentials.
export GOOGLE_APPLICATION_CREDENTIALS=${KOKORO_GFILE_DIR}/service-account.json
-export GCLOUD_PROJECT=long-door-651
+export GCLOUD_PROJECT=ucaip-sample-tests
cd $(dirname $0)/..
diff --git a/README.md b/README.md
index 8cf5a805..9b147a22 100644
--- a/README.md
+++ b/README.md
@@ -76,6 +76,9 @@ has instructions for running the samples.
| Sample | Source Code | Try it |
| --------------------------- | --------------------------------- | ------ |
+| Create-training-pipeline-image-classification | [source code](https://github.com/googleapis/nodejs-ai-platform/blob/master/samples/create-training-pipeline-image-classification.js) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/nodejs-ai-platform&page=editor&open_in_editor=samples/create-training-pipeline-image-classification.js,samples/README.md) |
+| List-endpoints | [source code](https://github.com/googleapis/nodejs-ai-platform/blob/master/samples/list-endpoints.js) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/nodejs-ai-platform&page=editor&open_in_editor=samples/list-endpoints.js,samples/README.md) |
+| Predict-image-classification | [source code](https://github.com/googleapis/nodejs-ai-platform/blob/master/samples/predict-image-classification.js) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/nodejs-ai-platform&page=editor&open_in_editor=samples/predict-image-classification.js,samples/README.md) |
| Quickstart | [source code](https://github.com/googleapis/nodejs-ai-platform/blob/master/samples/quickstart.js) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/nodejs-ai-platform&page=editor&open_in_editor=samples/quickstart.js,samples/README.md) |
diff --git a/linkinator.config.json b/linkinator.config.json
index 29a223b6..bdfc0dc1 100644
--- a/linkinator.config.json
+++ b/linkinator.config.json
@@ -3,7 +3,10 @@
"skip": [
"https://codecov.io/gh/googleapis/",
"www.googleapis.com",
- "img.shields.io"
+ "img.shields.io",
+ "https://github.com/googleapis/nodejs-ai-platform/blob/master/samples/list-endpoints.js",
+ "https://github.com/googleapis/nodejs-ai-platform/blob/master/samples/create-training-pipeline-image-classification.js",
+ "https://github.com/googleapis/nodejs-ai-platform/blob/master/samples/predict-image-classification.js"
],
"silent": true,
"concurrency": 10
diff --git a/package.json b/package.json
index 4281252b..6092e12c 100644
--- a/package.json
+++ b/package.json
@@ -46,7 +46,8 @@
"test": "c8 mocha build/test"
},
"dependencies": {
- "google-gax": "^2.9.2"
+ "google-gax": "^2.9.2",
+ "protobuf.js": "^1.1.2"
},
"devDependencies": {
"@types/mocha": "^8.0.3",
diff --git a/samples/README.md b/samples/README.md
index df485e1f..7ec4c2ac 100644
--- a/samples/README.md
+++ b/samples/README.md
@@ -1,18 +1,20 @@
-[//]: # "This README.md file is auto-generated, all changes to this file will be lost."
-[//]: # "To regenerate it, use `python -m synthtool`."
+[//]: # 'This README.md file is auto-generated, all changes to this file will be lost.'
+[//]: # 'To regenerate it, use `python -m synthtool`.'
+
# [AI Platform: Node.js Samples](https://github.com/googleapis/nodejs-ai-platform)
[![Open in Cloud Shell][shell_img]][shell_link]
-
-
## Table of Contents
-* [Before you begin](#before-you-begin)
-* [Samples](#samples)
- * [Quickstart](#quickstart)
+- [Before you begin](#before-you-begin)
+- [Samples](#samples)
+ - [Create-training-pipeline-image-classification](#create-training-pipeline-image-classification)
+ - [List-endpoints](#list-endpoints)
+ - [Predict-image-classification](#predict-image-classification)
+ - [Quickstart](#quickstart)
## Before you begin
@@ -27,23 +29,51 @@ Before running the samples, make sure you've followed the steps outlined in
## Samples
+### Create-training-pipeline-image-classification
+View the [source code](https://github.com/googleapis/nodejs-ai-platform/blob/master/samples/create-training-pipeline-image-classification.js).
-### Quickstart
+[![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/nodejs-ai-platform&page=editor&open_in_editor=samples/create-training-pipeline-image-classification.js,samples/README.md)
-View the [source code](https://github.com/googleapis/nodejs-ai-platform/blob/master/samples/quickstart.js).
+**Usage:**
-[![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/nodejs-ai-platform&page=editor&open_in_editor=samples/quickstart.js,samples/README.md)
+`node samples/create-training-pipeline-image-classification.js`
-__Usage:__
+---
+### List-endpoints
-`node samples/quickstart.js`
+View the [source code](https://github.com/googleapis/nodejs-ai-platform/blob/master/samples/list-endpoints.js).
+
+[![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/nodejs-ai-platform&page=editor&open_in_editor=samples/list-endpoints.js,samples/README.md)
+
+**Usage:**
+
+`node samples/list-endpoints.js`
+
+---
+### Predict-image-classification
+View the [source code](https://github.com/googleapis/nodejs-ai-platform/blob/master/samples/predict-image-classification.js).
+[![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/nodejs-ai-platform&page=editor&open_in_editor=samples/predict-image-classification.js,samples/README.md)
+**Usage:**
+`node samples/predict-image-classification.js`
+
+---
+
+### Quickstart
+
+View the [source code](https://github.com/googleapis/nodejs-ai-platform/blob/master/samples/quickstart.js).
+
+[![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/nodejs-ai-platform&page=editor&open_in_editor=samples/quickstart.js,samples/README.md)
+
+**Usage:**
+
+`node samples/quickstart.js`
[shell_img]: https://gstatic.com/cloudssh/images/open-btn.png
[shell_link]: https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/nodejs-ai-platform&page=editor&open_in_editor=samples/README.md
diff --git a/samples/create-training-pipeline-image-classification.js b/samples/create-training-pipeline-image-classification.js
new file mode 100644
index 00000000..a6c868ec
--- /dev/null
+++ b/samples/create-training-pipeline-image-classification.js
@@ -0,0 +1,135 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+'use strict';
+
+function main(
+ datasetId,
+ modelDisplayName,
+ trainingPipelineDisplayName,
+ project,
+ location = 'us-central1'
+) {
+ // [START aiplatform_create_training_pipeline_image_classification]
+ /**
+ * TODO(developer): Uncomment these variables before running the sample.
+ * (Not necessary if passing values as arguments)
+ */
+ /*
+ const datasetId = 'YOUR DATASET';
+ const modelDisplayName = 'NEW MODEL NAME;
+ const trainingPipelineDisplayName = 'NAME FOR TRAINING PIPELINE';
+ const project = 'YOUR PROJECT ID';
+ const location = 'us-central1';
+ */
+ // Imports the Google Cloud Pipeline Service Client library
+ const aiplatform = require('@google-cloud/aiplatform');
+
+ const {
+ definition,
+ } = aiplatform.protos.google.cloud.aiplatform.v1beta1.schema.trainingjob;
+ const ModelType = definition.AutoMlImageClassificationInputs.ModelType;
+
+ // Specifies the location of the api endpoint
+ const clientOptions = {
+ apiEndpoint: 'us-central1-aiplatform.googleapis.com',
+ };
+
+ // Instantiates a client
+ const pipelineServiceClient = new aiplatform.PipelineServiceClient(
+ clientOptions
+ );
+
+ async function createTrainingPipelineImageClassification() {
+ // Configure the parent resource
+ const parent = `projects/${project}/locations/${location}`;
+
+ // Values should match the input expected by your model.
+ const trainingTaskInputsMessage = new definition.AutoMlImageClassificationInputs(
+ {
+ multiLabel: true,
+ modelType: ModelType.CLOUD,
+ budgetMilliNodeHours: 8000,
+ disableEarlyStopping: false,
+ }
+ );
+
+ const trainingTaskInputs = trainingTaskInputsMessage.toValue();
+
+ const trainingTaskDefinition =
+ 'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml';
+
+ const modelToUpload = {displayName: modelDisplayName};
+ const inputDataConfig = {datasetId: datasetId};
+ const trainingPipeline = {
+ displayName: trainingPipelineDisplayName,
+ trainingTaskDefinition,
+ trainingTaskInputs,
+ inputDataConfig: inputDataConfig,
+ modelToUpload: modelToUpload,
+ };
+ const request = {
+ parent,
+ trainingPipeline,
+ };
+
+ // Create training pipeline request
+ const [response] = await pipelineServiceClient.createTrainingPipeline(
+ request
+ );
+
+ console.log('Create training pipeline image classification response');
+ console.log(`\tName : ${response.name}`);
+ console.log(`\tDisplay Name : ${response.displayName}`);
+ console.log(
+ `\tTraining task definition : ${response.trainingTaskDefinition}`
+ );
+ console.log(
+ `\tTraining task inputs : \
+ ${JSON.stringify(response.trainingTaskInputs)}`
+ );
+ console.log(
+ `\tTraining task metadata : \
+ ${JSON.stringify(response.trainingTaskMetadata)}`
+ );
+ console.log(`\tState ; ${response.state}`);
+ console.log(`\tCreate time : ${JSON.stringify(response.createTime)}`);
+ console.log(`\tStart time : ${JSON.stringify(response.startTime)}`);
+ console.log(`\tEnd time : ${JSON.stringify(response.endTime)}`);
+ console.log(`\tUpdate time : ${JSON.stringify(response.updateTime)}`);
+ console.log(`\tLabels : ${JSON.stringify(response.labels)}`);
+
+ const error = response.error;
+ console.log('\tError');
+ if (error === null) {
+ console.log('\t\tCode : {}');
+ console.log('\t\tMessage : {}');
+ } else {
+ console.log(`\t\tCode : ${error.code}`);
+ console.log(`\t\tMessage : ${error.message}`);
+ }
+ }
+
+ createTrainingPipelineImageClassification();
+ // [END aiplatform_create_training_pipeline_image_classification]
+}
+
+process.on('unhandledRejection', err => {
+ console.error(err.message);
+ process.exitCode = 1;
+});
+
+main(...process.argv.slice(2));
diff --git a/samples/list-endpoints.js b/samples/list-endpoints.js
new file mode 100644
index 00000000..14738adb
--- /dev/null
+++ b/samples/list-endpoints.js
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2020, Google, LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+'use strict';
+
+function main(projectId, location = 'us-central1') {
+ // [START aiplatform_list_endpoints]
+ /**
+ * TODO(developer): Uncomment these variables before running the sample.
+ */
+ // const projectId = 'YOUR_PROJECT_ID';
+ // const location = 'YOUR_PROJECT_LOCATION';
+
+ const {EndpointServiceClient} = require('@google-cloud/aiplatform');
+
+ // Specifies the location of the api endpoint
+ const clientOptions = {
+ apiEndpoint: 'us-central1-aiplatform.googleapis.com',
+ };
+ const client = new EndpointServiceClient(clientOptions);
+
+ async function listEndpoints() {
+ // Configure the parent resource
+ const parent = `projects/${projectId}/locations/${location}`;
+ const request = {
+ parent,
+ };
+
+ // Get and print out a list of all the endpoints for this resource
+ const [result] = await client.listEndpoints(request);
+ for (const endpoint of result) {
+ console.log(`\nEndpoint name: ${endpoint.name}`);
+ console.log(`Display name: ${endpoint.displayName}`);
+ if (endpoint.deployedModels[0]) {
+ console.log(
+ `First deployed model: ${endpoint.deployedModels[0].model}`
+ );
+ }
+ }
+ }
+
+ listEndpoints();
+ // [END aiplatform_list_endpoints]
+}
+
+main(...process.argv.slice(2)).catch(err => {
+ console.error(err);
+ process.exitCode = 1;
+});
diff --git a/samples/package.json b/samples/package.json
index ac49bd1e..1c716f08 100644
--- a/samples/package.json
+++ b/samples/package.json
@@ -16,6 +16,8 @@
"@google-cloud/aiplatform": "^1.0.0"
},
"devDependencies": {
- "mocha": "^8.0.0"
+ "chai": "^4.2.0",
+ "mocha": "^8.0.0",
+ "uuid": "^8.3.1"
}
}
diff --git a/samples/predict-image-classification.js b/samples/predict-image-classification.js
new file mode 100644
index 00000000..1d83e80e
--- /dev/null
+++ b/samples/predict-image-classification.js
@@ -0,0 +1,99 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+'use strict';
+
+function main(filename, endpointId, project, location = 'us-central1') {
+ // [START aiplatform_predict_image_classification]
+ /**
+ * TODO(developer): Uncomment these variables before running the sample.\
+ * (Not necessary if passing values as arguments)
+ */
+
+ // const filename = "YOUR_PREDICTION_FILE_NAME";
+ // const endpointId = "YOUR_ENDPOINT_ID";
+ // const project = 'YOUR_PROJECT_ID';
+ // const location = 'YOUR_PROJECT_LOCATION';
+ const aiplatform = require('@google-cloud/aiplatform');
+ const {
+ instance,
+ params,
+ prediction,
+ } = aiplatform.protos.google.cloud.aiplatform.v1beta1.schema.predict;
+
+ // Imports the Google Cloud Prediction Service Client library
+ const {PredictionServiceClient} = aiplatform;
+
+ // Specifies the location of the api endpoint
+ const clientOptions = {
+ apiEndpoint: 'us-central1-prediction-aiplatform.googleapis.com',
+ };
+
+ // Instantiates a client
+ const predictionServiceClient = new PredictionServiceClient(clientOptions);
+
+ async function predictImageClassification() {
+ // Configure the endpoint resource
+ const endpoint = `projects/${project}/locations/${location}/endpoints/${endpointId}`;
+
+ const parametersObj = new params.ImageClassificationPredictionParams({
+ confidenceThreshold: 0.5,
+ maxPredictions: 5,
+ });
+ const parameters = parametersObj.toValue();
+
+ const fs = require('fs');
+ const image = fs.readFileSync(filename, 'base64');
+ const instanceObj = new instance.ImageClassificationPredictionInstance({
+ content: image,
+ });
+ const instanceValue = instanceObj.toValue();
+
+ const instances = [instanceValue];
+ const request = {
+ endpoint,
+ instances,
+ parameters,
+ };
+
+ // Predict request
+ const [response] = await predictionServiceClient.predict(request);
+
+ console.log('Predict image classification response');
+ console.log(`\tDeployed model id : ${response.deployedModelId}`);
+ const predictions = response.predictions;
+ console.log('\tPredictions :');
+ for (const predictionValue of predictions) {
+ const predictionResultObj = prediction.ClassificationPredictionResult.fromValue(
+ predictionValue
+ );
+ for (const [i, label] of predictionResultObj.displayNames.entries()) {
+ console.log(`\tDisplay name: ${label}`);
+ console.log(`\tConfidences: ${predictionResultObj.confidences[i]}`);
+ console.log(`\tIDs: ${predictionResultObj.ids[i]}\n\n`);
+ }
+ }
+ }
+ predictImageClassification();
+ // [END aiplatform_predict_image_classification]
+}
+
+process.on('unhandledRejection', err => {
+ console.error(err.message);
+ process.exitCode = 1;
+});
+
+main(...process.argv.slice(2));
diff --git a/samples/resources/daisy.jpg b/samples/resources/daisy.jpg
new file mode 100644
index 00000000..ae01cae9
Binary files /dev/null and b/samples/resources/daisy.jpg differ
diff --git a/samples/test/create-training-pipeline-image-classification.test.js b/samples/test/create-training-pipeline-image-classification.test.js
new file mode 100644
index 00000000..c6c506e0
--- /dev/null
+++ b/samples/test/create-training-pipeline-image-classification.test.js
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+'use strict';
+
+const {assert} = require('chai');
+const {after, describe, it} = require('mocha');
+
+const uuid = require('uuid').v4;
+const cp = require('child_process');
+const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});
+
+const aiplatform = require('@google-cloud/aiplatform');
+const clientOptions = {
+ apiEndpoint: 'us-central1-aiplatform.googleapis.com',
+};
+
+const pipelineServiceClient = new aiplatform.PipelineServiceClient(
+ clientOptions
+);
+
+const datasetId = process.env.TRAINING_PIPELINE_IMAGE_CLASS_DATASET_ID;
+const modelDisplayName = `temp_create_training_pipeline_image_classification_model_test${uuid()}`;
+const trainingPipelineDisplayName = `temp_create_training_pipeline_image_classification_test_${uuid()}`;
+const project = process.env.CAIP_PROJECT_ID;
+const location = process.env.LOCATION;
+
+let trainingPipelineId;
+
+describe('AI platform create training pipeline image classification', () => {
+ it('should create a new image classification training pipeline', async () => {
+ const stdout = execSync(
+ `node ./create-training-pipeline-image-classification.js ${datasetId} ${modelDisplayName} ${trainingPipelineDisplayName} ${project} ${location}`
+ );
+ assert.match(stdout, /\/locations\/us-central1\/trainingPipelines\//);
+ trainingPipelineId = stdout
+ .split('/locations/us-central1/trainingPipelines/')[1]
+ .split('\n')[0];
+ });
+
+ after('should cancel the training pipeline and delete it', async () => {
+ const name = pipelineServiceClient.trainingPipelinePath(
+ project,
+ location,
+ trainingPipelineId
+ );
+
+ const cancelRequest = {
+ name,
+ };
+
+ pipelineServiceClient.cancelTrainingPipeline(cancelRequest).then(() => {
+ const deleteRequest = {
+ name,
+ };
+
+ return pipelineServiceClient.deleteTrainingPipeline(deleteRequest);
+ });
+ });
+});
diff --git a/samples/test/predict-image-classification.test.js b/samples/test/predict-image-classification.test.js
new file mode 100644
index 00000000..f2e8aa14
--- /dev/null
+++ b/samples/test/predict-image-classification.test.js
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+'use strict';
+
+const path = require('path');
+const {assert} = require('chai');
+const {describe, it} = require('mocha');
+
+const cp = require('child_process');
+const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});
+
+const filename = 'daisy.jpg';
+const local_file = path.resolve(
+ path.join(__dirname, `../resources/${filename}`)
+);
+const endpointId = process.env.PREDICT_IMAGE_CLASS_ENDPOINT_ID;
+const project = process.env.CAIP_PROJECT_ID;
+const location = process.env.LOCATION;
+
+describe('AI platform predict image classification', () => {
+ it('should make predictions using the image classification model', async () => {
+ const stdout = execSync(
+ `node ./predict-image-classification.js ${local_file} ${endpointId} ${project} ${location}`
+ );
+ assert.match(stdout, /Predict image classification response/);
+ });
+});
diff --git a/samples/test/quickstart.test.js b/samples/test/quickstart.test.js
index 07fb1e05..524df65f 100644
--- a/samples/test/quickstart.test.js
+++ b/samples/test/quickstart.test.js
@@ -23,7 +23,7 @@ const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});
describe('quickstart', () => {
it('should have functional quickstart', async () => {
- const stdout = execSync('node quickstart');
+ const stdout = execSync('node quickstart.js');
assert(stdout.match(/DatasetServiceClient/));
});
});
diff --git a/src/decorator.ts b/src/decorator.ts
new file mode 100644
index 00000000..c99dfa89
--- /dev/null
+++ b/src/decorator.ts
@@ -0,0 +1,109 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+import * as _helpers from './helpers';
+import * as protos from '../protos/protos';
+
+const enhancedTypesJson = require('./enhanced-types.json');
+
+interface JsonNode {
+ [index: string]: string[] | JsonNode | JsonNode[];
+}
+
+interface PrototypedObject {
+ prototype: Function;
+}
+
+interface NestedNamespace {
+ [index: string]: PrototypedObject | NestedNamespace;
+}
+
+// Walk the tree of nested namespaces contained within the enhanced-types.json file
+function walkNamespaces(
+ jsonNode: JsonNode,
+ rootNamespace: NestedNamespace
+): void {
+ for (const namespaceName in jsonNode) {
+ if (Object.hasOwnProperty.call(jsonNode, namespaceName)) {
+ const namespace = rootNamespace[namespaceName];
+
+ // Get the namespace object from JSON
+ const namespaceJsonObject = jsonNode[namespaceName];
+
+ // Verify that this is an array node.
+ if (
+ namespace &&
+ namespaceJsonObject &&
+ Array.isArray(namespaceJsonObject)
+ ) {
+ // Assign the methods to this list of types.
+ assignMethodsToMessages(
+ namespace as NestedNamespace,
+ namespaceJsonObject as string[]
+ );
+
+ // Check if this is another node.
+ } else if (
+ namespace &&
+ namespaceJsonObject &&
+ typeof namespaceJsonObject === 'object' &&
+ !Array.isArray(namespaceJsonObject)
+ ) {
+ // Iterate over the next level of namespaces
+ walkNamespaces(namespaceJsonObject, namespace as NestedNamespace);
+ }
+ }
+ }
+}
+
+// Assign the toValue() and fromValue() helper methods to the enhanced message objects.
+function assignMethodsToMessages(
+ // tslint:disable-next-line no-any
+ namespace: NestedNamespace,
+ messages: string[]
+): void {
+ for (const message of messages) {
+ if (message in namespace) {
+ const enhancedMessage: PrototypedObject = namespace[
+ message
+ ] as PrototypedObject;
+ if (enhancedMessage) {
+ Object.assign(enhancedMessage.prototype, _helpers.addToValue());
+
+ // Capture reference to `enhancedMessage` class in closure below.
+ const _addFromValue = {
+ fromValue: (value: object): object | undefined => {
+ const messageType = (enhancedMessage as unknown) as protobuf.Type;
+ const message = messageType.create();
+ const convertedValue = _helpers.fromValue(value);
+ if (convertedValue !== undefined) {
+ Object.assign(message, convertedValue);
+ return message;
+ }
+ return undefined;
+ },
+ };
+ Object.assign(enhancedMessage, _addFromValue);
+ }
+ }
+ }
+}
+
+export function _enhance(apiVersion: string): void {
+ const schemaRoot = enhancedTypesJson['schema'];
+ const namespaceRoot = ((protos.google.cloud
+ .aiplatform as unknown) as NestedNamespace)[apiVersion] as NestedNamespace;
+ const namespaceSchemaRoot = namespaceRoot['schema'];
+ walkNamespaces(schemaRoot, namespaceSchemaRoot as NestedNamespace);
+}
diff --git a/src/enhanced-types.json b/src/enhanced-types.json
new file mode 100644
index 00000000..15e6b62f
--- /dev/null
+++ b/src/enhanced-types.json
@@ -0,0 +1,69 @@
+{
+ "schema": {
+ "predict": {
+ "instance": [
+ "ImageClassificationPredictionInstance",
+ "ImageObjectDetectionPredictionInstance",
+ "ImageSegmentationPredictionInstance",
+ "TextClassificationPredictionInstance",
+ "TextExtractionPredictionInstance",
+ "TextSentimentPredictionInstance",
+ "VideoActionRecognitionPredictionInstance",
+ "VideoClassificationPredictionInstance",
+ "VideoObjectTrackingPredictionInstance"
+ ],
+ "params": [
+ "ImageClassificationPredictionParams",
+ "ImageObjectDetectionPredictionParams",
+ "ImageSegmentationPredictionParams",
+ "VideoActionRecognitionPredictionParams",
+ "VideoClassificationPredictionParams",
+ "VideoObjectTrackingPredictionParams"
+ ],
+ "prediction": [
+ "ClassificationPredictionResult",
+ "ImageObjectDetectionPredictionResult",
+ "ImageSegmentationPredictionResult",
+ "TabularClassificationPredictionResult",
+ "TabularRegressionPredictionResult",
+ "TextExtractionPredictionResult",
+ "TextSentimentPredictionResult",
+ "TimeSeriesForecastingPredictionResult",
+ "VideoActionRecognitionPredictionResult",
+ "VideoClassificationPredictionResult",
+ "VideoObjectTrackingPredictionResult"
+ ]
+ },
+ "trainingjob": {
+ "definition": [
+ "AutoMlForecasting",
+ "AutoMlForecastingInputs",
+ "AutoMlForecastingMetadata",
+ "AutoMlImageClassification",
+ "AutoMlImageClassificationInputs",
+ "AutoMlImageClassificationMetadata",
+ "AutoMlImageObjectDetection",
+ "AutoMlImageObjectDetectionInputs",
+ "AutoMlImageObjectDetectionMetadata",
+ "AutoMlImageSegmentation",
+ "AutoMlImageSegmentationInputs",
+ "AutoMlImageSegmentationMetadata",
+ "AutoMlTables",
+ "AutoMlTablesInputs",
+ "AutoMlTablesMetadata",
+ "AutoMlTextClassification",
+ "AutoMlTextClassificationInputs",
+ "AutoMlTextExtraction",
+ "AutoMlTextExtractionInputs",
+ "AutoMlTextSentiment",
+ "AutoMlTextSentimentInputs",
+ "AutoMlVideoActionRecognition",
+ "AutoMlVideoActionRecognitionInputs",
+ "AutoMlVideoClassification",
+ "AutoMlVideoClassificationInputs",
+ "AutoMlVideoObjectTracking",
+ "AutoMlVideoObjectTrackingInputs"
+ ]
+ }
+ }
+}
diff --git a/src/helpers.ts b/src/helpers.ts
new file mode 100644
index 00000000..f0752768
--- /dev/null
+++ b/src/helpers.ts
@@ -0,0 +1,79 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+import {
+ googleProtobufValueFromObject,
+ googleProtobufValueToObject,
+ ValueType,
+} from './value-converter';
+
+interface ToValueFunction {
+ toValue(): null | object | undefined;
+ // Add these two members so that we can convert to Message objects more easily.
+ $type: unknown;
+ toJson(): string;
+}
+
+// Assigns the toValue() function as a member of an enhanced class.
+export function addToValue() {
+ const methods: ToValueFunction = ({} as unknown) as ToValueFunction;
+
+ methods.toValue = function () {
+ return toValue((this as unknown) as protobuf.Message);
+ };
+
+ return methods;
+}
+
+/**
+ * Converts a protobuf.Message to a protobuf.Value object.
+ * @param message Message to convert
+ * @returns a Value-formatted object
+ */
+export function toValue(
+ message: protobuf.Message
+): null | object | undefined | protobuf.common.IValue {
+ if (message === undefined) {
+ return undefined;
+ }
+
+ const value = googleProtobufValueFromObject(
+ (message as unknown) as ValueType,
+ (val: object) => {
+ return val;
+ }
+ );
+ return value;
+}
+
+/**
+ * Creates instance of class from a protobuf.Value object.
+ * @param value Value to convert
+ * @returns a Message
+ */
+export function fromValue(
+ value: protobuf.common.IValue
+): object | null | undefined | string | number | ValueType | boolean {
+ if (!value) {
+ return undefined;
+ }
+
+ if (!value.structValue || !value.structValue.fields) {
+ throw new Error(
+ 'ERROR: fromValue() was provided a malformed protobuf object'
+ );
+ }
+
+ return googleProtobufValueToObject(value);
+}
diff --git a/src/index.ts b/src/index.ts
index 9e4a16eb..898a9b39 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -57,3 +57,12 @@ export default {
};
import * as protos from '../protos/protos';
export {protos};
+
+import {fromValue, toValue} from './helpers';
+
+const helpers = {toValue, fromValue};
+
+export {helpers};
+
+import {_enhance} from './decorator';
+_enhance('v1beta1');
diff --git a/src/value-converter.ts b/src/value-converter.ts
new file mode 100644
index 00000000..7e0cd5b1
--- /dev/null
+++ b/src/value-converter.ts
@@ -0,0 +1,123 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// TODO(): Remove this file once https://github.com/protobufjs/protobuf.js/pull/1495 is submitted.
+export interface ValueType {
+ [index: string]:
+ | null
+ | boolean
+ | string
+ | number
+ | ValueType
+ | Array;
+}
+
+// INTERNAL ONLY. This function is not exposed to external callers.
+export function googleProtobufValueFromObject(
+ object: ValueType,
+ create: (result: object) => object
+): object | null | ValueType | protobuf.common.IValue {
+ if (object === null) {
+ return create({
+ kind: 'nullValue',
+ nullValue: 0,
+ });
+ }
+ if (typeof object === 'boolean') {
+ return create({
+ kind: 'boolValue',
+ boolValue: object,
+ });
+ }
+ if (typeof object === 'number') {
+ return create({
+ kind: 'numberValue',
+ numberValue: object,
+ });
+ }
+ if (typeof object === 'string') {
+ return create({
+ kind: 'stringValue',
+ stringValue: object,
+ });
+ }
+ if (Array.isArray(object)) {
+ const array = object.map(element => {
+ return googleProtobufValueFromObject(element, create);
+ });
+ return create({
+ kind: 'listValue',
+ listValue: {
+ values: array,
+ },
+ });
+ }
+ if (typeof object === 'object') {
+ // tslint:disable-next-line no-explicit-any
+ const fields: any = {},
+ names: string[] = Object.keys(object);
+ for (let i = 0; i < names.length; ++i) {
+ const fieldName = names[i];
+ fields[fieldName] = googleProtobufValueFromObject(
+ object[fieldName] as ValueType,
+ create
+ );
+ }
+ return create({
+ kind: 'structValue',
+ structValue: {
+ fields: fields,
+ },
+ });
+ }
+ return null;
+}
+
+// INTERNAL ONLY. This function not exposed to external callers.
+// recursive google.protobuf.Value to plain JS object
+export function googleProtobufValueToObject(
+ message: protobuf.common.IValue
+): object | null | undefined | boolean | number | string {
+ if (message.kind === 'boolValue') {
+ return message.boolValue;
+ }
+ if (message.kind === 'nullValue') {
+ return null;
+ }
+ if (message.kind === 'numberValue') {
+ return message.numberValue;
+ }
+ if (message.kind === 'stringValue') {
+ return message.stringValue;
+ }
+ if (message.kind === 'listValue') {
+ return message.listValue?.values?.map(googleProtobufValueToObject);
+ }
+ if (message.kind === 'structValue') {
+ if (!message.structValue?.fields) {
+ return {};
+ }
+ const names = Object.keys(message.structValue.fields),
+ // tslint:disable-next-line no-explicit-any
+ struct: any = {};
+ for (let i = 0; i < names.length; ++i) {
+ struct[names[i]] = googleProtobufValueToObject(
+ message.structValue['fields'][names[i]]
+ );
+ }
+ return struct;
+ }
+ return undefined;
+}
diff --git a/synth.metadata b/synth.metadata
index 55175367..e717c9c6 100644
--- a/synth.metadata
+++ b/synth.metadata
@@ -3,23 +3,23 @@
{
"git": {
"name": ".",
- "remote": "https://github.com/googleapis/nodejs-ai-platform.git",
- "sha": "67351fc5139bdf8db1b244abe884c30003cbc4c1"
+ "remote": "git@github.com:googleapis/nodejs-ai-platform.git",
+ "sha": "e97531fa5a374ec9215d21aa6b9c3434490a531f"
}
},
{
"git": {
"name": "googleapis",
"remote": "https://github.com/googleapis/googleapis.git",
- "sha": "16dd59787d6ce130ab66066c02eeea9dac0c8f0e",
- "internalRef": "345712055"
+ "sha": "6dae98144d466d4f985b926baec6208b01572f55",
+ "internalRef": "347459563"
}
},
{
"git": {
"name": "synthtool",
"remote": "https://github.com/googleapis/synthtool.git",
- "sha": "15013eff642a7e7e855aed5a29e6e83c39beba2a"
+ "sha": "996775eca5fd934edac3c2ae34b80ff0395b1717"
}
}
],
diff --git a/synth.py b/synth.py
index aeaac29f..d81670c8 100644
--- a/synth.py
+++ b/synth.py
@@ -21,23 +21,41 @@
logging.basicConfig(level=logging.DEBUG)
+# List of excludes for the enhanced library
+excludes = [
+ "package.json",
+ "README.md",
+ "src/decorator.ts",
+ "src/enhanced-types.json",
+ "src/helpers.ts",
+ "src/index.ts",
+ "src/value-converter.ts",
+ "test/helpers.test.ts",
+ "test/index.test.ts",
+ "tsconfig.json",
+]
+
# run the gapic generator
gapic = gcp.GAPICBazel()
-versions = ['v1beta1']
-name = 'aiplatform'
+versions = ["v1beta1"]
+name = "aiplatform"
for version in versions:
- library = gapic.node_library(name, version)
- s.copy(library, excludes=["package.json", "README.md"])
+ library = gapic.node_library(name, version)
+ s.copy(library, excludes=excludes)
# Copy common templates
common_templates = gcp.CommonTemplates()
templates = common_templates.node_library(
- source_location='build/src', versions=versions)
+ source_location="build/src", versions=versions
+)
# We override the default sample configuration with a custom
# environment file:
-s.copy(templates, excludes=[
- ".kokoro/continuous/node12/samples-test.cfg",
- ".kokoro/presubmit/node12/samples-test.cfg"
-])
+s.copy(
+ templates,
+ excludes=[
+ ".kokoro/continuous/node12/samples-test.cfg",
+ ".kokoro/presubmit/node12/samples-test.cfg",
+ ],
+)
node.postprocess_gapic_library()
diff --git a/test/helpers.test.ts b/test/helpers.test.ts
new file mode 100644
index 00000000..943861f3
--- /dev/null
+++ b/test/helpers.test.ts
@@ -0,0 +1,322 @@
+/*!
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {describe, it} from 'mocha';
+import * as assert from 'assert';
+// eslint-disable-next-line @typescript-eslint/no-var-requires
+const aiplatform = require('../src');
+
+describe('AI Platform helper methods', () => {
+ const dataTypeObject = {
+ myBool: true,
+ myInt: 4,
+ myString: 'hello',
+ myNull: null,
+ myList: ['one', 'two'],
+ myObj: {
+ nested: 'obj',
+ },
+ };
+ const protobufTypeObject = {
+ kind: 'structValue',
+ structValue: {
+ fields: {
+ myBool: {
+ kind: 'boolValue',
+ boolValue: true,
+ },
+ myInt: {
+ kind: 'numberValue',
+ numberValue: 4,
+ },
+ myString: {
+ kind: 'stringValue',
+ stringValue: 'hello',
+ },
+ myNull: {
+ kind: 'nullValue',
+ nullValue: 0,
+ },
+ myList: {
+ kind: 'listValue',
+ listValue: {
+ values: [
+ {
+ kind: 'stringValue',
+ stringValue: 'one',
+ },
+ {
+ kind: 'stringValue',
+ stringValue: 'two',
+ },
+ ],
+ },
+ },
+ myObj: {
+ kind: 'structValue',
+ structValue: {
+ fields: {
+ nested: {
+ kind: 'stringValue',
+ stringValue: 'obj',
+ },
+ },
+ },
+ },
+ },
+ },
+ };
+
+ describe('toValue', () => {
+ it('exposes toValue() method', () => {
+ const {helpers} = aiplatform;
+ assert.ok(helpers.toValue);
+ });
+
+ it('converts protobuf data types', () => {
+ const {helpers} = aiplatform;
+
+ const actualToValueOuput = helpers.toValue(dataTypeObject);
+ const actualInnerStruct = actualToValueOuput.structValue;
+ assert.ok(actualInnerStruct);
+
+ const actualInnerFields = actualInnerStruct.fields;
+ assert.ok(actualInnerFields);
+
+ const actualBoolType = actualInnerFields.myBool;
+ assert.ok(actualBoolType);
+ assert.ok('boolValue' in actualBoolType);
+ assert.strictEqual(actualBoolType.boolValue, dataTypeObject.myBool);
+
+ const actualNumberType = actualInnerFields.myInt;
+ assert.ok(actualNumberType);
+ assert.ok('numberValue' in actualNumberType);
+ assert.strictEqual(actualNumberType.numberValue, dataTypeObject.myInt);
+
+ const actualStringType = actualInnerFields.myString;
+ assert.ok(actualStringType);
+ assert.ok('stringValue' in actualStringType);
+ assert.strictEqual(actualStringType.stringValue, dataTypeObject.myString);
+
+ const actualNullType = actualInnerFields.myNull;
+ assert.ok(actualNullType);
+ assert.ok('nullValue' in actualNullType);
+
+ const actualListType = actualInnerFields.myList;
+ assert.ok(actualListType);
+ assert.ok('listValue' in actualListType);
+ assert.ok('values' in actualListType.listValue);
+ assert.strictEqual(
+ actualListType.listValue.values.length,
+ dataTypeObject.myList.length
+ );
+
+ const actualStructType = actualInnerFields.myObj;
+ assert.ok(actualStructType);
+ assert.ok('structValue' in actualStructType);
+ assert.ok('fields' in actualStructType.structValue);
+ assert.ok('nested' in actualStructType.structValue.fields);
+ });
+
+ it('creates an empty protobuf structure when passed an empty object', () => {
+ const {helpers} = aiplatform;
+ const actualEmptyProtobufStruct = helpers.toValue({});
+
+ assert.ok(actualEmptyProtobufStruct);
+ assert.ok('structValue' in actualEmptyProtobufStruct);
+ assert.ok('fields' in actualEmptyProtobufStruct.structValue);
+ });
+
+ it('returns undefined if not passed an argument', () => {
+ const {helpers} = aiplatform;
+ const actualUndefinedResult = helpers.toValue();
+ assert.strictEqual(actualUndefinedResult, undefined);
+ });
+ });
+
+ describe('fromValue', () => {
+ it('exposes fromValue() method', () => {
+ const {helpers} = aiplatform;
+ assert.ok(helpers.fromValue);
+ });
+
+ it('converts protobuf object formatting to plain JavaScript objects', () => {
+ const {helpers} = aiplatform;
+ const actualConvertedObject = helpers.fromValue(protobufTypeObject);
+
+ assert.ok('myBool' in actualConvertedObject);
+ assert.ok('myInt' in actualConvertedObject);
+ assert.ok('myString' in actualConvertedObject);
+ assert.ok('myList' in actualConvertedObject);
+ assert.strictEqual(actualConvertedObject.myList.length, 2);
+ assert.ok('myObj' in actualConvertedObject);
+ assert.ok('nested' in actualConvertedObject.myObj);
+ });
+
+ it('throws an error if not provided a protobuf-formatted object', () => {
+ const {helpers} = aiplatform;
+ const malformedProtobufObject = {something: 'malformed'};
+
+ assert.throws(() => {
+ helpers.fromValue(malformedProtobufObject);
+ });
+ });
+
+ it('returns undefined if not passed an argument', () => {
+ const {helpers} = aiplatform;
+ const actualUndefinedResult = helpers.fromValue();
+ assert.strictEqual(actualUndefinedResult, undefined);
+ });
+ });
+
+ describe('dynamically assigned methods', () => {
+ const {
+ definition,
+ } = aiplatform.protos.google.cloud.aiplatform.v1beta1.schema.trainingjob;
+ const {
+ instance,
+ } = aiplatform.protos.google.cloud.aiplatform.v1beta1.schema.predict;
+
+ describe('toValue', () => {
+ const imageClassificationTrainingInputs = {
+ multiLabel: true,
+ modelType: definition.AutoMlImageClassificationInputs.ModelType.CLOUD,
+ budgetMilliNodeHours: 8000,
+ disableEarlyStopping: false,
+ };
+
+ const textClassificationPredictionInstance = {
+ content: 'this is some fake text',
+ };
+
+ it('exposes toValue() on instances of enhanced types', () => {
+ const actualTrainingTaskInputs = new definition.AutoMlImageClassificationInputs();
+ assert.ok(actualTrainingTaskInputs.toValue);
+ });
+
+ it('converts an enhanced type to a protobuf-formatted object', () => {
+ const trainingTaskInputs = new definition.AutoMlImageClassificationInputs(
+ imageClassificationTrainingInputs
+ );
+ const predictionInstance = new instance.TextClassificationPredictionInstance(
+ textClassificationPredictionInstance
+ );
+
+ const actualTrainingTaskInputValue = trainingTaskInputs.toValue();
+ const actualPredictionInstanceValue = predictionInstance.toValue();
+ const actualTrainingTaskFields =
+ actualTrainingTaskInputValue.structValue.fields;
+
+ assert.ok('multiLabel' in actualTrainingTaskFields);
+ assert.ok('budgetMilliNodeHours' in actualTrainingTaskFields);
+ assert.ok('modelType' in actualTrainingTaskFields);
+ assert.ok('disableEarlyStopping' in actualTrainingTaskFields);
+
+ assert.strictEqual(actualTrainingTaskFields.multiLabel.boolValue, true);
+ assert.strictEqual(
+ actualTrainingTaskFields.budgetMilliNodeHours.numberValue,
+ 8000
+ );
+ assert.strictEqual(
+ actualTrainingTaskFields.disableEarlyStopping.boolValue,
+ false
+ );
+ assert.strictEqual(
+ actualTrainingTaskFields.modelType.numberValue,
+ definition.AutoMlImageClassificationInputs.ModelType.CLOUD
+ );
+
+ const actualTextInstance =
+ actualPredictionInstanceValue.structValue.fields.content.stringValue;
+ assert.notStrictEqual(
+ actualTextInstance.indexOf(
+ textClassificationPredictionInstance.content
+ ),
+ -1
+ );
+ });
+ });
+
+ describe('fromValue', () => {
+ const imageClassificationTrainingInputs = {
+ kind: 'structValue',
+ structValue: {
+ fields: {
+ multiLabel: {
+ kind: 'boolValue',
+ boolValue: true,
+ },
+ modelType: {
+ kind: 'numberValue',
+ numberValue: 1,
+ },
+ budgetMilliNodeHours: {
+ kind: 'numberValue',
+ numberValue: 8000,
+ },
+ disableEarlyStopping: {
+ kind: 'boolValue',
+ boolValue: false,
+ },
+ },
+ },
+ };
+ const testText = 'this is some fake text';
+ const textClassificationPredictionInstance = {
+ kind: 'structValue',
+ structValue: {
+ fields: {
+ content: {
+ kind: 'stringValue',
+ stringValue: testText,
+ },
+ },
+ },
+ };
+
+ it('exposes fromValue() as a static method on enhanced types', () => {
+ assert.notStrictEqual(
+ definition.AutoMlImageClassificationInputs.fromValue,
+ undefined
+ );
+ });
+
+ it('converts protobuf-formatted objects into instanced of enhanced types', () => {
+ const actualTrainingTaskInputs = definition.AutoMlImageClassificationInputs.fromValue(
+ imageClassificationTrainingInputs
+ );
+ const actualPredictionInstance = instance.TextClassificationPredictionInstance.fromValue(
+ textClassificationPredictionInstance
+ );
+
+ assert.strictEqual(actualTrainingTaskInputs.budgetMilliNodeHours, 8000);
+ assert.strictEqual(
+ actualTrainingTaskInputs.modelType,
+ definition.AutoMlImageClassificationInputs.ModelType.CLOUD
+ );
+ assert.strictEqual(
+ actualTrainingTaskInputs.disableEarlyStopping,
+ false
+ );
+ assert.strictEqual(actualTrainingTaskInputs.multiLabel, true);
+ assert.notStrictEqual(
+ actualPredictionInstance.content.indexOf(testText),
+ -1
+ );
+ });
+ });
+ });
+});
diff --git a/test/index.test.ts b/test/index.test.ts
new file mode 100644
index 00000000..119a3a75
--- /dev/null
+++ b/test/index.test.ts
@@ -0,0 +1,62 @@
+/*!
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {describe, it} from 'mocha';
+import * as assert from 'assert';
+// eslint-disable-next-line @typescript-eslint/no-var-requires
+const aiplatform = require('../src');
+const enhancedTypes = require('../src/enhanced-types.json');
+
+describe('AI Platform enhanced types', () => {
+ const definitionTypeNames = enhancedTypes.schema.trainingjob.definition;
+ const predictInstanceTypeNames = enhancedTypes.schema.predict.instance;
+ const predictParamsTypeNames = enhancedTypes.schema.predict.params;
+ const predictResultTypeNames = enhancedTypes.schema.predict.prediction;
+
+ const {
+ definition,
+ } = aiplatform.protos.google.cloud.aiplatform.v1beta1.schema.trainingjob;
+ const {
+ instance,
+ params,
+ prediction,
+ } = aiplatform.protos.google.cloud.aiplatform.v1beta1.schema.predict;
+
+ function testNamespaceAgainstArray(
+ namespace: Record,
+ arr: string[]
+ ) {
+ for (const name of arr) {
+ assert.ok(name in namespace);
+ }
+ }
+
+ it('adds training job definition types', () => {
+ testNamespaceAgainstArray(definition, definitionTypeNames);
+ });
+
+ it('adds prediction instance types', () => {
+ testNamespaceAgainstArray(instance, predictInstanceTypeNames);
+ });
+
+ it('adds prediction param types', () => {
+ testNamespaceAgainstArray(params, predictParamsTypeNames);
+ });
+
+ it('adds prediction result types', () => {
+ testNamespaceAgainstArray(prediction, predictResultTypeNames);
+ });
+});
diff --git a/tsconfig.json b/tsconfig.json
index c78f1c88..cf56e41f 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -14,6 +14,7 @@
"src/**/*.ts",
"test/*.ts",
"test/**/*.ts",
- "system-test/*.ts"
+ "system-test/*.ts",
+ "src/enhanced-types.json"
]
}