Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

show rank charts in compare datasets node #6584

Merged
merged 4 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import { isEmpty, cloneDeep } from 'lodash';
import { mean, stddev } from '@/utils/stats';
import { isEmpty } from 'lodash';
import { Dataset, InterventionPolicy, ModelConfiguration } from '@/types/Types';
import { WorkflowNode, WorkflowPortStatus } from '@/types/workflow';
import { renameFnGenerator } from '@/components/workflow/ops/calibrate-ciemss/calibrate-utils';
import { Ref } from 'vue';
import { createRankingInterventionsChart } from '@/services/charts';
import { DATASET_VAR_NAME_PREFIX, getDatasetResultCSV, mergeResults, getDataset } from '@/services/dataset';
import {
DataArray,
Expand All @@ -14,8 +12,8 @@ import {
} from '@/services/models/simulation-service';
import { getInterventionPolicyById } from '@/services/intervention-policy';
import { getModelConfigurationById } from '@/services/model-configurations';
import { ChartData, getInterventionColorAndScoreMaps } from '@/composables/useCharts';
import { PlotValue, TimepointOption, RankOption, CompareDatasetsState } from './compare-datasets-operation';
import { ChartData } from '@/composables/useCharts';
import { PlotValue, CompareDatasetsState } from './compare-datasets-operation';

interface DataResults {
results: DataArray[];
Expand Down Expand Up @@ -184,141 +182,6 @@ export function buildChartData(
};
}

export function generateRankingCharts(
rankingCriteriaCharts,
rankingResultsChart,
node: WorkflowNode<CompareDatasetsState>,
chartData,
datasets: Ref<Dataset[]>,
modelConfigurations,
interventionPolicies
) {
// Reset charts
rankingCriteriaCharts.value = [];
rankingResultsChart.value = null;

const allRankedCriteriaValues: { score: number; policyName: string; configName: string }[][] = [];

const { interventionNameColorMap, interventionNameScoresMap } = getInterventionColorAndScoreMaps(
datasets,
modelConfigurations,
interventionPolicies
);

node.state.criteriaOfInterestCards.forEach((card) => {
if (!chartData.value || !card.selectedVariable) return;

const variableKey = `${chartData.value.pyciemssMap[card.selectedVariable]}_mean`;
let pointOfComparison: Record<string, number> = {};

if (card.timepoint === TimepointOption.OVERALL) {
const resultSummary = cloneDeep(chartData.value.resultSummary); // Must clone to avoid modifying the original data

// Note that the reduce function here only compares the variable of interest
// so only those key/value pairs will be relevant in the pointOfComparison object.
// Other keys like timepoint_id (that we aren't using) will be in pointOfComparison
// but they won't coincide with the value of the variable of interest.
pointOfComparison = resultSummary.reduce((acc, val) =>
Object.keys(val).reduce((acc2, key) => {
if (key.includes(variableKey)) {
acc2[key] = Math.max(acc[key], val[key]);
}
return acc2;
}, acc)
);
} else if (card.timepoint === TimepointOption.FIRST) {
pointOfComparison = chartData.value.resultSummary[0];
} else if (card.timepoint === TimepointOption.LAST) {
pointOfComparison = chartData.value.resultSummary[chartData.value.resultSummary.length - 1];
}

const rankingCriteriaValues: { score: number; policyName: string; configName: string }[] = [];

datasets.value.forEach((dataset, index: number) => {
const { metadata } = dataset;
const modelConfiguration: ModelConfiguration = modelConfigurations.value.find(
({ id }) => id === metadata.simulationAttributes?.modelConfigurationId
);
const policy: InterventionPolicy = interventionPolicies.value.find(
({ id }) => id === metadata.simulationAttributes?.interventionPolicyId
);

const policyName = policy?.name ?? 'no policy';

if (!modelConfiguration?.name) {
return;
}

rankingCriteriaValues.push({
score: pointOfComparison[`${variableKey}:${index}`] ?? 0,
policyName,
configName: modelConfiguration.name
});
});

const sortedRankingCriteriaValues =
card.rank === RankOption.MAXIMUM
? rankingCriteriaValues.sort((a, b) => b.score - a.score)
: rankingCriteriaValues.sort((a, b) => a.score - b.score);

rankingCriteriaCharts.value.push(
createRankingInterventionsChart(
sortedRankingCriteriaValues,
interventionNameColorMap,
card.name,
card.selectedVariable
)
);
allRankedCriteriaValues.push(sortedRankingCriteriaValues);
});

// For each criteria
allRankedCriteriaValues.forEach((criteriaValues, index) => {
const rankMutliplier = node.state.criteriaOfInterestCards[index].rank === RankOption.MINIMUM ? -1 : 1;

// Calculate mean and stdev for this criteria
const values = criteriaValues.map((val) => val.score);
const meanValue = mean(values);
let stdevValue = stddev(values);
if (stdevValue === 0) stdevValue = 1;

// For each policy
Object.keys(interventionNameScoresMap).forEach((policyName) => {
// For each value of the criteria
criteriaValues.forEach((criteriaValue) => {
if (criteriaValue.policyName !== policyName) return; // Skip criteria values that don't belong to this policy
const scoredPolicyCriteria = rankMutliplier * ((criteriaValue.score - meanValue) / stdevValue);
interventionNameScoresMap[policyName].push(scoredPolicyCriteria);
});
});
});

const scoredPolicies = Object.keys(interventionNameScoresMap)
.map((policyName) => ({
score: mean(interventionNameScoresMap[policyName]),
policyName,
configName: ''
})) // Sort from highest to lowest value
.sort((a, b) => b.score - a.score);

rankingResultsChart.value = createRankingInterventionsChart(scoredPolicies, interventionNameColorMap);
}

export async function generateImpactCharts(
chartData,
datasets: Ref<Dataset[]>,
datasetResults,
baselineDatasetIndex,
selectedPlotType
) {
chartData.value = buildChartData(
datasets.value,
datasetResults.value,
baselineDatasetIndex.value,
selectedPlotType.value
);
}

// TODO: this should probably be split up into smaller functions but for now it's at least not duplicated in the node and drilldown
// TODO: Please type the function params in this file for a later pass
export async function initialize(
Expand All @@ -333,15 +196,14 @@ export async function initialize(
} | null>,
modelConfigIdToInterventionPolicyIdMap: Ref<Record<string, string[]>>,
impactChartData: Ref<ChartData | null>,
rankingChartData,
baselineDatasetIndex,
selectedPlotType,
rankingChartData: Ref<ChartData | null>,
baselineDatasetIndex: Ref<number>,
selectedPlotType: Ref<PlotValue>,
modelConfigurations: Ref<ModelConfiguration[]>,
interventionPolicies: Ref<InterventionPolicy[]>,
rankingCriteriaCharts: Ref<any>,
rankingResultsChart: Ref<any>
interventionPolicies: Ref<InterventionPolicy[]>
) {
const { inputs } = node;
datasets.value = [];
const datasetInputs = inputs.filter(
(input) => input.type === 'datasetId' && input.status === WorkflowPortStatus.CONNECTED
);
Expand Down Expand Up @@ -385,7 +247,12 @@ export async function initialize(
datasetResults.value = await fetchDatasetResults(datasets.value);
isFetchingDatasets.value = false;

await generateImpactCharts(impactChartData, datasets, datasetResults, baselineDatasetIndex, selectedPlotType);
impactChartData.value = buildChartData(
datasets.value,
datasetResults.value,
baselineDatasetIndex.value,
selectedPlotType.value
);

const modelConfigurationIds = Object.keys(modelConfigIdToInterventionPolicyIdMap.value);
if (isEmpty(modelConfigurationIds)) return;
Expand All @@ -409,14 +276,4 @@ export async function initialize(
baselineDatasetIndex.value,
PlotValue.VALUE
);

generateRankingCharts(
rankingCriteriaCharts,
rankingResultsChart,
node,
rankingChartData,
datasets,
modelConfigurations,
interventionPolicies
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
>
<template #content>
<tera-drilldown-section class="px-3">
<Button class="ml-auto" size="small" label="Run" @click="onRun" />
<Button
v-if="knobs.selectedCompareOption !== CompareValue.RANK"
class="ml-auto"
size="small"
label="Run"
@click="onRun"
/>
<label>What do you want to compare?</label>
<Dropdown
v-model="knobs.selectedCompareOption"
Expand Down Expand Up @@ -54,7 +60,7 @@
<div class="flex flex-column gap-2" v-else-if="knobs.selectedCompareOption === CompareValue.RANK">
<label>Specify criteria of interest:</label>
<tera-criteria-of-interest-card
v-for="(card, i) in node.state.criteriaOfInterestCards"
v-for="(card, i) in knobs.criteriaOfInterestCards"
:key="i"
:card="card"
:variables="variableNames"
Expand Down Expand Up @@ -206,11 +212,16 @@
the scenario outcome meets the criteria of interest more. The dark line (zero) is the mean outcome
across all scenarios.
</p>
<vega-chart :visualization-spec="rankingResultsChart" :are-embed-actions-visible="false" expandable />
<vega-chart
v-if="rankingCharts.rankingResultsChart"
:visualization-spec="rankingCharts.rankingResultsChart"
:are-embed-actions-visible="false"
expandable
/>
</AccordionTab>
<AccordionTab header="Ranking criteria">
<vega-chart
v-for="(rankingCriteriaChart, index) in rankingCriteriaCharts"
v-for="(rankingCriteriaChart, index) in rankingCharts.rankingCriteriaCharts"
:visualization-spec="rankingCriteriaChart"
:are-embed-actions-visible="false"
:key="index"
Expand Down Expand Up @@ -380,7 +391,7 @@ import { WorkflowNode } from '@/types/workflow';
import TeraSliderPanel from '@/components/widgets/tera-slider-panel.vue';
import TeraDrilldownSection from '@/components/drilldown/tera-drilldown-section.vue';
import { DrilldownTabs, ChartSettingType } from '@/types/common';
import { onMounted, ref, watch, computed } from 'vue';
import { onMounted, ref, computed, onBeforeUnmount } from 'vue';
import Button from 'primevue/button';
import Accordion from 'primevue/accordion';
import AccordionTab from 'primevue/accordiontab';
Expand Down Expand Up @@ -412,7 +423,7 @@ import {
PlotValue,
type CompareDatasetsMap
} from './compare-datasets-operation';
import { generateRankingCharts, generateImpactCharts, initialize } from './compare-datasets-utils';
import { initialize, buildChartData } from './compare-datasets-utils';

const props = defineProps<{
node: WorkflowNode<CompareDatasetsState>;
Expand Down Expand Up @@ -463,23 +474,18 @@ const isFetchingDatasets = ref(false);
const areSimulationsFromSameModel = ref(true);

const onRun = () => {
if (knobs.value.selectedCompareOption === CompareValue.RANK) {
generateRankingCharts(
rankingCriteriaCharts,
rankingResultsChart,
props.node,
rankingChartData,
datasets,
modelConfigurations,
interventionPolicies
);
} else if (knobs.value.selectedCompareOption === CompareValue.SCENARIO) {
if (knobs.value.selectedCompareOption === CompareValue.SCENARIO) {
constructATETable();
}
};

function onChangeImpactComparison() {
generateImpactCharts(impactChartData, datasets, datasetResults, baselineDatasetIndex, selectedPlotType);
impactChartData.value = buildChartData(
datasets.value,
datasetResults.value,
baselineDatasetIndex.value,
selectedPlotType.value
);
}

interface BasicKnobs {
Expand Down Expand Up @@ -533,21 +539,21 @@ const chartSize = useDrilldownChartSize(outputPanel);

const impactChartData = ref<ChartData | null>(null);
const rankingChartData = ref<ChartData | null>(null);
const rankingResultsChart = ref<any>(null);
const rankingCriteriaCharts = ref<any>([]);

const chartData = computed(() => {
if (knobs.value.selectedCompareOption === CompareValue.RANK) {
return rankingChartData.value;
}
return impactChartData.value;
});

const variableNames = ref<string[]>([]);
const mappingOptions = ref<Record<string, string[]>>({});

const { generateAnnotation, getChartAnnotationsByChartId, useCompareDatasetCharts } = useCharts(
props.node.id,
null,
null,
impactChartData,
chartSize,
null,
null
);
const criteriaOfInterestCards = computed(() => knobs.value.criteriaOfInterestCards);

const { generateAnnotation, getChartAnnotationsByChartId, useCompareDatasetCharts, useInterventionRankingCharts } =
useCharts(props.node.id, null, null, chartData, chartSize, null, null);
const selectedPlotType = computed(() => knobs.value.selectedPlotType);
const baselineDatasetIndex = computed(() =>
datasets.value.findIndex((dataset) => dataset.id === knobs.value.selectedBaselineDatasetId)
Expand All @@ -560,6 +566,14 @@ const variableCharts = useCompareDatasetCharts(
modelConfigurations,
interventionPolicies
);

const rankingCharts = useInterventionRankingCharts(
criteriaOfInterestCards,
datasets,
modelConfigurations,
interventionPolicies
);

const groundTruthDatasetIndex = computed(() =>
datasets.value.findIndex((dataset) => dataset.id === knobs.value.selectedGroundTruthDatasetId)
);
Expand Down Expand Up @@ -772,9 +786,7 @@ onMounted(async () => {
baselineDatasetIndex,
selectedPlotType,
modelConfigurations,
interventionPolicies,
rankingCriteriaCharts,
rankingResultsChart
interventionPolicies
);

// Prepare variable dropdowns
Expand Down Expand Up @@ -817,20 +829,18 @@ onMounted(async () => {
constructWisTable();
});

watch(
() => knobs.value,
() => {
const state = cloneDeep(props.node.state);
state.criteriaOfInterestCards = knobs.value.criteriaOfInterestCards;
state.selectedCompareOption = knobs.value.selectedCompareOption;
state.selectedBaselineDatasetId = knobs.value.selectedBaselineDatasetId;
state.selectedGroundTruthDatasetId = knobs.value.selectedGroundTruthDatasetId;
state.selectedPlotType = knobs.value.selectedPlotType;
state.mapping = knobs.value.mapping;
emit('update-state', state);
},
{ deep: true }
);
onBeforeUnmount(async () => {
// flush changes
const clonedState = cloneDeep(props.node.state);
clonedState.criteriaOfInterestCards = knobs.value.criteriaOfInterestCards;
clonedState.selectedCompareOption = knobs.value.selectedCompareOption;
clonedState.selectedBaselineDatasetId = knobs.value.selectedBaselineDatasetId;
clonedState.selectedGroundTruthDatasetId = knobs.value.selectedGroundTruthDatasetId;
clonedState.selectedPlotType = knobs.value.selectedPlotType;
clonedState.mapping = knobs.value.mapping;

emit('update-state', clonedState);
});
</script>

<style scoped>
Expand Down
Loading
Loading