Skip to content

Commit 1fd0082

Browse files
oops
1 parent 6f3967e commit 1fd0082

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

scripts/train.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def train(
4141
epoch_size = train_sampler.itersize(config['batch_size_s'] * config['backprop_every'], config['batch_size_h'])
4242
collator = Collator(pad_value=-1, device=device, allow_self_loops=config['allow_self_loops'])
4343

44-
model = Trainer(config['model_config']).to(device(device)).to(dtype)
44+
model = Trainer(config['model_config']).to(device).to(dtype)
4545
optimizer = AdamW(params=model.parameters(), lr=1, weight_decay=1e-02)
4646
schedule = make_schedule(warmup_steps=config['warmup_epochs'] * epoch_size,
4747
warmdown_steps=config['warmdown_epochs'] * epoch_size,
@@ -86,19 +86,17 @@ def parse_args():
8686
default='../data/model.pt')
8787
parser.add_argument('--log_path', type=str, help='Where to log results',
8888
default='../data/log.txt')
89-
parser.add_argument('--use_half', type=bool, help='Whether to use half-precision floats',
90-
default=False)
9189
return parser.parse_args()
9290

9391

9492
if __name__ == '__main__':
9593
args = parse_args()
96-
train_cfg = json.load(open(args.config_path, 'r'))
94+
train_cfg: TrainCfg = json.load(open(args.config_path, 'r'))
9795
train(
9896
config=train_cfg,
9997
data_path=args.data_path,
10098
store_path=args.store_path,
10199
log_path=args.log_path,
102100
device='cuda',
103-
dtype=torch.half if args.use_half else torch.float
101+
dtype=torch.half if train_cfg['half_precision'] else torch.float
104102
)

src/Name/nn/training.py

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class TrainCfg(TypedDict):
2929
dev_files: list[str]
3030
test_files: list[str]
3131
allow_self_loops: bool
32+
half_precision: bool
3233

3334

3435
@dataclass

0 commit comments

Comments
 (0)