Resnet npu #412

Draft: ShawnXuan wants to merge 9 commits into main
1 change: 1 addition & 0 deletions Vision/classification/image/resnet50/config.py
@@ -26,6 +26,7 @@ def parse_args(ignore_unknown_args=False):
parser = argparse.ArgumentParser(
description="OneFlow ResNet50 Arguments", allow_abbrev=False
)
parser.add_argument("--device", type=str, default="cuda", help="device: cpu, cuda...")
parser.add_argument(
"--save",
type=str,
46 changes: 46 additions & 0 deletions Vision/classification/image/resnet50/examples/npu_eager.sh
@@ -0,0 +1,46 @@
# set -aux

export PYTHONUNBUFFERED=1
echo PYTHONUNBUFFERED=$PYTHONUNBUFFERED

CHECKPOINT_SAVE_PATH="./graph_checkpoints"
if [ ! -d "$CHECKPOINT_SAVE_PATH" ]; then
mkdir $CHECKPOINT_SAVE_PATH
fi

#OFRECORD_PATH="./mini-imagenet/ofrecord"
OFRECORD_PATH="/data0/datasets/ImageNet/ofrecord/"

if [ ! -d "$OFRECORD_PATH" ]; then
wget https://oneflow-public.oss-cn-beijing.aliyuncs.com/online_document/dataset/imagenet/mini-imagenet.zip
unzip mini-imagenet.zip
fi

OFRECORD_PART_NUM=1
LEARNING_RATE=0.256
MOM=0.875
EPOCH=90
TRAIN_BATCH_SIZE=50
VAL_BATCH_SIZE=50

# SRC_DIR=/path/to/models/resnet50
SRC_DIR=$(realpath $(dirname $0)/..)

python3 $SRC_DIR/train.py \
--ofrecord-path $OFRECORD_PATH \
--ofrecord-part-num $OFRECORD_PART_NUM \
--num-devices-per-node 1 \
--lr $LEARNING_RATE \
--momentum $MOM \
--num-epochs $EPOCH \
--warmup-epochs 5 \
--train-batch-size $TRAIN_BATCH_SIZE \
--val-batch-size $VAL_BATCH_SIZE \
--save $CHECKPOINT_SAVE_PATH \
--scale-grad \
--print-interval 1 \
--load checkpoints/init \
--device npu
#--use-gpu-decode \
#--samples-per-epoch 50 \
#--val-samples-per-epoch 50 \
47 changes: 47 additions & 0 deletions Vision/classification/image/resnet50/examples/npu_graph.sh
@@ -0,0 +1,47 @@
# set -aux

export PYTHONUNBUFFERED=1
echo PYTHONUNBUFFERED=$PYTHONUNBUFFERED

CHECKPOINT_SAVE_PATH="./graph_checkpoints"
if [ ! -d "$CHECKPOINT_SAVE_PATH" ]; then
mkdir $CHECKPOINT_SAVE_PATH
fi

#OFRECORD_PATH="./mini-imagenet/ofrecord"
OFRECORD_PATH="/data0/datasets/ImageNet/ofrecord/"

if [ ! -d "$OFRECORD_PATH" ]; then
wget https://oneflow-public.oss-cn-beijing.aliyuncs.com/online_document/dataset/imagenet/mini-imagenet.zip
unzip mini-imagenet.zip
fi

OFRECORD_PART_NUM=1
LEARNING_RATE=0.256
MOM=0.875
EPOCH=90
TRAIN_BATCH_SIZE=50
VAL_BATCH_SIZE=50

# SRC_DIR=/path/to/models/resnet50
SRC_DIR=$(realpath $(dirname $0)/..)

python3 $SRC_DIR/train.py \
--ofrecord-path $OFRECORD_PATH \
--ofrecord-part-num $OFRECORD_PART_NUM \
--num-devices-per-node 1 \
--lr $LEARNING_RATE \
--momentum $MOM \
--num-epochs $EPOCH \
--warmup-epochs 5 \
--train-batch-size $TRAIN_BATCH_SIZE \
--val-batch-size $VAL_BATCH_SIZE \
--save $CHECKPOINT_SAVE_PATH \
--scale-grad \
--print-interval 1 \
--load checkpoints/init \
--graph \
--device npu
#--use-gpu-decode \
#--samples-per-epoch 50 \
#--val-samples-per-epoch 50 \
10 changes: 6 additions & 4 deletions Vision/classification/image/resnet50/graph.py
@@ -51,11 +51,12 @@ def __init__(
self.cross_entropy = cross_entropy
self.data_loader = data_loader
self.add_optimizer(optimizer, lr_sch=lr_scheduler)
self.device = args.device

def build(self):
image, label = self.data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
logits = self.model(image)
loss = self.cross_entropy(logits, label)
if self.return_pred_and_label:
@@ -79,11 +80,12 @@ def __init__(self, model, data_loader):

self.data_loader = data_loader
self.model = model
self.device = args.device

def build(self):
image, label = self.data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
logits = self.model(image)
pred = logits.softmax()
return pred, label
4 changes: 2 additions & 2 deletions Vision/classification/image/resnet50/infer.py
@@ -55,7 +55,7 @@ def main(args):
print("***** Model Init *****")
model = resnet50()
model.load_state_dict(flow.load(args.model_path))
model = model.to("cuda")
model = model.to(args.device)
model.eval()
end_t = time.perf_counter()
print(f"***** Model Init Finish, time escapled {end_t - start_t:.6f} s *****")
@@ -65,7 +65,7 @@

start_t = end_t
image = load_image(args.image_path)
image = flow.Tensor(image, device=flow.device("cuda"))
image = flow.Tensor(image, device=flow.device(args.device))
if args.graph:
pred = model_graph(image)
else:
17 changes: 12 additions & 5 deletions Vision/classification/image/resnet50/models/data.py
@@ -31,8 +31,9 @@ def make_data_loader(args, mode, is_global=False, synthetic=False):
placement=placement,
sbp=sbp,
channel_last=args.channel_last,
device=args.device,
)
return data_loader.to("cuda")
return data_loader.to(args.device)

ofrecord_data_loader = OFRecordDataLoader(
ofrecord_dir=args.ofrecord_path,
@@ -45,6 +46,8 @@ def make_data_loader(args, mode, is_global=False, synthetic=False):
placement=placement,
sbp=sbp,
use_gpu_decode=args.use_gpu_decode,
device="cpu",
ShawnXuan (Contributor Author) commented: Use CPU decoding for now. (A short sketch of the resulting device hand-off follows this file's diff.)

#device=args.device,
)
return ofrecord_data_loader

@@ -62,6 +65,7 @@ def __init__(
placement=None,
sbp=None,
use_gpu_decode=False,
device="cuda",
):
super().__init__()

@@ -71,6 +75,7 @@ def __init__(
self.total_batch_size = total_batch_size
self.dataset_size = dataset_size
self.mode = mode
self.device = device

random_shuffle = True if mode == "train" else False
shuffle_after_epoch = True if mode == "train" else False
@@ -159,11 +164,12 @@ def forward(self):
else:
image_raw_bytes = self.image_decoder(record)
image = self.resize(image_raw_bytes)[0]
image = image.to("cuda")

label = self.label_decoder(record)
flip_code = self.flip()
flip_code = flip_code.to("cuda")
if self.use_gpu_decode:
# TODO NPU: the image tensor falls back to CPU here
flip_code = flip_code.to(self.device)
image = self.crop_mirror_norm(image, flip_code)
else:
record = self.ofrecord_reader()
@@ -184,6 +190,7 @@ def __init__(
placement=None,
sbp=None,
channel_last=False,
device="cuda",
):
super().__init__()

@@ -220,10 +227,10 @@ def __init__(
)
else:
self.image = flow.randint(
0, high=256, size=self.image_shape, dtype=flow.float32, device="cuda"
0, high=256, size=self.image_shape, dtype=flow.float32, device=device,
)
self.label = flow.randint(
0, high=self.num_classes, size=self.label_shape, device="cuda",
0, high=self.num_classes, size=self.label_shape, device=device,
).to(dtype=flow.int32)

def forward(self):
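A minimal sketch of the device hand-off implied by the CPU-decode comment above, assuming args.device is "npu" and a loader built with device="cpu" as in this diff; the helper name one_step is illustrative and not part of the PR. It only restates what graph.py and train.py already do: the OFRecord pipeline decodes on the host, and each batch is copied to the device afterwards.

import oneflow as flow
# import oneflow_npu  # assumption: needed to register the "npu" device, as train.py does

def one_step(data_loader, model, cross_entropy, device="npu"):
    # the loader was built with device="cpu", so decoding and augmentation run on the host
    image, label = data_loader()
    # per-batch host-to-device copy, mirroring build() in graph.py and forward() in train.py
    image = image.to(device)
    label = label.to(device)
    logits = model(image)
    return cross_entropy(logits, label)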
27 changes: 26 additions & 1 deletion Vision/classification/image/resnet50/models/optimizer.py
@@ -83,5 +83,30 @@ def forward(self, input, label):
# log_prob = input.softmax(dim=-1).log()
# onehot_label = flow.F.cast(onehot_label, log_prob.dtype)
# loss = flow.mul(log_prob * -1, onehot_label).sum(dim=-1).mean()
loss = flow._C.softmax_cross_entropy(input, onehot_label.to(dtype=input.dtype))
#loss = flow._C.softmax_cross_entropy(input, onehot_label.to(dtype=input.dtype))
ShawnXuan (Contributor Author) commented: softmax_cross_entropy is not supported on NPU yet.

loss = flow._C.cross_entropy(input, onehot_label.to(dtype=input.dtype), reduction='none')
ShawnXuan (Contributor Author) commented: Still need to verify whether training converges with this change. (A CPU-side equivalence sketch follows this hunk.)

return loss.mean()
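A quick CPU-side check of the op swap discussed above, comparing both calls against the plain log-softmax formula. This is a sketch rather than part of the diff, and it rests on two assumptions: that flow._C.softmax_cross_entropy returns the per-sample loss for one-hot targets, and that flow._C.cross_entropy accepts probability targets the way torch.nn.functional.cross_entropy does; if either assumption is off in a given OneFlow build, the printed deltas will show it.

import oneflow as flow

batch, num_classes = 4, 10
logits = flow.randn(batch, num_classes)
target = flow.randint(0, high=num_classes, size=(batch,))
onehot = flow.eye(num_classes)[target]  # assumption: tensor indexing builds the one-hot rows

# reference: standard per-sample cross entropy, -(onehot * log_softmax(logits)).sum(-1)
reference = -(onehot * flow.log_softmax(logits, dim=-1)).sum(dim=-1)

old = flow._C.softmax_cross_entropy(logits, onehot)            # op not yet supported on NPU
new = flow._C.cross_entropy(logits, onehot, reduction="none")  # replacement used in this PR

print("old vs reference:", (old - reference).abs().max().item())
print("new vs reference:", (new - reference).abs().max().item())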

class oldLabelSmoothLoss(flow.nn.Module):
ShawnXuan (Contributor Author) commented on Sep 24, 2024: This is the loss from flowvision. It needs dim_gather.

Flowingsun007 (Contributor) commented on Sep 25, 2024: softmax_cross_entropy and dim_gather should not be hard to implement; we can add them to the NPU development plan and try again once they are done.

Flowingsun007 (Contributor) commented on Sep 25, 2024: It looks like dim_gather is already supported on NPU (oneflow_npu/kernels/dim_gather_kernel.cpp), so only softmax_cross_entropy still needs to be implemented.

ShawnXuan (Contributor Author) commented: I implemented a softmax_cross_entropy, but its backward pass may still have issues. Since softmax_cross_entropy has no torch counterpart and I prefer developing torch-compatible ops, I went with the flowvision approach instead of softmax_cross_entropy. I'll try dim_gather later. (A small CPU smoke test of this loss follows the class below.)

"""NLL Loss with label smoothing
"""

#def __init__(self, smoothing=0.1):
#super(LabelSmoothingCrossEntropy, self).__init__()
def __init__(self, num_classes=-1, smooth_rate=0.0):
super().__init__()
assert smooth_rate < 1.0
self.smoothing = smooth_rate
self.confidence = 1.0 - smooth_rate

def forward(self, x: flow.Tensor, target: flow.Tensor) -> flow.Tensor:
# TODO: register F.log_softmax() function and switch flow.log(flow.softmax()) to F.log_softmax()
logprobs = flow.log_softmax(x, dim=-1)
# TODO: fix gather bug when dim < 0
# FIXME: only support cls task now
nll_loss = -logprobs.gather(dim=1, index=target.unsqueeze(1))
nll_loss = nll_loss.squeeze(1)
smooth_loss = -logprobs.mean(dim=-1)
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()
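A minimal CPU smoke test for the flowvision-style loss above, worth running before trying it on NPU because the forward pass goes through Tensor.gather (the dim_gather kernel mentioned in the thread). Shapes and values are illustrative; with smooth_rate=0.0 the result should match an unsmoothed cross entropy up to numerical noise.

import oneflow as flow

criterion = oldLabelSmoothLoss(num_classes=10, smooth_rate=0.1)
logits = flow.randn(4, 10, requires_grad=True)
target = flow.randint(0, high=10, size=(4,))

loss = criterion(logits, target)   # forward: log_softmax + gather + smoothing term
loss.backward()                    # backward also exercises the gradient of gather
print(loss.item(), logits.grad.shape)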

25 changes: 17 additions & 8 deletions Vision/classification/image/resnet50/train.py
@@ -9,6 +9,7 @@
import time

import oneflow as flow
import oneflow_npu
from oneflow.nn.parallel import DistributedDataParallel as ddp

from config import get_args
@@ -26,6 +27,7 @@
class Trainer(object):
def __init__(self):
args = get_args()
self.device = args.device
for k, v in args.__dict__.items():
setattr(self, k, v)

@@ -56,7 +58,7 @@ def __init__(self):
self.cross_entropy = make_cross_entropy(args)

self.train_data_loader = make_data_loader(
args, "train", self.is_global, self.synthetic_data
args, "validation", self.is_global, self.synthetic_data
)
self.val_data_loader = make_data_loader(
args, "validation", self.is_global, self.synthetic_data
@@ -89,12 +91,12 @@ def init_model(self):
start_t = time.perf_counter()

if self.is_global:
placement = flow.env.all_device_placement("cuda")
placement = flow.env.all_device_placement(self.device)
self.model = self.model.to_global(
placement=placement, sbp=flow.sbp.broadcast
)
else:
self.model = self.model.to("cuda")
self.model = self.model.to(self.device)

if self.load_path is None:
self.legacy_init_parameters()
@@ -247,6 +249,13 @@ def train_one_epoch(self):
else:
loss, pred, label = self.train_eager()

print("loss")
print(loss)
print("pred")
print(pred)
print("label")
print(label)
exit()
self.cur_iter += 1

loss = tol(loss, self.metric_local)
@@ -276,7 +285,7 @@ def train_eager(self):
param.grad /= self.world_size
else:
loss.backward()
loss = loss / self.world_size
#loss = loss / self.world_size

self.optimizer.step()
self.optimizer.zero_grad()
@@ -311,8 +320,8 @@ def eval(self):

def forward(self):
image, label = self.train_data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
logits = self.model(image)
loss = self.cross_entropy(logits, label)
if self.metric_train_acc:
@@ -323,8 +332,8 @@

def inference(self):
image, label = self.val_data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
with flow.no_grad():
logits = self.model(image)
pred = logits.softmax()