@@ -35,6 +35,7 @@ interface ClassificationAnalysis {
3535 dependent_variable : string ;
3636 training_percent ?: number ;
3737 num_top_classes ?: string ;
38+ prediction_field_name ?: string ;
3839 } ;
3940}
4041
@@ -74,13 +75,33 @@ export interface RegressionEvaluateResponse {
7475 } ;
7576}
7677
78+ export interface PredictedClass {
79+ predicted_class : string ;
80+ count : number ;
81+ }
82+
83+ export interface ConfusionMatrix {
84+ actual_class : string ;
85+ actual_class_doc_count : number ;
86+ predicted_classes : PredictedClass [ ] ;
87+ other_predicted_class_doc_count : number ;
88+ }
89+
90+ export interface ClassificationEvaluateResponse {
91+ classification : {
92+ multiclass_confusion_matrix : {
93+ confusion_matrix : ConfusionMatrix [ ] ;
94+ } ;
95+ } ;
96+ }
97+
7798interface GenericAnalysis {
7899 [ key : string ] : Record < string , any > ;
79100}
80101
81102interface LoadEvaluateResult {
82103 success : boolean ;
83- eval : RegressionEvaluateResponse | null ;
104+ eval : RegressionEvaluateResponse | ClassificationEvaluateResponse | null ;
84105 error : string | null ;
85106}
86107
@@ -109,6 +130,7 @@ export const getAnalysisType = (analysis: AnalysisConfig) => {
109130
110131export const getDependentVar = ( analysis : AnalysisConfig ) => {
111132 let depVar = '' ;
133+
112134 if ( isRegressionAnalysis ( analysis ) ) {
113135 depVar = analysis . regression . dependent_variable ;
114136 }
@@ -124,17 +146,26 @@ export const getPredictionFieldName = (analysis: AnalysisConfig) => {
124146 let predictionFieldName ;
125147 if ( isRegressionAnalysis ( analysis ) && analysis . regression . prediction_field_name !== undefined ) {
126148 predictionFieldName = analysis . regression . prediction_field_name ;
149+ } else if (
150+ isClassificationAnalysis ( analysis ) &&
151+ analysis . classification . prediction_field_name !== undefined
152+ ) {
153+ predictionFieldName = analysis . classification . prediction_field_name ;
127154 }
128155 return predictionFieldName ;
129156} ;
130157
131- export const getPredictedFieldName = ( resultsField : string , analysis : AnalysisConfig ) => {
158+ export const getPredictedFieldName = (
159+ resultsField : string ,
160+ analysis : AnalysisConfig ,
161+ forSort ?: boolean
162+ ) => {
132163 // default is 'ml'
133164 const predictionFieldName = getPredictionFieldName ( analysis ) ;
134165 const defaultPredictionField = `${ getDependentVar ( analysis ) } _prediction` ;
135166 const predictedField = `${ resultsField } .${
136167 predictionFieldName ? predictionFieldName : defaultPredictionField
137- } `;
168+ } ${ isClassificationAnalysis ( analysis ) && ! forSort ? '.keyword' : '' } `;
138169 return predictedField ;
139170} ;
140171
@@ -153,13 +184,32 @@ export const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysi
153184 return keys . length === 1 && keys [ 0 ] === ANALYSIS_CONFIG_TYPE . CLASSIFICATION ;
154185} ;
155186
156- export const isRegressionResultsSearchBoolQuery = (
157- arg : any
158- ) : arg is RegressionResultsSearchBoolQuery => {
187+ export const isResultsSearchBoolQuery = ( arg : any ) : arg is ResultsSearchBoolQuery => {
159188 const keys = Object . keys ( arg ) ;
160189 return keys . length === 1 && keys [ 0 ] === 'bool' ;
161190} ;
162191
192+ export const isRegressionEvaluateResponse = ( arg : any ) : arg is RegressionEvaluateResponse => {
193+ const keys = Object . keys ( arg ) ;
194+ return (
195+ keys . length === 1 &&
196+ keys [ 0 ] === ANALYSIS_CONFIG_TYPE . REGRESSION &&
197+ arg ?. regression ?. mean_squared_error !== undefined &&
198+ arg ?. regression ?. r_squared !== undefined
199+ ) ;
200+ } ;
201+
202+ export const isClassificationEvaluateResponse = (
203+ arg : any
204+ ) : arg is ClassificationEvaluateResponse => {
205+ const keys = Object . keys ( arg ) ;
206+ return (
207+ keys . length === 1 &&
208+ keys [ 0 ] === ANALYSIS_CONFIG_TYPE . CLASSIFICATION &&
209+ arg ?. classification ?. multiclass_confusion_matrix !== undefined
210+ ) ;
211+ } ;
212+
163213export interface DataFrameAnalyticsConfig {
164214 id : DataFrameAnalyticsId ;
165215 // Description attribute is not supported yet
@@ -254,17 +304,14 @@ export function getValuesFromResponse(response: RegressionEvaluateResponse) {
254304
255305 return { meanSquaredError, rSquared } ;
256306}
257- interface RegressionResultsSearchBoolQuery {
307+ interface ResultsSearchBoolQuery {
258308 bool : Dictionary < any > ;
259309}
260- interface RegressionResultsSearchTermQuery {
310+ interface ResultsSearchTermQuery {
261311 term : Dictionary < any > ;
262312}
263313
264- export type RegressionResultsSearchQuery =
265- | RegressionResultsSearchBoolQuery
266- | RegressionResultsSearchTermQuery
267- | SavedSearchQuery ;
314+ export type ResultsSearchQuery = ResultsSearchBoolQuery | ResultsSearchTermQuery | SavedSearchQuery ;
268315
269316export function getEvalQueryBody ( {
270317 resultsField,
@@ -274,23 +321,44 @@ export function getEvalQueryBody({
274321} : {
275322 resultsField : string ;
276323 isTraining : boolean ;
277- searchQuery ?: RegressionResultsSearchQuery ;
324+ searchQuery ?: ResultsSearchQuery ;
278325 ignoreDefaultQuery ?: boolean ;
279326} ) {
280- let query : RegressionResultsSearchQuery = {
327+ let query : ResultsSearchQuery = {
281328 term : { [ `${ resultsField } .is_training` ] : { value : isTraining } } ,
282329 } ;
283330
284331 if ( searchQuery !== undefined && ignoreDefaultQuery === true ) {
285332 query = searchQuery ;
286- } else if ( searchQuery !== undefined && isRegressionResultsSearchBoolQuery ( searchQuery ) ) {
333+ } else if ( searchQuery !== undefined && isResultsSearchBoolQuery ( searchQuery ) ) {
287334 const searchQueryClone = cloneDeep ( searchQuery ) ;
288335 searchQueryClone . bool . must . push ( query ) ;
289336 query = searchQueryClone ;
290337 }
291338 return query ;
292339}
293340
341+ interface EvaluateMetrics {
342+ classification : {
343+ multiclass_confusion_matrix : object ;
344+ } ;
345+ regression : {
346+ r_squared : object ;
347+ mean_squared_error : object ;
348+ } ;
349+ }
350+
351+ interface LoadEvalDataConfig {
352+ isTraining : boolean ;
353+ index : string ;
354+ dependentVariable : string ;
355+ resultsField : string ;
356+ predictionFieldName ?: string ;
357+ searchQuery ?: ResultsSearchQuery ;
358+ ignoreDefaultQuery ?: boolean ;
359+ jobType : ANALYSIS_CONFIG_TYPE ;
360+ }
361+
294362export const loadEvalData = async ( {
295363 isTraining,
296364 index,
@@ -299,34 +367,38 @@ export const loadEvalData = async ({
299367 predictionFieldName,
300368 searchQuery,
301369 ignoreDefaultQuery,
302- } : {
303- isTraining : boolean ;
304- index : string ;
305- dependentVariable : string ;
306- resultsField : string ;
307- predictionFieldName ?: string ;
308- searchQuery ?: RegressionResultsSearchQuery ;
309- ignoreDefaultQuery ?: boolean ;
310- } ) => {
370+ jobType,
371+ } : LoadEvalDataConfig ) => {
311372 const results : LoadEvaluateResult = { success : false , eval : null , error : null } ;
312373 const defaultPredictionField = `${ dependentVariable } _prediction` ;
313- const predictedField = `${ resultsField } .${
374+ let predictedField = `${ resultsField } .${
314375 predictionFieldName ? predictionFieldName : defaultPredictionField
315376 } `;
316377
378+ if ( jobType === ANALYSIS_CONFIG_TYPE . CLASSIFICATION ) {
379+ predictedField = `${ predictedField } .keyword` ;
380+ }
381+
317382 const query = getEvalQueryBody ( { resultsField, isTraining, searchQuery, ignoreDefaultQuery } ) ;
318383
384+ const metrics : EvaluateMetrics = {
385+ classification : {
386+ multiclass_confusion_matrix : { } ,
387+ } ,
388+ regression : {
389+ r_squared : { } ,
390+ mean_squared_error : { } ,
391+ } ,
392+ } ;
393+
319394 const config = {
320395 index,
321396 query,
322397 evaluation : {
323- regression : {
398+ [ jobType ] : {
324399 actual_field : dependentVariable ,
325400 predicted_field : predictedField ,
326- metrics : {
327- r_squared : { } ,
328- mean_squared_error : { } ,
329- } ,
401+ metrics : metrics [ jobType as keyof EvaluateMetrics ] ,
330402 } ,
331403 } ,
332404 } ;
@@ -341,3 +413,57 @@ export const loadEvalData = async ({
341413 return results ;
342414 }
343415} ;
416+
417+ interface TrackTotalHitsSearchResponse {
418+ hits : {
419+ total : {
420+ value : number ;
421+ relation : string ;
422+ } ;
423+ hits : any [ ] ;
424+ } ;
425+ }
426+
427+ interface LoadDocsCountConfig {
428+ ignoreDefaultQuery ?: boolean ;
429+ isTraining : boolean ;
430+ searchQuery : SavedSearchQuery ;
431+ resultsField : string ;
432+ destIndex : string ;
433+ }
434+
435+ interface LoadDocsCountResponse {
436+ docsCount : number | null ;
437+ success : boolean ;
438+ }
439+
440+ export const loadDocsCount = async ( {
441+ ignoreDefaultQuery = true ,
442+ isTraining,
443+ searchQuery,
444+ resultsField,
445+ destIndex,
446+ } : LoadDocsCountConfig ) : Promise < LoadDocsCountResponse > => {
447+ const query = getEvalQueryBody ( { resultsField, isTraining, ignoreDefaultQuery, searchQuery } ) ;
448+
449+ try {
450+ const body : SearchQuery = {
451+ track_total_hits : true ,
452+ query,
453+ } ;
454+
455+ const resp : TrackTotalHitsSearchResponse = await ml . esSearch ( {
456+ index : destIndex ,
457+ size : 0 ,
458+ body,
459+ } ) ;
460+
461+ const docsCount = resp . hits . total && resp . hits . total . value ;
462+ return { docsCount, success : docsCount !== undefined } ;
463+ } catch ( e ) {
464+ return {
465+ docsCount : null ,
466+ success : false ,
467+ } ;
468+ }
469+ } ;
0 commit comments