Skip to content

Commit

Permalink
update readme, args
Browse files Browse the repository at this point in the history
  • Loading branch information
vardaan123 committed Oct 22, 2023
1 parent b1f2312 commit 317491b
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 12 deletions.
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ python -u dump_preproc_data.py --dataset-path data/FB15K-237/ \
```
python -u main.py --dataset-path data/FB15K-237/ --cuda \
--save-dir ckpts/CKPT_DIR/ --sampling-type minerva \
--embed-dim 320 --n-attn-heads 8 --n-bert-layers 3 \
--lr 1e-2 --warmup 0.1 --batch-size 512 \
--n-epochs 300 --patience 20 \
--seed 12548 > ckpts/CKPT_DIR/log.txt 2>&1
Expand All @@ -69,7 +68,6 @@ python -u main.py --dataset-path data/FB15K-237/ --cuda \
```
python -u main.py --dataset-path data/WN18RR/ --cuda \
--save-dir ckpts/CKPT_DIR/ --sampling-type minerva \
--embed-dim 320 --n-attn-heads 8 --n-bert-layers 3 \
--lr 0.00175 --label-smoothing 0.1 --warmup 0.1 \
--batch-size 256 --n-epochs 500 \
--patience 100 --beam-size 40 --add-segment-embed --add-inverse-rels \
Expand All @@ -83,8 +81,8 @@ python -u main.py --dataset-path data/WN18RR/ --cuda \
python eval.py --dataset-path <DATA_PATH> --cuda \
--ckpt-path ckpts/CKPT_DIR/model.pt \
--split <valid/test> --sampling-type minerva \
--graph-connection type_1 --embed-dim 320 --n-attn-heads 8 \
--n-bert-layers 3 [--beam-size <>] [--add-segment-embed] [--add-inverse-rels]
--graph-connection type_1 \
[--beam-size <>] [--add-segment-embed] [--add-inverse-rels]
```


Expand Down
7 changes: 4 additions & 3 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def main():
parser.add_argument('--model-type', type=str, choices=['self-attn', 'cross-attn'], default='cross-attn')
parser.add_argument('--ckpt-path', type=str, required=True)
parser.add_argument('--sample-size', type=int, default=20, help='sample size in terms of no. of edges')
parser.add_argument('--embed-dim', type=int, default=768, help='embedding dim.')
parser.add_argument('--embed-dim', type=int, default=320, help='embedding dim.')
parser.add_argument('--n-attn-heads', type=int, default=8)
parser.add_argument('--n-bert-layers', type=int, default=3)
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--beta', type=float, default=0.999, help='beta_2 of adam')
parser.add_argument('--cuda', action='store_true')
Expand All @@ -43,8 +45,7 @@ def main():
parser.add_argument('--attention-probs-dropout-prob', type=float, default=0.1)

# Bert model args
parser.add_argument('--n-attn-heads', type=int, default=2)
parser.add_argument('--n-bert-layers', type=int, default=2)


args = parser.parse_args()

Expand Down
8 changes: 3 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def main():
parser.add_argument('--save-dir', type=str, required=True)
parser.add_argument('--ckpt-path', type=str, default=None)
parser.add_argument('--sample-size', type=int, default=20, help='sample size in terms of no. of edges')
parser.add_argument('--embed-dim', type=int, default=768, help='embedding dim.')
parser.add_argument('--embed-dim', type=int, default=320, help='embedding dim.')
parser.add_argument('--n-attn-heads', type=int, default=8)
parser.add_argument('--n-bert-layers', type=int, default=3)
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--warmup', type=float, default=0.1, help='percentage of steps for warmup')
parser.add_argument('--beta', type=float, default=0.999, help='beta_2 of adam')
Expand Down Expand Up @@ -79,10 +81,6 @@ def main():
parser.add_argument('--shuffle-batches', action='store_true')
parser.add_argument('--num-workers', type=int, default=8, help='no. of workers for dumping data in db')

# Bert model args
parser.add_argument('--n-attn-heads', type=int, default=8)
parser.add_argument('--n-bert-layers', type=int, default=3)

args = parser.parse_args()

print(args)
Expand Down

0 comments on commit 317491b

Please sign in to comment.