Commit

feat: add Lion optimizer (#4331)
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
enneamer and loadams authored Oct 5, 2023
1 parent d72edb3 commit 8e64c3b
Showing 21 changed files with 1,439 additions and 4 deletions.
43 changes: 43 additions & 0 deletions csrc/cpu/lion/fused_lion.cpp
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "cpu_lion.h"

// C++ interface

void multi_tensor_lion(int chunk_size,
                       at::Tensor noop_flag,
                       std::vector<std::vector<at::Tensor>> tensor_lists, /*gpmv*/
                       const float lr,
                       const float beta1,
                       const float beta2,
                       const int step,
                       const int mode,
                       const float weight_decay)
{
    static bool initialized = false;
    if (!initialized) {
        create_lion_optimizer(0);
        initialized = true;
    }
    for (int i = 0; i < tensor_lists[0].size(); i++) {
        ds_lion_step(0,
                     step,
                     lr,
                     beta1,
                     beta2,
                     weight_decay,
                     tensor_lists[1][i],
                     tensor_lists[0][i],
                     tensor_lists[2][i]);
    }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("multi_tensor_lion",
&multi_tensor_lion,
"Compute and apply gradient update to parameters for Lion optimizer");
}
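
Editor's note: the loop above fixes the layout of tensor_lists, since ds_lion_step takes (params, grads, exp_avg) while the loop passes tensor_lists[1][i], tensor_lists[0][i], tensor_lists[2][i]; that is, index 0 holds gradients, index 1 parameters, and index 2 the exponential moving averages. A minimal caller-side sketch under that assumption follows; the helper name, chunk size, and hyperparameter values are illustrative and not part of this commit.

// Hypothetical caller of the binding above; assumes the three vectors hold
// same-shaped float32 CPU tensors, in the order the loop indexes them.
void apply_fused_lion(std::vector<at::Tensor>& grads,
                      std::vector<at::Tensor>& params,
                      std::vector<at::Tensor>& exp_avgs,
                      float lr, float beta1, float beta2, int step, float weight_decay)
{
    at::Tensor noop_flag = at::zeros({1}, at::dtype(at::kInt));
    std::vector<std::vector<at::Tensor>> tensor_lists = {grads, params, exp_avgs};
    multi_tensor_lion(/*chunk_size=*/2048 * 32, noop_flag, tensor_lists,
                      lr, beta1, beta2, step, /*mode=*/0, weight_decay);
}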
233 changes: 233 additions & 0 deletions csrc/includes/cpu_lion.h
@@ -0,0 +1,233 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#define NOMINMAX // Windows idiosyncrasy
// https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c

#include <stdio.h>
#include <torch/extension.h>
#include <cassert>
#include "simd.h"

#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
#endif

#define STEP(SPAN)                                             \
    void Step_##SPAN(float* _params,                           \
                     float* grads,                             \
                     float* _exp_avg,                          \
                     size_t _param_size,                       \
                     ds_half_precision_t* dev_param = nullptr, \
                     bool half_precision = false);

class Lion_Optimizer {
public:
    Lion_Optimizer(float alpha = 1e-3,
                   float betta1 = 0.9,
                   float betta2 = 0.999,
                   float weight_decay = 0)
        : _alpha(alpha), _betta1(betta1), _betta2(betta2), _weight_decay(weight_decay), _step(0)
    {
#if defined(__ENABLE_CUDA__)
        cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
        cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

        _streams[0] = TrainingContext::Instance().GetCurrentStream();
        _streams[1] = TrainingContext::Instance().GetNewStream();
        _buf_index = false;
#endif
    }
    ~Lion_Optimizer()
    {
#if defined(__ENABLE_CUDA__)
        cudaFreeHost(_doubled_buffer[0]);
        cudaFreeHost(_doubled_buffer[1]);
#endif
    }

#if defined(__AVX512__) or defined(__AVX256__)
    template <int span>
    void Step_AVX(size_t* rounded_size,
                  float* _params,
                  float* grads,
                  float* _exp_avg,
                  size_t param_size,
                  ds_half_precision_t* dev_param = nullptr,
                  bool half_precision = false);
#endif
    STEP(1)
    STEP(4)
    STEP(8)
#if defined(__ENABLE_CUDA__)
    inline void SynchronizeStreams()
    {
        for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
    }
#endif
    inline void IncrementStep(size_t step, float beta1, float beta2)
    {
        _step++;
        if (_step != step || beta1 != _betta1 || beta2 != _betta2) {
            _step = step;
            _betta1 = beta1;
            _betta2 = beta2;
        }
    }
    inline void update_state(float lr, float weight_decay)
    {
        _alpha = lr;
        _weight_decay = weight_decay;
    }

private:
    float _alpha;
    float _betta1;
    float _betta2;
    float _weight_decay;
    size_t _step;

#if defined(__ENABLE_CUDA__)
    float* _doubled_buffer[2];
    cudaStream_t _streams[2];
    bool _buf_index;
#endif
};
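
Editor's note (not part of the diff): STEP(1) above expands to the declaration below; STEP(4) and STEP(8) declare Step_4 and Step_8 with identical parameter lists, giving the optimizer step entry points at unroll widths 1, 4, and 8.

void Step_1(float* _params,
            float* grads,
            float* _exp_avg,
            size_t _param_size,
            ds_half_precision_t* dev_param = nullptr,
            bool half_precision = false);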

#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
void Lion_Optimizer::Step_AVX(size_t* rounded_size,
                              float* _params,
                              float* grads,
                              float* _exp_avg,
                              size_t _param_size,
                              ds_half_precision_t* dev_params,
                              bool half_precision)
{
    size_t new_rounded_size = 0;
    int rshft = half_precision ? 1 : 0;

    constexpr float neg1 = -1.0f;
    AVX_Data neg1_4;
    neg1_4.data = SIMD_SET(neg1);

    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    float betta1_minus1 = 1 - _betta1;
    float betta2_minus1 = 1 - _betta2;
    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    float step_size = -_alpha;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    float after_decay = 1.0f - _alpha * _weight_decay;
    AVX_Data after_decay_4;
    if (_weight_decay > 0) after_decay_4.data = SIMD_SET(after_decay);

    new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span);
    for (size_t t = 0; t < new_rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
        size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#endif
#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
            AVX_Data grad_4[span];
            simd_load<span>(grad_4, grads + (i >> rshft), half_precision);

            AVX_Data momentum_4[span];
            simd_load<span>(momentum_4, _exp_avg + i, false);

            AVX_Data param_4[span];
            simd_load<span>(param_4, _params + (i >> rshft), half_precision);

            AVX_Data tmp_4[span];

            simd_mul<span>(tmp_4, momentum_4, betta1_4);
            simd_fma<span>(tmp_4, grad_4, betta1_minus1_4, tmp_4);
            // We already used intrinsics, so consider the machine representation fixed.
            simd_and<span>(tmp_4, tmp_4, neg1_4);
            simd_xor<span>(tmp_4, tmp_4, step_size_4);
            if (_weight_decay > 0) {
                simd_fma<span>(param_4, param_4, after_decay_4, tmp_4);
            } else {
                simd_add<span>(param_4, param_4, tmp_4);
            }

            simd_mul<span>(momentum_4, momentum_4, betta2_4);
            simd_fma<span>(momentum_4, grad_4, betta2_minus1_4, momentum_4);

            simd_store<span>(_params + (i >> rshft), param_4, half_precision);
#if defined(__ENABLE_CUDA__)
            if (dev_params) {
                simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
            }
#endif
            simd_store<span>(_exp_avg + i, momentum_4, false);
        }
#if defined(__ENABLE_CUDA__)
        if (dev_params) {
            if (half_precision)
                launch_param_update_half(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            else
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);

            _buf_index = !_buf_index;
        }
#endif
    }
    *rounded_size = new_rounded_size;
}
#endif
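
Editor's note: for readers who prefer not to decode the SIMD bit tricks, the following scalar sketch shows the per-element update that Step_AVX implements, written directly from the Lion update rule: c = beta1*m + (1-beta1)*g, then p = p*(1 - lr*wd) - lr*sign(c), then m = beta2*m + (1-beta2)*g. The function below is illustrative only and not part of this commit.

// Illustrative scalar equivalent of one Lion step over n float32 elements.
void lion_step_scalar(float* params,
                      const float* grads,
                      float* exp_avg,
                      size_t n,
                      float lr,
                      float beta1,
                      float beta2,
                      float weight_decay)
{
    const float after_decay = 1.0f - lr * weight_decay;  // decoupled weight decay
    for (size_t i = 0; i < n; ++i) {
        // c = beta1 * m + (1 - beta1) * g
        const float c = beta1 * exp_avg[i] + (1.0f - beta1) * grads[i];
        // p = p * (1 - lr * wd) - lr * sign(c)
        const float sign_c = (c > 0.0f) - (c < 0.0f);
        params[i] = params[i] * after_decay - lr * sign_c;
        // m = beta2 * m + (1 - beta2) * g
        exp_avg[i] = beta2 * exp_avg[i] + (1.0f - beta2) * grads[i];
    }
}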

int create_lion_optimizer(int optimizer_id,
                          float alpha = 1e-3,
                          float betta1 = 0.9,
                          float betta2 = 0.999,
                          float weight_decay = 0,
                          bool should_log = false);

int ds_lion_step(int optimizer_id,
                 size_t step,
                 float lr,
                 float beta1,
                 float beta2,
                 float weight_decay,
                 torch::Tensor& params,
                 torch::Tensor& grads,
                 torch::Tensor& exp_avg);

int ds_lion_step_plus_copy(int optimizer_id,
                           size_t step,
                           float lr,
                           float beta1,
                           float beta2,
                           float weight_decay,
                           torch::Tensor& params,
                           torch::Tensor& grads,
                           torch::Tensor& exp_avg,
                           torch::Tensor& gpu_params);

int destroy_lion_optimizer(int optimizer_id);
60 changes: 59 additions & 1 deletion csrc/includes/simd.h
@@ -24,6 +24,10 @@
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_AND(x, y) _mm512_and_ps(x, y)
#define SIMD_ANDNOT(x, y) _mm512_andnot_ps(x, y)
#define SIMD_OR(x, y) _mm512_or_ps(x, y)
#define SIMD_XOR(x, y) _mm512_xor_ps(x, y)
#define SIMD_WIDTH 16

#define SIMD_LOAD2(x, h) \
@@ -42,10 +46,14 @@
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_AND(x, y) _mm256_and_ps(x, y)
#define SIMD_ANDNOT(x, y) _mm256_andnot_ps(x, y)
#define SIMD_OR(x, y) _mm256_or_ps(x, y)
#define SIMD_XOR(x, y) _mm256_xor_ps(x, y)
#define SIMD_WIDTH 8

#define SIMD_LOAD2(x, h) \
    ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x))

#define SIMD_STORE2(x, d, h)                                                                \
    ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
         : _mm256_storeu_ps(x, d))
@@ -136,5 +144,55 @@ inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); }
}
template <int span>
inline void simd_and(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_AND(src_a_l[i].data, src_a_r.data); }
}
template <int span>
inline void simd_and(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_AND(src_a_l[i].data, src_a_r[i].data); }
}
template <int span>
inline void simd_andnot(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ANDNOT(src_a_l[i].data, src_a_r.data); }
}
template <int span>
inline void simd_andnot(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) {
        dst[i].data = SIMD_ANDNOT(src_a_l[i].data, src_a_r[i].data);
    }
}
template <int span>
inline void simd_or(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_OR(src_a_l[i].data, src_a_r.data); }
}
template <int span>
inline void simd_or(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_OR(src_a_l[i].data, src_a_r[i].data); }
}
template <int span>
inline void simd_xor(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_XOR(src_a_l[i].data, src_a_r.data); }
}
template <int span>
inline void simd_xor(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_XOR(src_a_l[i].data, src_a_r[i].data); }
}

#endif
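
Editor's note: the new helpers come in two shapes, one that broadcasts a single AVX_Data operand across the span and one that is fully elementwise. The sketch below (assuming an AVX256 or AVX512 build, so that AVX_Data, SIMD_SET, and the macros above are defined) shows how they compose; the helper is illustrative and not part of this commit.

// Flip the sign of each lane of vals[i] wherever the matching lane of ref[i]
// is negative: extract sign bits with the broadcast simd_and overload, then
// apply them with the elementwise simd_xor overload.
template <int span>
inline void flip_sign_where_negative(AVX_Data* dst, AVX_Data* vals, AVX_Data* ref)
{
    AVX_Data sign_mask;
    sign_mask.data = SIMD_SET(-0.0f);       // 0x80000000 in every lane
    AVX_Data signs[span];
    simd_and<span>(signs, ref, sign_mask);  // signs[i] = ref[i] & sign_mask
    simd_xor<span>(dst, vals, signs);       // dst[i]   = vals[i] ^ signs[i]
}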
16 changes: 16 additions & 0 deletions csrc/lion/cpu_lion.cpp
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "cpu_lion.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("lion_update", &ds_lion_step, "DeepSpeed CPU Lion update (C++)");
m.def("lion_update_copy",
&ds_lion_step_plus_copy,
"DeepSpeed CPU Lion update and param copy (C++)");
m.def("create_lion", &create_lion_optimizer, "DeepSpeed CPU Lion (C++)");
m.def("destroy_lion", &destroy_lion_optimizer, "DeepSpeed CPU Lion destroy (C++)");
}
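
Editor's note: these bindings expose the create/step/destroy lifecycle declared in cpu_lion.h. A sketch of that lifecycle in C++ follows, assuming float32 CPU tensors param, grad, and exp_avg of matching shape and an illustrative step count; the helper name and hyperparameter values are placeholders, not part of this commit.

// Illustrative driver for the bound entry points (called from Python through
// the names registered above). All values below are placeholders.
void run_lion(torch::Tensor& param, torch::Tensor& grad, torch::Tensor& exp_avg, size_t num_steps)
{
    const int opt_id = 0;
    create_lion_optimizer(opt_id, /*alpha=*/1e-4f, /*betta1=*/0.9f,
                          /*betta2=*/0.99f, /*weight_decay=*/0.0f, /*should_log=*/true);
    for (size_t step = 1; step <= num_steps; ++step) {
        ds_lion_step(opt_id, step, /*lr=*/1e-4f, /*beta1=*/0.9f, /*beta2=*/0.99f,
                     /*weight_decay=*/0.0f, param, grad, exp_avg);
    }
    destroy_lion_optimizer(opt_id);
}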