71 changes: 43 additions & 28 deletions paddle/fluid/imperative/amp_auto_cast.cc
@@ -14,6 +14,7 @@

#include "paddle/fluid/imperative/amp_auto_cast.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
@@ -35,14 +36,29 @@ AmpOperators& AmpOperators::Instance() {
return instance;
}

std::shared_ptr<std::unordered_set<std::string>> AmpOperators::GetAllowOps() {
std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableAllowOps() {
return allow_ops_;
}

std::shared_ptr<std::unordered_set<std::string>> AmpOperators::GetBlockOps() {
std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableBlockOps() {
return block_ops_;
}

std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
os << "allow ops: ";
auto allow_ops = ops.GetMutableAllowOps();
std::copy((*allow_ops).begin(), (*allow_ops).end(),
std::ostream_iterator<std::string>(os, " "));
os << "; ";
os << "block ops: ";
auto block_ops = ops.GetMutableBlockOps();
std::copy((*block_ops).begin(), (*block_ops).end(),
std::ostream_iterator<std::string>(os, " "));
return os;
}

inline std::string GetDtypeStr(
const std::shared_ptr<imperative::VarBase>& var) {
return framework::DataTypeToString(var->DataType());
@@ -115,51 +131,50 @@ static inline framework::proto::VarType::Type GetPromoteType(

NameVarBaseMap AutoCastInputs(const std::string& op_type,
const NameVarBaseMap& ins) {
NameVarBaseMap new_ins = {};
if (AmpOperators::Instance().GetAllowOps()->count(op_type)) {
for (const auto& pair : ins) {
NameVarBaseMap new_ins(ins);
if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) {
for (auto& pair : new_ins) {
// NOTE(zhiqiu): batch_norm and layer_norm only support input X in fp16.
if ((op_type == "batch_norm" || op_type == "layer_norm") &&
pair.first != "X") {
continue;
}

VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float16";
for (const auto& var : pair.second) {
auto new_var = CastToFP16(var);
new_ins[pair.first].emplace_back(new_var);
for (auto& var : pair.second) {
var = CastToFP16(var);
}
}
return new_ins;
} else if (AmpOperators::Instance().GetBlockOps()->count(op_type)) {
for (const auto& pair : ins) {
} else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float";
for (const auto& var : pair.second) {
auto new_var = CastToFP32(var);
new_ins[pair.first].emplace_back(new_var);
for (auto& var : pair.second) {
var = CastToFP32(var);
}
}
return new_ins;
} else {
auto dst_type = GetPromoteType(ins);

for (const auto& pair : ins) {
for (auto& pair : new_ins) {
// NOTE(zhiqiu): batch_norm and layer_norm only support input X in fp16.
if ((op_type == "batch_norm" || op_type == "layer_norm") &&
pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
continue;
}
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to "
<< framework::DataTypeToString(dst_type);
for (const auto& var : pair.second) {
// NOTE(zhiqiu): Conv + BN always occur together, so we needn't
// cast X of batch_norm to FP32; it is produced by conv in FP16.
if (op_type == "batch_norm" && pair.first == "X" &&
dst_type == framework::proto::VarType::FP32) {
new_ins[pair.first].emplace_back(var);
continue;
}
auto new_var = dst_type == framework::proto::VarType::FP32
? CastToFP32(var)
: CastToFP16(var);
new_ins[pair.first].emplace_back(new_var);
for (auto& var : pair.second) {
var = (dst_type == framework::proto::VarType::FP32 ? CastToFP32(var)
: CastToFP16(var));
}
}
return new_ins;
}
return ins;
return new_ins;
}

} // namespace imperative
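The reworked AutoCastInputs above copy-constructs new_ins from ins and then casts variables in place, rather than filling an initially empty map, so inputs that are skipped (for example the non-X inputs of batch_norm and layer_norm) still come back unchanged. A minimal standalone sketch of that copy-then-cast-in-place pattern, with hypothetical stand-ins for NameVarBaseMap and the cast helpers (not part of this PR):

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins: a "variable" is just a tagged dtype here.
struct Var {
  std::string name;
  std::string dtype;
};
using NameVarMap = std::map<std::string, std::vector<Var>>;

Var CastTo(const Var& v, const std::string& dst) {
  return {v.name, dst};  // a real cast would insert a cast op instead
}

// Copy-then-mutate: every input key/value survives; only selected slots
// are rewritten, mirroring the structure of the new AutoCastInputs.
NameVarMap AutoCastSketch(const std::string& op_type, const NameVarMap& ins,
                          const std::string& dst_dtype) {
  NameVarMap new_ins(ins);  // full copy up front
  for (auto& pair : new_ins) {
    if ((op_type == "batch_norm" || op_type == "layer_norm") &&
        pair.first != "X") {
      continue;  // left untouched, but still present in the result
    }
    for (auto& var : pair.second) {
      var = CastTo(var, dst_dtype);  // rewrite in place
    }
  }
  return new_ins;
}

int main() {
  NameVarMap ins = {{"X", {{"x", "fp32"}}}, {"Scale", {{"s", "fp32"}}}};
  auto outs = AutoCastSketch("layer_norm", ins, "fp16");
  std::cout << outs["X"][0].dtype << " " << outs["Scale"][0].dtype << "\n";
  // prints: fp16 fp32
}

Running the sketch leaves Scale in fp32 while X becomes fp16, which mirrors the mixed-precision inputs exercised by the new layer_norm test below.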
6 changes: 4 additions & 2 deletions paddle/fluid/imperative/amp_auto_cast.h
@@ -36,9 +36,9 @@ class AmpOperators {

static AmpOperators& Instance();

std::shared_ptr<std::unordered_set<std::string>> GetAllowOps();
std::shared_ptr<std::unordered_set<std::string>> GetMutableAllowOps();

std::shared_ptr<std::unordered_set<std::string>> GetBlockOps();
std::shared_ptr<std::unordered_set<std::string>> GetMutableBlockOps();

private:
AmpOperators(); // forbid calling default constructor
@@ -52,6 +52,8 @@ class AmpOperators {
std::shared_ptr<std::unordered_set<std::string>> block_ops_;
};

std::ostream& operator<<(std::ostream& os, AmpOperators& ops);

// NOTE(zhiqiu): AutoCastGuard is used for RAII.
class AutoCastGuard {
public:
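The NOTE above only states that AutoCastGuard follows the RAII pattern, and the class body is collapsed in this diff view. As a rough illustration of the pattern only, not of the real class, a guard records the previous mode in its constructor, switches the mode, and restores it in its destructor (all names below besides AutoCastGuard are hypothetical):

#include <iostream>

// Hypothetical tracer with a single AMP on/off flag.
struct TracerLike {
  bool amp_enabled = false;
};

// RAII guard: enable AMP for the lifetime of the guard, then restore.
class AutoCastGuardSketch {
 public:
  AutoCastGuardSketch(TracerLike* tracer, bool enable)
      : tracer_(tracer), pre_mode_(tracer->amp_enabled) {
    tracer_->amp_enabled = enable;
  }
  ~AutoCastGuardSketch() { tracer_->amp_enabled = pre_mode_; }

  // Forbid copy, as a guard should stay bound to one scope.
  AutoCastGuardSketch(const AutoCastGuardSketch&) = delete;
  AutoCastGuardSketch& operator=(const AutoCastGuardSketch&) = delete;

 private:
  TracerLike* tracer_;
  bool pre_mode_;
};

int main() {
  TracerLike tracer;
  {
    AutoCastGuardSketch guard(&tracer, true);
    std::cout << tracer.amp_enabled << "\n";  // 1 inside the guarded scope
  }
  std::cout << tracer.amp_enabled << "\n";  // 0, restored on scope exit
}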
39 changes: 21 additions & 18 deletions paddle/fluid/pybind/imperative.cc
@@ -1257,27 +1257,30 @@ void BindImperative(py::module *m_ptr) {
py::return_value_policy::reference)
.def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName,
py::arg("key") = "dygraph_tmp")
.def(
"_set_amp_op_list",
[](imperative::Tracer &self,
std::unordered_set<std::string> &allow_ops,
std::unordered_set<std::string> &block_ops) {
// NOTE(zhiqiu): The automatic conversion in pybind11 between C++
// STL and Python set/list/dict involves a copy operation that
// prevents pass-by-reference semantics, so it is ok to swap.
// The reason we do not directly pass
// std::shared_ptr<std::unordered_set<std::string>>
// is that pybind11 forbids shared_ptr<T> where T is not a custom
// type.
imperative::AmpOperators::Instance().GetAllowOps()->swap(allow_ops);
imperative::AmpOperators::Instance().GetBlockOps()->swap(block_ops);
})
.def("_set_amp_op_list",
[](imperative::Tracer &self,
std::unordered_set<std::string> &allow_ops,
std::unordered_set<std::string> &block_ops) {
// NOTE(zhiqiu): The automatic conversion in pybind11 between C++
// STL and Python set/list/dict involves a copy operation that
// prevents pass-by-reference semantics, so it is ok to swap.
// The reason we do not directly pass
// std::shared_ptr<std::unordered_set<std::string>>
// is that pybind11 forbids shared_ptr<T> where T is not a custom
// type.
imperative::AmpOperators::Instance().GetMutableAllowOps()->swap(
allow_ops);
imperative::AmpOperators::Instance().GetMutableBlockOps()->swap(
block_ops);
VLOG(4) << "AMP operators changed, "
<< imperative::AmpOperators::Instance();
})
.def("_get_amp_op_list",
[](imperative::Tracer &self) {
return std::make_tuple(
*(imperative::AmpOperators::Instance().GetAllowOps()),
*(imperative::AmpOperators::Instance().GetBlockOps()));
*(imperative::AmpOperators::Instance().GetMutableAllowOps()),
*(imperative::AmpOperators::Instance().GetMutableBlockOps()));
})
.def("trace",
[](imperative::Tracer &self, const std::string &type,
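The NOTE in _set_amp_op_list above leans on the fact that pybind11 materializes a Python set argument as a freshly converted std::unordered_set, so swapping that temporary into the globally held set transfers its contents without a second copy, and without passing shared_ptr arguments, which pybind11 forbids for non-custom types. A small pybind11-free sketch of the same swap pattern (names hypothetical, not part of this PR):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>

using OpSet = std::unordered_set<std::string>;

// Globally held set, standing in for AmpOperators' allow_ops_/block_ops_.
std::shared_ptr<OpSet> g_allow_ops = std::make_shared<OpSet>();

// 'incoming' models the temporary set pybind11 builds from a Python set;
// since it is already a copy, swapping moves the contents in O(1) instead
// of copying every string again.
void SetAllowOps(OpSet& incoming) { g_allow_ops->swap(incoming); }

int main() {
  OpSet from_python = {"conv2d", "matmul", "layer_norm"};  // pretend copy
  SetAllowOps(from_python);
  for (const auto& op : *g_allow_ops) std::cout << op << " ";
  std::cout << "\n";  // conv2d matmul layer_norm (unordered)
}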
@@ -389,5 +389,21 @@ def test_resnet(self):
self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2))


class TestLayerNormFp16(unittest.TestCase):
r''' layer_norm and batch_norm support mixed inputs, i.e., only input x is fp16
and other params are fp32.
'''

def test_layer_norm_fp16(self):
if fluid.is_compiled_with_cuda():
with fluid.dygraph.guard(fluid.CUDAPlace(0)):
x = paddle.rand([2, 2, 2, 3])
layer_norm = paddle.nn.LayerNorm(x.shape[1:])
with paddle.amp.auto_cast(custom_white_list=['layer_norm']):
out = layer_norm(x)

self.assertTrue(out.dtype == fluid.core.VarDesc.VarType.FP16)


if __name__ == '__main__':
unittest.main()