
Commit 21e1e36

Merge pull request #5435 from Courtesy-Xs/add_gpu_launch_config

Add query and other components

2 parents f7aecc0 + 5eb5ff1

22 files changed: +401 -118 lines

extensions/csrc/common/dev_info_mgr.h (new file)

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+#pragma once
+
+#include <memory>
+
+#include "common/nvgpu_dev_info.h"
+#include "target.h"
+
+namespace colossalAI {
+namespace common {
+
+template <typename Ret>
+class DevInfoMgr final {
+ public:
+  static std::unique_ptr<Ret> GetDevInfo(int device_num) const {
+    return std::make_unique<Ret>(device_num);
+  }
+};
+
+}  // namespace common
+}  // namespace colossalAI
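
A minimal usage sketch (not part of the commit) for the new query helper. NVGPUDevInfo is an assumed type name suggested by the common/nvgpu_dev_info.h include, and the sketch assumes the intended plain static call, since a const qualifier on a static member function as committed would not compile.

// Illustrative only: NVGPUDevInfo and its namespace are assumptions.
#include "common/dev_info_mgr.h"

void query_device_example() {
  // Build an info object describing CUDA device 0; ownership goes to the caller.
  auto info = colossalAI::common::DevInfoMgr<NVGPUDevInfo>::GetDevInfo(/*device_num=*/0);
}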

extensions/csrc/cuda/type_shim.h renamed to extensions/csrc/common/micros.h

Lines changed: 9 additions & 88 deletions

@@ -9,7 +9,15 @@
 
 #include <ATen/ATen.h>
 
-#include "compat.h"
+#ifndef TORCH_CHECK
+#define TORCH_CHECK AT_CHECK
+#endif
+
+#ifdef VERSION_GE_1_3
+#define DATA_PTR data_ptr
+#else
+#define DATA_PTR data
+#endif
 
 #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
   switch (TYPE) { \
@@ -214,90 +222,3 @@
     AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
              "'"); \
   }
-
-template <typename T>
-__device__ __forceinline__ T reduce_block_into_lanes(
-    T *x, T val, int lanes = 1,
-    bool share_result = false)  // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-  int blockSize =
-      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
-
-  if (blockSize >= 64) {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-#pragma unroll
-  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
-    if (tid < i) x[tid] = x[tid] + x[tid + i];
-    __syncthreads();
-  }
-
-  T final;
-
-  if (tid < 32) {
-    if (blockSize >= 64)
-      final = x[tid] + x[tid + 32];
-    else
-      final = val;
-    // __SYNCWARP();
-
-#pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1)
-      final = final + __shfl_down_sync(0xffffffff, final, i);
-  }
-
-  if (share_result) {
-    if (tid < lanes) x[tid] = final;  // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
-
-template <typename T>
-__device__ __forceinline__ T reduce_block_into_lanes_max_op(
-    T *x, T val, int lanes = 1,
-    bool share_result = false)  // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-  int blockSize =
-      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
-
-  if (blockSize >= 64) {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-#pragma unroll
-  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
-    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
-    __syncthreads();
-  }
-
-  T final;
-
-  if (tid < 32) {
-    if (blockSize >= 64)
-      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
-    else
-      final = val;
-    // __SYNCWARP();
-
-#pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1)
-      final =
-          fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
-  }
-
-  if (share_result) {
-    if (tid < lanes) x[tid] = final;  // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
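
The two macro blocks moved in from compat.h are portability shims: TORCH_CHECK falls back to the older AT_CHECK name, and DATA_PTR selects data_ptr or data depending on VERSION_GE_1_3. A minimal sketch of the typical call pattern, assuming nothing beyond ATen (illustrative, not from this commit):

#include <ATen/ATen.h>

#include "common/micros.h"

void check_and_read(const at::Tensor& t) {
  TORCH_CHECK(t.is_cuda(), "expected a CUDA tensor");  // expands to AT_CHECK on very old PyTorch
  const float* p = t.DATA_PTR<float>();                // data_ptr<float>() on PyTorch >= 1.3
  (void)p;  // silence unused-variable warnings in this sketch
}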

extensions/csrc/cuda/include/mp_type_traits.h renamed to extensions/csrc/common/mp_type_traits.h

Lines changed: 5 additions & 5 deletions

@@ -2,10 +2,10 @@
 
 #include <ATen/ATen.h>
 
-#include "../type_shim.h"
+#include "micros.h"
 
-namespace infer {
-namespace dtype {
+namespace colossalAI {
+namespace common {
 
 template <typename T>
 class MPTypeTrait {
@@ -31,5 +31,5 @@ class MPTypeTrait<at::BFloat16> {
   using Type = float;
 };
 
-}  // namespace dtype
-}  // namespace infer
+}  // namespace common
+}  // namespace colossalAI
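
MPTypeTrait maps a storage type to the wider type used for arithmetic; per the specializations in this header, at::Half and at::BFloat16 both compute in float. A small sketch of the promote-compute-cast pattern, mirroring the activation kernel below (the function itself is illustrative):

#include <ATen/ATen.h>

#include "common/mp_type_traits.h"

template <typename T>
__device__ __forceinline__ T scale_by(T x, float alpha) {
  // Promote to the math type, do the arithmetic there, then cast back to the storage type.
  using MT = typename colossalAI::common::MPTypeTrait<T>::Type;
  return static_cast<T>(static_cast<MT>(x) * static_cast<MT>(alpha));
}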

extensions/csrc/common/target.h (new file)

Lines changed: 134 additions & 0 deletions

@@ -0,0 +1,134 @@
+#pragma once
+
+#include <exception>
+#include <iostream>
+#include <string>
+
+namespace colossalAI {
+namespace common {
+
+class Target {
+ public:
+  enum class OS : int {
+    Unk = -1,
+    Linux,
+    Windows,
+  };
+  enum class Arch : int {
+    Unk = -1,
+    X86,
+    Arm,
+    NVGPU,
+    AMDGPU,
+    Ascend,
+  };
+  enum class BitLen : int {
+    Unk = -1,
+    k32,
+    k64,
+  };
+
+  explicit Target(OS os, Arch arch, BitLen bitlen)
+      : os_(os), arch_(arch), bitlen_(bitlen) {}
+
+  bool defined() const {
+    return (os_ != OS::Unk) && (arch_ != Arch::Unk) && (bitlen_ != BitLen::Unk);
+  }
+
+  std::string str() const {
+    std::string s{"OS: "};
+    switch (os_) {
+      case OS::Unk:
+        s += "Unk";
+        break;
+      case OS::Linux:
+        s += "Linux";
+        break;
+      case OS::Windows:
+        s += "Windows";
+        break;
+      default:
+        throw std::invalid_argument("Invalid OS type!");
+    }
+    s += "\t";
+    s += "Arch: ";
+
+    switch (arch_) {
+      case Arch::Unk:
+        s += "Unk";
+        break;
+      case Arch::X86:
+        s += "X86";
+        break;
+      case Arch::Arm:
+        s += "Arm";
+        break;
+      case Arch::NVGPU:
+        s += "NVGPU";
+        break;
+      case Arch::AMDGPU:
+        s += "AMDGPU";
+        break;
+      case Arch::Ascend:
+        s += "Ascend";
+        break;
+      default:
+        throw std::invalid_argument("Invalid Arch type!");
+    }
+    s += "\t";
+    s += "BitLen: ";
+
+    switch (bitlen_) {
+      case BitLen::Unk:
+        s += "Unk";
+        break;
+      case BitLen::k32:
+        s += "k32";
+        break;
+      case BitLen::k64:
+        s += "k64";
+        break;
+      default:
+        throw std::invalid_argument("Invalid target bit length!");
+    }
+
+    return s;
+  }
+
+  OS os() const { return os_; }
+  Arch arch() const { return arch_; }
+  BitLen bitlen() const { return bitlen_; }
+
+  static Target DefaultX86Target();
+  static Target DefaultArmTarget();
+  static Target DefaultRocmTarget();
+  static Target DefaultAscendTarget();
+
+  static Target DefaultCUDATarget() {
+    return Target(OS::Linux, Arch::CUDA, BitLen::k64);
+  }
+
+  friend std::ostream& operator<<(std::ostream& os, const Target& target);
+  friend bool operator==(const Target& lhs, const Target& rhs);
+  friend bool operator!=(const Target& lhs, const Target& rhs);
+
+ private:
+  OS os_{OS::Unk};
+  Arch arch_{Arch::Unk};
+  BitLen bitlen_{BitLen::Unk};
+};
+
+std::ostream& operator<<(std::ostream& os, const Target& target) {
+  std::cout << target.str() << std::endl;
+}
+bool operator==(const Target& lhs, const Target& rhs) {
+  return (lhs.os_ == rhs.os_) && (lhs.arch_ == rhs.arch_) &&
+         (lhs.bitlen_ == rhs.bitlen_);
+}
+bool operator!=(const Target& lhs, const Target& rhs) {
+  return (lhs.os_ != rhs.os_) && (lhs.arch_ != rhs.arch_) &&
+         (lhs.bitlen_ != rhs.bitlen_);
+}
+
+}  // namespace common
+}  // namespace colossalAI
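
A short usage sketch for the new Target descriptor (illustrative, not part of the commit). It constructs the target directly rather than via DefaultCUDATarget(), since that factory as committed refers to Arch::CUDA, which is not one of the declared enumerators.

#include <iostream>

#include "common/target.h"

void describe_target_example() {
  using colossalAI::common::Target;
  Target t(Target::OS::Linux, Target::Arch::NVGPU, Target::BitLen::k64);
  if (t.defined()) {
    std::cout << t.str() << std::endl;  // tab-separated: "OS: Linux  Arch: NVGPU  BitLen: k64"
  }
}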

extensions/csrc/cuda/activation_kernel.cu

Lines changed: 4 additions & 4 deletions

@@ -2,13 +2,13 @@
 #include <torch/extension.h>
 #include <stdio.h>
 
-#include "type_shim.h"
-#include "include/mp_type_traits.h"
+#include "../common/micros.h"
+#include "../common/mp_type_traits.h"
 
 template<typename T>
 __device__ __forceinline__ T silu_kernel(const T& x) {
   // x * sigmoid(x)
-  using MT = typename infer::dtype::MPTypeTrait<T>::Type;
+  using MT = typename colossalAI::common::MPTypeTrait<T>::Type;
   return static_cast<T>((static_cast<MT>(x)) / (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
 }
 
@@ -17,7 +17,7 @@ __global__ void act_and_mul_kernel(
   const scalar_t* __restrict__ ins_data,
   scalar_t* __restrict__ outs_data,
   const int64_t numel) {
-  using MT = typename infer::dtype::MPTypeTrait<scalar_t>::Type;
+  using MT = typename colossalAI::common::MPTypeTrait<scalar_t>::Type;
 
   int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   const int64_t grid_size = blockDim.x * gridDim.x;
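
act_and_mul_kernel walks its input with a grid-stride loop, so any launch configuration eventually covers all numel elements. A hypothetical host-side launcher (the function name, the 256-thread block size, and the stream parameter are illustrative, not taken from the commit):

template <typename scalar_t>
void launch_act_and_mul(const scalar_t* ins_data, scalar_t* outs_data,
                        int64_t numel, cudaStream_t stream) {
  const int threads = 256;  // illustrative block size
  const int blocks = static_cast<int>((numel + threads - 1) / threads);
  act_and_mul_kernel<scalar_t>
      <<<blocks, threads, 0, stream>>>(ins_data, outs_data, numel);
}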

extensions/csrc/cuda/compat.h (deleted)

Lines changed: 0 additions & 10 deletions

@@ -1,10 +0,0 @@
-// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
-#ifndef TORCH_CHECK
-#define TORCH_CHECK AT_CHECK
-#endif
-
-#ifdef VERSION_GE_1_3
-#define DATA_PTR data_ptr
-#else
-#define DATA_PTR data
-#endif

extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 #include <torch/extension.h>
 #include <stdio.h>
 
-#include "type_shim.h"
+#include "../common/micros.h"
 
 template<typename scalar_t>
 __global__ void decode_kv_cache_memcpy_kernel(
