Skip to content

Commit a2878e3

Browse files
authored
[Inference] Add Reduce Utils (#5537)
* add reduce utils * add using to delete namespace prefix
1 parent 04aca9e commit a2878e3

File tree

9 files changed

+180
-363
lines changed

9 files changed

+180
-363
lines changed

extensions/csrc/common/micros.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,6 @@
99

1010
#include <ATen/ATen.h>
1111

12-
#ifndef TORCH_CHECK
13-
#define TORCH_CHECK AT_CHECK
14-
#endif
15-
16-
#ifdef VERSION_GE_1_3
17-
#define DATA_PTR data_ptr
18-
#else
19-
#define DATA_PTR data
20-
#endif
21-
2212
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
2313
switch (TYPE) { \
2414
case at::ScalarType::Half: { \
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#pragma once
2+
3+
#include <cuda.h>
4+
#include <cuda_fp16.h>
5+
#include <cuda_runtime.h>
6+
7+
#include <functional>
8+
9+
namespace colossalAI {
10+
namespace cuda {
11+
namespace funcs {
12+
13+
enum class BinaryOpType { kAdd = 0, kMinus, kMul, KDiv, kMax, KMin };
14+
15+
template <typename T, BinaryOpType Op>
16+
struct BinaryOpFunctor;
17+
18+
template <typename T>
19+
struct BinaryOpFunctor<T, BinaryOpType::kAdd>
20+
: public std::binary_function<T, T, T> {
21+
__host__ __device__ T operator()(T lhs, T rhs) { return lhs + rhs; }
22+
};
23+
24+
template <typename T>
25+
struct BinaryOpFunctor<T, BinaryOpType::kMax>
26+
: public std::binary_function<T, T, T> {
27+
__host__ __device__ T operator()(T lhs, T rhs) { return max(lhs, rhs); }
28+
};
29+
30+
} // namespace funcs
31+
} // namespace cuda
32+
} // namespace colossalAI

0 commit comments

Comments
 (0)