Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vision Model Overview E2E Tests - Refactor #2185

Merged
merged 12 commits into from
Jul 24, 2023
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import {
describeModelOverview,
modelAssessmentDatasets
} from "@responsible-ai/e2e";
// Model Overview E2E coverage for the fridge image-classification dataset.
// The dataset description is handed directly to the shared suite builder,
// labelled with the dataset's own name.
describeModelOverview(
  modelAssessmentDatasets.FridgeImageClassificationModelDebugging,
  "FridgeImageClassificationModelDebugging"
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import {
describeModelOverview,
modelAssessmentDatasets
} from "@responsible-ai/e2e";
// Model Overview E2E coverage for the fridge multilabel dataset.
// The dataset description is handed directly to the shared suite builder,
// labelled with the dataset's own name.
describeModelOverview(
  modelAssessmentDatasets.FridgeMultilabelModelDebugging,
  "FridgeMultilabelModelDebugging"
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import {
describeModelOverview,
modelAssessmentDatasets
} from "@responsible-ai/e2e";
// Model Overview E2E coverage for the fridge object-detection dataset.
// The dataset description is handed directly to the shared suite builder,
// labelled with the dataset's own name.
describeModelOverview(
  modelAssessmentDatasets.FridgeObjectDetectionModelDebugging,
  "FridgeObjectDetectionModelDebugging"
);
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ export const FridgeImageClassificationModelDebugging = {
name: "All data",
sampleSize: "134"
}
]
],
newCohort: {
metrics: {
accuracy: "0.9",
macroF1: "0.9",
macroPrecision: "0.9",
macroRecall: "0.9",
microF1: "0.9",
microPrecision: "0.9",
microRecall: "0.9"
},
name: "CohortCreateE2E-image-classification",
sampleSize: "5"
}
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ export const FridgeMultilabelModelDebugging = {
name: "All data",
sampleSize: "10"
}
]
],
newCohort: {
metrics: {
exactMatchRatio: "1",
hammingScore: "1"
},
name: "CohortCreateE2E-multilabel",
sampleSize: "3"
}
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ export const FridgeObjectDetectionModelDebugging = {
name: "All data",
sampleSize: "5"
}
]
],
newCohort: {
metrics: {
averagePrecision: "1",
averageRecall: "1",
meanAveragePrecision: "1"
},
name: "CohortCreateE2E-object-detection",
sampleSize: "2"
}
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,22 @@ export function describeModelOverview(
isNotebookTest = true
): void {
describe(testName, () => {
const isVision =
Advitya17 marked this conversation as resolved.
Show resolved Hide resolved
datasetShape.isObjectDetection ||
datasetShape.isMultiLabel ||
datasetShape.isImageClassification
? true
: false;
if (isNotebookTest) {
before(() => {
visit(name);
});
} else {
before(() => {
cy.visit(`#/modelAssessment/${name}/light/english/Version-2`);
const dashboardName = isVision
? "modelAssessmentVision"
: "modelAssessment";
cy.visit(`#/${dashboardName}/${name}/light/english/Version-2`);
});
}

Expand All @@ -38,7 +47,8 @@ export function describeModelOverview(
ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
datasetShape,
false,
isNotebookTest
isNotebookTest,
isVision
);
});

Expand All @@ -57,7 +67,8 @@ export function describeModelOverview(
);
ensureAllModelOverviewFeatureCohortsViewElementsAfterSelectionArePresent(
datasetShape,
1
1,
isVision
);
});

Expand All @@ -69,16 +80,19 @@ export function describeModelOverview(
);
ensureAllModelOverviewFeatureCohortsViewElementsAfterSelectionArePresent(
datasetShape,
2
2,
isVision
);
});

it("should show new cohorts in charts", () => {
ensureNewCohortsShowUpInCharts(datasetShape, isNotebookTest);
ensureNewCohortsShowUpInCharts(datasetShape, isNotebookTest, isVision);
});

it("should pivot between charts when clicking", () => {
ensureChartsPivot(datasetShape, isNotebookTest, true);
if (!isVision) {
ensureChartsPivot(datasetShape, isNotebookTest, true);
}
});
} else {
it("should not have 'Model overview' component", () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ import { getNumberOfCohorts } from "./numberOfCohorts";
export function ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
datasetShape: IModelAssessmentData,
includeNewCohort: boolean,
isNotebookTest: boolean
isNotebookTest: boolean,
isVision: boolean
): void {
const data = datasetShape.modelOverviewData;
const initialCohorts = data?.initialCohorts;
Expand All @@ -23,7 +24,10 @@ export function ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
"not.exist"
);
if (isNotebookTest) {
if (getNumberOfCohorts(datasetShape, includeNewCohort) <= 1) {
if (
getNumberOfCohorts(datasetShape, includeNewCohort) <= 1 ||
datasetShape.isObjectDetection
) {
cy.get(Locators.ModelOverviewHeatmapVisualDisplayToggle).should(
"not.exist"
);
Expand All @@ -45,6 +49,24 @@ export function ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
"meanSquaredError",
"meanPrediction"
);
} else if (datasetShape.isImageClassification) {
metricsOrder.push(
"accuracy",
"f1Score",
"precisionScore",
"recallScore",
"falsePositiveRate",
"falseNegativeRate",
"selectionRate"
);
} else if (datasetShape.isMultiLabel) {
metricsOrder.push("exactMatchRatio", "hammingScore");
} else if (datasetShape.isObjectDetection) {
metricsOrder.push(
"meanAveragePrecision",
"averagePrecision",
"averageRecall"
);
} else {
metricsOrder.push("accuracy");
if (!datasetShape.isMulticlass) {
Expand All @@ -69,35 +91,39 @@ export function ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
});
});

if (isNotebookTest) {
cy.get(Locators.ModelOverviewHeatmapCells)
.should("have.length", (cohorts?.length || 0) * (metricsOrder.length + 1))
.each(($cell) => {
// somehow the cell string is one invisible character longer, trim
expect($cell.text().slice(0, $cell.text().length - 1)).to.be.oneOf(
heatmapCellContents
);
});
}

cy.get(
Locators.ModelOverviewDisaggregatedAnalysisBaseCohortDisclaimer
).should("not.exist");
cy.get(Locators.ModelOverviewDisaggregatedAnalysisBaseCohortWarning).should(
"not.exist"
);

const defaultVisibleChart = getDefaultVisibleChart(
datasetShape.isRegression,
datasetShape.isBinary
);
assertChartVisibility(datasetShape, defaultVisibleChart);

if (defaultVisibleChart === Locators.ModelOverviewMetricChart) {
ensureNotebookModelOverviewMetricChartIsCorrect(
isNotebookTest,
datasetShape,
includeNewCohort
if (!isVision) {
if (isNotebookTest) {
cy.get(Locators.ModelOverviewHeatmapCells)
.should(
"have.length",
(cohorts?.length || 0) * (metricsOrder.length + 1)
)
.each(($cell) => {
// somehow the cell string is one invisible character longer, trim
expect($cell.text().slice(0, $cell.text().length - 1)).to.be.oneOf(
heatmapCellContents
);
});
}
const defaultVisibleChart = getDefaultVisibleChart(
datasetShape.isRegression,
datasetShape.isBinary
);
assertChartVisibility(datasetShape, defaultVisibleChart);

if (defaultVisibleChart === Locators.ModelOverviewMetricChart) {
ensureNotebookModelOverviewMetricChartIsCorrect(
isNotebookTest,
datasetShape,
includeNewCohort
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,31 @@ import {

export function ensureAllModelOverviewFeatureCohortsViewElementsAfterSelectionArePresent(
datasetShape: IModelAssessmentData,
selectedFeatures: number
selectedFeatures: number,
isVision: boolean
): void {
cy.get(Locators.ModelOverviewFeatureSelection).should("exist");
cy.get(Locators.ModelOverviewFeatureConfigurationActionButton).should(
"exist"
);
cy.get(Locators.ModelOverviewHeatmapVisualDisplayToggle).should("exist");
cy.get(Locators.ModelOverviewDatasetCohortStatsTable).should("not.exist");
cy.get(Locators.ModelOverviewDisaggregatedAnalysisTable).should("exist");

const defaultVisibleChart = getDefaultVisibleChart(
datasetShape.isRegression,
datasetShape.isBinary
);
assertChartVisibility(datasetShape, defaultVisibleChart);
if (!isVision) {
cy.get(Locators.ModelOverviewHeatmapVisualDisplayToggle).should("exist"); // TODO: check!
cy.get(Locators.ModelOverviewDisaggregatedAnalysisTable).should("exist");

assertNumberOfChartRowsEqual(
datasetShape,
selectedFeatures,
defaultVisibleChart
);
const defaultVisibleChart = getDefaultVisibleChart(
datasetShape.isRegression,
datasetShape.isBinary
);
assertChartVisibility(datasetShape, defaultVisibleChart);

assertNumberOfChartRowsEqual(
datasetShape,
selectedFeatures,
defaultVisibleChart
);
}
}

function assertNumberOfChartRowsEqual(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,21 @@ import { ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent } from

export function ensureNewCohortsShowUpInCharts(
datasetShape: IModelAssessmentData,
isNotebookTest: boolean
isNotebookTest: boolean,
isVision: boolean
): void {
cy.get(Locators.ModelOverviewCohortViewDatasetCohortViewButton).click();
ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
datasetShape,
false,
isNotebookTest
isNotebookTest,
isVision
);
createCohort(datasetShape.modelOverviewData?.newCohort?.name);
ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
datasetShape,
true,
isNotebookTest
isNotebookTest,
isVision
);
}
Loading