Skip to content

Commit dcbf94a

Browse files
committed
Ready for publication.
1 parent 89dee2e commit dcbf94a

File tree

9 files changed

+130
-145
lines changed

9 files changed

+130
-145
lines changed

README.md

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,6 @@
22

33
Code for _DeepXML: A Deep Extreme Multi-Label Learning Framework Applied to Short Text Documents_
44

5-
## Requirements
6-
7-
---
8-
9-
* Pyxclib
10-
* NumPy
11-
* PyTorch
12-
* Numba
13-
* Scikit-learn
14-
155
---
166

177
## Architectures and algorithms
@@ -32,54 +22,65 @@ DeepXML supports multiple feature architectures such as Bag-of-embedding/Astec,
3222

3323
---
3424

35-
## Example use cases
25+
## Setting up
3626

3727
---
3828

39-
### A single learner with DeepXML framework
29+
### Expected directory structure
4030

41-
The DeepXML framework can be utilized as follows. A json file is used to specify architecture and other arguments.
31+
```txt
32+
+-- <work_dir>
33+
| +-- programs
34+
| | +-- deepxml
35+
| | +-- deepxml
36+
| +-- data
37+
| +-- <dataset>
38+
| +-- models
39+
| +-- results
4240
43-
```bash
44-
./run_main.sh 0 DeepXML EURLex-4K 0 108
4541
```
4642

47-
### An ensemble of multiple learners with DeepXML framework
43+
### Download data for Astec
4844

49-
An ensemble can be trained as follows. A json file is used to specify architecture and other arguments.
45+
```txt
46+
* Download the BoW features (zipped file) from the XML repository.
47+
* Extract the zipped file into the data directory.
48+
* The following files should be available in <work_dir>/data/<dataset>
49+
- train.txt
50+
- test.txt
51+
- fasttextB_embeddings_300d.npy or fasttextB_embeddings_512d.npy
52+
```
5053

51-
```bash
52-
./run_main.sh 0 DeepXML EURLex-4K 0 108,666,786
54+
### Convert to new data format
55+
56+
```perl
57+
# A Perl script is provided (in deepxml/tools) to convert the data into the new format expected by Astec
58+
# Either set the $data_dir variable to the data directory of a particular dataset or replace it with the actual path
59+
perl convert_format.pl $data_dir/train.txt $data_dir/trn_X_Xf.txt $data_dir/trn_X_Y.txt
60+
perl convert_format.pl $data_dir/test.txt $data_dir/tst_X_Xf.txt $data_dir/tst_X_Y.txt
5361
```
5462

55-
## Full documentation
63+
## Example use cases
5664

5765
---
5866

59-
### Expected directory structure
67+
### A single learner with DeepXML framework
6068

61-
```txt
62-
+-- work_dir
63-
| +-- programs
64-
| | +-- deepxml
65-
| | +-- deepxml
66-
| +-- data_dir
67-
| +-- dataset
68-
| +-- model_dir
69-
| +-- results_dir
69+
The DeepXML framework can be utilized as follows. A JSON file is used to specify architecture and other arguments. Please refer to the full documentation below for more details.
7070

71+
```bash
72+
./run_main.sh 0 DeepXML EURLex-4K 0 108
7173
```
7274

73-
### Convert the data to new format
75+
### An ensemble of multiple learners with DeepXML framework
7476

75-
```perl
76-
# A perl script is provided (deepxml/tools) to convert the data into new format as expected by DeepXML
77-
perl convert_format.pl <data_dir>/train.txt <data_dir>/trn_X_Xf.txt <data_dir>/trn_X_Y.txt
77+
An ensemble can be trained as follows. A JSON file is used to specify architecture and other arguments.
7878

79-
perl convert_format.pl <data_dir>/test.txt <data_dir>/tst_X_Xf.txt <data_dir>/tst_X_Y.txt
79+
```bash
80+
./run_main.sh 0 DeepXML EURLex-4K 0 108,666,786
8081
```
8182

82-
### Run details
83+
## Full Documentation
8384

8485
```txt
8586
./run_main.sh <gpu_id> <framework> <dataset> <version> <seed>
@@ -93,12 +94,13 @@ perl convert_format.pl <data_dir>/test.txt <data_dir>/tst_X_Xf.txt <data_dir>/ts
9394
9495
* dataset
9596
- Name of the dataset.
96-
- Expected files in work_dir/data/<dataset>
97+
- Astec expects the following files in <work_dir>/data/<dataset>
9798
- trn_X_Xf.txt
9899
- trn_X_Y.txt
99100
- tst_X_Xf.txt
100101
- tst_X_Y.txt
101-
- fasttextB_embeddings_300d.npy or fasttextB_embeddings_512d.npy
102+
- fasttextB_embeddings_300d.npy or fasttextB_embeddings_512d.npy
103+
- You can set 'embedding_dims' in the config file to switch between 300d and 512d embeddings.
102104
103105
* version
104106
- Different runs can be managed by version and seed.

deepxml/configs/DeepXML/AmazonTitles-2.5M.json

Lines changed: 0 additions & 85 deletions
This file was deleted.

deepxml/configs/DeepXML/AmazonTitles-3M.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"embedding_dims": 300,
1313
"beta": 0.10,
1414
"top_k": 300,
15+
"save_top_k": 100,
1516
"save_predictions": true,
1617
"trn_label_fname": "trn_X_Y.txt",
1718
"val_label_fname": "tst_X_Y.txt",
@@ -36,7 +37,7 @@
3637
"extreme": {
3738
"num_epochs": 15,
3839
"dlr_factor": 0.5,
39-
"learning_rate": 0.002,
40+
"learning_rate": 0.0005,
4041
"batch_size": 255,
4142
"dlr_step": 14,
4243
"ns_method": "ensemble",

deepxml/configs/DeepXML/AmazonTitles-670K.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
"dlr_factor": 0.5,
2626
"learning_rate": 0.02,
2727
"batch_size": 255,
28-
"dlr_step": 14,
28+
"dlr_step": 10,
2929
"normalize": true,
3030
"optim": "Adam",
3131
"init": "token_embeddings",

deepxml/run_scripts/Identity.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{
2+
"representation_dims": "#ARGS.embedding_dims;",
23
"transform_coarse": {
34
"order": ["_identity"],
45
"_identity": {}

deepxml/runner.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def create_surrogate_mapping(data_dir, g_config, seed):
3636
return data_stats, mapping
3737

3838

39-
def evaluate(g_config, data_dir, pred_fname, betas=-1, n_learners=1):
39+
def evaluate(g_config, data_dir, pred_fname, filter_fname=None, betas=-1, n_learners=1):
4040
if n_learners == 1:
4141
func = evalaute_one.main
4242
else:
@@ -46,14 +46,20 @@ def evaluate(g_config, data_dir, pred_fname, betas=-1, n_learners=1):
4646
data_dir = os.path.join(data_dir, dataset)
4747
A = g_config['A']
4848
B = g_config['B']
49+
if 'save_top_k' in g_config:
50+
top_k = g_config['save_top_k']
51+
else:
52+
top_k = g_config['top_k']
4953
ans = func(
5054
tst_label_fname=os.path.join(
5155
data_dir, g_config["tst_label_fname"]),
5256
trn_label_fname=os.path.join(
5357
data_dir, g_config["trn_label_fname"]),
5458
pred_fname=pred_fname,
5559
A=A,
56-
B=B,
60+
B=B,
61+
top_k=top_k,
62+
filter_fname=filter_fname,
5763
betas=betas,
5864
save=g_config["save_predictions"])
5965
return ans
@@ -86,6 +92,11 @@ def run_deepxml(work_dir, version, seed, config):
8692

8793
# Directory and filenames
8894
data_dir = os.path.join(work_dir, 'data')
95+
96+
filter_fname = os.path.join(data_dir, dataset, 'filter_labels_test.txt')
97+
if not os.path.isfile(filter_fname):
98+
filter_fname = None
99+
89100
result_dir = os.path.join(
90101
work_dir, 'results', 'DeepXML', arch, dataset, f'v_{version}')
91102
model_dir = os.path.join(
@@ -158,6 +169,7 @@ def run_deepxml(work_dir, version, seed, config):
158169
g_config=g_config,
159170
data_dir=data_dir,
160171
pred_fname=pred_fname,
172+
filter_fname=filter_fname,
161173
betas=[0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.75, 0.90])
162174
f_rstats = os.path.join(result_dir, 'log_eval.txt')
163175
with open(f_rstats, "w") as fp:
@@ -220,6 +232,7 @@ def run_deepxml(work_dir, version, seed, config):
220232
ans = evaluate(
221233
g_config=g_config,
222234
data_dir=data_dir,
235+
filter_fname=filter_fname,
223236
pred_fname=pred_fname,
224237
betas=[0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.75, 0.90])
225238
with open(f_rstats, 'a') as fp:
@@ -241,6 +254,9 @@ def run_deepxml_ova(work_dir, version, seed, config):
241254
work_dir, 'results', 'DeepXML-OVA', arch, dataset, f'v_{version}')
242255
model_dir = os.path.join(
243256
work_dir, 'models', 'DeepXML-OVA', arch, dataset, f'v_{version}')
257+
filter_fname = os.path.join(data_dir, dataset, 'filter_labels_test.txt')
258+
if not os.path.isfile(filter_fname):
259+
filter_fname = None
244260

245261
_args = parameters.Parameters("Parameters")
246262
_args.parse_args()
@@ -279,6 +295,7 @@ def run_deepxml_ova(work_dir, version, seed, config):
279295
pred_fname = os.path.join(result_dir, 'tst_predictions')
280296
ans = evaluate(
281297
g_config=g_config,
298+
filter_fname=filter_fname,
282299
data_dir=data_dir,
283300
pred_fname=pred_fname)
284301
f_rstats = os.path.join(result_dir, 'log_eval.txt')
@@ -300,6 +317,9 @@ def run_deepxml_ann(work_dir, version, seed, config):
300317
work_dir, 'results', 'DeepXML-ANNS', arch, dataset, f'v_{version}')
301318
model_dir = os.path.join(
302319
work_dir, 'models', 'DeepXML-ANNS', arch, dataset, f'v_{version}')
320+
filter_fname = os.path.join(data_dir, dataset, 'filter_labels_test.txt')
321+
if not os.path.isfile(filter_fname):
322+
filter_fname = None
303323
_args = parameters.Parameters("Parameters")
304324
_args.parse_args()
305325
_args.update(config['global'])
@@ -339,6 +359,7 @@ def run_deepxml_ann(work_dir, version, seed, config):
339359
pred_fname = os.path.join(result_dir, 'tst_predictions')
340360
ans = evaluate(
341361
g_config=g_config,
362+
filter_fname=filter_fname,
342363
data_dir=data_dir,
343364
pred_fname=pred_fname,
344365
betas=[0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.75, 0.90])

0 commit comments

Comments
 (0)