Skip to content

Commit 139d8bd

Browse files
committed
ver 0.15.4
miscs
1 parent 968a758 commit 139d8bd

File tree

4 files changed

+154
-18
lines changed

4 files changed

+154
-18
lines changed
Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,42 @@
11
import re
22

3-
__all__ = ['str2float', 'str2int', 'str2int_l', 'countchar', 'rmchar', 'rmchars', 'rmws']
3+
__all__ = ['str2float', 'str2int', 'int_or_str2int', 'str2int_l',
4+
'countchar', 'rmchar', 'rmchars', 'rmws', 'include_exclude']
45

56

6-
def str2float(v: str) -> float:
7-
v = v.strip()
8-
if v.startswith('(') and v.endswith(')'):
9-
return -str2int(v[1:-1])
7+
def str2float(s: str) -> float:
8+
s = s.strip()
9+
if s.startswith('(') and s.endswith(')'):
10+
return -str2int(s[1:-1])
1011

11-
v = ''.join(re.findall(r'[\d.\-]+', v))
12-
if not v:
12+
s = ''.join(re.findall(r'[\d.\-]+', s))
13+
if not s:
1314
return 0.
14-
elif v == '-':
15+
elif s == '-':
1516
return 0
1617
else:
17-
return float(v)
18+
return float(s)
1819

1920

20-
def str2int(v: str) -> int:
21-
v = v.strip()
22-
if v.startswith('(') and v.endswith(')'):
23-
return -str2int(v[1:-1])
21+
def str2int(s: str) -> int:
22+
s = s.strip()
23+
if s.startswith('(') and s.endswith(')'):
24+
return -str2int(s[1:-1])
2425

25-
v = ''.join(re.findall(r'[\d\-]+', v))
26-
if not v:
26+
s = ''.join(re.findall(r'[\d\-]+', s))
27+
if not s:
2728
return 0
28-
elif v == '-':
29+
elif s == '-':
2930
return 0
3031
else:
31-
return int(v)
32+
return int(s)
33+
34+
35+
def int_or_str2int(v) -> int:
36+
if isinstance(v, str):
37+
return str2int(v)
38+
else:
39+
return v
3240

3341

3442
def str2int_l(l: list) -> list:
@@ -51,3 +59,15 @@ def rmchars(s: str, cs: list) -> str:
5159

5260
def rmws(s: str) -> str:
5361
return rmchars(s, ['\n', ' ', '\t'])
62+
63+
64+
def include_exclude(s: str, includes: list = None, excludes: list = None) -> bool:
65+
if includes is not None:
66+
if any([v not in s for v in includes]):
67+
return False
68+
69+
if excludes is not None:
70+
if any([v in s for v in excludes]):
71+
return False
72+
73+
return True

ntc/ex/mnist_acc.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from ntc.ns import *
2+
3+
4+
__all__ = ['main']
5+
6+
7+
xent = nn.CrossEntropyLoss()
8+
9+
10+
class Net(nn.Module):
11+
def __init__(self):
12+
super(Net, self).__init__()
13+
self.conv1 = nn.Conv2d(1, 32, 5, 2)
14+
self.conv2 = nn.Conv2d(32, 32, 3, 1)
15+
self.conv3 = nn.Conv2d(32, 32, 3, 1)
16+
self.fc1 = nn.Linear(2048, 128)
17+
self.fc2 = nn.Linear(128, 10)
18+
19+
def forward(self, x):
20+
x = self.conv1(x)
21+
x = F.leaky_relu(x, 0.1)
22+
23+
x = self.conv2(x)
24+
x = F.leaky_relu(x, 0.1)
25+
26+
x = self.conv3(x)
27+
x = F.leaky_relu(x, 0.1)
28+
29+
x = torch.flatten(x, 1)
30+
31+
x = self.fc1(x)
32+
x = F.leaky_relu(x, 0.1)
33+
34+
x = self.fc2(x)
35+
return x
36+
37+
38+
def forward(model, xs, ys, device):
39+
xs, ys = xs.to(device), ys.to(device)
40+
logits = model(xs)
41+
loss = xent(logits, ys)
42+
43+
pred = logits.argmax(dim=1, keepdim=True)
44+
correct = pred.eq(ys.view_as(pred)).sum().item()
45+
return loss, correct
46+
47+
48+
def train(model, device, train_loader, opt, epoch):
49+
met = AverageMeters()
50+
model.train()
51+
acc_count = 0
52+
ACC_NUM = 16
53+
opt.zero_grad()
54+
55+
for i, (xs, ys) in enumerate(train_loader):
56+
57+
loss, correct = forward(model, xs, ys, device)
58+
loss.backward()
59+
acc_count += 1
60+
61+
if acc_count % ACC_NUM == 0:
62+
acc_count = 0
63+
opt.step()
64+
opt.zero_grad()
65+
66+
met.update('loss', loss.item())
67+
met.update('acc', correct, ys.size(0))
68+
else:
69+
if acc_count != 0:
70+
opt.step()
71+
72+
sayi(f'Epoch {epoch:03d}] Train Loss: {met.avg("loss"):.3f}, Acc: {met.avg("acc"):.3f}')
73+
74+
75+
def test(model, device, test_loader):
76+
model.eval()
77+
met = AverageMeters()
78+
with torch.no_grad():
79+
for xs, ys in test_loader:
80+
loss, correct = forward(model, xs, ys, device)
81+
82+
met.update('loss', loss.item())
83+
met.update('acc', correct, ys.size(0))
84+
85+
sayi(f'Test Loss: {met.avg("loss"):.3f}, Acc: {met.avg("acc"):.3f}')
86+
87+
88+
def get_dataloader():
89+
import torchvision
90+
from torchvision.transforms import ToTensor
91+
d_train = torchvision.datasets.MNIST('./mnist/', train=True, transform=ToTensor(), download=True)
92+
d_test = torchvision.datasets.MNIST('./mnist/', train=False, transform=ToTensor(), download=True)
93+
94+
l_train = DataLoader(d_train, batch_size=4, shuffle=True, pin_memory=True)
95+
l_test = DataLoader(d_test, batch_size=256, shuffle=False, pin_memory=True)
96+
97+
return l_train, l_test
98+
99+
100+
def main():
101+
device = torch.device("cuda")
102+
model = Net().to(device)
103+
opt = optim.Adam(model.parameters(), lr=1e-4)
104+
105+
train_loader, test_loader = get_dataloader()
106+
107+
sayi('Start training')
108+
for epoch in range(1, 100 + 1):
109+
train(model, device, train_loader, opt, epoch)
110+
test(model, device, test_loader)
111+
112+
113+
if __name__ == '__main__':
114+
set_cuda(0)
115+
main()

ntc/miscs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import numpy as np
23

34
__all__ = ['abstain_loss', 'calc_correct', 'to_device', 'to_numpy']
45

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setup(name='nuclear-python',
1212

13-
version='0.15.3.1',
13+
version='0.15.4',
1414

1515
url='https://github.com/nuclearboy95/nuclear-python',
1616

0 commit comments

Comments
 (0)