Support group lasso for pruning (intel#28)
* Support group lasso for pruning
* Fixed pylint error
* Fixed dataset select in autodistillation
* Fixed import mixture error
* Fixed wandb login error
* Fixed examples error
PenghuiCheng authored Apr 8, 2022
1 parent 157d22e commit e22cf8c
Showing 37 changed files with 4,364 additions and 154 deletions.
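For context, group lasso regularization prunes structured groups of weights (for example whole output channels or attention heads) by adding the sum of per-group L2 norms to the training loss; groups whose norm is driven to zero can then be removed entirely. The following is a minimal PyTorch sketch of that penalty, not the toolkit's actual implementation; the row-wise grouping and the `reg_strength` value are illustrative assumptions.

```python
import torch

def group_lasso_penalty(weight: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Sum of per-group L2 norms: sum_g ||w_g||_2.

    With dim=1, each row of a 2-D weight matrix is one group, so the
    penalty drives whole rows toward zero, making them prunable.
    """
    return weight.norm(p=2, dim=dim).sum()

def regularized_loss(task_loss: torch.Tensor, model: torch.nn.Module,
                     reg_strength: float = 1e-4) -> torch.Tensor:
    # Penalize every 2-D weight matrix in the model; biases and
    # LayerNorm weights (1-D) are left untouched.
    penalty = sum(group_lasso_penalty(p)
                  for name, p in model.named_parameters()
                  if name.endswith("weight") and p.dim() == 2)
    return task_loss + reg_strength * penalty
```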
@@ -76,6 +76,7 @@

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
# check_min_version("4.10.0")
os.environ["WANDB_DISABLED"] = "true"

logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
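The one-line change repeated across the example scripts below disables Weights & Biases logging before the `Trainer` starts, which avoids the interactive wandb login prompt on machines without credentials. As a sketch of the same effect (the `report_to` alternative is a standard `TrainingArguments` field in transformers >= 4.12, shown here for comparison, not what this commit uses):

```python
import os

# What the commit adds: disable the wandb callback globally,
# before any transformers Trainer is constructed.
os.environ["WANDB_DISABLED"] = "true"

# Per-run alternative via TrainingArguments:
# args = TrainingArguments(output_dir="out", report_to="none")
```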
@@ -52,6 +52,8 @@
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.12.0")

os.environ["WANDB_DISABLED"] = "true"

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

logger = logging.getLogger(__name__)
@@ -50,6 +50,8 @@
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.12.0")

os.environ["WANDB_DISABLED"] = "true"

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

logger = logging.getLogger(__name__)
@@ -47,6 +47,8 @@
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.12.0")

os.environ["WANDB_DISABLED"] = "true"

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

logger = logging.getLogger(__name__)
@@ -50,6 +50,8 @@

logger = logging.getLogger(__name__)

os.environ["WANDB_DISABLED"] = "true"


@dataclass
class ModelArguments:
@@ -59,6 +59,8 @@

logger = logging.getLogger(__name__)

os.environ["WANDB_DISABLED"] = "true"


@dataclass
class ModelArguments:
@@ -53,6 +53,8 @@

logger = logging.getLogger(__name__)

os.environ["WANDB_DISABLED"] = "true"


@dataclass
class ModelArguments:
@@ -0,0 +1,72 @@
Step-by-Step
============

This document lists the steps to reproduce the PyTorch BERT pruning results.

# Prerequisites

### 1. Installation

#### Python First

Python 3.7 or a later version is recommended.

#### Install nlp-toolkit
```bash
pip install nlp-toolkit
```

#### Install PyTorch

Install the GPU build of PyTorch; visit [pytorch.org](https://pytorch.org/) for platform-specific instructions.
```bash
# Install pytorch
pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
```
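
Before moving on, it is worth confirming that the CUDA build is active:
```bash
# Should print 1.10.0+cu113 and True on a CUDA-capable machine
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```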

#### Install BERT dependencies

```bash
cd examples/pytorch/eager/language_translation/BERT_sparse
pip3 install -r requirements.txt --ignore-installed PyYAML
```
```bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
> **Note**
>
> If no CUDA runtime is found, please export CUDA_HOME='/usr/local/cuda'.
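
A quick import check confirms that apex built correctly:
```bash
python -c "from apex import amp; print('apex OK')"
```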
### 2. Prepare Dataset

* For the SQuAD task, download the dataset from the [SQuAD explorer page](https://rajpurkar.github.io/SQuAD-explorer/), for example:
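```bash
# SQuAD v1.1 train and dev sets; whether the script expects v1.1 or v2.0
# is an assumption here -- check run_squad_sparse.sh for the exact files.
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
```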
### 3. Prepare Pretrained Model
* Download the BERT-large pretrained model from [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/models/bert_pyt_ckpt_large_pretraining_amp_lamb/files?version=20.03.0).
```bash
# wget cmd
wget https://api.ngc.nvidia.com/v2/models/nvidia/bert_pyt_ckpt_large_pretraining_amp_lamb/versions/20.03.0/files/bert_large_pretrained_amp.pt

# curl cmd
curl -LO https://api.ngc.nvidia.com/v2/models/nvidia/bert_pyt_ckpt_large_pretraining_amp_lamb/versions/20.03.0/files/bert_large_pretrained_amp.pt
```
# Run
Enter the conda environment you created, then run the script.
```bash
bash scripts/run_squad_sparse.sh /path/to/model.pt 2.0 16 5e-5 tf32 /path/to/data /path/to/outdir prune_bert.yaml
```
The default parameters are as follows:
```shell
init_checkpoint=${1:-"/path/to/ckpt_8601.pt"}
epochs=${2:-"2.0"}
batch_size=${3:-"4"}
learning_rate=${4:-"3e-5"}
precision=${5:-"tf32"}
BERT_PREP_WORKING_DIR=${6:-'/path/to/bert_data'}
OUT_DIR=${7:-"./results/SQuAD"}
prune_config=${8:-"prune_bert.yaml"}
```
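
Each positional argument overrides the corresponding default above. For example, to train with batch size 32 while keeping the other scripted defaults:
```bash
bash scripts/run_squad_sparse.sh /path/to/model.pt 2.0 32 3e-5 tf32 /path/to/bert_data ./results/SQuAD prune_bert.yaml
```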
# Original BERT README

Please refer to the [BERT README](https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/README.md).
@@ -0,0 +1,13 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"max_position_embeddings": 512,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"type_vocab_size": 2,
"vocab_size": 30522
}
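
These hyperparameters match the standard BERT-large architecture. As a sanity check, a rough parameter count can be derived from the config alone; the filename `bert_config.json` and the omission of the pooler and task head are assumptions of this sketch:

```python
import json

with open("bert_config.json") as f:  # assumed filename for the config above
    cfg = json.load(f)

h, ff, L = cfg["hidden_size"], cfg["intermediate_size"], cfg["num_hidden_layers"]
V, P, T = cfg["vocab_size"], cfg["max_position_embeddings"], cfg["type_vocab_size"]

embeddings = (V + P + T) * h + 2 * h      # token/position/type tables + LayerNorm
per_layer = (
    4 * (h * h + h)                       # Q, K, V and output projections
    + (h * ff + ff) + (ff * h + h)        # feed-forward up/down projections
    + 2 * 2 * h                           # two LayerNorms (weight + bias)
)
print(f"~{(embeddings + L * per_layer) / 1e6:.0f}M parameters")  # prints ~334M
```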