Skip to content

Commit 233dfd4

Browse files
committed
ver 0.15.14 added tb
1 parent a94fc50 commit 233dfd4

File tree

6 files changed

+100
-2
lines changed

6 files changed

+100
-2
lines changed

npy/utils/multirank/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def is_matrix(X):
2626
def is_multirank(X):
2727
if is_singleton(X):
2828
return False
29-
return is_matrix(X)
29+
return is_matrix(X[0])
3030

3131

3232
#########

ntc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .modules import *
33
from .miscs import *
44
from . import ex
5+
from . import tb
56
from . import transforms
67
from . import image
78
from . import loss

ntc/tb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .core import set_name, get_name, off, on, set_root_path
2+
from .add import *

ntc/tb/add.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from .core import get_writer
2+
3+
__all__ = ['add_scalars', 'add_scalar']
4+
5+
6+
##############################################
7+
8+
9+
def add_scalars(d, step=None, name=None, prefix=''):
10+
for k, v in d.items():
11+
add_scalar(k, v, step=step, name=name, prefix=prefix)
12+
13+
14+
def add_scalar(key, value, step=None, name=None, prefix=''):
15+
writer = get_writer(name)
16+
17+
if prefix:
18+
key = f'{prefix}/{key}'
19+
20+
writer.add_scalar(key, value, global_step=step)

ntc/tb/core.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from npy import ddict
2+
from npy.ns import sayi
3+
import datetime
4+
import os
5+
6+
_DEFAULT_TB_NAME = ''
7+
_USE_TB = True
8+
_ROOT_PATH = './tblog'
9+
10+
__all__ = ['set_name', 'get_name', 'get_writer', 'off', 'on', 'set_root_path']
11+
12+
13+
def off():
14+
global _USE_TB
15+
_USE_TB = False
16+
17+
18+
def on():
19+
global _USE_TB
20+
_USE_TB = True
21+
22+
23+
def get_tb_use():
24+
global _USE_TB
25+
return _USE_TB
26+
27+
28+
def set_name(name: str):
29+
global _DEFAULT_TB_NAME
30+
_DEFAULT_TB_NAME = name
31+
32+
33+
def get_name() -> str:
34+
global _DEFAULT_TB_NAME
35+
return _DEFAULT_TB_NAME
36+
37+
38+
def get_root_path() -> str:
39+
global _ROOT_PATH
40+
return _ROOT_PATH
41+
42+
43+
def set_root_path(v: str):
44+
global _ROOT_PATH
45+
_ROOT_PATH = v
46+
47+
48+
def get_log_path(name: str):
49+
time_str = datetime.datetime.now().strftime('%d %b %H:%M:%S')
50+
if name == '':
51+
exp_name = f'{time_str}'
52+
else:
53+
exp_name = f'{name} ({time_str})'
54+
55+
sayi(f'Tensorboard exp name is {exp_name}')
56+
57+
log_path = os.path.join(get_root_path(), exp_name)
58+
return log_path
59+
60+
61+
def filewriter_factory(name: str):
62+
from torch.utils.tensorboard import SummaryWriter
63+
log_path = get_log_path(name)
64+
return SummaryWriter(log_path)
65+
66+
67+
writers = ddict(filewriter_factory)
68+
69+
70+
def get_writer(name=None):
71+
if not _USE_TB:
72+
return
73+
if name is None:
74+
name = get_name()
75+
return writers[name]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setup(name='nuclear-python',
1212

13-
version='0.15.12',
13+
version='0.15.14',
1414

1515
url='https://github.com/nuclearboy95/nuclear-python',
1616

0 commit comments

Comments
 (0)