From c6c92218e0d8b4a6644504e72e94b77ac48a5a46 Mon Sep 17 00:00:00 2001 From: An Guangyan Date: Tue, 29 Nov 2022 19:58:28 +0800 Subject: [PATCH] update k_confidence test --- .gitignore | 3 ++- main.py | 8 ++++---- problem_config/example.py | 6 +++--- utils.py | 12 ++++++------ 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index eb0780f..3e9f28a 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ __pycache__/ *$py.class *.pkl temp*.py -nas/config.json \ No newline at end of file +nas/config.json +record.txt diff --git a/main.py b/main.py index 289c786..fbe715c 100644 --- a/main.py +++ b/main.py @@ -249,11 +249,11 @@ def parse_value(_val): def main_benchmark(problem_name: str): _seeds = 20 - _n_proc = 1 + _n_proc = 20 init_seed = 42 - _estimate_gram = 2.6 + _estimate_gram = 3.5 usage_check(_n_proc) - gpu_ids = [0] + gpu_ids = [0, 1, 2, 3, 4, 5, 6, 7] _res = benchmark_for_seeds(main, post_mean_std, seeds=_seeds, @@ -300,4 +300,4 @@ def fast_seed(seed: int) -> None: 'd7': 'DTLZ7b', }) # main(problems.d7, do_plot=True, print_progress=True, do_train=True) - main_benchmark(problems.d7) + main_benchmark(problems.d1) diff --git a/problem_config/example.py b/problem_config/example.py index 0cfa9cf..3a70340 100644 --- a/problem_config/example.py +++ b/problem_config/example.py @@ -13,9 +13,9 @@ def get_args(): args = NamedDict() args.problem_dim = (10, 3) args.train_test = (300, 1) - args.epoch = 50 - args.sgd_epoch = 50 - args.sgd_select_n = 50 + args.epoch = 10 + args.sgd_epoch = 10 + args.sgd_select_n = 100 args.update_lr = 0.0025 args.meta_lr = 0.001 args.fine_tune_lr = 0.005 diff --git a/utils.py b/utils.py index d7bea03..c9dcefb 100644 --- a/utils.py +++ b/utils.py @@ -32,10 +32,10 @@ def set_ipython_exception_hook(): sys.excepthook = _IPythonExceptionHook() -def test(): +def test(_n, _m): import numpy as np - tot = 900 - each = 50 + tot = _n + each = _m tot_arr = np.arange(tot, dtype=np.int32) sel = np.zeros_like(tot_arr, dtype=bool) i = 0 @@ -84,7 +84,7 @@ def draw_curve(n, m): import matplotlib.pyplot as plt k = list(range(10, 200)) y = [((n ** ki - (n - m) ** ki) / n ** ki) ** n for ki in k] - arr = [test() for _ in range(2000)] + arr = [test(n, m) for _ in range(2000)] plt.hist(arr, bins=50, density=True) div = np.diff(y) / np.diff(k) div = np.concatenate(([div[0]], div)) @@ -94,7 +94,7 @@ def draw_curve(n, m): if __name__ == '__main__': - _n, _m = 900, 50 - draw_curve(_n, _m) + _n, _m = 900, 100 + # draw_curve(_n, _m) v = calculate_confidence_k(_n, _m) print(v)