4 changes: 4 additions & 0 deletions paddle/fluid/framework/op_kernel_type.h
@@ -65,6 +65,10 @@ class OpKernelType {

size_t hash_key() const { return Hash()(*this); }

bool operator<(const OpKernelType& o) const {
return hash_key() < o.hash_key();
}

bool operator==(const OpKernelType& o) const;

bool operator!=(const OpKernelType& o) const { return !(*this == o); }
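Note: the new operator< orders OpKernelType values by hash_key(), which gives the strict weak ordering std::map requires and is what lets the var_cache member added to VariableWrapper below use OpKernelType as its key. A stand-alone sketch of the pattern (Key is a stand-in for OpKernelType, not Paddle code):

    #include <functional>
    #include <map>
    #include <string>

    struct Key {
      std::string repr;
      size_t hash_key() const { return std::hash<std::string>()(repr); }
      // Same idea as OpKernelType::operator<: order by the hash value.
      bool operator<(const Key& o) const { return hash_key() < o.hash_key(); }
    };

    int main() {
      std::map<Key, int> cache;
      cache[Key{"float32_cpu"}] = 1;            // insertion goes through operator<
      return cache.count(Key{"float32_cpu"});   // lookup goes through operator<
    }

One consequence of ordering by hash is that two distinct keys that happen to share a hash would compare equivalent and occupy the same map slot.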
10 changes: 10 additions & 0 deletions paddle/fluid/imperative/prepared_operator.cc
@@ -20,6 +20,16 @@
namespace paddle {
namespace imperative {

const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var) {
return var->SharedVar();
}

const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<VariableWrapper>& var) {
return var;
}

const framework::Tensor* GetTensorFromVar(const framework::Variable& var) {
if (var.IsType<framework::LoDTensor>()) {
return &(var.Get<framework::LoDTensor>());
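Note: the two GetVariableWrapper overloads exist so that the templated PrepareData in prepared_operator.h can reach the underlying VariableWrapper whether it is instantiated with VarBase or with VariableWrapper inputs; overload resolution picks the right accessor at compile time. A minimal stand-alone sketch of that dispatch (stand-in types, not Paddle code):

    #include <memory>

    struct Wrapper { int cache_entries = 0; };
    struct VarBase {
      std::shared_ptr<Wrapper> wrapper = std::make_shared<Wrapper>();
      const std::shared_ptr<Wrapper>& SharedVar() const { return wrapper; }
    };

    // Same shape as the PR's overloads: unwrap a VarBase, pass a Wrapper through.
    const std::shared_ptr<Wrapper>& GetWrapper(const std::shared_ptr<VarBase>& v) {
      return v->SharedVar();
    }
    const std::shared_ptr<Wrapper>& GetWrapper(const std::shared_ptr<Wrapper>& v) {
      return v;
    }

    template <typename VarType>
    void Touch(const std::shared_ptr<VarType>& v) {
      GetWrapper(v)->cache_entries += 1;  // compiles for either instantiation
    }

    int main() {
      Touch(std::make_shared<VarBase>());
      Touch(std::make_shared<Wrapper>());
    }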
52 changes: 42 additions & 10 deletions paddle/fluid/imperative/prepared_operator.h
@@ -64,6 +64,11 @@ void SetForwardDataTypeOfGradVar<VarBase>(const std::shared_ptr<VarBase>& var) {
}
}

extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var);
extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<VariableWrapper>& var);

template <typename VarType>
std::shared_ptr<NameVarMap<VarType>> PrepareData(
const framework::OperatorWithKernel& op, const NameVarMap<VarType>& ins,
@@ -82,23 +87,50 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
} else {
VLOG(3) << "Transform Variable " << var_base->Name() << " from "
<< kernel_type_for_var << " to " << expected_kernel_key;
framework::Tensor out;
TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
&out);
if (NeedTransformDataType(kernel_type_for_var, expected_kernel_key)) {
// To avoid NameVarMap copy construction overhead in general
// scenarios, if inplace transformed, return original input directly

if (GetVariableWrapper(var_base)->hasCacheKey(expected_kernel_key)) {
VLOG(3) << "Hit variable_wrapper cache: key="
<< expected_kernel_key;
std::shared_ptr<VariableWrapper> cache_var =
GetVariableWrapper(var_base)->getCacheValue(
expected_kernel_key);
if (tmp_ins_ptr == nullptr) {
tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
}

const auto* tensor = GetTensorFromVar(cache_var->Var());
auto tmp_var = std::make_shared<VarType>(var_base->Name());
tmp_var->SetType(var_base->Type());
SetTensorToVariable(var_base->Var(), out, tmp_var->MutableVar());
SetTensorToVariable(cache_var->Var(), *tensor,
tmp_var->MutableVar());
(*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
} else {
// if dtype is same, transform inplace will not change the original
// value, transform inplace to avoid multiple copy
SetTensorToVariable(var_base->Var(), out, var_base->MutableVar());
framework::Tensor out;
TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
&out);
if (NeedTransformDataType(kernel_type_for_var,
expected_kernel_key)) {
// To avoid NameVarMap copy construction overhead in general
// scenarios, if inplace transformed, return original input
// directly
if (tmp_ins_ptr == nullptr) {
tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
}
auto tmp_var = std::make_shared<VarType>(var_base->Name());
tmp_var->SetType(var_base->Type());
SetTensorToVariable(var_base->Var(), out, tmp_var->MutableVar());
(*tmp_ins_ptr)[name_pair.first][i] = tmp_var;

GetVariableWrapper(var_base)->setCacheValue(
expected_kernel_key, GetVariableWrapper(tmp_var));
VLOG(3) << "Set cache to variable_wrapper: key="
<< expected_kernel_key;
} else {
// if dtype is the same, an inplace transform will not change the
// original value, so transform inplace to avoid an extra copy
SetTensorToVariable(var_base->Var(), out, var_base->MutableVar());
}
}
}
}
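Note: the new branch in PrepareData only kicks in when NeedTransformDataType is true. On that path the input's VariableWrapper cache is consulted with the expected kernel key: on a hit the previously transformed variable is reused, on a miss the tensor is transformed once and the resulting wrapper is stored under that key for later runs; when only layout or place differ, the old inplace behaviour is kept. A simplified stand-alone sketch of the hit/miss flow (stand-in types and a hypothetical Transform helper, not Paddle code):

    #include <map>
    #include <memory>
    #include <string>

    struct Tensor { std::string dtype; };
    struct Var {
      Tensor tensor;
      // keyed by the expected kernel key, like var_cache in the PR
      std::map<std::string, std::shared_ptr<Var>> cache;
    };

    // Hypothetical stand-in for TransformData: copy, then "promote" the dtype.
    Tensor Transform(const Tensor& in, const std::string& dtype) {
      Tensor out = in;
      out.dtype = dtype;
      return out;
    }

    std::shared_ptr<Var> Prepare(const std::shared_ptr<Var>& in,
                                 const std::string& expected_key) {
      auto it = in->cache.find(expected_key);
      if (it != in->cache.end()) {
        return it->second;                 // cache hit: reuse the transformed var
      }
      auto out = std::make_shared<Var>();
      out->tensor = Transform(in->tensor, expected_key);
      in->cache[expected_key] = out;       // cache miss: transform once and store
      return out;
    }

    int main() {
      auto v = std::make_shared<Var>();
      v->tensor.dtype = "complex64";
      auto a = Prepare(v, "float32");      // miss: transform runs
      auto b = Prepare(v, "float32");      // hit: same shared_ptr comes back
      return (a == b) ? 0 : 1;
    }

Because the cached entry is a shared_ptr stored inside the source wrapper, the transformed variable stays alive as long as that wrapper does, which is what makes repeated promotions of the same input cheap.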
21 changes: 21 additions & 0 deletions paddle/fluid/imperative/variable_wrapper.h
@@ -14,10 +14,12 @@

#pragma once

#include <map>
#include <memory>
#include <string>
#include <utility>

#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/hooks.h"
#include "paddle/fluid/imperative/op_base.h"
@@ -238,6 +240,21 @@ class VariableWrapper {
inplace_version_snapshot_ = new_version;
}

bool hasCacheKey(const paddle::framework::OpKernelType& key) {
return var_cache.find(key) != var_cache.end();
}

std::shared_ptr<VariableWrapper> getCacheValue(
const paddle::framework::OpKernelType& key) {
return var_cache[key];
}

void setCacheValue(const paddle::framework::OpKernelType& key,
std::shared_ptr<VariableWrapper> val) {
var_cache[key] = val;
return;
}

private:
void SetGradVar(const std::shared_ptr<VariableWrapper>& var) {
auto shared_var = grad_var_.lock();
@@ -311,6 +328,10 @@ class VariableWrapper {
framework::Variable var_;
std::string name_;

// Used to cache the dtype-promoted VariableWrapper in the real and complex
// compute of Paddle Quantum
std::map<paddle::framework::OpKernelType, std::shared_ptr<VariableWrapper>>
var_cache;
// add this property for users may set stop_gradient themselves and this
// should override the frameworks setting (-1) unset, (1) true, (0) false
int overrided_stop_gradient_{-1};
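Note: hasCacheKey/getCacheValue/setCacheValue are the only public surface of var_cache, and PrepareData above is the caller added in this diff. The intended call pattern, sketched with assumed in-scope names (wrapper, key, promoted_wrapper are illustrative, not a verbatim call site):

    if (wrapper->hasCacheKey(key)) {
      std::shared_ptr<VariableWrapper> hit = wrapper->getCacheValue(key);
      // reuse the already dtype-promoted variable
    } else {
      wrapper->setCacheValue(key, promoted_wrapper);  // remember it for next time
    }

Since getCacheValue goes through std::map::operator[], a lookup on a missing key would insert an empty entry, so callers need to check hasCacheKey first, as PrepareData does.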