
GPT2 #205

Merged
merged 13 commits on Sep 15, 2023
refine README and python files
KungYork committed Aug 21, 2023
commit df0b1fede37a0b306a81482089c9b4521cfdf354
12 changes: 12 additions & 0 deletions distilbert_finetune.py
Collaborator comment:


Is this file actually used? If not, please delete it.

@@ -0,0 +1,12 @@
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_id = logits.argmax().item()
model.config.id2label[predicted_class_id]
9 changes: 7 additions & 2 deletions training/benchmarks/gpt2/pytorch/config/_base.py
@@ -1,8 +1,13 @@
# Required parameters

vendor: str = None
shh2000 marked this conversation as resolved.
data_dir: str = None
name: str = "GPT2"
cudnn_benchmark: bool = False
cudnn_deterministic: bool = True

use_env: bool = True
log_freq: int = 1
dist_backend: str = None
device: str = None

# =========================================================
@@ -24,7 +29,6 @@
# data
# =========================================================

data_dir: str = None
train_data_prefix: str = "lambada_train_text_document"
test_data_prefix: str = "lambada_test.json"
init_checkpoint: str = "model_optim_rng.pt"
@@ -113,6 +117,7 @@
# distributed parallel
# =========================================================

dist_backend: str = None
DDP_impl: str = "native"
gradient_accumulation_fusion: bool = False
use_contiguous_buffers_in_local_ddp: bool = False
@@ -64,7 +64,6 @@ def build_data_loader(dataset, train_batch_size, num_workers, drop_last,

return data_loader


class MegatronPretrainingSampler:

    def __init__(self, total_samples, consumed_samples, train_batch_size,
@@ -112,4 +111,3 @@ def __iter__(self):
        if len(batch) > 0 and not self.drop_last:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]
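The sampler fragment above yields the final partial batch when drop_last is disabled, with each data-parallel rank keeping only its slice. A hedged sketch of what a get_start_end_idx-style helper computes (modeled on Megatron-LM's pretraining sampler; the exact attribute names used by this benchmark are an assumption):

```python
# Illustrative only: slice a global batch of sample indices across data-parallel ranks.
def get_start_end_idx(data_parallel_rank: int, micro_batch_size: int) -> tuple[int, int]:
    # Each rank reads a contiguous window of the (possibly partial) batch.
    start_idx = data_parallel_rank * micro_batch_size
    end_idx = start_idx + micro_batch_size
    return start_idx, end_idx

print(get_start_end_idx(1, 4))  # rank 1 with micro_batch_size=4 -> (4, 8)
```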

@@ -8,7 +8,7 @@
import torch

from dataloaders.indexed_dataset import make_dataset as make_indexed_dataset
from dataloaders.data_samplers import build_pretraining_data_loader, build_data_loader
from dataloaders.dataloader import build_pretraining_data_loader, build_data_loader
from dataloaders import get_tokenizer

import config
10 changes: 7 additions & 3 deletions training/benchmarks/gpt2/pytorch/model/losses/cross_entropy.py
@@ -2,9 +2,13 @@


def cross_entropy(outputs, target):
    #para: outputs, [b, s, vocab_size]
    #      target, [b, s]
    #return: loss, [b, s]
    """
    Compute the cross entropy loss of output and target.

    para: outputs, [b, s, vocab_size]
          target, [b, s]
    return: loss, [b, s]
    """

    logits = outputs.clone()
    # logits = outputs
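The new docstring pins down the tensor shapes but not the computation itself. For reference, a per-token cross entropy with those shapes can be sketched as follows (an illustrative sketch, not necessarily what the benchmark computes after the clone above):

```python
import torch
import torch.nn.functional as F

def per_token_cross_entropy(outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """outputs: [b, s, vocab_size], target: [b, s] -> loss: [b, s]."""
    b, s, vocab_size = outputs.shape
    # F.cross_entropy expects [N, C] logits and [N] class indices, so flatten
    # the batch and sequence dimensions, then restore the [b, s] shape.
    loss = F.cross_entropy(outputs.reshape(-1, vocab_size),
                           target.reshape(-1),
                           reduction="none")
    return loss.reshape(b, s)
```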
9 changes: 4 additions & 5 deletions training/benchmarks/gpt2/pytorch/run_pretraining.py
@@ -1,8 +1,8 @@
"""GPT2 Pretraining"""
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""GPT2 Pretraining"""

import argparse
import os
@@ -29,7 +29,6 @@

def main():
    import config
    from config import mutable_params
    global logger

    if config.use_env and 'LOCAL_RANK' in os.environ:
4 changes: 4 additions & 0 deletions training/benchmarks/gpt2/pytorch/train/trainer.py
@@ -1,3 +1,7 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")

import math
import time
import os
4 changes: 4 additions & 0 deletions training/benchmarks/gpt2/pytorch/train/trainer_adapter.py
@@ -1,3 +1,7 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")

import os
from typing import Tuple

4 changes: 4 additions & 0 deletions training/benchmarks/gpt2/pytorch/train/training_state.py
@@ -1,3 +1,7 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")

from dataclasses import dataclass

import torch
4 changes: 2 additions & 2 deletions training/nvidia/gpt2-pytorch/README.md
@@ -21,7 +21,7 @@
| Metric name | Metric value | Notes |
| -------------- | ----------------------- | ------------------------------------------- |
| Task category | Natural language encoding | |
| Model | bert-large-uncased | |
| Model | megatron-gpt2-345m | |
| Dataset | Wikipedia | |
| Data precision | precision, see "Performance metrics" | fp32/amp/fp16 selectable |
| Hyperparameter changes | fix_hp, see "Performance metrics" | special hyperparameters needed to saturate the hardware when measuring throughput |
@@ -38,4 +38,4 @@

| Configuration | precision | fix_hp | e2e_time | p_whole | p_train | p_core | lambada_acc | mem |
| ------------------- | --------- | ---------------- | -------- | ------- | ------- | ------ | ------- | --------- |
| A100 single node, 8 GPUs (1x8) | fp32 | bs=32,lr=0.00015 | | 2.30 | 88.36 | 89.57 | | 33.7/40.0 |
| A100 single node, 8 GPUs (1x8) | fp32 | bs=32,lr=0.00015 | 853.75 | 2.30 | 88.36 | 89.57 | 0.7001 | 33.7/40.0 |
Collaborator comment:


p_whole is quite low. Would you consider slightly reducing the eval frequency and adding some train_steps, so that more training steps are run while the total time stays roughly the same?
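For context on that suggestion, a rough reading of the table above, treating p_whole and p_train as throughputs measured over the whole run and over the training phase respectively (an assumption about the metric definitions, not something stated in this diff):

```python
# Back-of-the-envelope arithmetic using only the reported numbers.
p_whole, p_train = 2.30, 88.36
print(f"p_whole / p_train = {p_whole / p_train:.1%}")  # ~2.6%
# Under the throughput-ratio reading above, most wall-clock time is spent
# outside the training loop (setup, data prep, eval), so evaluating less
# often and spending the saved time on extra train steps should raise
# p_whole without changing the total runtime much.
```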

2 changes: 0 additions & 2 deletions training/nvidia/gpt2-pytorch/config/config_A100x1x8.py
@@ -1,3 +1 @@
from config_common import *

dist_backend = "nccl"
3 changes: 2 additions & 1 deletion training/nvidia/gpt2-pytorch/config/config_common.py
@@ -1,5 +1,6 @@
vendor = 'kunlunxin'
vendor = 'nvidia'

# disable fp16
fp16 = False

dist_backend = "nccl"
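Together with config_A100x1x8.py above, the configs now layer cleanly: the machine-level file star-imports config_common.py, which in turn overrides the defaults declared in the benchmark's _base.py. A self-contained sketch of that layering (illustrative only; the dict merge stands in for Python's `from config_common import *`, and any value not shown in the diffs, such as the fp16 default, is an assumption):

```python
# Emulate the three config layers with plain dicts; later layers win.
base = {"vendor": None, "dist_backend": None, "fp16": True}           # _base.py defaults (fp16 default assumed)
common = {"vendor": "nvidia", "fp16": False, "dist_backend": "nccl"}  # config_common.py
machine = {}                                                          # config_A100x1x8.py adds nothing after this change

effective = {**base, **common, **machine}
print(effective)  # {'vendor': 'nvidia', 'dist_backend': 'nccl', 'fp16': False}
```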