diff --git a/package-lock.json b/package-lock.json index b869350b75b..bfdf9e0d774 100644 --- a/package-lock.json +++ b/package-lock.json @@ -8092,8 +8092,7 @@ "node_modules/@mongodb-js/mongodb-constants": { "version": "0.2.2", "resolved": "https://registry.npmjs.org/@mongodb-js/mongodb-constants/-/mongodb-constants-0.2.2.tgz", - "integrity": "sha512-vm1G+/WRWmXGyE9ZnhDv9toe+LRu1x0F/lGEwqWESfBiUUUuVZhj25fS2o4IL7H4pJ31sFxr7/gu+ER8OkmtzA==", - "dev": true + "integrity": "sha512-vm1G+/WRWmXGyE9ZnhDv9toe+LRu1x0F/lGEwqWESfBiUUUuVZhj25fS2o4IL7H4pJ31sFxr7/gu+ER8OkmtzA==" }, "node_modules/@mongodb-js/mongodb-downloader": { "version": "0.2.6", @@ -41227,9 +41226,22 @@ "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.2.0.tgz", "integrity": "sha512-0zTyLGyDJYd/MBxG1AhJkKa6fpEBds4OQO2ut0w7OYG+ZGhGea09lijvzsqegYSik88zc7cUtIlnnO+/BvD6gQ==", "dependencies": { + "@mongodb-js/atlas-service": "^0.2.1", + "@mongodb-js/compass-components": "^1.11.0", + "@mongodb-js/compass-crud": "^13.11.1", + "@mongodb-js/compass-editor": "^0.10.0", + "@mongodb-js/compass-logging": "^1.1.7", + "@mongodb-js/compass-utils": "^0.3.3", + "@mongodb-js/explain-plan-helper": "^1.1.0", + "@mongodb-js/mongodb-constants": "^0.6.0", + "@mongodb-js/mongodb-redux-common": "^2.0.9", "@types/json-schema": "^7.0.8", "ajv": "^6.12.5", - "ajv-keywords": "^3.5.2" + "ajv-keywords": "^3.5.2", + "bson": "^5.2.0", + "compass-preferences-model": "^2.11.1", + "hadron-document": "^8.3.0", + "hadron-type-checker": "^7.0.3" }, "engines": { "node": ">= 10.13.0" @@ -41237,6 +41249,22 @@ "funding": { "type": "opencollective", "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "@mongodb-js/atlas-service": "^0.2.1", + "@mongodb-js/compass-components": "^1.11.0", + "@mongodb-js/compass-crud": "^13.11.1", + "@mongodb-js/compass-editor": "^0.10.0", + "@mongodb-js/compass-logging": "^1.1.7", + "@mongodb-js/compass-utils": "^0.3.3", + "@mongodb-js/explain-plan-helper": "^1.1.0", + "@mongodb-js/mongodb-constants": "^0.6.0", + "@mongodb-js/mongodb-redux-common": "^2.0.9", + "bson": "^5.2.0", + "compass-preferences-model": "^2.11.1", + "hadron-document": "^8.3.0", + "hadron-type-checker": "^7.0.3", + "react": "^17.0.2" } }, "node_modules/terser-webpack-plugin/node_modules/supports-color": { @@ -44319,6 +44347,7 @@ "version": "9.12.1", "license": "SSPL", "dependencies": { + "@mongodb-js/atlas-service": "^0.2.1", "@mongodb-js/compass-components": "^1.11.0", "@mongodb-js/compass-crud": "^13.11.1", "@mongodb-js/compass-editor": "^0.10.0", @@ -44363,6 +44392,7 @@ "mongodb-data-service": "^22.9.1", "mongodb-ns": "^2.4.0", "mongodb-query-parser": "^2.5.0", + "mongodb-schema": "^11.2.1", "nyc": "^15.1.0", "prettier": "^2.7.1", "prop-types": "^15.7.2", @@ -44377,6 +44407,7 @@ "xvfb-maybe": "^0.2.1" }, "peerDependencies": { + "@mongodb-js/atlas-service": "^0.2.1", "@mongodb-js/compass-components": "^1.11.0", "@mongodb-js/compass-crud": "^13.11.1", "@mongodb-js/compass-editor": "^0.10.0", @@ -58251,6 +58282,7 @@ "@dnd-kit/sortable": "^7.0.2", "@dnd-kit/utilities": "^3.2.1", "@electron/remote": "^2.0.10", + "@mongodb-js/atlas-service": "^0.2.1", "@mongodb-js/compass-components": "^1.11.0", "@mongodb-js/compass-crud": "^13.11.1", "@mongodb-js/compass-editor": "^0.10.0", @@ -58286,6 +58318,7 @@ "mongodb-data-service": "^22.9.1", "mongodb-ns": "^2.4.0", "mongodb-query-parser": "^2.5.0", + "mongodb-schema": "^11.2.1", "nyc": "^15.1.0", "prettier": "^2.7.1", "prop-types": "^15.7.2", @@ -61972,8 +62005,7 @@ "@mongodb-js/mongodb-constants": { "version": "0.2.2", "resolved": "https://registry.npmjs.org/@mongodb-js/mongodb-constants/-/mongodb-constants-0.2.2.tgz", - "integrity": "sha512-vm1G+/WRWmXGyE9ZnhDv9toe+LRu1x0F/lGEwqWESfBiUUUuVZhj25fS2o4IL7H4pJ31sFxr7/gu+ER8OkmtzA==", - "dev": true + "integrity": "sha512-vm1G+/WRWmXGyE9ZnhDv9toe+LRu1x0F/lGEwqWESfBiUUUuVZhj25fS2o4IL7H4pJ31sFxr7/gu+ER8OkmtzA==" }, "@mongodb-js/mongodb-downloader": { "version": "0.2.6", @@ -92392,9 +92424,22 @@ "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.2.0.tgz", "integrity": "sha512-0zTyLGyDJYd/MBxG1AhJkKa6fpEBds4OQO2ut0w7OYG+ZGhGea09lijvzsqegYSik88zc7cUtIlnnO+/BvD6gQ==", "requires": { + "@mongodb-js/atlas-service": "^0.2.1", + "@mongodb-js/compass-components": "^1.11.0", + "@mongodb-js/compass-crud": "^13.11.1", + "@mongodb-js/compass-editor": "^0.10.0", + "@mongodb-js/compass-logging": "^1.1.7", + "@mongodb-js/compass-utils": "^0.3.3", + "@mongodb-js/explain-plan-helper": "^1.1.0", + "@mongodb-js/mongodb-constants": "^0.6.0", + "@mongodb-js/mongodb-redux-common": "^2.0.9", "@types/json-schema": "^7.0.8", "ajv": "^6.12.5", - "ajv-keywords": "^3.5.2" + "ajv-keywords": "^3.5.2", + "bson": "^5.2.0", + "compass-preferences-model": "^2.11.1", + "hadron-document": "^8.3.0", + "hadron-type-checker": "^7.0.3" } }, "supports-color": { diff --git a/packages/atlas-service/src/main.spec.ts b/packages/atlas-service/src/main.spec.ts index 8e287d67f88..d268255fd18 100644 --- a/packages/atlas-service/src/main.spec.ts +++ b/packages/atlas-service/src/main.spec.ts @@ -13,6 +13,20 @@ function getListenerCount(emitter: EventEmitter) { }, 0); } +const atlasAIServiceTests: { + functionName: 'getQueryFromUserInput' | 'getAggregationFromUserInput'; + aiEndpoint: string; +}[] = [ + { + functionName: 'getQueryFromUserInput', + aiEndpoint: 'mql-query', + }, + { + functionName: 'getAggregationFromUserInput', + aiEndpoint: 'mql-aggregation', + }, +]; + describe('AtlasServiceMain', function () { const sandbox = Sinon.createSandbox(); @@ -199,170 +213,172 @@ describe('AtlasServiceMain', function () { }); }); - describe('getQueryFromUserInput', function () { - it('makes a post request with the user input to the endpoint in the environment', async function () { - AtlasService['fetch'] = sandbox.stub().resolves({ - ok: true, - json() { - return Promise.resolve({ - content: { query: { find: { test: 'pineapple' } } }, - }); - }, - }) as any; - - const res = await AtlasService.getQueryFromUserInput({ - userInput: 'test', - signal: new AbortController().signal, - collectionName: 'jam', - databaseName: 'peanut', - schema: { _id: { types: [{ bsonType: 'ObjectId' }] } }, - sampleDocuments: [{ _id: 1234 }], - }); - - const { args } = ( - AtlasService['fetch'] as unknown as Sinon.SinonStub - ).getCall(0); - - expect(AtlasService['fetch']).to.have.been.calledOnce; - expect(args[0]).to.eq('http://example.com/ai/api/v1/mql-query'); - expect(args[1].body).to.eq( - '{"userInput":"test","collectionName":"jam","databaseName":"peanut","schema":{"_id":{"types":[{"bsonType":"ObjectId"}]}},"sampleDocuments":[{"_id":1234}]}' - ); - expect(res).to.have.nested.property( - 'content.query.find.test', - 'pineapple' - ); - }); + for (const { functionName, aiEndpoint } of atlasAIServiceTests) { + describe(functionName, function () { + it('makes a post request with the user input to the endpoint in the environment', async function () { + AtlasService['fetch'] = sandbox.stub().resolves({ + ok: true, + json() { + return Promise.resolve({ + content: { query: { filter: "{ test: 'pineapple' }" } }, + }); + }, + }) as any; - it('uses the abort signal in the fetch request', async function () { - const c = new AbortController(); - c.abort(); - try { - await AtlasService.getQueryFromUserInput({ - signal: c.signal, + const res = await AtlasService[functionName]({ userInput: 'test', - collectionName: 'test', + signal: new AbortController().signal, + collectionName: 'jam', databaseName: 'peanut', + schema: { _id: { types: [{ bsonType: 'ObjectId' }] } }, + sampleDocuments: [{ _id: 1234 }], }); - expect.fail('Expected getQueryFromUserInput to throw'); - } catch (err) { - expect(err).to.have.property('message', 'This operation was aborted'); - } - }); - it('throws if the request would be too much for the ai', async function () { - try { - await AtlasService.getQueryFromUserInput({ - userInput: 'test', - collectionName: 'test', - databaseName: 'peanut', - sampleDocuments: [{ test: '4'.repeat(60000) }], - }); - expect.fail('Expected getQueryFromUserInput to throw'); - } catch (err) { - expect(err).to.have.property( - 'message', - 'Error: too large of a request to send to the ai. Please use a smaller prompt or collection with smaller documents.' - ); - } - }); - - it('passes fewer documents if the request would be too much for the ai with all of the documents', async function () { - AtlasService['fetch'] = sandbox.stub().resolves({ - ok: true, - json() { - return Promise.resolve({}); - }, - }) as any; + const { args } = ( + AtlasService['fetch'] as unknown as Sinon.SinonStub + ).getCall(0); - await AtlasService.getQueryFromUserInput({ - userInput: 'test', - collectionName: 'test.test', - databaseName: 'peanut', - sampleDocuments: [ - { a: '1' }, - { a: '2' }, - { a: '3' }, - { a: '4'.repeat(50000) }, - ], + expect(AtlasService['fetch']).to.have.been.calledOnce; + expect(args[0]).to.eq(`http://example.com/ai/api/v1/${aiEndpoint}`); + expect(args[1].body).to.eq( + '{"userInput":"test","collectionName":"jam","databaseName":"peanut","schema":{"_id":{"types":[{"bsonType":"ObjectId"}]}},"sampleDocuments":[{"_id":1234}]}' + ); + expect(res).to.have.nested.property( + 'content.query.filter', + "{ test: 'pineapple' }" + ); }); - const { args } = ( - AtlasService['fetch'] as unknown as Sinon.SinonStub - ).getCall(0); + it('uses the abort signal in the fetch request', async function () { + const c = new AbortController(); + c.abort(); + try { + await AtlasService[functionName]({ + signal: c.signal, + userInput: 'test', + collectionName: 'test', + databaseName: 'peanut', + }); + expect.fail(`Expected ${functionName} to throw`); + } catch (err) { + expect(err).to.have.property('message', 'This operation was aborted'); + } + }); - expect(AtlasService['fetch']).to.have.been.calledOnce; - expect(args[1].body).to.eq( - '{"userInput":"test","collectionName":"test.test","databaseName":"peanut","sampleDocuments":[{"a":"1"}]}' - ); - }); + it('throws if the request would be too much for the ai', async function () { + try { + await AtlasService[functionName]({ + userInput: 'test', + collectionName: 'test', + databaseName: 'peanut', + sampleDocuments: [{ test: '4'.repeat(60000) }], + }); + expect.fail(`Expected ${functionName} to throw`); + } catch (err) { + expect(err).to.have.property( + 'message', + 'Error: too large of a request to send to the ai. Please use a smaller prompt or collection with smaller documents.' + ); + } + }); - it('throws the error', async function () { - AtlasService['fetch'] = sandbox.stub().resolves({ - ok: false, - status: 500, - statusText: 'Internal Server Error', - }) as any; + it('passes fewer documents if the request would be too much for the ai with all of the documents', async function () { + AtlasService['fetch'] = sandbox.stub().resolves({ + ok: true, + json() { + return Promise.resolve({}); + }, + }) as any; - try { - await AtlasService.getQueryFromUserInput({ + await AtlasService[functionName]({ userInput: 'test', collectionName: 'test.test', databaseName: 'peanut', + sampleDocuments: [ + { a: '1' }, + { a: '2' }, + { a: '3' }, + { a: '4'.repeat(50000) }, + ], }); - expect.fail('Expected getQueryFromUserInput to throw'); - } catch (err) { - expect(err).to.have.property('message', '500 Internal Server Error'); - } - }); - it('should throw if COMPASS_ATLAS_SERVICE_BASE_URL is not set', async function () { - delete process.env.COMPASS_ATLAS_SERVICE_BASE_URL; + const { args } = ( + AtlasService['fetch'] as unknown as Sinon.SinonStub + ).getCall(0); - try { - await AtlasService.getQueryFromUserInput({ - userInput: 'test', - collectionName: 'test.test', - databaseName: 'peanut', - }); - expect.fail('Expected AtlasService.signIn() to throw'); - } catch (err) { - expect(err).to.have.property( - 'message', - 'No AI Query endpoint to fetch. Please set the environment variable `COMPASS_ATLAS_SERVICE_BASE_URL`' + expect(AtlasService['fetch']).to.have.been.calledOnce; + expect(args[1].body).to.eq( + '{"userInput":"test","collectionName":"test.test","databaseName":"peanut","sampleDocuments":[{"a":"1"}]}' ); - } - }); + }); - it('should wait for token refresh if called when expired', async function () { - AtlasService['fetch'] = sandbox.stub().resolves({ - ok: true, - json() { - return Promise.resolve({ test: 1 }); - }, - }) as any; - AtlasService['oidcPluginLogger'].emit( - 'mongodb-oidc-plugin:refresh-started' - ); - const [query] = await Promise.all([ - AtlasService.getQueryFromUserInput({ - userInput: 'test', - collectionName: 'test', - databaseName: 'test', - sampleDocuments: [], - }), - (() => { - AtlasService['oidcPluginLogger'].emit( - 'mongodb-oidc-plugin:refresh-succeeded' - ); - AtlasService['oidcPluginLogger'].emit( - 'mongodb-oidc-plugin:state-updated' + it('throws the error', async function () { + AtlasService['fetch'] = sandbox.stub().resolves({ + ok: false, + status: 500, + statusText: 'Internal Server Error', + }) as any; + + try { + await AtlasService[functionName]({ + userInput: 'test', + collectionName: 'test.test', + databaseName: 'peanut', + }); + expect.fail(`Expected ${functionName} to throw`); + } catch (err) { + expect(err).to.have.property('message', '500 Internal Server Error'); + } + }); + + it('should throw if COMPASS_ATLAS_SERVICE_BASE_URL is not set', async function () { + delete process.env.COMPASS_ATLAS_SERVICE_BASE_URL; + + try { + await AtlasService[functionName]({ + userInput: 'test', + collectionName: 'test.test', + databaseName: 'peanut', + }); + expect.fail('Expected AtlasService.signIn() to throw'); + } catch (err) { + expect(err).to.have.property( + 'message', + 'No AI Query endpoint to fetch. Please set the environment variable `COMPASS_ATLAS_SERVICE_BASE_URL`' ); - })(), - ]); - expect(query).to.deep.eq({ test: 1 }); + } + }); + + it('should wait for token refresh if called when expired', async function () { + AtlasService['fetch'] = sandbox.stub().resolves({ + ok: true, + json() { + return Promise.resolve({ test: 1 }); + }, + }) as any; + AtlasService['oidcPluginLogger'].emit( + 'mongodb-oidc-plugin:refresh-started' + ); + const [query] = await Promise.all([ + AtlasService[functionName]({ + userInput: 'test', + collectionName: 'test', + databaseName: 'test', + sampleDocuments: [], + }), + (() => { + AtlasService['oidcPluginLogger'].emit( + 'mongodb-oidc-plugin:refresh-succeeded' + ); + AtlasService['oidcPluginLogger'].emit( + 'mongodb-oidc-plugin:state-updated' + ); + })(), + ]); + expect(query).to.deep.eq({ test: 1 }); + }); }); - }); + } describe('throwIfNotOk', function () { it('should not throw if res is ok', async function () { diff --git a/packages/atlas-service/src/main.ts b/packages/atlas-service/src/main.ts index 66c40d75f30..7fcef7346d8 100644 --- a/packages/atlas-service/src/main.ts +++ b/packages/atlas-service/src/main.ts @@ -19,7 +19,13 @@ import type { Response } from 'node-fetch'; import fetch from 'node-fetch'; import type { SimplifiedSchema } from 'mongodb-schema'; import type { Document } from 'mongodb'; -import type { AIQuery, IntrospectInfo, Token, UserInfo } from './util'; +import type { + AIAggregation, + AIQuery, + IntrospectInfo, + Token, + UserInfo, +} from './util'; import { broadcast, getStoragePaths, @@ -262,6 +268,7 @@ export class AtlasService { 'introspect', 'isAuthenticated', 'signIn', + 'getAggregationFromUserInput', 'getQueryFromUserInput', ], this.ipcMain @@ -604,6 +611,69 @@ export class AtlasService { return res.json() as Promise; } + static async getAggregationFromUserInput({ + signal, + userInput, + collectionName, + databaseName, + schema, + sampleDocuments, + }: { + userInput: string; + collectionName: string; + databaseName: string; + schema?: SimplifiedSchema; + sampleDocuments?: Document[]; + signal?: AbortSignal; + }) { + throwIfAborted(signal); + + let msgBody = JSON.stringify({ + userInput, + collectionName, + databaseName, + schema, + sampleDocuments, + }); + if (msgBody.length > AI_MAX_REQUEST_SIZE) { + // When the message body is over the max size, we try + // to see if with fewer sample documents we can still perform the request. + // If that fails we throw an error indicating this collection's + // documents are too large to send to the ai. + msgBody = JSON.stringify({ + userInput, + collectionName, + databaseName, + schema, + sampleDocuments: sampleDocuments?.slice(0, AI_MIN_SAMPLE_DOCUMENTS), + }); + if (msgBody.length > AI_MAX_REQUEST_SIZE) { + throw new Error( + 'Error: too large of a request to send to the ai. Please use a smaller prompt or collection with smaller documents.' + ); + } + } + + await this.maybeWaitForToken({ signal }); + + const res = await this.fetch( + `${this.apiBaseUrl}/ai/api/v1/mql-aggregation`, + { + signal: signal as NodeFetchAbortSignal | undefined, + method: 'POST', + headers: { + Authorization: `Bearer ${this.token?.accessToken ?? ''}`, + 'Content-Type': 'application/json', + }, + body: msgBody, + } + ); + + await throwIfNotOk(res); + + return res.json() as Promise; + } + static async getQueryFromUserInput({ signal, userInput, @@ -640,7 +710,6 @@ export class AtlasService { schema, sampleDocuments: sampleDocuments?.slice(0, AI_MIN_SAMPLE_DOCUMENTS), }); - // Why this is not happening on the backend? if (msgBody.length > AI_MAX_REQUEST_SIZE) { throw new Error( 'Error: too large of a request to send to the ai. Please use a smaller prompt or collection with smaller documents.' diff --git a/packages/atlas-service/src/renderer.ts b/packages/atlas-service/src/renderer.ts index 8c2ba2cdb83..20578ed9a6a 100644 --- a/packages/atlas-service/src/renderer.ts +++ b/packages/atlas-service/src/renderer.ts @@ -16,18 +16,21 @@ export class AtlasService { | 'introspect' | 'isAuthenticated' | 'signIn' + | 'getAggregationFromUserInput' | 'getQueryFromUserInput' >('AtlasService', [ 'getUserInfo', 'introspect', 'isAuthenticated', 'signIn', + 'getAggregationFromUserInput', 'getQueryFromUserInput', ]); getUserInfo = this.ipc.getUserInfo; introspect = this.ipc.introspect; isAuthenticated = this.ipc.isAuthenticated; + getAggregationFromUserInput = this.ipc.getAggregationFromUserInput; getQueryFromUserInput = this.ipc.getQueryFromUserInput; async signIn( diff --git a/packages/atlas-service/src/util.ts b/packages/atlas-service/src/util.ts index 7a5613722e5..afcd1b7d128 100644 --- a/packages/atlas-service/src/util.ts +++ b/packages/atlas-service/src/util.ts @@ -11,6 +11,14 @@ export type IntrospectInfo = { active: boolean }; export type Token = plugin.IdPServerResponse; +export type AIAggregation = { + content?: { + aggregation?: { + pipeline?: unknown; + }; + }; +}; + export type AIQuery = { content?: { query?: unknown; diff --git a/packages/compass-aggregations/package.json b/packages/compass-aggregations/package.json index 33d1abbeab6..fb600be6402 100644 --- a/packages/compass-aggregations/package.json +++ b/packages/compass-aggregations/package.json @@ -38,6 +38,7 @@ }, "license": "SSPL", "peerDependencies": { + "@mongodb-js/atlas-service": "^0.2.1", "@mongodb-js/compass-components": "^1.11.0", "@mongodb-js/compass-crud": "^13.11.1", "@mongodb-js/compass-editor": "^0.10.0", @@ -83,6 +84,7 @@ "mongodb-data-service": "^22.9.1", "mongodb-ns": "^2.4.0", "mongodb-query-parser": "^2.5.0", + "mongodb-schema": "^11.2.1", "nyc": "^15.1.0", "prettier": "^2.7.1", "prop-types": "^15.7.2", @@ -97,6 +99,7 @@ "xvfb-maybe": "^0.2.1" }, "dependencies": { + "@mongodb-js/atlas-service": "^0.2.1", "@mongodb-js/compass-components": "^1.11.0", "@mongodb-js/compass-crud": "^13.11.1", "@mongodb-js/compass-editor": "^0.10.0", diff --git a/packages/compass-aggregations/src/components/pipeline-toolbar/index.tsx b/packages/compass-aggregations/src/components/pipeline-toolbar/index.tsx index 6bca6afde33..93cea0b3e62 100644 --- a/packages/compass-aggregations/src/components/pipeline-toolbar/index.tsx +++ b/packages/compass-aggregations/src/components/pipeline-toolbar/index.tsx @@ -7,10 +7,13 @@ import { useDarkMode, } from '@mongodb-js/compass-components'; import { connect } from 'react-redux'; +import { usePreference } from 'compass-preferences-model'; import PipelineHeader from './pipeline-header'; import PipelineOptions from './pipeline-options'; import PipelineSettings from './pipeline-settings'; +import { PipelineAI } from './pipeline-ai'; +import { hideInput as hideAIInput } from '../../modules/pipeline-builder/pipeline-ai'; import type { RootState } from '../../modules'; import PipelineResultsHeader from '../pipeline-results-workspace/pipeline-results-header'; @@ -52,23 +55,28 @@ const optionsStyles = css({ }); type PipelineToolbarProps = { + isAIInputVisible?: boolean; isBuilderView: boolean; showRunButton: boolean; showExportButton: boolean; showExplainButton: boolean; onChangePipelineOutputOption: (val: PipelineOutputOption) => void; + onHideAIInputClick?: () => void; pipelineOutputOption: PipelineOutputOption; }; export const PipelineToolbar: React.FunctionComponent = ({ + isAIInputVisible = false, isBuilderView, showRunButton, showExportButton, showExplainButton, onChangePipelineOutputOption, + onHideAIInputClick, pipelineOutputOption, }) => { const darkMode = useDarkMode(); + const enableAIExperience = usePreference('enableAIExperience', React); const [isOptionsVisible, setIsOptionsVisible] = useState(false); return (
= ({
)} + {enableAIExperience && ( + { + onHideAIInputClick?.(); + }} + show={isAIInputVisible} + /> + )} {isBuilderView ? (
@@ -110,7 +126,10 @@ export const PipelineToolbar: React.FunctionComponent = ({ ); }; -const mapState = ({ workspace }: RootState) => ({ - isBuilderView: workspace === 'builder', +const mapState = (state: RootState) => ({ + isBuilderView: state.workspace === 'builder', + isAIInputVisible: state.pipelineBuilder.aiPipeline.isInputVisible, }); -export default connect(mapState)(PipelineToolbar); +export default connect(mapState, { + onHideAIInputClick: hideAIInput, +})(PipelineToolbar); diff --git a/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-ai.spec.tsx b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-ai.spec.tsx new file mode 100644 index 00000000000..0840c87a796 --- /dev/null +++ b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-ai.spec.tsx @@ -0,0 +1,146 @@ +import React from 'react'; +import type { ComponentProps } from 'react'; +import { + cleanup, + fireEvent, + render, + screen, + waitFor, +} from '@testing-library/react'; +import { expect } from 'chai'; +import sinon from 'sinon'; +import type { SinonSpy } from 'sinon'; +import { Provider } from 'react-redux'; +import preferencesAccess from 'compass-preferences-model'; + +import { PipelineAI } from './pipeline-ai'; +import configureStore from '../../../test/configure-store'; +import { + AIPipelineActionTypes, + changeAIPromptText, +} from '../../modules/pipeline-builder/pipeline-ai'; + +const noop = () => { + /* no op */ +}; + +const renderPipelineAI = ({ + ...props +}: Partial> = {}) => { + const store = configureStore(); + + render( + + + + ); + return store; +}; + +const feedbackPopoverTextAreaId = 'feedback-popover-textarea'; + +describe('PipelineAI Component', function () { + let store: ReturnType; + afterEach(cleanup); + + describe('when rendered', function () { + let onCloseSpy: SinonSpy; + beforeEach(function () { + onCloseSpy = sinon.spy(); + store = renderPipelineAI({ + onClose: onCloseSpy, + }); + }); + + it('calls to close robot button is clicked', function () { + expect(onCloseSpy.called).to.be.false; + const closeButton = screen.getByTestId('close-ai-button'); + expect(closeButton).to.be.visible; + closeButton.click(); + expect(onCloseSpy.calledOnce).to.be.true; + }); + }); + + describe('when rendered with text', function () { + beforeEach(function () { + store = renderPipelineAI(); + store.dispatch(changeAIPromptText('test')); + }); + + it('calls to clear the text when the X is clicked', function () { + expect(store.getState().pipelineBuilder.aiPipeline.aiPromptText).to.equal( + 'test' + ); + + const clearTextButton = screen.getByTestId('ai-text-clear-prompt'); + expect(clearTextButton).to.be.visible; + clearTextButton.click(); + + expect(store.getState().pipelineBuilder.aiPipeline.aiPromptText).to.equal( + '' + ); + }); + }); + + describe('Pipeline AI Feedback', function () { + let trackUsageStatistics: boolean | undefined; + + beforeEach(async function () { + store = renderPipelineAI(); + trackUsageStatistics = + preferencesAccess.getPreferences().trackUsageStatistics; + // 'compass:track' will only emit if tracking is enabled. + await preferencesAccess.savePreferences({ trackUsageStatistics: true }); + }); + + afterEach(async function () { + await preferencesAccess.savePreferences({ trackUsageStatistics }); + }); + + it('should log a telemetry event with the entered text on submit', async function () { + // Note: This is coupling this test with internals of the logger and telemetry. + // We're doing this as this is a unique case where we're using telemetry + // for feedback. Avoid repeating this elsewhere. + const trackingLogs: any[] = []; + process.on('compass:track', (event) => trackingLogs.push(event)); + + // No feedback popover is shown yet. + expect(screen.queryByTestId(feedbackPopoverTextAreaId)).to.not.exist; + expect(screen.queryByTestId('ai-feedback-thumbs-up')).to.not.exist; + + store.dispatch({ + type: AIPipelineActionTypes.AIPipelineSucceeded, + }); + + expect(screen.queryByTestId(feedbackPopoverTextAreaId)).to.not.exist; + const thumbsUpButton = screen.getByTestId('ai-feedback-thumbs-up'); + expect(thumbsUpButton).to.be.visible; + thumbsUpButton.click(); + + const textArea = screen.getByTestId(feedbackPopoverTextAreaId); + expect(textArea).to.be.visible; + fireEvent.change(textArea, { + target: { value: 'this is the pipeline I was looking for' }, + }); + + screen.getByText('Submit').click(); + + await waitFor( + () => { + // No feedback popover is shown. + expect(screen.queryByTestId(feedbackPopoverTextAreaId)).to.not.exist; + expect(trackingLogs).to.deep.equal([ + { + event: 'PipelineAI Feedback', + properties: { + feedback: 'positive', + text: 'this is the pipeline I was looking for', + }, + }, + ]); + }, + { interval: 10 } + ); + }); + }); +}); diff --git a/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-ai.tsx b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-ai.tsx new file mode 100644 index 00000000000..bee673d99f7 --- /dev/null +++ b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-ai.tsx @@ -0,0 +1,58 @@ +import React from 'react'; +import { GenerativeAIInput } from '@mongodb-js/compass-components'; +import { connect } from 'react-redux'; +import createLoggerAndTelemetry from '@mongodb-js/compass-logging'; + +import { + changeAIPromptText, + cancelAIPipelineGeneration, + runAIPipelineGeneration, +} from '../../modules/pipeline-builder/pipeline-ai'; +import type { RootState } from '../../modules'; + +const { log, mongoLogId, track } = createLoggerAndTelemetry('AI-PIPELINE-UI'); + +const onSubmitFeedback = (feedback: 'positive' | 'negative', text: string) => { + log.info(mongoLogId(1_001_000_232), 'PipelineAI', 'AI pipeline feedback', { + feedback, + text, + }); + + track('PipelineAI Feedback', () => ({ + feedback, + text, + })); +}; + +type PipelineAIProps = Omit< + React.ComponentProps, + 'onSubmitFeedback' +>; + +function PipelineAI(props: PipelineAIProps) { + return ( + + ); +} + +const ConnectedPipelineAI = connect( + (state: RootState) => { + return { + aiPromptText: state.pipelineBuilder.aiPipeline.aiPromptText, + isFetching: state.pipelineBuilder.aiPipeline.status === 'fetching', + didSucceed: state.pipelineBuilder.aiPipeline.status === 'success', + errorMessage: state.pipelineBuilder.aiPipeline.errorMessage, + }; + }, + { + onChangeAIPromptText: changeAIPromptText, + onCancelRequest: cancelAIPipelineGeneration, + onSubmitText: runAIPipelineGeneration, + } +)(PipelineAI); + +export { ConnectedPipelineAI as PipelineAI }; diff --git a/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/index.tsx b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/index.tsx index 93b3c305a4e..0ee2a8f8e0a 100644 --- a/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/index.tsx +++ b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/index.tsx @@ -117,27 +117,29 @@ export const PipelineHeader: React.FunctionComponent = ({ isOpenPipelineVisible, }) => { return ( -
- {isOpenPipelineVisible && ( -
- Pipeline - +
+
+ {isOpenPipelineVisible && ( +
+ Pipeline + +
+ )} +
+ +
+
+
- )} -
- -
-
-
); diff --git a/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-actions.spec.tsx b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-actions.spec.tsx index 992052da62f..53b51c1a1ed 100644 --- a/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-actions.spec.tsx +++ b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-actions.spec.tsx @@ -27,6 +27,7 @@ describe('PipelineActions', function () { render( {}} onCollectionScanInsightActionButtonClick={() => {}} + onShowAIInputClick={() => {}} /> ); }); @@ -93,6 +95,7 @@ describe('PipelineActions', function () { render( {}} isAtlasDeployed={false} onCollectionScanInsightActionButtonClick={() => {}} + onShowAIInputClick={() => {}} /> ); }); @@ -130,6 +134,7 @@ describe('PipelineActions', function () { render( {}} isAtlasDeployed={true} onCollectionScanInsightActionButtonClick={() => {}} + onShowAIInputClick={() => {}} /> ); }); @@ -165,6 +171,7 @@ describe('PipelineActions', function () { isExportButtonDisabled={true} isRunButtonDisabled={true} isOptionsVisible={true} + showAIEntry={false} showRunButton={true} showExportButton={true} showExplainButton={true} @@ -174,6 +181,7 @@ describe('PipelineActions', function () { onExplainAggregation={onExplainAggregationSpy} onUpdateView={() => {}} onCollectionScanInsightActionButtonClick={() => {}} + onShowAIInputClick={() => {}} /> ); }); diff --git a/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-actions.tsx b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-actions.tsx index d7e3c48cf47..9b824cfe404 100644 --- a/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-actions.tsx +++ b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-actions.tsx @@ -1,6 +1,7 @@ import React from 'react'; import { connect } from 'react-redux'; import { + AIExperienceEntry, Button, MoreOptionsToggle, PerformanceSignals, @@ -22,6 +23,7 @@ import { import { isOutputStage } from '../../../utils/stage'; import { openCreateIndexModal } from '../../../modules/insights'; import { usePreference } from 'compass-preferences-model'; +import { showInput as showAIInput } from '../../../modules/pipeline-builder/pipeline-ai'; const containerStyles = css({ display: 'flex', @@ -51,6 +53,9 @@ type PipelineActionsProps = { isAtlasDeployed?: boolean; + showAIEntry: boolean; + onShowAIInputClick: () => void; + showCollectionScanInsight?: boolean; onCollectionScanInsightActionButtonClick: () => void; }; @@ -65,6 +70,8 @@ export const PipelineActions: React.FunctionComponent = ({ isUpdateViewButtonDisabled, isExplainButtonDisabled, showExplainButton, + showAIEntry, + onShowAIInputClick, onUpdateView, onRunAggregation, onToggleOptions, @@ -75,8 +82,13 @@ export const PipelineActions: React.FunctionComponent = ({ onCollectionScanInsightActionButtonClick, }) => { const showInsights = usePreference('showInsights', React); + const enableAIExperience = usePreference('enableAIExperience', React); + return (
+ {enableAIExperience && showAIEntry && ( + + )} {showInsights && showCollectionScanInsight && (
{ isRunButtonDisabled: hasSyntaxErrors, isExplainButtonDisabled: hasSyntaxErrors, isExportButtonDisabled: isMergeOrOutPipeline || hasSyntaxErrors, + showAIEntry: + !state.pipelineBuilder.aiPipeline.isInputVisible && + resultPipeline.length > 0, showUpdateViewButton: Boolean(state.editViewName), isUpdateViewButtonDisabled: !state.isModified || hasSyntaxErrors, isAtlasDeployed: state.isAtlasDeployed, @@ -172,6 +187,7 @@ const mapDispatch = { onExportAggregationResults: exportAggregationResults, onExplainAggregation: explainAggregation, onCollectionScanInsightActionButtonClick: openCreateIndexModal, + onShowAIInputClick: showAIInput, }; export default connect(mapState, mapDispatch)(React.memo(PipelineActions)); diff --git a/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-stages.spec.tsx b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-stages.spec.tsx index a15e756b5f5..00e92306164 100644 --- a/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-stages.spec.tsx +++ b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-stages.spec.tsx @@ -14,10 +14,12 @@ const renderPipelineStages = ( render( {}} onEditPipelineClick={() => {}} + onShowAIInputClick={() => {}} {...props} /> ); diff --git a/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-stages.tsx b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-stages.tsx index ef5f260970e..b15ac69ab55 100644 --- a/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-stages.tsx +++ b/packages/compass-aggregations/src/components/pipeline-toolbar/pipeline-header/pipeline-stages.tsx @@ -1,22 +1,24 @@ import React from 'react'; import { connect } from 'react-redux'; import { + AIExperienceEntry, Pipeline, Stage, Description, Link, css, - cx, spacing, Button, Icon, } from '@mongodb-js/compass-components'; +import { usePreference } from 'compass-preferences-model'; import type { RootState } from '../../../modules'; import { editPipeline } from '../../../modules/workspace'; import type { Workspace } from '../../../modules/workspace'; import { getPipelineStageOperatorsFromBuilderState } from '../../../modules/pipeline-builder/builder-helpers'; import { addStage } from '../../../modules/pipeline-builder/stage-editor'; +import { showInput as showAIInput } from '../../../modules/pipeline-builder/pipeline-ai'; const containerStyles = css({ display: 'flex', @@ -42,8 +44,10 @@ type PipelineStagesProps = { isResultsMode: boolean; stages: string[]; showAddNewStage: boolean; + showAIEntry: boolean; onAddStageClick: () => void; onEditPipelineClick: (workspace: Workspace) => void; + onShowAIInputClick: () => void; }; const nbsp = '\u00a0'; @@ -52,26 +56,42 @@ export const PipelineStages: React.FunctionComponent = ({ isResultsMode, stages, showAddNewStage, + showAIEntry, onAddStageClick, onEditPipelineClick, + onShowAIInputClick, }) => { + const enableAIExperience = usePreference('enableAIExperience', React); + return (
{stages.length === 0 ? ( - + Your pipeline is currently empty. {showAddNewStage && ( <> - {nbsp}To get started add the{nbsp} - onAddStageClick()} - hideExternalIcon - data-testid="pipeline-toolbar-add-stage-button" - > - first stage. - + {enableAIExperience && showAIEntry ? ( + <>{nbsp}Need help getting started? + ) : ( + <> + {nbsp}To get started add the{nbsp} + onAddStageClick()} + hideExternalIcon + data-testid="pipeline-toolbar-add-stage-button" + > + first stage. + + + )} + + )} + {enableAIExperience && showAIEntry && ( + <> + {nbsp} + )} @@ -102,8 +122,13 @@ const mapState = (state: RootState) => { const isResultsMode = state.workspace === 'results'; const isStageMode = state.pipelineBuilder.pipelineMode === 'builder-ui'; return { + showAIEntry: !state.pipelineBuilder.aiPipeline.isInputVisible, stages: stages.filter(Boolean) as string[], - showAddNewStage: !isResultsMode && isStageMode && stages.length === 0, + showAddNewStage: + !state.pipelineBuilder.aiPipeline.isInputVisible && + !isResultsMode && + isStageMode && + stages.length === 0, isResultsMode, }; }; @@ -111,5 +136,6 @@ const mapState = (state: RootState) => { const mapDispatch = { onAddStageClick: addStage, onEditPipelineClick: editPipeline, + onShowAIInputClick: showAIInput, }; export default connect(mapState, mapDispatch)(React.memo(PipelineStages)); diff --git a/packages/compass-aggregations/src/modules/comments.js b/packages/compass-aggregations/src/modules/comments.js index bd2ba51162e..ff2ea6ae1cb 100644 --- a/packages/compass-aggregations/src/modules/comments.js +++ b/packages/compass-aggregations/src/modules/comments.js @@ -1,4 +1,5 @@ import { ActionTypes as ConfirmNewPipelineActions } from './is-new-pipeline-confirm'; +import { AIPipelineActionTypes } from './pipeline-builder/pipeline-ai'; import { RESTORE_PIPELINE } from './saved-pipeline'; import { APPLY_SETTINGS } from './settings'; @@ -19,7 +20,10 @@ export default function reducer(state = INITIAL_STATE, action) { if (action.type === APPLY_SETTINGS) { return action.settings.isCommentMode ?? state; } - if (action.type === ConfirmNewPipelineActions.NewPipelineConfirmed) { + if ( + action.type === ConfirmNewPipelineActions.NewPipelineConfirmed || + action.type === AIPipelineActionTypes.LoadGeneratedPipeline + ) { return INITIAL_STATE; } if (action.type === RESTORE_PIPELINE) { diff --git a/packages/compass-aggregations/src/modules/id.js b/packages/compass-aggregations/src/modules/id.js index a42465f7c92..0b5cb8ab886 100644 --- a/packages/compass-aggregations/src/modules/id.js +++ b/packages/compass-aggregations/src/modules/id.js @@ -2,6 +2,7 @@ import { ObjectId } from 'bson'; import { CLONE_PIPELINE } from './clone-pipeline'; import { ActionTypes as ConfirmNewPipelineActions } from '././is-new-pipeline-confirm'; import { RESTORE_PIPELINE } from './saved-pipeline'; +import { AIPipelineActionTypes } from './pipeline-builder/pipeline-ai'; /** * Id create action. @@ -25,6 +26,7 @@ export default function reducer(state = INITIAL_STATE, action) { if ( action.type === CREATE_ID || action.type === CLONE_PIPELINE || + action.type === AIPipelineActionTypes.LoadGeneratedPipeline || action.type === ConfirmNewPipelineActions.NewPipelineConfirmed ) { return new ObjectId().toHexString(); diff --git a/packages/compass-aggregations/src/modules/index.ts b/packages/compass-aggregations/src/modules/index.ts index 455a7f04d3e..c1f967a890b 100644 --- a/packages/compass-aggregations/src/modules/index.ts +++ b/packages/compass-aggregations/src/modules/index.ts @@ -1,5 +1,6 @@ import type { Action, AnyAction } from 'redux'; import { combineReducers } from 'redux'; +import type { AtlasService } from '@mongodb-js/atlas-service/renderer'; import dataService from './data-service'; import fields from './fields'; import editViewName from './edit-view-name'; @@ -89,6 +90,7 @@ export type RootState = ReturnType; export type PipelineBuilderExtraArgs = { pipelineBuilder: PipelineBuilder; pipelineStorage: PipelineStorage; + atlasService: AtlasService; }; export type PipelineBuilderThunkDispatch = diff --git a/packages/compass-aggregations/src/modules/insights.ts b/packages/compass-aggregations/src/modules/insights.ts index 4b6a68be0d6..c76da2bd247 100644 --- a/packages/compass-aggregations/src/modules/insights.ts +++ b/packages/compass-aggregations/src/modules/insights.ts @@ -6,6 +6,7 @@ import { cancellableWait } from '@mongodb-js/compass-utils'; import type { PipelineBuilderThunkAction } from '.'; import { ActionTypes as ConfirmNewPipelineActions } from './is-new-pipeline-confirm'; import { RESTORE_PIPELINE } from './saved-pipeline'; +import { AIPipelineActionTypes } from './pipeline-builder/pipeline-ai'; const FETCH_EXPLAIN_PLAN_SUCCESS = 'compass-aggregations/FETCH_EXPLAIN_PLAN_SUCCESS'; @@ -22,10 +23,11 @@ const reducer: Reducer<{ isCollectionScan: boolean }> = ( isCollectionScan: action.explainPlan.isCollectionScan, }; } - if (action.type === ConfirmNewPipelineActions.NewPipelineConfirmed) { - return { ...INITIAL_STATE }; - } - if (action.type === RESTORE_PIPELINE) { + if ( + action.type === ConfirmNewPipelineActions.NewPipelineConfirmed || + action.type === AIPipelineActionTypes.LoadGeneratedPipeline || + action.type === RESTORE_PIPELINE + ) { return { ...INITIAL_STATE }; } return state; diff --git a/packages/compass-aggregations/src/modules/name.spec.js b/packages/compass-aggregations/src/modules/name.spec.ts similarity index 100% rename from packages/compass-aggregations/src/modules/name.spec.js rename to packages/compass-aggregations/src/modules/name.spec.ts diff --git a/packages/compass-aggregations/src/modules/name.js b/packages/compass-aggregations/src/modules/name.ts similarity index 51% rename from packages/compass-aggregations/src/modules/name.js rename to packages/compass-aggregations/src/modules/name.ts index 6b43082a299..08bbebbb2f2 100644 --- a/packages/compass-aggregations/src/modules/name.js +++ b/packages/compass-aggregations/src/modules/name.ts @@ -1,29 +1,31 @@ +import type { AnyAction, Reducer } from 'redux'; + import { ActionTypes as ConfirmNewPipelineActions } from './is-new-pipeline-confirm'; +import { AIPipelineActionTypes } from './pipeline-builder/pipeline-ai'; import { RESTORE_PIPELINE } from './saved-pipeline'; import { SAVING_PIPELINE_APPLY } from './saving-pipeline'; -/** - * The initial state. - */ -export const INITIAL_STATE = ''; +type State = string; + +export const INITIAL_STATE: State = ''; /** * Reducer function for handle state changes to name. - * - * @param {String} state - The name state. - * @param {Object} action - The action. - * - * @returns {String} The new state. */ -export default function reducer(state = INITIAL_STATE, action) { +const reducer: Reducer = (state = INITIAL_STATE, action) => { if (action.type === SAVING_PIPELINE_APPLY) { return action.name; } - if (action.type === ConfirmNewPipelineActions.NewPipelineConfirmed) { + if ( + action.type === ConfirmNewPipelineActions.NewPipelineConfirmed || + action.type === AIPipelineActionTypes.LoadGeneratedPipeline + ) { return INITIAL_STATE; } if (action.type === RESTORE_PIPELINE) { return action.storedOptions.name; } return state; -} +}; + +export default reducer; diff --git a/packages/compass-aggregations/src/modules/namespace.js b/packages/compass-aggregations/src/modules/namespace.ts similarity index 100% rename from packages/compass-aggregations/src/modules/namespace.js rename to packages/compass-aggregations/src/modules/namespace.ts diff --git a/packages/compass-aggregations/src/modules/pipeline-builder/builder-helpers.spec.ts b/packages/compass-aggregations/src/modules/pipeline-builder/builder-helpers.spec.ts index 56e13fc179b..2df92e0e00e 100644 --- a/packages/compass-aggregations/src/modules/pipeline-builder/builder-helpers.spec.ts +++ b/packages/compass-aggregations/src/modules/pipeline-builder/builder-helpers.spec.ts @@ -2,6 +2,7 @@ import { expect } from 'chai'; import { applyMiddleware, createStore as createReduxStore } from 'redux'; import type { DataService } from 'mongodb-data-service'; import thunk from 'redux-thunk'; +import { AtlasService } from '@mongodb-js/atlas-service/renderer'; import reducer from '..'; import { getPipelineStageOperatorsFromBuilderState } from './builder-helpers'; @@ -32,6 +33,7 @@ function createStore(pipelineSource = `[{$match: {_id: 1}}, {$limit: 10}]`) { }, applyMiddleware( thunk.withExtraArgument({ + atlasService: new AtlasService(), pipelineBuilder, pipelineStorage: new PipelineStorage(), }) diff --git a/packages/compass-aggregations/src/modules/pipeline-builder/index.ts b/packages/compass-aggregations/src/modules/pipeline-builder/index.ts index ac4323386d6..62e93223845 100644 --- a/packages/compass-aggregations/src/modules/pipeline-builder/index.ts +++ b/packages/compass-aggregations/src/modules/pipeline-builder/index.ts @@ -1,10 +1,12 @@ import { combineReducers } from 'redux'; +import aiPipeline from './pipeline-ai'; import stageEditor from './stage-editor'; import pipeline from './text-editor-pipeline'; import pipelineMode from './pipeline-mode'; import outputStage from './text-editor-output-stage'; const reducer = combineReducers({ + aiPipeline, pipelineMode, stageEditor, textEditor: combineReducers({ diff --git a/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-ai.spec.ts b/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-ai.spec.ts new file mode 100644 index 00000000000..75e8a2dd35d --- /dev/null +++ b/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-ai.spec.ts @@ -0,0 +1,156 @@ +import { expect } from 'chai'; +import Sinon from 'sinon'; + +import configureStore from '../../../test/configure-store'; +import { + AIPipelineActionTypes, + cancelAIPipelineGeneration, + runAIPipelineGeneration, +} from './pipeline-ai'; +import type { ConfigureStoreOptions } from '../../stores/store'; +import { toggleAutoPreview } from '../auto-preview'; + +describe('AIPipelineReducer', function () { + const sandbox = Sinon.createSandbox(); + + afterEach(function () { + sandbox.reset(); + }); + + function createStore(opts: Partial = {}) { + return configureStore({ + ...opts, + }); + } + + describe('runAIPipelineGeneration', function () { + describe('with a successful server response', function () { + it('should succeed', async function () { + const mockAtlasService = { + getAggregationFromUserInput: sandbox.stub().resolves({ + content: { aggregation: { pipeline: [{ $match: { _id: 1 } }] } }, + }), + }; + + const mockDataService = { + sample: sandbox.stub().resolves([{ _id: 42 }]), + getConnectionString: sandbox.stub().returns({ hosts: [] }), + }; + + const store = createStore({ + namespace: 'database.collection', + dataProvider: { + dataProvider: mockDataService as any, + }, + atlasService: mockAtlasService as any, + }); + // Set autoPreview false so that it doesn't start the + // follow up async preview doc requests. + store.dispatch(toggleAutoPreview(false)); + expect(store.getState().pipelineBuilder.aiPipeline.status).to.equal( + 'ready' + ); + + await store.dispatch(runAIPipelineGeneration('testing prompt')); + + expect(mockAtlasService.getAggregationFromUserInput).to.have.been + .calledOnce; + expect( + mockAtlasService.getAggregationFromUserInput.getCall(0) + ).to.have.nested.property('args[0].userInput', 'testing prompt'); + expect( + mockAtlasService.getAggregationFromUserInput.getCall(0) + ).to.have.nested.property('args[0].collectionName', 'collection'); + expect( + mockAtlasService.getAggregationFromUserInput.getCall(0) + ).to.have.nested.property('args[0].databaseName', 'database'); + // Sample documents are currently disabled. + expect( + mockAtlasService.getAggregationFromUserInput.getCall(0) + ).to.not.have.nested.property('args[0].sampleDocuments'); + + expect( + store.getState().pipelineBuilder.aiPipeline.aiPipelineFetchId + ).to.equal(-1); + expect( + store.getState().pipelineBuilder.aiPipeline.errorMessage + ).to.equal(undefined); + expect(store.getState().pipelineBuilder.aiPipeline.status).to.equal( + 'success' + ); + }); + }); + + describe('when there is an error', function () { + it('sets the error on the store', async function () { + const mockAtlasService = { + getAggregationFromUserInput: sandbox + .stub() + .rejects(new Error('500 Internal Server Error')), + }; + + const store = createStore({ atlasService: mockAtlasService as any }); + expect( + store.getState().pipelineBuilder.aiPipeline.errorMessage + ).to.equal(undefined); + await store.dispatch(runAIPipelineGeneration('testing prompt') as any); + expect( + store.getState().pipelineBuilder.aiPipeline.aiPipelineFetchId + ).to.equal(-1); + expect( + store.getState().pipelineBuilder.aiPipeline.errorMessage + ).to.equal('500 Internal Server Error'); + expect(store.getState().pipelineBuilder.aiPipeline.status).to.equal( + 'ready' + ); + }); + + it('resets the store if errs was caused by user being unauthorized', async function () { + const authError = new Error('Unauthorized'); + (authError as any).statusCode = 401; + const mockAtlasService = { + getAggregationFromUserInput: sandbox.stub().rejects(authError), + }; + const store = createStore({ atlasService: mockAtlasService as any }); + await store.dispatch(runAIPipelineGeneration('testing prompt') as any); + expect(store.getState().pipelineBuilder.aiPipeline).to.deep.eq({ + status: 'ready', + aiPromptText: '', + errorMessage: undefined, + isInputVisible: false, + aiPipelineFetchId: -1, + }); + }); + }); + }); + + describe('cancelAIPipelineGeneration', function () { + it('should unset the fetching id and set the status on the store', function () { + const store = createStore(); + expect( + store.getState().pipelineBuilder.aiPipeline.aiPipelineFetchId + ).to.equal(-1); + + store.dispatch({ + type: AIPipelineActionTypes.AIPipelineStarted, + fetchId: 1, + }); + + expect(store.getState().pipelineBuilder.aiPipeline.status).to.equal( + 'fetching' + ); + expect( + store.getState().pipelineBuilder.aiPipeline.aiPipelineFetchId + ).to.equal(1); + + store.dispatch(cancelAIPipelineGeneration()); + + expect( + store.getState().pipelineBuilder.aiPipeline.aiPipelineFetchId + ).to.equal(-1); + expect(store.getState().pipelineBuilder.aiPipeline.status).to.equal( + 'ready' + ); + }); + }); +}); diff --git a/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-ai.ts b/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-ai.ts new file mode 100644 index 00000000000..91cb43ac303 --- /dev/null +++ b/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-ai.ts @@ -0,0 +1,408 @@ +import type { Reducer } from 'redux'; +import createLoggerAndTelemetry from '@mongodb-js/compass-logging'; +import { getSimplifiedSchema } from 'mongodb-schema'; +import toNS from 'mongodb-ns'; +import preferences from 'compass-preferences-model'; +import { openToast } from '@mongodb-js/compass-components'; +import type { Document } from 'mongodb'; + +import type { PipelineBuilderThunkAction } from '../'; +import { isAction } from '../../utils/is-action'; +import type { PipelineParserError } from './pipeline-parser/utils'; +import type Stage from './stage'; +import { updatePipelinePreview } from './builder-helpers'; + +const { log, mongoLogId } = createLoggerAndTelemetry('AI-PIPELINE-UI'); + +const emptyPipelineError = + 'No pipeline was returned. Please try again with a different prompt.'; + +type AIPipelineStatus = 'ready' | 'fetching' | 'success'; + +export type AIPipelineState = { + errorMessage: string | undefined; + isInputVisible: boolean; + aiPromptText: string; + status: AIPipelineStatus; + aiPipelineFetchId: number; // Maps to the AbortController of the current fetch (or -1). +}; + +export const initialState: AIPipelineState = { + status: 'ready', + aiPromptText: '', + errorMessage: undefined, + isInputVisible: false, + aiPipelineFetchId: -1, +}; + +export const enum AIPipelineActionTypes { + AIPipelineStarted = 'compass-aggregations/pipeline-builder/pipeline-ai/AIPipelineStarted', + AIPipelineCancelled = 'compass-aggregations/pipeline-builder/pipeline-ai/AIPipelineCancelled', + AIPipelineFailed = 'compass-aggregations/pipeline-builder/pipeline-ai/AIPipelineFailed', + AIPipelineSucceeded = 'compass-aggregations/pipeline-builder/pipeline-ai/AIPipelineSucceeded', + CancelAIPipelineGeneration = 'compass-aggregations/pipeline-builder/pipeline-ai/CancelAIPipelineGeneration', + ShowInput = 'compass-aggregations/pipeline-builder/pipeline-ai/ShowInput', + HideInput = 'compass-aggregations/pipeline-builder/pipeline-ai/HideInput', + ChangeAIPromptText = 'compass-aggregations/pipeline-builder/pipeline-ai/ChangeAIPromptText', + LoadGeneratedPipeline = 'compass-aggregations/LoadGeneratedPipeline', +} + +const NUM_DOCUMENTS_TO_SAMPLE = 4; + +const AIPipelineAbortControllerMap = new Map(); + +let aiPipelineFetchId = 0; + +function getAbortSignal() { + const id = ++aiPipelineFetchId; + const controller = new AbortController(); + AIPipelineAbortControllerMap.set(id, controller); + return { id, signal: controller.signal }; +} + +function abort(id: number) { + const controller = AIPipelineAbortControllerMap.get(id); + controller?.abort(); + return AIPipelineAbortControllerMap.delete(id); +} + +function cleanupAbortSignal(id: number) { + return AIPipelineAbortControllerMap.delete(id); +} + +type ShowInputAction = { + type: AIPipelineActionTypes.ShowInput; +}; + +type HideInputAction = { + type: AIPipelineActionTypes.HideInput; +}; + +type ChangeAIPromptTextAction = { + type: AIPipelineActionTypes.ChangeAIPromptText; + text: string; +}; + +export const changeAIPromptText = (text: string): ChangeAIPromptTextAction => ({ + type: AIPipelineActionTypes.ChangeAIPromptText, + text, +}); + +export type LoadGeneratedPipelineAction = { + type: AIPipelineActionTypes.LoadGeneratedPipeline; + pipelineText: string; + pipeline: Document[] | null; + syntaxErrors: PipelineParserError[]; + stages: Stage[]; +}; + +type AIPipelineStartedAction = { + type: AIPipelineActionTypes.AIPipelineStarted; + fetchId: number; +}; + +type AIPipelineFailedAction = { + type: AIPipelineActionTypes.AIPipelineFailed; + errorMessage: string; + networkErrorCode?: number; +}; + +export type AIPipelineSucceededAction = { + type: AIPipelineActionTypes.AIPipelineSucceeded; +}; + +function logFailed(errorMessage: string) { + log.info( + mongoLogId(1_001_000_230), + 'AIPipeline', + 'AI pipeline request failed', + { + errorMessage, + } + ); +} + +export const runAIPipelineGeneration = ( + userInput: string +): PipelineBuilderThunkAction< + Promise, + | AIPipelineStartedAction + | AIPipelineFailedAction + | AIPipelineSucceededAction + | LoadGeneratedPipelineAction +> => { + return async (dispatch, getState, { atlasService, pipelineBuilder }) => { + const { + pipelineBuilder: { + aiPipeline: { aiPipelineFetchId: existingFetchId }, + }, + namespace, + dataService: { dataService }, + } = getState(); + + if (aiPipelineFetchId !== -1) { + // Cancel the active request as this one will override. + abort(existingFetchId); + } + + const abortController = new AbortController(); + const { id: fetchId, signal } = getAbortSignal(); + + dispatch({ + type: AIPipelineActionTypes.AIPipelineStarted, + fetchId, + }); + + let jsonResponse; + try { + const sampleDocuments = + (await dataService?.sample?.( + namespace, + { + query: {}, + size: NUM_DOCUMENTS_TO_SAMPLE, + }, + { + maxTimeMS: preferences.getPreferences().maxTimeMS, + promoteValues: false, + }, + { + abortSignal: signal, + } + )) || []; + const schema = await getSimplifiedSchema(sampleDocuments); + + const { collection: collectionName, database: databaseName } = + toNS(namespace); + jsonResponse = await atlasService.getAggregationFromUserInput({ + signal: abortController.signal, + userInput, + collectionName, + databaseName, + schema, + // sampleDocuments, // For now we are not passing sample documents to the ai. + }); + } catch (err: any) { + if (signal.aborted) { + // If we already aborted so we ignore the error. + return; + } + logFailed(err?.message); + // We're going to reset input state with this error, show the error in the + // toast instead + if (err.statusCode === 401) { + openToast('ai-unauthorized', { + variant: 'important', + title: 'Network Error', + description: 'Unauthorized', + timeout: 5000, + }); + } + dispatch({ + type: AIPipelineActionTypes.AIPipelineFailed, + errorMessage: err?.message, + networkErrorCode: err.statusCode, + }); + return; + } finally { + // Remove the AbortController from the Map as we either finished + // waiting for the fetch or cancelled at this point. + cleanupAbortSignal(fetchId); + } + + if (signal.aborted) { + log.info( + mongoLogId(1_001_000_231), + 'AIPipeline', + 'Cancelled ai pipeline request' + ); + return; + } + + let pipelineText; + try { + // Error when the response is empty or there is nothing to map. + if (!jsonResponse?.content?.aggregation?.pipeline) { + throw new Error(emptyPipelineError); + } + + pipelineText = String(jsonResponse?.content?.aggregation?.pipeline); + + if (!pipelineText || !pipelineText?.length) { + throw new Error(emptyPipelineError); + } + } catch (err: any) { + logFailed(err?.message); + dispatch({ + type: AIPipelineActionTypes.AIPipelineFailed, + errorMessage: err?.message, + }); + return; + } + + log.info( + mongoLogId(1_001_000_228), + 'AIPipeline', + 'AI pipeline request succeeded', + { + pipelineText, + } + ); + + dispatch({ + type: AIPipelineActionTypes.AIPipelineSucceeded, + }); + + pipelineBuilder.reset(pipelineText); + + dispatch({ + type: AIPipelineActionTypes.LoadGeneratedPipeline, + stages: pipelineBuilder.stages, + pipelineText: pipelineBuilder.source, + pipeline: pipelineBuilder.pipeline, + syntaxErrors: pipelineBuilder.syntaxError, + }); + + dispatch(updatePipelinePreview()); + }; +}; + +type CancelAIPipelineGenerationAction = { + type: AIPipelineActionTypes.CancelAIPipelineGeneration; +}; + +export const cancelAIPipelineGeneration = (): PipelineBuilderThunkAction< + void, + CancelAIPipelineGenerationAction +> => { + return (dispatch, getState) => { + // Abort any ongoing op. + abort(getState().pipelineBuilder.aiPipeline.aiPipelineFetchId); + + dispatch({ + type: AIPipelineActionTypes.CancelAIPipelineGeneration, + }); + }; +}; + +export const showInput = (): PipelineBuilderThunkAction> => { + return async (dispatch, _getState, { atlasService }) => { + try { + if (process.env.COMPASS_E2E_SKIP_ATLAS_SIGNIN !== 'true') { + await atlasService.signIn({ promptType: 'ai-promo-modal' }); + } + dispatch({ type: AIPipelineActionTypes.ShowInput }); + } catch { + // if sign in failed / user canceled we just don't show the input + } + }; +}; + +export const hideInput = (): PipelineBuilderThunkAction< + void, + HideInputAction +> => { + return (dispatch) => { + // Cancel any ongoing op when we hide. + dispatch(cancelAIPipelineGeneration()); + dispatch({ type: AIPipelineActionTypes.HideInput }); + }; +}; + +const aiPipelineReducer: Reducer = ( + state = initialState, + action +) => { + if ( + isAction( + action, + AIPipelineActionTypes.AIPipelineStarted + ) + ) { + return { + ...state, + status: 'fetching', + errorMessage: undefined, + aiPipelineFetchId: action.fetchId, + }; + } + + if ( + isAction( + action, + AIPipelineActionTypes.AIPipelineFailed + ) + ) { + // If fetching query failed due to authentication error, reset the state to + // hide the input and show the "Ask AI" button again: this should start the + // sign in flow for the user when clicked + if (action.networkErrorCode === 401) { + return { ...initialState }; + } + + return { + ...state, + status: 'ready', + aiPipelineFetchId: -1, + errorMessage: action.errorMessage, + }; + } + + if ( + isAction( + action, + AIPipelineActionTypes.AIPipelineSucceeded + ) + ) { + return { + ...state, + status: 'success', + aiPipelineFetchId: -1, + }; + } + + if ( + isAction( + action, + AIPipelineActionTypes.CancelAIPipelineGeneration + ) + ) { + return { + ...state, + status: 'ready', + aiPipelineFetchId: -1, + }; + } + + if (isAction(action, AIPipelineActionTypes.ShowInput)) { + return { + ...state, + isInputVisible: true, + }; + } + + if (isAction(action, AIPipelineActionTypes.HideInput)) { + return { + ...state, + isInputVisible: false, + }; + } + + if ( + isAction( + action, + AIPipelineActionTypes.ChangeAIPromptText + ) + ) { + return { + ...state, + // Reset the status after a successful run when the user change's the text. + status: state.status === 'success' ? 'ready' : state.status, + aiPromptText: action.text, + }; + } + + return state; +}; + +export default aiPipelineReducer; diff --git a/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-mode.ts b/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-mode.ts index eff684c80f0..00486b047cb 100644 --- a/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-mode.ts +++ b/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-mode.ts @@ -11,6 +11,8 @@ import type Stage from './stage'; import type { PipelineParserError } from './pipeline-parser/utils'; import { createLoggerAndTelemetry } from '@mongodb-js/compass-logging'; import { RESTORE_PIPELINE } from '../saved-pipeline'; +import { AIPipelineActionTypes } from './pipeline-ai'; +import type { LoadGeneratedPipelineAction } from './pipeline-ai'; const { track } = createLoggerAndTelemetry('COMPASS-AGGREGATIONS-UI'); @@ -39,7 +41,13 @@ const reducer: Reducer = (state = INITIAL_STATE, action) => { ) { return action.mode; } - if (action.type === RESTORE_PIPELINE) { + if ( + action.type === RESTORE_PIPELINE || + isAction( + action, + AIPipelineActionTypes.LoadGeneratedPipeline + ) + ) { // Force as-text editor mode if loaded pipeline contains syntax errors if (action.syntaxErrors.length > 0) { return 'as-text'; diff --git a/packages/compass-aggregations/src/modules/pipeline-builder/stage-editor.spec.ts b/packages/compass-aggregations/src/modules/pipeline-builder/stage-editor.spec.ts index c706451d769..9bfa785e94b 100644 --- a/packages/compass-aggregations/src/modules/pipeline-builder/stage-editor.spec.ts +++ b/packages/compass-aggregations/src/modules/pipeline-builder/stage-editor.spec.ts @@ -2,6 +2,7 @@ import { expect } from 'chai'; import type { DataService } from 'mongodb-data-service'; import { applyMiddleware, createStore as createReduxStore } from 'redux'; import thunk from 'redux-thunk'; +import { AtlasService } from '@mongodb-js/atlas-service/renderer'; import { PipelineBuilder } from './pipeline-builder'; import { changeStageOperator, @@ -109,6 +110,7 @@ function createStore({ }, applyMiddleware( thunk.withExtraArgument({ + atlasService: new AtlasService(), pipelineBuilder, pipelineStorage: new PipelineStorage(), }) diff --git a/packages/compass-aggregations/src/modules/pipeline-builder/stage-editor.ts b/packages/compass-aggregations/src/modules/pipeline-builder/stage-editor.ts index 14d3ceb6eaa..4f0f8324fe0 100644 --- a/packages/compass-aggregations/src/modules/pipeline-builder/stage-editor.ts +++ b/packages/compass-aggregations/src/modules/pipeline-builder/stage-editor.ts @@ -25,6 +25,8 @@ import { isOutputStage } from '../../utils/stage'; import { mapPipelineModeToEditorViewType } from './builder-helpers'; import { getId } from './stage-ids'; import { fetchExplainForPipeline } from '../insights'; +import { AIPipelineActionTypes } from './pipeline-ai'; +import type { LoadGeneratedPipelineAction } from './pipeline-ai'; const { track } = createLoggerAndTelemetry('COMPASS-AGGREGATIONS-UI'); export const enum StageEditorActionTypes { @@ -866,6 +868,10 @@ const reducer: Reducer = ( isAction( action, PipelineModeActionTypes.PipelineModeToggled + ) || + isAction( + action, + AIPipelineActionTypes.LoadGeneratedPipeline ) ) { const stages = action.stages.map((stage: Stage, idx: number) => { diff --git a/packages/compass-aggregations/src/modules/pipeline-builder/text-editor-output-stage.ts b/packages/compass-aggregations/src/modules/pipeline-builder/text-editor-output-stage.ts index b21728a426a..8352c8e860e 100644 --- a/packages/compass-aggregations/src/modules/pipeline-builder/text-editor-output-stage.ts +++ b/packages/compass-aggregations/src/modules/pipeline-builder/text-editor-output-stage.ts @@ -12,6 +12,8 @@ import { aggregatePipeline } from '../../utils/cancellable-aggregation'; import { gotoOutResults } from '../out-results-fn'; import type { PipelineModeToggledAction } from './pipeline-mode'; import { ActionTypes as PipelineModeActionTypes } from './pipeline-mode'; +import { AIPipelineActionTypes } from './pipeline-ai'; +import type { LoadGeneratedPipelineAction } from './pipeline-ai'; const enum OutputStageActionTypes { FetchStarted = 'compass-aggregations/pipeline-builder/text-editor-output-stage/FetchStarted', @@ -54,6 +56,10 @@ const reducer: Reducer = (state = INITIAL_STATE, action) => { action, PipelineModeActionTypes.PipelineModeToggled ) || + isAction( + action, + AIPipelineActionTypes.LoadGeneratedPipeline + ) || action.type === RESTORE_PIPELINE || action.type === ConfirmNewPipelineActions.NewPipelineConfirmed ) { diff --git a/packages/compass-aggregations/src/modules/pipeline-builder/text-editor-pipeline.spec.ts b/packages/compass-aggregations/src/modules/pipeline-builder/text-editor-pipeline.spec.ts index 9bff22ad9f2..5f6b3633355 100644 --- a/packages/compass-aggregations/src/modules/pipeline-builder/text-editor-pipeline.spec.ts +++ b/packages/compass-aggregations/src/modules/pipeline-builder/text-editor-pipeline.spec.ts @@ -1,6 +1,7 @@ import { expect } from 'chai'; import { applyMiddleware, createStore as createReduxStore } from 'redux'; import thunk from 'redux-thunk'; +import { AtlasService } from '@mongodb-js/atlas-service/renderer'; import { PipelineBuilder } from './pipeline-builder'; import { changeEditorValue, @@ -43,6 +44,7 @@ function createStore( }, applyMiddleware( thunk.withExtraArgument({ + atlasService: new AtlasService(), pipelineBuilder, pipelineStorage: new PipelineStorage(), }) diff --git a/packages/compass-aggregations/src/modules/pipeline-builder/text-editor-pipeline.ts b/packages/compass-aggregations/src/modules/pipeline-builder/text-editor-pipeline.ts index 793728e30c8..a3bc1307bf4 100644 --- a/packages/compass-aggregations/src/modules/pipeline-builder/text-editor-pipeline.ts +++ b/packages/compass-aggregations/src/modules/pipeline-builder/text-editor-pipeline.ts @@ -15,6 +15,8 @@ import { ActionTypes as ConfirmNewPipelineActions } from '../is-new-pipeline-con import { RESTORE_PIPELINE } from '../saved-pipeline'; import { capMaxTimeMSAtPreferenceLimit } from 'compass-preferences-model'; import { fetchExplainForPipeline } from '../insights'; +import { AIPipelineActionTypes } from './pipeline-ai'; +import type { LoadGeneratedPipelineAction } from './pipeline-ai'; export const enum EditorActionTypes { EditorPreviewFetch = 'compass-aggregations/pipeline-builder/text-editor-pipeline/TextEditorPreviewFetch', @@ -78,6 +80,10 @@ const reducer: Reducer = (state = INITIAL_STATE, action) => { action, PipelineModeActionTypes.PipelineModeToggled ) || + isAction( + action, + AIPipelineActionTypes.LoadGeneratedPipeline + ) || action.type === RESTORE_PIPELINE || action.type === ConfirmNewPipelineActions.NewPipelineConfirmed ) { diff --git a/packages/compass-aggregations/src/stores/store.ts b/packages/compass-aggregations/src/stores/store.ts index ce335f4f633..a366f861bff 100644 --- a/packages/compass-aggregations/src/stores/store.ts +++ b/packages/compass-aggregations/src/stores/store.ts @@ -3,6 +3,7 @@ import { createStore, applyMiddleware } from 'redux'; import thunk from 'redux-thunk'; import toNS from 'mongodb-ns'; import { toJSString } from 'mongodb-query-parser'; +import { AtlasService } from '@mongodb-js/atlas-service/renderer'; import reducer from '../modules'; import { fieldsChanged } from '../modules/fields'; import { refreshInputDocuments } from '../modules/input-documents'; @@ -122,6 +123,10 @@ export type ConfigureStoreOptions = { * the stage wizard to populate the dropdown for $lookup use-case. */ collections: CollectionInfo[]; + /** + * Service for making ai requests. + */ + atlasService: AtlasService; }>; const configureStore = (options: ConfigureStoreOptions) => { @@ -218,6 +223,7 @@ const configureStore = (options: ConfigureStoreOptions) => { thunk.withExtraArgument({ pipelineBuilder, pipelineStorage, + atlasService: options.atlasService ?? new AtlasService(), }) ) ); diff --git a/packages/compass-components/src/components/generative-ai/ai-experience-entry.ts b/packages/compass-components/src/components/generative-ai/ai-experience-entry.tsx similarity index 80% rename from packages/compass-components/src/components/generative-ai/ai-experience-entry.ts rename to packages/compass-components/src/components/generative-ai/ai-experience-entry.tsx index fefeecfef42..185b2e12d8c 100644 --- a/packages/compass-components/src/components/generative-ai/ai-experience-entry.ts +++ b/packages/compass-components/src/components/generative-ai/ai-experience-entry.tsx @@ -1,16 +1,19 @@ +import React from 'react'; import { palette } from '@leafygreen-ui/palette'; import { css, cx } from '@leafygreen-ui/emotion'; import { spacing } from '@leafygreen-ui/tokens'; import { + RobotSVG, getRobotSVGString, robotSVGDarkModeStyles, robotSVGLightModeStyles, robotSVGStyles, } from './robot-svg'; import { focusRing } from '../../hooks/use-focus-ring'; +import { useDarkMode } from '../../hooks/use-theme'; -const aiQueryEntryStyles = css( +const aiEntryStyles = css( { // Reset button styles. border: 'none', @@ -45,7 +48,7 @@ const aiQueryEntryStyles = css( robotSVGStyles ); -const aiQueryEntryDarkModeStyles = css( +const aiEntryDarkModeStyles = css( { color: palette.green.dark1, '&:hover': { @@ -58,7 +61,7 @@ const aiQueryEntryDarkModeStyles = css( robotSVGDarkModeStyles ); -const aiQueryEntryLightModeStyles = css( +const aiEntryLightModeStyles = css( { color: palette.green.dark2, '&:hover': { @@ -71,6 +74,23 @@ const aiQueryEntryLightModeStyles = css( robotSVGLightModeStyles ); +function AIExperienceEntry({ onClick }: { onClick: () => void }) { + const darkMode = useDarkMode(); + + return ( + + ); +} + // We build the AI Placeholder with html elements as our // codemirror placeholder extension accepts `HTMLElement`s. function createAIPlaceholderHTMLPlaceholder({ @@ -105,8 +125,8 @@ function createAIPlaceholderHTMLPlaceholder({ }); aiButtonEl.className = cx( - aiQueryEntryStyles, - darkMode ? aiQueryEntryDarkModeStyles : aiQueryEntryLightModeStyles + aiEntryStyles, + darkMode ? aiEntryDarkModeStyles : aiEntryLightModeStyles ); const robotIconSVG = `Ask AI @@ -118,4 +138,4 @@ ${getRobotSVGString()}`; return containerEl; } -export { createAIPlaceholderHTMLPlaceholder }; +export { AIExperienceEntry, createAIPlaceholderHTMLPlaceholder }; diff --git a/packages/compass-components/src/components/generative-ai/ai-feedback.tsx b/packages/compass-components/src/components/generative-ai/ai-feedback.tsx index cd05fbf2bd7..471f242dc62 100644 --- a/packages/compass-components/src/components/generative-ai/ai-feedback.tsx +++ b/packages/compass-components/src/components/generative-ai/ai-feedback.tsx @@ -134,7 +134,7 @@ function AIFeedback({ onSubmitFeedback }: AIFeedbackProps) { )} onClick={() => setChosenFeedbackOption('positive')} size="small" - data-testid="ai-query-feedback-thumbs-up" + data-testid="ai-feedback-thumbs-up" ref={feedbackPositiveButtonRef} > diff --git a/packages/compass-components/src/components/generative-ai/ai-text-input.spec.tsx b/packages/compass-components/src/components/generative-ai/generative-ai-input.spec.tsx similarity index 95% rename from packages/compass-components/src/components/generative-ai/ai-text-input.spec.tsx rename to packages/compass-components/src/components/generative-ai/generative-ai-input.spec.tsx index bbdb617a14c..ef3cd14826d 100644 --- a/packages/compass-components/src/components/generative-ai/ai-text-input.spec.tsx +++ b/packages/compass-components/src/components/generative-ai/generative-ai-input.spec.tsx @@ -52,7 +52,7 @@ describe('GenerativeAIInput Component', function () { it('calls to close robot button is clicked', function () { expect(onCloseSpy.called).to.be.false; - const closeButton = screen.getByTestId('close-ai-query-button'); + const closeButton = screen.getByTestId('close-ai-button'); expect(closeButton).to.be.visible; closeButton.click(); expect(onCloseSpy.calledOnce).to.be.true; @@ -96,7 +96,7 @@ describe('GenerativeAIInput Component', function () { // No feedback popover is shown yet. expect(screen.queryByTestId(feedbackPopoverTextAreaId)).to.not.exist; - const thumbsUpButton = screen.getByTestId('ai-query-feedback-thumbs-up'); + const thumbsUpButton = screen.getByTestId('ai-feedback-thumbs-up'); expect(thumbsUpButton).to.be.visible; thumbsUpButton.click(); diff --git a/packages/compass-components/src/components/generative-ai/generative-ai-input.tsx b/packages/compass-components/src/components/generative-ai/generative-ai-input.tsx index cf2b440be04..62c74c52793 100644 --- a/packages/compass-components/src/components/generative-ai/generative-ai-input.tsx +++ b/packages/compass-components/src/components/generative-ai/generative-ai-input.tsx @@ -16,8 +16,6 @@ const containerStyles = css({ display: 'flex', flexDirection: 'column', gap: spacing[1], - margin: `0px ${spacing[2]}px`, - marginTop: '2px', }); const inputBarContainerStyles = css({ @@ -119,7 +117,7 @@ const closeAIButtonStyles = css( focusRing ); -const closeText = 'Close AI Query'; +const closeText = 'Close AI Helper'; const SubmitArrowSVG = ({ darkMode }: { darkMode?: boolean }) => ( void; onChangeAIPromptText: (text: string) => void; @@ -162,6 +161,7 @@ function GenerativeAIInput({ didSucceed, errorMessage, isFetching, + placeholder = 'Tell Compass what documents to find (e.g. which movies were released in 2000)', show, onCancelRequest, onClose, @@ -223,9 +223,9 @@ function GenerativeAIInput({ className={textInputStyles} ref={promptTextInputRef} sizeVariant="small" - data-testid="ai-query-user-text-input" + data-testid="ai-user-text-input" aria-label="Enter a plain text query that the AI will translate into MongoDB query language." - placeholder="Tell Compass what documents to find (e.g. which movies were released in 2000)" + placeholder={placeholder} value={aiPromptText} onChange={(evt: React.ChangeEvent) => onChangeAIPromptText(evt.currentTarget.value) @@ -235,7 +235,7 @@ function GenerativeAIInput({
{aiPromptText && ( onChangeAIPromptText('')} data-testid="ai-text-clear-prompt" > @@ -248,7 +248,7 @@ function GenerativeAIInput({ generateButtonStyles, !darkMode && generateButtonLightModeStyles )} - data-testid="ai-query-generate-button" + data-testid="ai-generate-button" onClick={() => isFetching ? onCancelRequest() : onSubmitText(aiPromptText) } @@ -292,7 +292,7 @@ function GenerativeAIInput({ ) : (
{errorMessage && (
- + {errorMessage}
diff --git a/packages/compass-components/src/components/generative-ai/index.ts b/packages/compass-components/src/components/generative-ai/index.ts index 255f0a7b6d4..f2f97da4ac7 100644 --- a/packages/compass-components/src/components/generative-ai/index.ts +++ b/packages/compass-components/src/components/generative-ai/index.ts @@ -1,2 +1,5 @@ export { GenerativeAIInput } from './generative-ai-input'; -export { createAIPlaceholderHTMLPlaceholder } from './ai-experience-entry'; +export { + AIExperienceEntry, + createAIPlaceholderHTMLPlaceholder, +} from './ai-experience-entry'; diff --git a/packages/compass-components/src/index.ts b/packages/compass-components/src/index.ts index 19219036752..651be49cf27 100644 --- a/packages/compass-components/src/index.ts +++ b/packages/compass-components/src/index.ts @@ -51,6 +51,7 @@ export { DocumentIcon } from './components/icons/document-icon'; export { FavoriteIcon } from './components/icons/favorite-icon'; export { NoSavedItemsIcon } from './components/icons/no-saved-items-icon'; export { + AIExperienceEntry, GenerativeAIInput, createAIPlaceholderHTMLPlaceholder, } from './components/generative-ai'; diff --git a/packages/compass-e2e-tests/helpers/selectors.ts b/packages/compass-e2e-tests/helpers/selectors.ts index 91d48e09de0..38151f82351 100644 --- a/packages/compass-e2e-tests/helpers/selectors.ts +++ b/packages/compass-e2e-tests/helpers/selectors.ts @@ -1034,11 +1034,10 @@ export const queryBarExportToLanguageButton = (tabName: string): string => { }; export const QueryBarAskAIButton = '[data-testid="open-ai-query-ask-ai-button"]'; -export const QueryBarAITextInput = '[data-testid="ai-query-user-text-input"]'; +export const QueryBarAITextInput = '[data-testid="ai-user-text-input"]'; export const QueryBarAIGenerateQueryButton = - '[data-testid="ai-query-generate-button"]'; -export const QueryBarAIErrorMessageBanner = - '[data-testid="ai-query-error-msg"]'; + '[data-testid="ai-generate-button"]'; +export const QueryBarAIErrorMessageBanner = '[data-testid="ai-error-msg"]'; // Workspace tabs at the top export const SelectedWorkspaceTabButton = diff --git a/packages/compass-e2e-tests/tests/collection-ai-query.test.ts b/packages/compass-e2e-tests/tests/collection-ai-query.test.ts index 0d3336c4a56..c958184b87b 100644 --- a/packages/compass-e2e-tests/tests/collection-ai-query.test.ts +++ b/packages/compass-e2e-tests/tests/collection-ai-query.test.ts @@ -76,11 +76,7 @@ describe('Collection ai query', function () { body: { content: { query: { - filter: { - i: { - $gt: 50, - }, - }, + filter: '{i: {$gt: 50}}', }, }, }, diff --git a/packages/compass-query-bar/src/components/query-ai.spec.tsx b/packages/compass-query-bar/src/components/query-ai.spec.tsx index cc84daf87f5..46e9663e187 100644 --- a/packages/compass-query-bar/src/components/query-ai.spec.tsx +++ b/packages/compass-query-bar/src/components/query-ai.spec.tsx @@ -56,7 +56,7 @@ describe('QueryAI Component', function () { it('calls to close robot button is clicked', function () { expect(onCloseSpy.called).to.be.false; - const closeButton = screen.getByTestId('close-ai-query-button'); + const closeButton = screen.getByTestId('close-ai-button'); expect(closeButton).to.be.visible; closeButton.click(); expect(onCloseSpy.calledOnce).to.be.true; @@ -104,7 +104,7 @@ describe('QueryAI Component', function () { // No feedback popover is shown yet. expect(screen.queryByTestId(feedbackPopoverTextAreaId)).to.not.exist; - expect(screen.queryByTestId('ai-query-feedback-thumbs-up')).to.not.exist; + expect(screen.queryByTestId('ai-feedback-thumbs-up')).to.not.exist; store.dispatch({ type: AIQueryActionTypes.AIQuerySucceeded, @@ -112,7 +112,7 @@ describe('QueryAI Component', function () { }); expect(screen.queryByTestId(feedbackPopoverTextAreaId)).to.not.exist; - const thumbsUpButton = screen.getByTestId('ai-query-feedback-thumbs-up'); + const thumbsUpButton = screen.getByTestId('ai-feedback-thumbs-up'); expect(thumbsUpButton).to.be.visible; thumbsUpButton.click(); diff --git a/packages/compass-query-bar/src/components/query-bar.tsx b/packages/compass-query-bar/src/components/query-bar.tsx index 54fd23e1654..1145a7fdb63 100644 --- a/packages/compass-query-bar/src/components/query-bar.tsx +++ b/packages/compass-query-bar/src/components/query-bar.tsx @@ -92,6 +92,11 @@ const queryOptionsContainerStyles = css({ gap: spacing[2], }); +const queryAIContainerStyles = css({ + margin: `0px ${spacing[2]}px`, + marginTop: '2px', +}); + const queryBarDocumentationLink = 'https://docs.mongodb.com/compass/current/query/filter/'; @@ -307,12 +312,14 @@ export const QueryBar: React.FunctionComponent = ({
)} {enableAIQuery && ( - { - onHideAIInputClick?.(); - }} - show={isAIInputVisible} - /> +
+ { + onHideAIInputClick?.(); + }} + show={isAIInputVisible} + /> +
)} ); diff --git a/packages/compass-query-bar/src/stores/ai-query-reducer.ts b/packages/compass-query-bar/src/stores/ai-query-reducer.ts index 3405afec829..84053c590ae 100644 --- a/packages/compass-query-bar/src/stores/ai-query-reducer.ts +++ b/packages/compass-query-bar/src/stores/ai-query-reducer.ts @@ -3,11 +3,13 @@ import createLoggerAndTelemetry from '@mongodb-js/compass-logging'; import { getSimplifiedSchema } from 'mongodb-schema'; import toNS from 'mongodb-ns'; import preferences from 'compass-preferences-model'; -import { EJSON } from 'bson'; import type { QueryBarThunkAction } from './query-bar-store'; import { isAction } from '../utils'; -import { mapQueryToFormFields } from '../utils/query'; +import { + mapQueryToFormFields, + parseQueryAttributesToFormFields, +} from '../utils/query'; import type { QueryFormFields } from '../constants/query-properties'; import { DEFAULT_FIELD_VALUES } from '../constants/query-bar-store'; import { openToast } from '@mongodb-js/compass-components'; @@ -209,12 +211,12 @@ export const runAIQuery = ( ); } - const query = EJSON.deserialize(jsonResponse?.content?.query); + const query = jsonResponse?.content?.query; - fields = mapQueryToFormFields({ - ...DEFAULT_FIELD_VALUES, - ...(query ?? {}), - }); + fields = { + ...mapQueryToFormFields(DEFAULT_FIELD_VALUES), + ...parseQueryAttributesToFormFields(query ?? {}), + }; } catch (err: any) { logFailed(err?.message); dispatch({ diff --git a/packages/compass-query-bar/src/utils/query.ts b/packages/compass-query-bar/src/utils/query.ts index 5160a8e6bdf..a89f22562a7 100644 --- a/packages/compass-query-bar/src/utils/query.ts +++ b/packages/compass-query-bar/src/utils/query.ts @@ -51,6 +51,35 @@ export function doesQueryHaveExtraOptionsSet(fields?: QueryFormFields) { return false; } +export function parseQueryAttributesToFormFields(query?: { + filter?: string; + project?: string; + collation?: string; + sort?: string; + skip?: string; + limit?: string; + maxTimeMS?: string; +}): QueryFormFields { + return Object.fromEntries( + Object.entries(query ?? {}) + .map(([key, valueString]) => { + if (!isQueryProperty(key) || typeof valueString === 'undefined') { + return null; + } + + const value = validateField(key, valueString); + const valid: boolean = value !== false; + return [ + key, + { string: valueString, value: valid ? value : null, valid }, + ] as const; + }) + .filter((value) => { + return value !== null; + }) as [string, unknown][] + ) as QueryFormFields; +} + /** * Map query document to the query fields state only preserving valid values */