Description
Currently, the inner gradient accumulation in Embedding and take does not use safe accumulation, which means we lose precision in the fp16 case. Here is an example that amplifies the issue:
import mxnet as mx
import numpy as np
mx.npx.set_np()
ctx = mx.gpu()
vocab_size = 8
embedding_dim = 1
index_num = 100000
dat = mx.np.random.randint(0, vocab_size, size=(index_num,), ctx=ctx)
for dtype in [np.float16, np.float32]:
    weight = mx.np.random.normal(0, 1, size=(vocab_size, embedding_dim), ctx=ctx, dtype=dtype)
    weight.attach_grad(grad_req='add')
    weight.grad[:] = 1.0
    with mx.autograd.record():
        out = mx.npx.embedding(dat, weight, input_dim=vocab_size, output_dim=embedding_dim) * 0.01
    out.backward()
    print('dtype=', dtype)
    print(weight.grad)
Output:
dtype= <class 'numpy.float16'>
[[32.]
[32.]
[32.]
[32.]
[32.]
[32.]
[32.]
[32.]] @gpu(0)
dtype= <class 'numpy.float32'>
[[126.748665]
[127.53883 ]
[125.30836 ]
[125.36837 ]
[126.278564]
[127.05873 ]
[124.74824 ]
[125.018295]] @gpu(0)
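The fp16 result stalls at 32 because each index only contributes 0.01 to its row's gradient: assuming the 100000 indices are spread roughly evenly (~12500 hits per row), the true value is about 1 + 12500 * 0.01 ≈ 126, but once the fp16 accumulator reaches 32 the increment is below half a ULP (the fp16 spacing at 32 is 0.03125) and gets rounded away. A minimal NumPy sketch of the accumulation behaviour (an illustration, not the actual kernel):

import numpy as np

# ~100000 indices spread roughly evenly over 8 rows -> ~12500 hits of 0.01 per row
increments = np.full(12500, 0.01, dtype=np.float16)

# naive fp16 accumulation (what the repro above effectively runs into)
acc16 = np.float16(1.0)
for x in increments:
    acc16 += x          # stalls at 32.0: 0.01 is below half a ULP of fp16 at 32

# "safe accumulation": accumulate in fp32, cast to fp16 only at the end
acc32 = np.float32(1.0) + increments.astype(np.float32).sum()

print(acc16)                # 32.0
print(np.float16(acc32))    # ~126.0, in line with the fp32 run above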
The same thing happens for take:
import mxnet as mx
import numpy as np
mx.npx.set_np()
ctx = mx.gpu()
vocab_size = 8
embedding_dim = 1
index_num = 100000
dat = mx.np.random.randint(0, vocab_size, size=(index_num,), ctx=ctx)
for dtype in [np.float16, np.float32]:
    weight = mx.np.random.normal(0, 1, size=(vocab_size, embedding_dim), ctx=ctx, dtype=dtype)
    weight.attach_grad(grad_req='add')
    weight.grad[:] = 1.0
    with mx.autograd.record():
        out = mx.np.take(weight, dat, axis=0) * 0.01
    out.backward()
    print('dtype=', dtype)
    print(weight.grad)
Output:
dtype= <class 'numpy.float16'>
[[32.]
[32.]
[32.]
[32.]
[32.]
[32.]
[32.]
[32.]] @gpu(0)
dtype= <class 'numpy.float32'>
[[125.44839 ]
[126.68865 ]
[126.62864 ]
[125.44839 ]
[127.028725]
[126.38859 ]
[125.108315]
[125.32836 ]] @gpu(0)
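For reference, a rough NumPy sketch of the semantics safe accumulation would give these backward passes: scatter-add the fp16 output gradients into an fp32 accumulator and cast back to the weight dtype only once at the end. The names and shapes below are illustrative and do not reflect MXNet internals:

import numpy as np

np.random.seed(0)
vocab_size, embedding_dim, index_num = 8, 1, 100000
idx = np.random.randint(0, vocab_size, size=(index_num,))
out_grad = np.full((index_num, embedding_dim), 0.01, dtype=np.float16)

# current behaviour: accumulate directly in the fp16 gradient buffer
grad_fp16 = np.ones((vocab_size, embedding_dim), dtype=np.float16)
np.add.at(grad_fp16, idx, out_grad)           # every row stalls at 32.0

# safe accumulation: accumulate in fp32, cast to fp16 only at the end
acc = np.ones((vocab_size, embedding_dim), dtype=np.float32)
np.add.at(acc, idx, out_grad.astype(np.float32))
grad_safe = acc.astype(np.float16)            # ~125-127 per row, matching the fp32 run

print(grad_fp16.ravel())
print(grad_safe.ravel())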