Skip to content

Commit 0e03ec1

Browse files
committed
convert sdmetrics quant metrics to classic key:value pair
1 parent 8a0adbe commit 0e03ec1

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

netshare/pre_post_processors/netshare/choose_best_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pandas as pd
55
import numpy as np
66

7-
from .util import create_sdmetrics_config
7+
from .util import create_sdmetrics_config, convert_sdmetricsConfigQuant_to_fieldValueDict
88
from sdmetrics.reports.timeseries import QualityReport
99

1010

@@ -18,7 +18,9 @@ def compare_rawdf_syndfs(
1818
comparison_type='quantitative')
1919
report = QualityReport(config_dict=sdmetrics_config['config'])
2020
report.generate(raw_df, syn_dfs[0], sdmetrics_config['metadata'])
21-
print("\n\n\n", report.dict_metric_scores)
21+
metricValueDict = convert_sdmetricsConfigQuant_to_fieldValueDict(
22+
report.dict_metric_scores)
23+
print(metricValueDict)
2224

2325
()+1
2426

netshare/pre_post_processors/netshare/util.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pickle
44
import math
55
import json
6+
import ast
67
import socket
78
import struct
89
import ipaddress
@@ -25,6 +26,26 @@
2526
from ...model_managers.netshare_manager.netshare_util import get_configid_from_kv
2627

2728

29+
def convert_sdmetricsConfigQuant_to_fieldValueDict(
30+
sdmetricsConfigQuant
31+
):
32+
'''Convert the sdmetricsConfigQuant to fieldValueDict
33+
Args:
34+
sdmetricsConfigQuant (dict): returned by create_sdmetrics_config(..., comparison_type='quantitative')
35+
Returns:
36+
fieldValueDict (dict): {field_name: value}
37+
'''
38+
39+
fieldValueDict = {}
40+
for metric_type, metrics in sdmetricsConfigQuant.items():
41+
for metric_class_name, metric_class in metrics.items():
42+
for field_name, field_value in metric_class.items():
43+
fieldValueDict[ast.literal_eval(
44+
field_name)[0]] = field_value[0][0]
45+
46+
return fieldValueDict
47+
48+
2849
def create_sdmetrics_config(
2950
config_pre_post_processor,
3051
comparison_type='both'

0 commit comments

Comments
 (0)