Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat support upload external model register #293

Open
wants to merge 9 commits into
base: feature/model-registry
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: update base review
Signed-off-by: xyinshen <xyinshen@amazon.com>
  • Loading branch information
xyinshen committed Dec 6, 2023
commit bbc935955a76c71f7c9a9384c974fa5047181b28
1 change: 1 addition & 0 deletions public/components/common/forms/model_file_format.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export const ModelFileFormatSelect = ({ readOnly = false }: Props) => {
});

const { ref: fileFormatInputRef, ...fileFormatField } = modelFileFormatController.field;

const selectedFileFormatOption = useMemo(() => {
if (fileFormatField.value) {
return FILE_FORMAT_OPTIONS.find((fmt) => fmt.value === fileFormatField.value);
Expand Down
1 change: 1 addition & 0 deletions public/components/monitoring/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ export const Monitoring = () => {
} = useMonitoring();
const [previewModel, setPreviewModel] = useState<ModelDeploymentItem | null>(null);
const searchInputRef = useRef<HTMLInputElement | null>();

const setInputRef = useCallback((node: HTMLInputElement | null) => {
searchInputRef.current = node;
}, []);
Expand Down
1 change: 1 addition & 0 deletions public/components/monitoring/model_connector_filter.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export const ModelConnectorFilter = ({
),
[internalConnectorsResult?.data, allExternalConnectors]
);

return (
<OptionsFilter
id="modelConnectorNameFilter"
Expand Down
51 changes: 51 additions & 0 deletions public/components/register_model/model_deployment.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import React, { useCallback, useState } from 'react';
import { EuiCheckbox, EuiText, EuiFormRow } from '@elastic/eui';
import { useController, useFormContext } from 'react-hook-form';
import { useSearchParams } from '../../hooks/use_search_params';
export const ModelDeployment = () => {
const searchParams = useSearchParams();
const typeParams = searchParams.get('type');
const [checked, setChecked] = useState(false);
const { control } = useFormContext<{ deployment: boolean }>();
const modelDeploymentController = useController({
name: 'deployment',
control,
});

const { ref: deploymentInputRef, ...deploymentField } = modelDeploymentController.field;
const onDeploymentChange = useCallback(
(e) => {
setChecked(e.target.checked);
deploymentField.onChange(checked);
},
[deploymentField, checked]
);
return (
<EuiFormRow label={typeParams === 'external' ? 'Activation' : 'Deployment'}>
<div>
{<EuiText size="xs">Needs a description</EuiText>}
{(typeParams === 'upload' || typeParams === 'import') && (
<EuiCheckbox
id="deployment"
label="Start deployment automatically"
checked={checked}
onChange={onDeploymentChange}
/>
)}
{typeParams === 'external' && (
<EuiCheckbox
id="activation"
label="Activate on registration"
checked={checked}
onChange={onDeploymentChange}
/>
)}
</div>
</EuiFormRow>
);
};
10 changes: 5 additions & 5 deletions public/components/register_model/model_source.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ import { useMonitoring } from '../monitoring/use_monitoring';

export const ModelSource = () => {
const { allExternalConnectors } = useMonitoring();
const CONNECTOR_OPTIONS = allExternalConnectors?.map((item) => {
return Object.assign({}, { label: item.name, value: item.description });
const connectorOptions = allExternalConnectors?.map((item) => {
return Object.assign({}, { label: item.name, value: item.id });
});
const { control } = useFormContext<{ modelConnector: string }>();

Expand All @@ -36,9 +36,9 @@ export const ModelSource = () => {
const { ref: fileFormatInputRef, ...fileFormatField } = modelConnectorController.field;
const selectedConnectorOption = useMemo(() => {
if (fileFormatField.value) {
return CONNECTOR_OPTIONS?.find((connector) => connector.value === fileFormatField.value);
return connectorOptions?.find((connector) => connector.value === fileFormatField.value);
}
}, [fileFormatField, CONNECTOR_OPTIONS]);
}, [fileFormatField, connectorOptions]);

const onConnectorChange = useCallback(
(options: Array<EuiComboBoxOptionOption<string>>) => {
Expand All @@ -65,7 +65,7 @@ export const ModelSource = () => {
<EuiFormRow label="Model connector">
<EuiComboBox
inputRef={fileFormatInputRef}
options={CONNECTOR_OPTIONS}
options={connectorOptions}
singleSelection={{ asPlainText: true }}
selectedOptions={selectedConnectorOption ? [selectedConnectorOption] : []}
placeholder="Select a connector"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ interface IItem {
checked?: 'on' | undefined;
description: string;
}
interface Props {
getPreSelected: (val: boolean) => void;
}
const renderModelOption = (option: IItem, searchValue: string) => {
return (
<>
Expand All @@ -37,7 +34,7 @@ const renderModelOption = (option: IItem, searchValue: string) => {
</>
);
};
export const PreTrainedModelSelect = ({ getPreSelected }: Props) => {
export const PreTrainedModelSelect = () => {
useEffect(() => {
const subscribe = modelRepositoryManager.getPreTrainedModels$().subscribe((models) => {
setModelRepoSelection(
Expand All @@ -52,30 +49,25 @@ export const PreTrainedModelSelect = ({ getPreSelected }: Props) => {
subscribe.unsubscribe();
};
}, []);
const ShowRest = useCallback(
(selected: boolean) => {
getPreSelected(selected);
},
[getPreSelected]
);
const [modelRepoSelection, setModelRepoSelection] = useState<Array<EuiSelectableOption<IItem>>>(
[]
);
const history = useHistory();
const onChange = useCallback(
(modelSelection: Array<EuiSelectableOption<IItem>>) => {
setModelRepoSelection(modelSelection);
ShowRest(true);
// ShowRest(true);
},
[ShowRest]
// [ShowRest]
[]
);
useEffect(() => {
const selectedOption = modelRepoSelection.find((option) => option.checked === 'on');
if (selectedOption?.label) {
history.push(
`${generatePath(routerPaths.registerModel, { id: undefined })}/?type=import&name=${
selectedOption?.label
}&version=${selectedOption?.label}`
}`
);
}
}, [modelRepoSelection, history]);
Expand Down
133 changes: 54 additions & 79 deletions public/components/register_model/register_model.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ import {
EuiFlexItem,
EuiTextColor,
EuiLink,
EuiFormRow,
EuiCheckbox,
EuiPageContent,
} from '@elastic/eui';
import useObservable from 'react-use/lib/useObservable';
Expand All @@ -32,7 +30,7 @@ import { mountReactNode } from '../../../../../src/core/public/utils';
import { routerPaths } from '../../../common/router_paths';
import { ErrorCallOut } from '../../components/common';
import { modelRepositoryManager } from '../../utils/model_repository_manager';
import { PreTrainedModelSelect } from './pretrainedmodel_select';
import { PreTrainedModelSelect } from './pretrained_model_select';
import { modelTaskManager } from './model_task_manager';
import { ModelVersionNotesPanel } from './model_version_notes';
import { modelFileUploadManager } from './model_file_upload_manager';
Expand All @@ -44,7 +42,7 @@ import { ConfigurationPanel } from './model_configuration';
import { ModelTagsPanel } from './model_tags';
import { submitModelWithFile, submitModelWithURL } from './register_model_api';
import { ModelSource } from './model_source';

import { ModelDeployment } from './model_deployment';
const DEFAULT_VALUES = {
name: '',
description: '',
Expand Down Expand Up @@ -89,36 +87,42 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo
services: { chrome, notifications },
} = useOpenSearchDashboards();
const isLocked = useObservable(chrome?.getIsNavDrawerLocked$() ?? from([false]));
const [preSelected, setPreSelect] = useState(false);
const getPreSelected = (val: boolean) => {
setPreSelect(val);
};
const formType = isValidModelRegisterFormType(typeParams) ? typeParams : 'upload';
const partials =
formType === 'import'
? [
PreTrainedModelSelect,
...(!preSelected ? [] : [ModelDetailsPanel]),
...(!preSelected ? [] : [ModelTagsPanel]),
...(!preSelected ? [] : [ModelVersionNotesPanel]),
]
: formType === 'upload'
? [
...(registerToModelId ? [] : [ModelOverviewTitle]),
...(registerToModelId ? [] : [ModelDetailsPanel]),
...(registerToModelId ? [] : [ModelTagsPanel]),
...(registerToModelId ? [] : [FileAndVersionTitle]),
ArtifactPanel,
ConfigurationPanel,
...(registerToModelId ? [ModelTagsPanel] : []),
ModelVersionNotesPanel,
]
: [
...(registerToModelId ? [] : [ModelDetailsPanel]),
...(registerToModelId ? [] : [ModelTagsPanel]),
ModelSource,
ModelVersionNotesPanel,
];
const partials = (() => {
if (formType === 'import') {
if (!nameParams) {
return [PreTrainedModelSelect];
}
return [
PreTrainedModelSelect,
ModelDetailsPanel,
ModelTagsPanel,
ModelVersionNotesPanel,
ModelDeployment,
];
}
if (formType === 'external') {
return [
...(registerToModelId ? [] : [ModelDetailsPanel]),
...(registerToModelId ? [] : [ModelTagsPanel]),
ModelSource,
ModelVersionNotesPanel,
ModelDeployment,
];
}
return [
...(registerToModelId ? [] : [ModelOverviewTitle]),
...(registerToModelId ? [] : [ModelDetailsPanel]),
...(registerToModelId ? [] : [ModelTagsPanel]),
...(registerToModelId ? [] : [FileAndVersionTitle]),
ArtifactPanel,
ConfigurationPanel,
...(registerToModelId ? [ModelTagsPanel] : []),
ModelVersionNotesPanel,
ModelDeployment,
];
})();

const form = useForm<ModelFileFormData | ModelUrlFormData>({
mode: 'onChange',
defaultValues,
Expand Down Expand Up @@ -275,21 +279,23 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo
// eslint-disable-next-line no-console
console.log(errors);
}, []);

const getPageTitle = () => {
if (registerToModelId) {
return 'Register version';
}
switch (formType) {
case 'external':
return 'Register external model';
case 'external':
return 'Register pre-trained model';
default:
return 'Register your own model';
}
};
const errorCount = Object.keys(form.formState.errors).length;
const formHeader = (
<>
<EuiPageHeader
pageTitle={
registerToModelId
? 'Register version'
: formType === 'external'
? 'Register external model'
: formType === 'import'
? 'Register pre-trained model'
: 'Register your own model'
}
/>
<EuiPageHeader pageTitle={getPageTitle} />
<EuiText style={{ maxWidth: 725 }}>
<small>
{registerToModelId && (
Expand All @@ -314,33 +320,6 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo
</EuiText>
</>
);
const [checked, setChecked] = useState(false);
const onChange = (e: any) => {
setChecked(e.target.checked);
};
const formFooter = (
<EuiFormRow label={formType === 'external' ? 'Activation' : 'Deployment'}>
<div>
{<EuiText size="xs">Needs a description</EuiText>}
{(formType === 'upload' || formType === 'import') && (
<EuiCheckbox
id="deployment"
label="Start deployment automatically"
checked={checked}
onChange={(e) => onChange(e)}
/>
)}
{formType === 'external' && !registerToModelId && (
<EuiCheckbox
id="activation"
label="Activate on registration"
checked={checked}
onChange={(e) => onChange(e)}
/>
)}
</div>
</EuiFormRow>
);
return (
<EuiPageContent
verticalPosition="center"
Expand All @@ -366,19 +345,15 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo
)}
{partials.map((FormPartial, i) => (
<React.Fragment key={i}>
{FormPartial === PreTrainedModelSelect ? (
<FormPartial getPreSelected={getPreSelected} />
) : (
<FormPartial />
)}
{FormPartial === PreTrainedModelSelect ? <FormPartial /> : <FormPartial />}
{FormPartial === ModelOverviewTitle || FormPartial === FileAndVersionTitle ? (
<EuiSpacer size="s" />
) : (
<EuiSpacer size="xl" />
)}
</React.Fragment>
))}
{formType === 'import' ? preSelected && formFooter : formFooter}
{/* {formType === 'import' ? nameParams && formFooter : formFooter} */}
</EuiPanel>
<EuiSpacer size="xxl" />
<EuiSpacer size="xxl" />
Expand All @@ -398,7 +373,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo
</EuiFlexItem>
)}
<EuiFlexItem grow={false}>
<EuiButton onClick={() => setIsSubmitted(true)} iconType="cross" color="ghost">
<EuiButton onClick={() => setIsSubmitted(false)} iconType="cross" color="ghost">
Cancel
</EuiButton>
</EuiFlexItem>
Expand Down
1 change: 1 addition & 0 deletions public/components/register_model/register_model.types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export interface ModelFormBase {
tags?: Tag[];
versionNotes?: string;
type?: 'import' | 'upload' | 'external';
deployment: boolean;
}

/**
Expand Down
Loading