Skip to content

Commit

Permalink
chore(gen-ai): remove strict validation of query response for aggrega…
Browse files Browse the repository at this point in the history
…tion generation (#5858)
  • Loading branch information
Anemy authored Jun 4, 2024
1 parent 7c89ff2 commit b319a42
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ import util from 'util';
import { execFile as callbackExecFile } from 'child_process';
import decomment from 'decomment';

import {
validateAIQueryResponse,
validateAIAggregationResponse,
} from '../../src/atlas-ai-service';
import { loadFixturesToDB } from './fixtures';
import type { Fixtures } from './fixtures';
import { AtlasAPI } from './ai-backend';
Expand Down Expand Up @@ -229,6 +233,10 @@ const runOnce = async (
if (assertResult) {
let cursor;

type === 'query'
? validateAIQueryResponse(response)
: validateAIAggregationResponse(response);

if (
type === 'aggregation' ||
(type === 'query' &&
Expand Down
201 changes: 103 additions & 98 deletions packages/compass-generative-ai/src/atlas-ai-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,107 @@ function buildQueryOrAggregationMessageBody(
return msgBody;
}

function hasExtraneousKeys(obj: any, expectedKeys: string[]) {
return Object.keys(obj).some((key) => !expectedKeys.includes(key));
}

export function validateAIQueryResponse(
response: any
): asserts response is AIQuery {
const { content } = response ?? {};

if (typeof content !== 'object' || content === null) {
throw new Error('Unexpected response: expected content to be an object');
}

if (hasExtraneousKeys(content, ['query', 'aggregation'])) {
throw new Error(
'Unexpected keys in response: expected query and aggregation'
);
}

const { query, aggregation } = content;

if (!query && !aggregation) {
throw new Error(
'Unexpected response: expected query or aggregation, got none'
);
}

if (query && typeof query !== 'object') {
throw new Error('Unexpected response: expected query to be an object');
}

if (
hasExtraneousKeys(query, [
'filter',
'project',
'collation',
'sort',
'skip',
'limit',
])
) {
throw new Error(
'Unexpected keys in response: expected filter, project, collation, sort, skip, limit, aggregation'
);
}

for (const field of [
'filter',
'project',
'collation',
'sort',
'skip',
'limit',
]) {
if (query[field] && typeof query[field] !== 'string') {
throw new Error(
`Unexpected response: expected field ${field} to be a string, got ${JSON.stringify(
query[field],
null,
2
)}`
);
}
}

if (aggregation && typeof aggregation.pipeline !== 'string') {
throw new Error(
`Unexpected response: expected aggregation pipeline to be a string, got ${JSON.stringify(
aggregation,
null,
2
)}`
);
}
}

export function validateAIAggregationResponse(
response: any
): asserts response is AIAggregation {
const { content } = response;

if (typeof content !== 'object' || content === null) {
throw new Error('Unexpected response: expected content to be an object');
}

if (hasExtraneousKeys(content, ['aggregation'])) {
throw new Error('Unexpected keys in response: expected aggregation');
}

if (content.aggregation && typeof content.aggregation.pipeline !== 'string') {
// Compared to queries where we will always get the `query` field, for
// aggregations backend deletes the whole `aggregation` key if pipeline is
// empty, so we only validate `pipeline` key if `aggregation` key is present
throw new Error(
`Unexpected response: expected aggregation to be a string, got ${String(
content.aggregation.pipeline
)}`
);
}
}

export class AtlasAiService {
private initPromise: Promise<void> | null = null;

Expand Down Expand Up @@ -240,110 +341,18 @@ export class AtlasAiService {
return this.getQueryOrAggregationFromUserInput(
AGGREGATION_URI,
input,
this.validateAIAggregationResponse.bind(this)
validateAIAggregationResponse
);
}

async getQueryFromUserInput(input: GenerativeAiInput) {
return this.getQueryOrAggregationFromUserInput(
QUERY_URI,
input,
this.validateAIQueryResponse.bind(this)
validateAIQueryResponse
);
}

private validateAIQueryResponse(response: any): asserts response is AIQuery {
const { content } = response ?? {};

if (typeof content !== 'object' || content === null) {
throw new Error('Unexpected response: expected content to be an object');
}

if (this.hasExtraneousKeys(content, ['query', 'aggregation'])) {
throw new Error(
'Unexpected keys in response: expected query and aggregation'
);
}

const { query, aggregation } = content;

if (typeof query !== 'object' || query === null) {
throw new Error('Unexpected response: expected query to be an object');
}

if (
this.hasExtraneousKeys(query, [
'filter',
'project',
'collation',
'sort',
'skip',
'limit',
])
) {
throw new Error(
'Unexpected keys in response: expected filter, project, collation, sort, skip, limit, aggregation'
);
}

for (const field of [
'filter',
'project',
'collation',
'sort',
'skip',
'limit',
]) {
if (query[field] && typeof query[field] !== 'string') {
throw new Error(
`Unexpected response: expected field ${field} to be a string, got ${JSON.stringify(
query[field],
null,
2
)}`
);
}
}

if (aggregation && typeof aggregation.pipeline !== 'string') {
throw new Error(
`Unexpected response: expected aggregation pipeline to be a string, got ${JSON.stringify(
aggregation,
null,
2
)}`
);
}
}

private validateAIAggregationResponse(
response: any
): asserts response is AIAggregation {
const { content } = response;

if (typeof content !== 'object' || content === null) {
throw new Error('Unexpected response: expected content to be an object');
}

if (this.hasExtraneousKeys(content, ['aggregation'])) {
throw new Error('Unexpected keys in response: expected aggregation');
}

if (
content.aggregation &&
typeof content.aggregation.pipeline !== 'string'
) {
// Compared to queries where we will always get the `query` field, for
// aggregations backend deletes the whole `aggregation` key if pipeline is
// empty, so we only validate `pipeline` key if `aggregation` key is present
throw new Error(
`Unexpected response: expected aggregation to be a string, got ${String(
content.aggregation.pipeline
)}`
);
}
}

private validateAIFeatureEnablementResponse(
response: any
): asserts response is AIFeatureEnablement {
Expand All @@ -352,8 +361,4 @@ export class AtlasAiService {
throw new Error('Unexpected response: expected features to be an object');
}
}

private hasExtraneousKeys(obj: any, expectedKeys: string[]) {
return Object.keys(obj).some((key) => !expectedKeys.includes(key));
}
}

0 comments on commit b319a42

Please sign in to comment.