This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Need to use safe accumulation for calculating the gradient of Embedding + Take #17703

@sxjscience

Description

Currently, the inner gradient accumulation in Embedding and take is not based on safe accumulation, which means we lose precision in the fp16 case. Here's an example that amplifies the issue:

import mxnet as mx
import numpy as np
mx.npx.set_np()

ctx = mx.gpu()
vocab_size = 8
embedding_dim = 1
index_num = 100000

dat = mx.np.random.randint(0, vocab_size, size=(index_num,), ctx=ctx)

for dtype in [np.float16, np.float32]:
    weight = mx.np.random.normal(0, 1, size=(vocab_size, embedding_dim), ctx=ctx, dtype=dtype)

    weight.attach_grad(grad_req='add')
    weight.grad[:] = 1.0
    with mx.autograd.record():
        out = mx.npx.embedding(dat, weight, input_dim=vocab_size, output_dim=embedding_dim) * 0.01
        out.backward()
    print('dtype=', dtype)
    print(weight.grad)

Output:

dtype= <class 'numpy.float16'>
[[32.]
 [32.]
 [32.]
 [32.]
 [32.]
 [32.]
 [32.]
 [32.]] @gpu(0)
dtype= <class 'numpy.float32'>
[[126.748665]
 [127.53883 ]
 [125.30836 ]
 [125.36837 ]
 [126.278564]
 [127.05873 ]
 [124.74824 ]
 [125.018295]] @gpu(0)
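The fp16 result saturates at 32 because of how round-to-nearest behaves once the accumulator gets large: the spacing (ULP) of float16 at 32 is 2**5 * 2**-10 = 0.03125, so adding 0.01 rounds straight back to 32. A minimal NumPy sketch of the effect, using the same per-row workload as above (~100000 indices / 8 rows, step 0.01):

```python
import numpy as np

acc16 = np.float16(0.0)
acc32 = np.float32(0.0)
step = 0.01
for _ in range(12500):  # ~100000 indices spread over 8 rows
    # Emulate an fp16 accumulator: every partial sum is rounded to float16.
    acc16 = np.float16(acc16 + np.float16(step))
    # Safe accumulation keeps the running sum in float32.
    acc32 = np.float32(acc32 + np.float32(step))
print('fp16 accumulator:', acc16)  # stalls at 32.0
print('fp32 accumulator:', acc32)  # ~125, matching the fp32 gradient above
```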

The same happens for take:

import mxnet as mx
import numpy as np
mx.npx.set_np()

ctx = mx.gpu()
vocab_size = 8
embedding_dim = 1
index_num = 100000

dat = mx.np.random.randint(0, vocab_size, size=(index_num,), ctx=ctx)

for dtype in [np.float16, np.float32]:
    weight = mx.np.random.normal(0, 1, size=(vocab_size, embedding_dim), ctx=ctx, dtype=dtype)

    weight.attach_grad(grad_req='add')
    weight.grad[:] = 1.0
    with mx.autograd.record():
        out = mx.np.take(weight, dat, axis=0) * 0.01
        out.backward()
    print('dtype=', dtype)
    print(weight.grad)

Output:

dtype= <class 'numpy.float16'>
[[32.]
 [32.]
 [32.]
 [32.]
 [32.]
 [32.]
 [32.]
 [32.]] @gpu(0)
dtype= <class 'numpy.float32'>
[[125.44839 ]
 [126.68865 ]
 [126.62864 ]
 [125.44839 ]
 [127.028725]
 [126.38859 ]
 [125.108315]
 [125.32836 ]] @gpu(0)
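The proposed fix — safe accumulation — is to sum the incoming gradients in a higher-precision (fp32) buffer and cast back to the storage dtype only at the end. A minimal NumPy sketch of the idea (the function name and signature are hypothetical, not MXNet's actual kernel interface):

```python
import numpy as np

def take_grad_safe(indices, out_grad, num_rows, acc_dtype=np.float32):
    """Hypothetical sketch of safe accumulation for the backward pass of
    take/Embedding: scatter-add the output gradient into a higher-precision
    buffer, then cast back to the original dtype."""
    acc = np.zeros((num_rows,) + out_grad.shape[1:], dtype=acc_dtype)
    # np.add.at performs an unbuffered scatter-add, so repeated indices
    # accumulate correctly; the sums happen entirely in acc_dtype.
    np.add.at(acc, indices, out_grad.astype(acc_dtype))
    return acc.astype(out_grad.dtype)

# Same shapes as the repro: 100000 fp16 gradients of 0.01 over 8 rows.
indices = np.random.randint(0, 8, size=100000)
out_grad = np.full((100000, 1), 0.01, dtype=np.float16)
grad = take_grad_safe(indices, out_grad, num_rows=8)
print(grad)  # each row lands near 125 instead of saturating at 32
```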
