Skip to content

Commit

Permalink
add 'Generative AI' submenu (#971)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq authored Aug 29, 2024
1 parent 29e2a47 commit 79158ab
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ export function SendButton(props: SendButtonProps): JSX.Element {
if (activeCell.exists) {
props.onSend({
type: 'cell',
source: activeCell.manager.getContent(false).source
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
source: activeCell.manager.getContent(false)!.source
});
closeMenu();
return;
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/src/contexts/active-cell-context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ export class ActiveCellManager {
* `ActiveCellContentWithError` object that describes both the active cell and
* the error output.
*/
getContent(withError: false): CellContent;
getContent(withError: false): CellContent | null;
getContent(withError: true): CellWithErrorContent | null;
getContent(withError = false): CellContent | CellWithErrorContent | null {
const sharedModel = this._activeCell?.model.sharedModel;
Expand Down
17 changes: 13 additions & 4 deletions packages/jupyter-ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ import { ChatHandler } from './chat_handler';
import { buildErrorWidget } from './widgets/chat-error';
import { completionPlugin } from './completions';
import { statusItemPlugin } from './status';
import { IJaiCompletionProvider, IJaiMessageFooter } from './tokens';
import { IJaiCompletionProvider, IJaiCore, IJaiMessageFooter } from './tokens';
import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
import { ActiveCellManager } from './contexts/active-cell-context';
import { Signal } from '@lumino/signaling';
import { menuPlugin } from './plugins/menu-plugin';

export type DocumentTracker = IWidgetTracker<IDocumentWidget>;

Expand All @@ -35,17 +36,18 @@ export namespace CommandIDs {
/**
* Initialization data for the jupyter_ai extension.
*/
const plugin: JupyterFrontEndPlugin<void> = {
const plugin: JupyterFrontEndPlugin<IJaiCore> = {
id: '@jupyter-ai/core:plugin',
autoStart: true,
requires: [IRenderMimeRegistry],
optional: [
IGlobalAwareness,
ILayoutRestorer,
IThemeManager,
IJaiCompletionProvider,
IJaiMessageFooter
],
requires: [IRenderMimeRegistry],
provides: IJaiCore,
activate: async (
app: JupyterFrontEnd,
rmRegistry: IRenderMimeRegistry,
Expand Down Expand Up @@ -114,7 +116,14 @@ const plugin: JupyterFrontEndPlugin<void> = {
},
label: 'Focus the jupyter-ai chat'
});

return {
activeCellManager,
chatHandler,
chatWidget,
selectionWatcher
};
}
};

export default [plugin, statusItemPlugin, completionPlugin];
export default [plugin, statusItemPlugin, completionPlugin, menuPlugin];
158 changes: 158 additions & 0 deletions packages/jupyter-ai/src/plugins/menu-plugin.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import {
JupyterFrontEnd,
JupyterFrontEndPlugin
} from '@jupyterlab/application';

import { IJaiCore } from '../tokens';
import { AiService } from '../handler';
import { Menu } from '@lumino/widgets';
import { CommandRegistry } from '@lumino/commands';

export namespace CommandIDs {
export const explain = 'jupyter-ai:explain';
export const fix = 'jupyter-ai:fix';
export const optimize = 'jupyter-ai:optimize';
export const refactor = 'jupyter-ai:refactor';
}

/**
* Optional plugin that adds a "Generative AI" submenu to the context menu.
* These implement UI shortcuts that explain, fix, refactor, or optimize code in
* a notebook or file.
*
* **This plugin is experimental and may be removed in a future release.**
*/
export const menuPlugin: JupyterFrontEndPlugin<void> = {
id: '@jupyter-ai/core:menu-plugin',
autoStart: true,
requires: [IJaiCore],
activate: (app: JupyterFrontEnd, jaiCore: IJaiCore) => {
const { activeCellManager, chatHandler, chatWidget, selectionWatcher } =
jaiCore;

function activateChatSidebar() {
app.shell.activateById(chatWidget.id);
}

function getSelection(): AiService.Selection | null {
const textSelection = selectionWatcher.selection;
const activeCell = activeCellManager.getContent(false);
const selection: AiService.Selection | null = textSelection
? { type: 'text', source: textSelection.text }
: activeCell
? { type: 'cell', source: activeCell.source }
: null;

return selection;
}

function buildLabelFactory(baseLabel: string): () => string {
return () => {
const textSelection = selectionWatcher.selection;
const activeCell = activeCellManager.getContent(false);

return textSelection
? `${baseLabel} (${textSelection.numLines} lines selected)`
: activeCell
? `${baseLabel} (1 active cell)`
: baseLabel;
};
}

// register commands
const menuCommands = new CommandRegistry();
menuCommands.addCommand(CommandIDs.explain, {
execute: () => {
const selection = getSelection();
if (!selection) {
return;
}

activateChatSidebar();
chatHandler.sendMessage({
prompt: 'Explain the code below.',
selection
});
},
label: buildLabelFactory('Explain code'),
isEnabled: () => !!getSelection()
});
menuCommands.addCommand(CommandIDs.fix, {
execute: () => {
const activeCellWithError = activeCellManager.getContent(true);
if (!activeCellWithError) {
return;
}

chatHandler.sendMessage({
prompt: '/fix',
selection: {
type: 'cell-with-error',
error: activeCellWithError.error,
source: activeCellWithError.source
}
});
},
label: () => {
const activeCellWithError = activeCellManager.getContent(true);
return activeCellWithError
? 'Fix code cell (1 error cell)'
: 'Fix code cell (no error cell)';
},
isEnabled: () => {
const activeCellWithError = activeCellManager.getContent(true);
return !!activeCellWithError;
}
});
menuCommands.addCommand(CommandIDs.optimize, {
execute: () => {
const selection = getSelection();
if (!selection) {
return;
}

activateChatSidebar();
chatHandler.sendMessage({
prompt: 'Optimize the code below.',
selection
});
},
label: buildLabelFactory('Optimize code'),
isEnabled: () => !!getSelection()
});
menuCommands.addCommand(CommandIDs.refactor, {
execute: () => {
const selection = getSelection();
if (!selection) {
return;
}

activateChatSidebar();
chatHandler.sendMessage({
prompt: 'Refactor the code below.',
selection
});
},
label: buildLabelFactory('Refactor code'),
isEnabled: () => !!getSelection()
});

// add commands as a context menu item containing a "Generative AI" submenu
const submenu = new Menu({
commands: menuCommands
});
submenu.id = 'jupyter-ai:submenu';
submenu.title.label = 'Generative AI';
submenu.addItem({ command: CommandIDs.explain });
submenu.addItem({ command: CommandIDs.fix });
submenu.addItem({ command: CommandIDs.optimize });
submenu.addItem({ command: CommandIDs.refactor });

app.contextMenu.addItem({
type: 'submenu',
selector: '.jp-Editor',
rank: 1,
submenu
});
}
};
9 changes: 9 additions & 0 deletions packages/jupyter-ai/src/selection-watcher.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ function getTextSelection(widget: Widget | null): Selection | null {
start,
end,
text,
numLines: text.split('\n').length,
widgetId: widget.id,
...(cellId && {
cellId
Expand All @@ -88,6 +89,10 @@ export type Selection = CodeEditor.ITextSelection & {
* The text within the selection as a string.
*/
text: string;
/**
* Number of lines contained by the text selection.
*/
numLines: number;
/**
* The ID of the document widget in which the selection was made.
*/
Expand All @@ -109,6 +114,10 @@ export class SelectionWatcher {
setInterval(this._poll.bind(this), 200);
}

get selection(): Selection | null {
return this._selection;
}

get selectionChanged(): Signal<this, Selection | null> {
return this._selectionChanged;
}
Expand Down
23 changes: 22 additions & 1 deletion packages/jupyter-ai/src/tokens.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import React from 'react';
import { Token } from '@lumino/coreutils';
import { ISignal } from '@lumino/signaling';
import type { IRankedMenu } from '@jupyterlab/ui-components';
import type { IRankedMenu, ReactWidget } from '@jupyterlab/ui-components';

import { AiService } from './handler';
import { ChatHandler } from './chat_handler';
import { ActiveCellManager } from './contexts/active-cell-context';
import { SelectionWatcher } from './selection-watcher';

export interface IJaiStatusItem {
addItem(item: IRankedMenu.IItemOptions): void;
Expand Down Expand Up @@ -46,3 +50,20 @@ export const IJaiMessageFooter = new Token<IJaiMessageFooter>(
'jupyter_ai:IJaiMessageFooter',
'Optional component that is used to render a footer on each Jupyter AI chat message, when provided.'
);

export interface IJaiCore {
chatWidget: ReactWidget;
chatHandler: ChatHandler;
activeCellManager: ActiveCellManager;
selectionWatcher: SelectionWatcher;
}

/**
* The Jupyter AI core provider token. Frontend plugins that want to extend the
* Jupyter AI frontend by adding features which send messages or observe the
* current text selection & active cell should require this plugin.
*/
export const IJaiCore = new Token<IJaiCore>(
'jupyter_ai:core',
'The core implementation of the frontend.'
);

0 comments on commit 79158ab

Please sign in to comment.