Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev add spectral_norm #9674

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open

dev add spectral_norm #9674

wants to merge 13 commits into from

Conversation

hhhfccz
Copy link
Contributor

@hhhfccz hhhfccz commented Jan 1, 2023

@BBuf 修复了一些 spectral_norm 实现过程中遇到的bug

  • 修复 dot 在 cpu 下不支持 int32 与 int64 计算的 bug (因为matmul)
  • 增加 spectral_norm 的基本功能
  • 修复 kaiming_uniform_ 和 kaiming_normal_ 在输入0size tensor 的时候的除0 bug
  • 新增 oneflow.linalg.multi_dot()
  • spectral_norm 的 load_state_dict 测试, load 与 hook
  • global test
  • spectral_norm 和 multi_dot 的文档

@ehuaa
Copy link

ehuaa commented Mar 10, 2023

您好,之前看pytorch中的commit记录 module中增加相应模块的原因是为了 maintain BC
1678415477411

1678415531305
因为pytorch的这个接口是在线上迭代更新所以才引入了version,是不是oneflow中不用增加这一部分
pytorch/pytorch@2cd912b
pytorch/pytorch@ac994f2

@hhhfccz
Copy link
Contributor Author

hhhfccz commented Mar 10, 2023

您好,之前看pytorch中的commit记录 module中增加相应模块的原因是为了 maintain BC 1678415477411

1678415531305 因为pytorch的这个接口是在线上迭代更新所以才引入了version,是不是oneflow中不用增加这一部分 pytorch/pytorch@2cd912b pytorch/pytorch@ac994f2

目前version是默认为1,所以是希望对标上这个hook;您的意思是对标version=0嘛,目前来看weight_norm似乎对标的是version=0,

@hhhfccz
Copy link
Contributor Author

hhhfccz commented Mar 10, 2023

写SN目的还有一个是为了写SR的一些网络,我去实验看看SR那边需不需要保留version=1
@ehuaa 如果有新增代码或者docstring的话可以直接提交在这个PR

@ehuaa
Copy link

ehuaa commented Mar 10, 2023

写SN目的还有一个是为了写SR的一些网络,我去实验看看SR那边需不需要保留version=1 @ehuaa 如果有新增代码或者docstring的话可以直接提交在这个PR

好的~

@ehuaa
Copy link

ehuaa commented Mar 10, 2023

您好,之前看pytorch中的commit记录 module中增加相应模块的原因是为了 maintain BC 1678415477411
1678415531305 因为pytorch的这个接口是在线上迭代更新所以才引入了version,是不是oneflow中不用增加这一部分 pytorch/pytorch@2cd912b pytorch/pytorch@ac994f2

目前version是默认为1,所以是希望对标上这个hook;您的意思是对标version=0嘛,目前来看weight_norm似乎对标的是version=0,

其实也不是 就是 可以看下pytorch下面的两个hook 他们针对的是之前版本version=0的一个fix
1678417837589
我的理解是针对version=0 或者None的module的一种hotfix,但是oneflow的spectral norm的module已经是基于稳定版pytorch对齐的了,version一直是1,其实上面两个hook里面的逻辑都不会执行。

@hhhfccz
Copy link
Contributor Author

hhhfccz commented Mar 10, 2023

您好,之前看pytorch中的commit记录 module中增加相应模块的原因是为了 maintain BC 1678415477411
1678415531305 因为pytorch的这个接口是在线上迭代更新所以才引入了version,是不是oneflow中不用增加这一部分 pytorch/pytorch@2cd912b pytorch/pytorch@ac994f2

目前version是默认为1,所以是希望对标上这个hook;您的意思是对标version=0嘛,目前来看weight_norm似乎对标的是version=0,

其实也不是 就是 可以看下pytorch下面的两个hook 他们针对的是之前版本version=0的一个fix 1678417837589 我的理解是针对version=0 或者None的module的一种hotfix,但是oneflow的spectral norm的module已经是基于稳定版pytorch对齐的了,version一直是1,其实上面两个hook里面的逻辑都不会执行。

嗯嗯,我刚刚发现了我觉得是OK的,现在在看测试,目前测试会跟oneflow的Graph有一点冲突正在解决,稍后看麻烦review一下看看

@ehuaa
Copy link

ehuaa commented Mar 10, 2023

您好,之前看pytorch中的commit记录 module中增加相应模块的原因是为了 maintain BC 1678415477411
1678415531305 因为pytorch的这个接口是在线上迭代更新所以才引入了version,是不是oneflow中不用增加这一部分 pytorch/pytorch@2cd912b pytorch/pytorch@ac994f2

目前version是默认为1,所以是希望对标上这个hook;您的意思是对标version=0嘛,目前来看weight_norm似乎对标的是version=0,

嗯嗯 另外他后面version-1的in-place操作相关的fix主要是针对pytorch dataparallel master-slave模式带来的一些问题,oneflow的DistributedDataParallel和他的实现其实不一样。 我还没有跑SN的网络测试,不过感觉大概oneflow的模式大概能天然fix?单侧test里面感觉可以多加些ddp的~多交流

@github-actions
Copy link
Contributor

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 141.3ms (= 14126.5ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 144.8ms (= 14481.3ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.03 (= 144.8ms / 141.3ms)

OneFlow resnet50 time: 83.1ms (= 8307.6ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 87.0ms (= 8704.6ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.05 (= 87.0ms / 83.1ms)

OneFlow resnet50 time: 51.2ms (= 10236.9ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 61.3ms (= 12258.4ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.20 (= 61.3ms / 51.2ms)

OneFlow resnet50 time: 33.8ms (= 6751.3ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 45.9ms (= 9182.7ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.36 (= 45.9ms / 33.8ms)

OneFlow resnet50 time: 25.7ms (= 5136.0ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 37.4ms (= 7483.0ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.46 (= 37.4ms / 25.7ms)

OneFlow swin dataloader time: 0.235s (= 47.022s / 200, num_workers=1)
PyTorch swin dataloader time: 0.151s (= 30.234s / 200, num_workers=1)
Relative speed: 0.643 (= 0.151s / 0.235s)

OneFlow swin dataloader time: 0.069s (= 13.875s / 200, num_workers=4)
PyTorch swin dataloader time: 0.042s (= 8.399s / 200, num_workers=4)
Relative speed: 0.605 (= 0.042s / 0.069s)

OneFlow swin dataloader time: 0.041s (= 8.199s / 200, num_workers=8)
PyTorch swin dataloader time: 0.023s (= 4.532s / 200, num_workers=8)
Relative speed: 0.553 (= 0.023s / 0.041s)

❌ OneFlow resnet50 time: 153.6ms (= 15358.2ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 166.3ms (= 16630.7ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.08 (= 166.3ms / 153.6ms)

OneFlow resnet50 time: 93.8ms (= 9375.7ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 104.5ms (= 10453.1ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.11 (= 104.5ms / 93.8ms)

OneFlow resnet50 time: 61.2ms (= 12243.2ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 81.0ms (= 16190.1ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.32 (= 81.0ms / 61.2ms)

OneFlow resnet50 time: 43.2ms (= 8645.5ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 75.0ms (= 15005.0ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.74 (= 75.0ms / 43.2ms)

OneFlow resnet50 time: 35.1ms (= 7022.5ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 67.7ms (= 13534.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.93 (= 67.7ms / 35.1ms)

@github-actions
Copy link
Contributor

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9674/

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants