|
4 | 4 | import pandas as pd
|
5 | 5 | import torch
|
6 | 6 |
|
7 |
| -from fastai.tabular import tabular_data_from_df, get_tabular_learner |
8 | 7 | from fastai.core import partition
|
9 |
| -from fastai.torch_core import tensor |
| 8 | +from fastai.layer_optimizer import LayerOptimizer |
10 | 9 |
|
11 | 10 | class TestFastAI(unittest.TestCase):
|
12 | 11 | def test_partition(self):
|
13 | 12 | result = partition([1,2,3,4,5], 2)
|
14 | 13 |
|
15 | 14 | self.assertEqual(3, len(result))
|
16 | 15 |
|
17 |
| - def test_has_version(self): |
18 |
| - self.assertGreater(len(fastai.__version__), 1) |
19 |
| - |
20 |
| - # based on https://github.com/fastai/fastai/blob/master/tests/test_torch_core.py#L17 |
21 |
| - def test_torch_tensor(self): |
22 |
| - a = tensor([1, 2, 3]) |
23 |
| - b = torch.tensor([1, 2, 3]) |
| 16 | + # based on https://github.com/fastai/fastai/blob/0.7.0/tests/test_layer_optimizer.py |
| 17 | + def test_layer_optimizer(self): |
| 18 | + lo = LayerOptimizer(FakeOpt, fastai_params_('A', 'B', 'C'), 1e-2, 1e-4) |
| 19 | + fast_check_optimizer_(lo.opt, [(nm, 1e-2, 1e-4) for nm in 'ABC']) |
24 | 20 |
|
25 |
| - self.assertTrue(torch.all(a == b)) |
26 | 21 |
|
27 |
| - def test_tabular(self): |
28 |
| - df = pd.read_csv("/input/tests/data/train.csv") |
| 22 | +class Par(object): |
| 23 | + def __init__(self, x, grad=True): |
| 24 | + self.x = x |
| 25 | + self.requires_grad = grad |
| 26 | + def parameters(self): return [self] |
29 | 27 |
|
30 |
| - train_df, valid_df = df[:-5].copy(),df[-5:].copy() |
31 |
| - dep_var = "label" |
32 |
| - cont_names = [] |
33 |
| - for i in range(784): |
34 |
| - cont_names.append("pixel" + str(i)) |
35 | 28 |
|
36 |
| - data = tabular_data_from_df("", train_df, valid_df, dep_var, cont_names=cont_names, cat_names=[]) |
37 |
| - learn = get_tabular_learner(data, layers=[200, 100]) |
38 |
| - learn.fit(epochs=1) |
| 29 | +class FakeOpt(object): |
| 30 | + def __init__(self, params): self.param_groups = params |
| 31 | + |
| 32 | + |
| 33 | +def fastai_params_(*names): return [Par(nm) for nm in names] |
| 34 | + |
| 35 | +def fast_check_optimizer_(opt, expected): |
| 36 | + actual = opt.param_groups |
| 37 | + assert len(actual) == len(expected) |
| 38 | + for (a, e) in zip(actual, expected): fastai_check_param_(a, *e) |
| 39 | + |
| 40 | +def fastai_check_param_(par, nm, lr, wd): |
| 41 | + assert par['params'][0].x == nm |
| 42 | + assert par['lr'] == lr |
| 43 | + assert par['weight_decay'] == wd |
0 commit comments