[Optimizer] Support AdamW optimizer for EmbeddingVariable. (DeepRec-A…
fuhailin authored Aug 5, 2022
1 parent 6ef311f commit 986498f
Showing 18 changed files with 1,106 additions and 3 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
@@ -32,6 +32,7 @@ This is the first release of DeepRec. DeepRec has super large-scale distributed
### **Optimizer**
- AdamAsync Optimizer
- AdagradDecay Optimizer
- AdamW Optimizer

### **Op & Hardware Acceleration**
- Operator Optimization: CPU/GPU optimization for tens of ops, including Unique, Gather, DynamicStitch, BiasAdd, Select, Transpose, SparseSegmentReduction, Where, DynamicPartition, and SparseConcat.
55 changes: 55 additions & 0 deletions docs/AdamW-Optimizer.md
@@ -0,0 +1,55 @@
# AdamW Optimizer
## Introduction
The AdamW optimizer supports Embedding Variable; compared with the Adam optimizer, it adds decoupled weight decay.

This is an implementation of the AdamW optimizer described in "Decoupled Weight Decay Regularization" by Loshchilov & Hutter (https://arxiv.org/abs/1711.05101).
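
For reference, the per-row update implemented by the sparse CPU kernel in this commit is shown below. Note that the weight-decay term is applied directly with coefficient $\lambda$ (i.e. `weight_decay`) and is not scaled by the learning rate:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\alpha_t &= \mathrm{lr} \cdot \sqrt{1 - \beta_2^t} \,/\, (1 - \beta_1^t) \\
\theta_t &= \theta_{t-1} - \alpha_t\, m_t / (\sqrt{v_t} + \epsilon) - \lambda\, \theta_{t-1}
\end{aligned}
$$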


## User Interface
For training, you only need to construct a `tf.train.AdamWOptimizer`; it is used the same way as the other native TF optimizers. Its definition is as follows:
```python
class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
  def __init__(self,
               weight_decay,
               learning_rate=0.001,
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-8,
               use_locking=False,
               name="AdamW"):

# Invocation:
optimizer = tf.train.AdamWOptimizer(
    weight_decay=weight_decay_new,
    learning_rate=learning_rate_new,
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-8)
```
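
Note (a caveat carried over from the upstream `DecoupledWeightDecayExtension`, not stated in this commit): when the learning rate is decayed by a schedule, the weight decay should normally be decayed in step with it. A minimal sketch, assuming `weight_decay` also accepts a Tensor; the schedule values are illustrative:

```python
global_step = tf.train.get_or_create_global_step()
lr = tf.train.exponential_decay(0.1, global_step, 10000, 0.96)
# Scale the weight decay by the same factor as the learning rate.
wd = 0.01 * lr / 0.1
optimizer = tf.train.AdamWOptimizer(weight_decay=wd, learning_rate=lr)
```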
## Example
```python
import tensorflow as tf

var = tf.get_variable("var_0", shape=[10, 16],
                      initializer=tf.ones_initializer(tf.float32))

emb = tf.nn.embedding_lookup(var, tf.cast([0,1,2,5,6,7], tf.int64))
fun = tf.multiply(emb, 2.0, name='multiply')
loss = tf.reduce_sum(fun, name='reduce_sum')

gs = tf.train.get_or_create_global_step()
opt = tf.train.AdamWOptimizer(weight_decay=0.01, learning_rate=0.1)

g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v, global_step=gs)

init = tf.global_variables_initializer()

sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
with tf.Session(config=sess_config) as sess:
sess.run([init])
print(sess.run([emb, train_op, loss]))
print(sess.run([emb, train_op, loss]))
print(sess.run([emb, train_op, loss]))
```
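
Since this commit targets EmbeddingVariable, a variant of the example above using one is sketched below (assuming DeepRec's `tf.get_embedding_variable` API; the parameter values are illustrative):

```python
import tensorflow as tf

# Sketch: same training loop, but the table is an EmbeddingVariable.
var = tf.get_embedding_variable("var_1",
                                embedding_dim=16,
                                initializer=tf.ones_initializer(tf.float32))

emb = tf.nn.embedding_lookup(var, tf.cast([0, 1, 2, 5, 6, 7], tf.int64))
loss = tf.reduce_sum(tf.multiply(emb, 2.0))

gs = tf.train.get_or_create_global_step()
opt = tf.train.AdamWOptimizer(weight_decay=0.01, learning_rate=0.1)
train_op = opt.apply_gradients(opt.compute_gradients(loss), global_step=gs)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(3):
    print(sess.run([emb, train_op, loss]))
```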

158 changes: 158 additions & 0 deletions tensorflow/core/kernels/training_ali_ops.cc
@@ -2452,4 +2452,162 @@ TF_CALL_float(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS

template <typename Device, typename T, typename Tindex>
class KvSparseApplyAdamWOp : public OpKernel {
public:
explicit KvSparseApplyAdamWOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
}

void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockEmbeddingVariableInputMutexesInOrder<Tindex, T>(ctx, use_exclusive_lock_,
{0, 1, 2});
EmbeddingVar<Tindex, T>* var = nullptr;
OP_REQUIRES_OK(ctx, GetInputEmbeddingVar(ctx, 0, &var));
core::ScopedUnref unref_var(var);

EmbeddingVar<Tindex, T>* m = nullptr;
OP_REQUIRES_OK(ctx, GetInputEmbeddingVar(ctx, 1, &m));
core::ScopedUnref unref_m(m);

EmbeddingVar<Tindex, T>* v = nullptr;
OP_REQUIRES_OK(ctx, GetInputEmbeddingVar(ctx, 2, &v));
core::ScopedUnref unref_v(v);

const Tensor& beta1_power = ctx->input(3);
const Tensor& beta2_power = ctx->input(4);
const Tensor& lr = ctx->input(5);
const Tensor& beta1 = ctx->input(6);
const Tensor& beta2 = ctx->input(7);
const Tensor& epsilon = ctx->input(8);
const Tensor& grad = ctx->input(9);
const Tensor& indices = ctx->input(10);
const Tensor& global_step = ctx->input(11);
const Tensor& weight_decay = ctx->input(12);

OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(beta1_power.shape()),
errors::InvalidArgument("beta1_power is not a scalar: ",
beta1_power.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(beta2_power.shape()),
errors::InvalidArgument("beta2_power is not a scalar: ",
beta2_power.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(lr.shape()),
errors::InvalidArgument("lr is not a scalar: ",
lr.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(beta1.shape()),
errors::InvalidArgument("beta1 is not a scalar: ",
beta1.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(beta2.shape()),
errors::InvalidArgument("beta2 is not a scalar: ",
beta2.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
errors::InvalidArgument("epsilon is not a scalar: ",
epsilon.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsVector(indices.shape()),
errors::InvalidArgument("indices must be one-dimensional"));

OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(weight_decay.shape()),
errors::InvalidArgument("weight_decay is not a scalar: ",
weight_decay.shape().DebugString()));

int64 inner_dim = 1;
TensorShape var_shape({var->ValueLen()});
for (int d = 0; d < var_shape.dims(); d++) {
OP_REQUIRES(ctx, var_shape.dim_size(d) == grad.dim_size(d + 1),
errors::InvalidArgument(strings::StrCat(
"var and grad must match in dimension ", d + 1)));
inner_dim *= grad.dim_size(d + 1);
}
OP_REQUIRES(
ctx, inner_dim > 0,
errors::InvalidArgument(
"Inner dimension should be greater than zero."));

OP_REQUIRES(
ctx, IsLegacyScalar(global_step.shape()),
errors::InvalidArgument(
"global_step is not a scalar: ", global_step.shape().DebugString()));

const int64 N = indices.dim_size(0);
OP_REQUIRES(
ctx, grad.dim_size(0) == N,
errors::InvalidArgument(
"grad must be the same size as indices in the first dimension."));

if (N > 0) {
T beta1_power_scalar = beta1_power.scalar<T>()();
T beta2_power_scalar = beta2_power.scalar<T>()();
T lr_scalar = lr.scalar<T>()();
T beta1_scalar = beta1.scalar<T>()();
T beta2_scalar = beta2.scalar<T>()();
T epsilon_scalar = epsilon.scalar<T>()();
T weight_decay_scalar = weight_decay.scalar<T>()();
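// Bias-corrected step size: alpha = lr * sqrt(1 - beta2^t) / (1 - beta1^t).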
const T alpha = lr_scalar *
Eigen::numext::sqrt(static_cast<T>(1) - beta2_power_scalar) /
(static_cast<T>(1) - beta1_power_scalar);

auto DoWork = [this, ctx, inner_dim, &var, &m, &v, &grad, &indices,
&beta1_power_scalar, &beta2_power_scalar, &lr_scalar, &beta1_scalar,
&beta2_scalar, &epsilon_scalar, &alpha, &global_step,
&weight_decay_scalar] (int64 start_i, int64 limit_i) {
if (inner_dim > 0) {
auto grad_flat = grad.flat_outer_dims<T>();
auto indices_vec = indices.vec<Tindex>();

int64 gs = global_step.scalar<int64>()();

for (int64 i = start_i; i < limit_i; i++) {
const Tindex index = indices_vec(i);
ValuePtr<T>* value_ptr = nullptr;
bool is_filter = false;
OP_REQUIRES_OK(ctx, var->LookupOrCreateKey(index, &value_ptr, &is_filter));
var->UpdateVersion(value_ptr, gs);
if (is_filter) {
auto var_i = var->flat(value_ptr);
auto m_a = m->flat(value_ptr);
auto v_a = v->flat(value_ptr);
auto g = grad_flat.template chip<0>(i);
// m_a = beta1 * m + (1 - beta1) * g
m_a += (g - m_a) * (static_cast<T>(1) - beta1_scalar);
// v_a = beta2 * v + (1 - beta2) * (g * g)
v_a += (g.square() - v_a) * (static_cast<T>(1) - beta2_scalar);
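// var -= alpha * m / (sqrt(v) + epsilon) + weight_decay * var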
var_i -= (m_a * alpha) / (v_a.sqrt() + epsilon_scalar) + weight_decay_scalar * var_i;
var->Commit(index, value_ptr);
}
}
}
};

const int64 cost = 1000;
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers, N, cost, DoWork);
}
}

private:
bool use_exclusive_lock_;
};

#define REGISTER_KERNELS(T, Tindices) \
REGISTER_KERNEL_BUILDER(Name("KvResourceSparseApplyAdamW") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<Tindices>("Tindices"), \
KvSparseApplyAdamWOp<CPUDevice, T, Tindices>);
#define REGISTER_CPU_KERNELS(T) \
REGISTER_KERNELS(T, int32); \
REGISTER_KERNELS(T, int64);

TF_CALL_float(REGISTER_CPU_KERNELS);

#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS

} // namespace tensorflow
24 changes: 24 additions & 0 deletions tensorflow/core/ops/training_ali_ops.cc
@@ -507,4 +507,28 @@ REGISTER_OP("KvResourceSparseApplyGradientDescent")
.Attr("use_locking: bool = false")
.SetShapeFn(KvApplyGradientDescentShapeFn);

REGISTER_OP("KvResourceSparseApplyAdamW")
.Input("var: resource")
.Input("m: resource")
.Input("v: resource")
.Input("beta1_power: T")
.Input("beta2_power: T")
.Input("lr: T")
.Input("beta1: T")
.Input("beta2: T")
.Input("epsilon: T")
.Input("grad: T")
.Input("indices: Tindices")
.Input("global_step: Tstep")
.Input("weight_decay: T")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64, string}")
.Attr("Tstep: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn([](InferenceContext* c) {
return KvResourceApplyAdamShapeFn(c, true /* sparse */);
})
.Doc(R"doc(
)doc");

} // namespace tensorflow
17 changes: 17 additions & 0 deletions tensorflow/python/BUILD
@@ -7379,3 +7379,20 @@ cuda_py_test(
],
xla_enable_strict_auto_jit = True,
)

cuda_py_test(
name = "weight_decay_optimizers_test",
size = "small",
srcs = ["training/weight_decay_optimizers_test.py"],
additional_deps = [
":array_ops",
":framework",
":math_ops",
":platform",
":training",
":platform_test",
":client_testlib",
"//third_party/py/numpy",
],
xla_enable_strict_auto_jit = True,
)
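
The test file itself is not shown in this view; for reference, a numpy sketch of the update such a weight-decay-optimizer test typically checks against (illustrative, not the actual test contents):

```python
import numpy as np

def adamw_update_numpy(param, g, t, m, v, lr=0.001, beta1=0.9,
                       beta2=0.999, epsilon=1e-8, weight_decay=0.01):
  """Reference AdamW step at timestep t (1-indexed)."""
  lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
  m_t = beta1 * m + (1 - beta1) * g
  v_t = beta2 * v + (1 - beta2) * g * g
  param_t = param - lr_t * m_t / (np.sqrt(v_t) + epsilon) - weight_decay * param
  return param_t, m_t, v_t
```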