Skip to content

Commit 0cd5bb0

Browse files
[ML] DF Analytics: create classification jobs results view (#52584)
* wip: create classification results page + table and evaluate panel * enable view link for classification jobs * wip: fetch classification eval data * wip: display confusion matrix in datagrid * evaluate panel: add heatmap for cells and doc count * Update use of loadEvalData in expanded row component * Add metric type for evaluate endpoint and fix localization error * handle no incorrect prediction classes case for confusion matrix. remove unused translation * setCellProps needs to be called from a lifecycle method - wrap in useEffect * TypeScript improvements * fix datagrid column resize affecting results table * allow custom prediction field for classification jobs * ensure values are rounded correctly and add tooltip * temp workaroun for datagrid width issues
1 parent 79a8528 commit 0cd5bb0

File tree

25 files changed

+1514
-116
lines changed

25 files changed

+1514
-116
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
@import 'pages/analytics_exploration/components/exploration/index';
22
@import 'pages/analytics_exploration/components/regression_exploration/index';
3+
@import 'pages/analytics_exploration/components/classification_exploration/index';
34
@import 'pages/analytics_management/components/analytics_list/index';
45
@import 'pages/analytics_management/components/create_analytics_form/index';
56
@import 'pages/analytics_management/components/create_analytics_flyout/index';

x-pack/legacy/plugins/ml/public/application/data_frame_analytics/common/analytics.ts

Lines changed: 156 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ interface ClassificationAnalysis {
3535
dependent_variable: string;
3636
training_percent?: number;
3737
num_top_classes?: string;
38+
prediction_field_name?: string;
3839
};
3940
}
4041

@@ -74,13 +75,33 @@ export interface RegressionEvaluateResponse {
7475
};
7576
}
7677

78+
export interface PredictedClass {
79+
predicted_class: string;
80+
count: number;
81+
}
82+
83+
export interface ConfusionMatrix {
84+
actual_class: string;
85+
actual_class_doc_count: number;
86+
predicted_classes: PredictedClass[];
87+
other_predicted_class_doc_count: number;
88+
}
89+
90+
export interface ClassificationEvaluateResponse {
91+
classification: {
92+
multiclass_confusion_matrix: {
93+
confusion_matrix: ConfusionMatrix[];
94+
};
95+
};
96+
}
97+
7798
interface GenericAnalysis {
7899
[key: string]: Record<string, any>;
79100
}
80101

81102
interface LoadEvaluateResult {
82103
success: boolean;
83-
eval: RegressionEvaluateResponse | null;
104+
eval: RegressionEvaluateResponse | ClassificationEvaluateResponse | null;
84105
error: string | null;
85106
}
86107

@@ -109,6 +130,7 @@ export const getAnalysisType = (analysis: AnalysisConfig) => {
109130

110131
export const getDependentVar = (analysis: AnalysisConfig) => {
111132
let depVar = '';
133+
112134
if (isRegressionAnalysis(analysis)) {
113135
depVar = analysis.regression.dependent_variable;
114136
}
@@ -124,17 +146,26 @@ export const getPredictionFieldName = (analysis: AnalysisConfig) => {
124146
let predictionFieldName;
125147
if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) {
126148
predictionFieldName = analysis.regression.prediction_field_name;
149+
} else if (
150+
isClassificationAnalysis(analysis) &&
151+
analysis.classification.prediction_field_name !== undefined
152+
) {
153+
predictionFieldName = analysis.classification.prediction_field_name;
127154
}
128155
return predictionFieldName;
129156
};
130157

131-
export const getPredictedFieldName = (resultsField: string, analysis: AnalysisConfig) => {
158+
export const getPredictedFieldName = (
159+
resultsField: string,
160+
analysis: AnalysisConfig,
161+
forSort?: boolean
162+
) => {
132163
// default is 'ml'
133164
const predictionFieldName = getPredictionFieldName(analysis);
134165
const defaultPredictionField = `${getDependentVar(analysis)}_prediction`;
135166
const predictedField = `${resultsField}.${
136167
predictionFieldName ? predictionFieldName : defaultPredictionField
137-
}`;
168+
}${isClassificationAnalysis(analysis) && !forSort ? '.keyword' : ''}`;
138169
return predictedField;
139170
};
140171

@@ -153,13 +184,32 @@ export const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysi
153184
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION;
154185
};
155186

156-
export const isRegressionResultsSearchBoolQuery = (
157-
arg: any
158-
): arg is RegressionResultsSearchBoolQuery => {
187+
export const isResultsSearchBoolQuery = (arg: any): arg is ResultsSearchBoolQuery => {
159188
const keys = Object.keys(arg);
160189
return keys.length === 1 && keys[0] === 'bool';
161190
};
162191

192+
export const isRegressionEvaluateResponse = (arg: any): arg is RegressionEvaluateResponse => {
193+
const keys = Object.keys(arg);
194+
return (
195+
keys.length === 1 &&
196+
keys[0] === ANALYSIS_CONFIG_TYPE.REGRESSION &&
197+
arg?.regression?.mean_squared_error !== undefined &&
198+
arg?.regression?.r_squared !== undefined
199+
);
200+
};
201+
202+
export const isClassificationEvaluateResponse = (
203+
arg: any
204+
): arg is ClassificationEvaluateResponse => {
205+
const keys = Object.keys(arg);
206+
return (
207+
keys.length === 1 &&
208+
keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION &&
209+
arg?.classification?.multiclass_confusion_matrix !== undefined
210+
);
211+
};
212+
163213
export interface DataFrameAnalyticsConfig {
164214
id: DataFrameAnalyticsId;
165215
// Description attribute is not supported yet
@@ -254,17 +304,14 @@ export function getValuesFromResponse(response: RegressionEvaluateResponse) {
254304

255305
return { meanSquaredError, rSquared };
256306
}
257-
interface RegressionResultsSearchBoolQuery {
307+
interface ResultsSearchBoolQuery {
258308
bool: Dictionary<any>;
259309
}
260-
interface RegressionResultsSearchTermQuery {
310+
interface ResultsSearchTermQuery {
261311
term: Dictionary<any>;
262312
}
263313

264-
export type RegressionResultsSearchQuery =
265-
| RegressionResultsSearchBoolQuery
266-
| RegressionResultsSearchTermQuery
267-
| SavedSearchQuery;
314+
export type ResultsSearchQuery = ResultsSearchBoolQuery | ResultsSearchTermQuery | SavedSearchQuery;
268315

269316
export function getEvalQueryBody({
270317
resultsField,
@@ -274,23 +321,44 @@ export function getEvalQueryBody({
274321
}: {
275322
resultsField: string;
276323
isTraining: boolean;
277-
searchQuery?: RegressionResultsSearchQuery;
324+
searchQuery?: ResultsSearchQuery;
278325
ignoreDefaultQuery?: boolean;
279326
}) {
280-
let query: RegressionResultsSearchQuery = {
327+
let query: ResultsSearchQuery = {
281328
term: { [`${resultsField}.is_training`]: { value: isTraining } },
282329
};
283330

284331
if (searchQuery !== undefined && ignoreDefaultQuery === true) {
285332
query = searchQuery;
286-
} else if (searchQuery !== undefined && isRegressionResultsSearchBoolQuery(searchQuery)) {
333+
} else if (searchQuery !== undefined && isResultsSearchBoolQuery(searchQuery)) {
287334
const searchQueryClone = cloneDeep(searchQuery);
288335
searchQueryClone.bool.must.push(query);
289336
query = searchQueryClone;
290337
}
291338
return query;
292339
}
293340

341+
interface EvaluateMetrics {
342+
classification: {
343+
multiclass_confusion_matrix: object;
344+
};
345+
regression: {
346+
r_squared: object;
347+
mean_squared_error: object;
348+
};
349+
}
350+
351+
interface LoadEvalDataConfig {
352+
isTraining: boolean;
353+
index: string;
354+
dependentVariable: string;
355+
resultsField: string;
356+
predictionFieldName?: string;
357+
searchQuery?: ResultsSearchQuery;
358+
ignoreDefaultQuery?: boolean;
359+
jobType: ANALYSIS_CONFIG_TYPE;
360+
}
361+
294362
export const loadEvalData = async ({
295363
isTraining,
296364
index,
@@ -299,34 +367,38 @@ export const loadEvalData = async ({
299367
predictionFieldName,
300368
searchQuery,
301369
ignoreDefaultQuery,
302-
}: {
303-
isTraining: boolean;
304-
index: string;
305-
dependentVariable: string;
306-
resultsField: string;
307-
predictionFieldName?: string;
308-
searchQuery?: RegressionResultsSearchQuery;
309-
ignoreDefaultQuery?: boolean;
310-
}) => {
370+
jobType,
371+
}: LoadEvalDataConfig) => {
311372
const results: LoadEvaluateResult = { success: false, eval: null, error: null };
312373
const defaultPredictionField = `${dependentVariable}_prediction`;
313-
const predictedField = `${resultsField}.${
374+
let predictedField = `${resultsField}.${
314375
predictionFieldName ? predictionFieldName : defaultPredictionField
315376
}`;
316377

378+
if (jobType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION) {
379+
predictedField = `${predictedField}.keyword`;
380+
}
381+
317382
const query = getEvalQueryBody({ resultsField, isTraining, searchQuery, ignoreDefaultQuery });
318383

384+
const metrics: EvaluateMetrics = {
385+
classification: {
386+
multiclass_confusion_matrix: {},
387+
},
388+
regression: {
389+
r_squared: {},
390+
mean_squared_error: {},
391+
},
392+
};
393+
319394
const config = {
320395
index,
321396
query,
322397
evaluation: {
323-
regression: {
398+
[jobType]: {
324399
actual_field: dependentVariable,
325400
predicted_field: predictedField,
326-
metrics: {
327-
r_squared: {},
328-
mean_squared_error: {},
329-
},
401+
metrics: metrics[jobType as keyof EvaluateMetrics],
330402
},
331403
},
332404
};
@@ -341,3 +413,57 @@ export const loadEvalData = async ({
341413
return results;
342414
}
343415
};
416+
417+
interface TrackTotalHitsSearchResponse {
418+
hits: {
419+
total: {
420+
value: number;
421+
relation: string;
422+
};
423+
hits: any[];
424+
};
425+
}
426+
427+
interface LoadDocsCountConfig {
428+
ignoreDefaultQuery?: boolean;
429+
isTraining: boolean;
430+
searchQuery: SavedSearchQuery;
431+
resultsField: string;
432+
destIndex: string;
433+
}
434+
435+
interface LoadDocsCountResponse {
436+
docsCount: number | null;
437+
success: boolean;
438+
}
439+
440+
export const loadDocsCount = async ({
441+
ignoreDefaultQuery = true,
442+
isTraining,
443+
searchQuery,
444+
resultsField,
445+
destIndex,
446+
}: LoadDocsCountConfig): Promise<LoadDocsCountResponse> => {
447+
const query = getEvalQueryBody({ resultsField, isTraining, ignoreDefaultQuery, searchQuery });
448+
449+
try {
450+
const body: SearchQuery = {
451+
track_total_hits: true,
452+
query,
453+
};
454+
455+
const resp: TrackTotalHitsSearchResponse = await ml.esSearch({
456+
index: destIndex,
457+
size: 0,
458+
body,
459+
});
460+
461+
const docsCount = resp.hits.total && resp.hits.total.value;
462+
return { docsCount, success: docsCount !== undefined };
463+
} catch (e) {
464+
return {
465+
docsCount: null,
466+
success: false,
467+
};
468+
}
469+
};

x-pack/legacy/plugins/ml/public/application/data_frame_analytics/common/fields.ts

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ export const sortRegressionResultsFields = (
7777
) => {
7878
const dependentVariable = getDependentVar(jobConfig.analysis);
7979
const resultsField = jobConfig.dest.results_field;
80-
const predictedField = getPredictedFieldName(resultsField, jobConfig.analysis);
80+
const predictedField = getPredictedFieldName(resultsField, jobConfig.analysis, true);
8181
if (a === `${resultsField}.is_training`) {
8282
return -1;
8383
}
@@ -96,6 +96,14 @@ export const sortRegressionResultsFields = (
9696
if (b === dependentVariable) {
9797
return 1;
9898
}
99+
100+
if (a === `${resultsField}.prediction_probability`) {
101+
return -1;
102+
}
103+
if (b === `${resultsField}.prediction_probability`) {
104+
return 1;
105+
}
106+
99107
return a.localeCompare(b);
100108
};
101109

@@ -107,7 +115,7 @@ export const sortRegressionResultsColumns = (
107115
) => (a: string, b: string) => {
108116
const dependentVariable = getDependentVar(jobConfig.analysis);
109117
const resultsField = jobConfig.dest.results_field;
110-
const predictedField = getPredictedFieldName(resultsField, jobConfig.analysis);
118+
const predictedField = getPredictedFieldName(resultsField, jobConfig.analysis, true);
111119

112120
const typeofA = typeof obj[a];
113121
const typeofB = typeof obj[b];
@@ -136,6 +144,14 @@ export const sortRegressionResultsColumns = (
136144
return 1;
137145
}
138146

147+
if (a === `${resultsField}.prediction_probability`) {
148+
return -1;
149+
}
150+
151+
if (b === `${resultsField}.prediction_probability`) {
152+
return 1;
153+
}
154+
139155
if (typeofA !== 'string' && typeofB === 'string') {
140156
return 1;
141157
}
@@ -184,6 +200,43 @@ export function getFlattenedFields(obj: EsDocSource, resultsField: string): EsFi
184200
return flatDocFields.filter(f => f !== ML__ID_COPY);
185201
}
186202

203+
export const getDefaultClassificationFields = (
204+
docs: EsDoc[],
205+
jobConfig: DataFrameAnalyticsConfig
206+
): EsFieldName[] => {
207+
if (docs.length === 0) {
208+
return [];
209+
}
210+
const resultsField = jobConfig.dest.results_field;
211+
const newDocFields = getFlattenedFields(docs[0]._source, resultsField);
212+
return newDocFields
213+
.filter(k => {
214+
if (k === `${resultsField}.is_training`) {
215+
return true;
216+
}
217+
// predicted value of dependent variable
218+
if (k === getPredictedFieldName(resultsField, jobConfig.analysis, true)) {
219+
return true;
220+
}
221+
// actual value of dependent variable
222+
if (k === getDependentVar(jobConfig.analysis)) {
223+
return true;
224+
}
225+
226+
if (k === `${resultsField}.prediction_probability`) {
227+
return true;
228+
}
229+
230+
if (k.split('.')[0] === resultsField) {
231+
return false;
232+
}
233+
234+
return docs.some(row => row._source[k] !== null);
235+
})
236+
.sort((a, b) => sortRegressionResultsFields(a, b, jobConfig))
237+
.slice(0, DEFAULT_REGRESSION_COLUMNS);
238+
};
239+
187240
export const getDefaultRegressionFields = (
188241
docs: EsDoc[],
189242
jobConfig: DataFrameAnalyticsConfig

0 commit comments

Comments
 (0)