Skip to content

Commit

Permalink
update k_confidence test
Browse files Browse the repository at this point in the history
  • Loading branch information
ASSANDHOLE committed Dec 6, 2022
1 parent 11c501b commit c6c9221
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 14 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ __pycache__/
*$py.class
*.pkl
temp*.py
nas/config.json
nas/config.json
record.txt
8 changes: 4 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,11 @@ def parse_value(_val):

def main_benchmark(problem_name: str):
_seeds = 20
_n_proc = 1
_n_proc = 20
init_seed = 42
_estimate_gram = 2.6
_estimate_gram = 3.5
usage_check(_n_proc)
gpu_ids = [0]
gpu_ids = [0, 1, 2, 3, 4, 5, 6, 7]
_res = benchmark_for_seeds(main,
post_mean_std,
seeds=_seeds,
Expand Down Expand Up @@ -300,4 +300,4 @@ def fast_seed(seed: int) -> None:
'd7': 'DTLZ7b',
})
# main(problems.d7, do_plot=True, print_progress=True, do_train=True)
main_benchmark(problems.d7)
main_benchmark(problems.d1)
6 changes: 3 additions & 3 deletions problem_config/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ def get_args():
args = NamedDict()
args.problem_dim = (10, 3)
args.train_test = (300, 1)
args.epoch = 50
args.sgd_epoch = 50
args.sgd_select_n = 50
args.epoch = 10
args.sgd_epoch = 10
args.sgd_select_n = 100
args.update_lr = 0.0025
args.meta_lr = 0.001
args.fine_tune_lr = 0.005
Expand Down
12 changes: 6 additions & 6 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def set_ipython_exception_hook():
sys.excepthook = _IPythonExceptionHook()


def test():
def test(_n, _m):
import numpy as np
tot = 900
each = 50
tot = _n
each = _m
tot_arr = np.arange(tot, dtype=np.int32)
sel = np.zeros_like(tot_arr, dtype=bool)
i = 0
Expand Down Expand Up @@ -84,7 +84,7 @@ def draw_curve(n, m):
import matplotlib.pyplot as plt
k = list(range(10, 200))
y = [((n ** ki - (n - m) ** ki) / n ** ki) ** n for ki in k]
arr = [test() for _ in range(2000)]
arr = [test(n, m) for _ in range(2000)]
plt.hist(arr, bins=50, density=True)
div = np.diff(y) / np.diff(k)
div = np.concatenate(([div[0]], div))
Expand All @@ -94,7 +94,7 @@ def draw_curve(n, m):


if __name__ == '__main__':
_n, _m = 900, 50
draw_curve(_n, _m)
_n, _m = 900, 100
# draw_curve(_n, _m)
v = calculate_confidence_k(_n, _m)
print(v)

0 comments on commit c6c9221

Please sign in to comment.