Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions x-pack/plugins/ml/common/constants/data_frame_analytics.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

export const DEFAULT_RESULTS_FIELD = 'ml';
6 changes: 6 additions & 0 deletions x-pack/plugins/ml/common/types/data_frame_analytics.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,9 @@ export interface DataFrameAnalyticsConfig {
version: string;
allow_lazy_start?: boolean;
}

export enum ANALYSIS_CONFIG_TYPE {
OUTLIER_DETECTION = 'outlier_detection',
REGRESSION = 'regression',
CLASSIFICATION = 'classification',
}
23 changes: 23 additions & 0 deletions x-pack/plugins/ml/common/types/feature_importance.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

export interface ClassFeatureImportance {
class_name: string | boolean;
importance: number;
}
export interface FeatureImportance {
feature_name: string;
importance?: number;
classes?: ClassFeatureImportance[];
}

export interface TopClass {
class_name: string;
class_probability: number;
class_score: number;
}

export type TopClasses = TopClass[];
79 changes: 79 additions & 0 deletions x-pack/plugins/ml/common/util/analytics_utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

import {
AnalysisConfig,
ClassificationAnalysis,
OutlierAnalysis,
RegressionAnalysis,
ANALYSIS_CONFIG_TYPE,
} from '../types/data_frame_analytics';

export const isOutlierAnalysis = (arg: any): arg is OutlierAnalysis => {
const keys = Object.keys(arg);
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.OUTLIER_DETECTION;
};

export const isRegressionAnalysis = (arg: any): arg is RegressionAnalysis => {
const keys = Object.keys(arg);
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.REGRESSION;
};

export const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysis => {
const keys = Object.keys(arg);
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION;
};

export const getDependentVar = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['dependent_variable']
| ClassificationAnalysis['classification']['dependent_variable'] => {
let depVar = '';

if (isRegressionAnalysis(analysis)) {
depVar = analysis.regression.dependent_variable;
}

if (isClassificationAnalysis(analysis)) {
depVar = analysis.classification.dependent_variable;
}
return depVar;
};

export const getPredictionFieldName = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['prediction_field_name']
| ClassificationAnalysis['classification']['prediction_field_name'] => {
// If undefined will be defaulted to dependent_variable when config is created
let predictionFieldName;
if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) {
predictionFieldName = analysis.regression.prediction_field_name;
} else if (
isClassificationAnalysis(analysis) &&
analysis.classification.prediction_field_name !== undefined
) {
predictionFieldName = analysis.classification.prediction_field_name;
}
return predictionFieldName;
};

export const getDefaultPredictionFieldName = (analysis: AnalysisConfig) => {
return `${getDependentVar(analysis)}_prediction`;
};
export const getPredictedFieldName = (
resultsField: string,
analysis: AnalysisConfig,
forSort?: boolean
) => {
// default is 'ml'
const predictionFieldName = getPredictionFieldName(analysis);
const predictedField = `${resultsField}.${
predictionFieldName ? predictionFieldName : getDefaultPredictionFieldName(analysis)
}`;
return predictedField;
};
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,14 @@ export const getDataGridSchemasFromFieldTypes = (fieldTypes: FieldTypes, results
schema = 'numeric';
}

if (
field.includes(`${resultsField}.${FEATURE_IMPORTANCE}`) ||
field.includes(`${resultsField}.${TOP_CLASSES}`)
) {
if (field.includes(`${resultsField}.${TOP_CLASSES}`)) {
schema = 'json';
}

if (field.includes(`${resultsField}.${FEATURE_IMPORTANCE}`)) {
schema = 'featureImportance';
}

return { id: field, schema, isSortable };
});
};
Expand Down Expand Up @@ -250,10 +251,6 @@ export const useRenderCellValue = (
return cellValue ? 'true' : 'false';
}

if (typeof cellValue === 'object' && cellValue !== null) {
return JSON.stringify(cellValue);
}

return cellValue;
};
}, [indexPattern?.fields, pagination.pageIndex, pagination.pageSize, tableItems]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
*/

import { isEqual } from 'lodash';
import React, { memo, useEffect, FC } from 'react';

import React, { memo, useEffect, FC, useMemo } from 'react';
import { i18n } from '@kbn/i18n';

import {
Expand All @@ -24,13 +23,16 @@ import {
} from '@elastic/eui';

import { CoreSetup } from 'src/core/public';

import { DEFAULT_SAMPLER_SHARD_SIZE } from '../../../../common/constants/field_histograms';

import { INDEX_STATUS } from '../../data_frame_analytics/common';
import { ANALYSIS_CONFIG_TYPE, INDEX_STATUS } from '../../data_frame_analytics/common';

import { euiDataGridStyle, euiDataGridToolbarSettings } from './common';
import { UseIndexDataReturnType } from './types';
import { DecisionPathPopover } from './feature_importance/decision_path_popover';
import { TopClasses } from '../../../../common/types/feature_importance';
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';

// TODO Fix row hovering + bar highlighting
// import { hoveredRow$ } from './column_chart';

Expand All @@ -41,6 +43,9 @@ export const DataGridTitle: FC<{ title: string }> = ({ title }) => (
);

interface PropsWithoutHeader extends UseIndexDataReturnType {
baseline?: number;
analysisType?: ANALYSIS_CONFIG_TYPE;
resultsField?: string;
dataTestSubj: string;
toastNotifications: CoreSetup['notifications']['toasts'];
}
Expand All @@ -60,6 +65,7 @@ type Props = PropsWithHeader | PropsWithoutHeader;
export const DataGrid: FC<Props> = memo(
(props) => {
const {
baseline,
chartsVisible,
chartsButtonVisible,
columnsWithCharts,
Expand All @@ -80,8 +86,10 @@ export const DataGrid: FC<Props> = memo(
toastNotifications,
toggleChartVisibility,
visibleColumns,
predictionFieldName,
resultsField,
analysisType,
} = props;

// TODO Fix row hovering + bar highlighting
// const getRowProps = (item: any) => {
// return {
Expand All @@ -90,6 +98,45 @@ export const DataGrid: FC<Props> = memo(
// };
// };

const popOverContent = useMemo(() => {
return analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION ||
analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION
? {
featureImportance: ({ children }: { cellContentsElement: any; children: any }) => {
const rowIndex = children?.props?.visibleRowIndex;
const row = data[rowIndex];
if (!row) return <div />;
// if resultsField for some reason is not available then use ml
const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD;
const parsedFIArray = row[mlResultsField].feature_importance;
let predictedValue: string | number | undefined;
let topClasses: TopClasses = [];
if (
predictionFieldName !== undefined &&
row &&
row[mlResultsField][predictionFieldName] !== undefined
) {
predictedValue = row[mlResultsField][predictionFieldName];
topClasses = row[mlResultsField].top_classes;
}

return (
<DecisionPathPopover
analysisType={analysisType}
predictedValue={predictedValue}
baseline={baseline}
featureImportance={parsedFIArray}
topClasses={topClasses}
predictionFieldName={
predictionFieldName ? predictionFieldName.replace('_prediction', '') : undefined
}
/>
);
},
}
: undefined;
}, [baseline, data]);

useEffect(() => {
if (invalidSortingColumnns.length > 0) {
invalidSortingColumnns.forEach((columnId) => {
Expand Down Expand Up @@ -225,6 +272,7 @@ export const DataGrid: FC<Props> = memo(
}
: {}),
}}
popoverContents={popOverContent}
pagination={{
...pagination,
pageSizeOptions: [5, 10, 25],
Expand Down
Loading