@@ -23,8 +23,8 @@ limitations under the License. */
2323#include " paddle/fluid/framework/data_type_transform.h"
2424#include " paddle/fluid/framework/details/nan_inf_utils.h"
2525#include " paddle/fluid/framework/op_call_stack.h"
26+ #include " paddle/fluid/framework/pten_utils.h"
2627#include " paddle/fluid/framework/shape_inference.h"
27- #include " paddle/fluid/framework/tcmpt_utils.h"
2828#include " paddle/fluid/framework/transfer_scope_cache.h"
2929#include " paddle/fluid/framework/unused_var_check.h"
3030#include " paddle/fluid/framework/var_type.h"
@@ -1140,7 +1140,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
11401140 // and RCOM backend, the XPU, NPU and MKLDNN will be supported in the second
11411141 // phase
11421142 if (FLAGS_run_pt_kernel &&
1143- pt ::KernelFactory::Instance ().ContainsKernel (type_.c_str ())) {
1143+ pten ::KernelFactory::Instance ().ContainsKernel (type_.c_str ())) {
11441144 if (pt_kernel_signature_.get () == nullptr || pt_kernel_.get () == nullptr ) {
11451145 ChoosePtKernel (exe_ctx);
11461146 }
@@ -1286,10 +1286,11 @@ void OperatorWithKernel::ChoosePtKernel(const ExecutionContext& ctx) const {
12861286
12871287 kernel_type_.reset (new OpKernelType (InnerGetExpectedKernelType (ctx)));
12881288
1289- auto pt_kernel_name = pt ::KernelName (pt_kernel_signature_->first );
1289+ auto pt_kernel_name = pten ::KernelName (pt_kernel_signature_->first );
12901290 auto pt_kernel_key = TransOpKernelTypeToPtKernelKey (*kernel_type_.get ());
1291- pt_kernel_.reset (new pt::Kernel (pt::KernelFactory::Instance ().SelectKernel (
1292- pt_kernel_name, pt_kernel_key)));
1291+ pt_kernel_.reset (
1292+ new pten::Kernel (pten::KernelFactory::Instance ().SelectKernel (
1293+ pt_kernel_name, pt_kernel_key)));
12931294
12941295 if (pt_kernel_->IsValid ()) {
12951296 VLOG (1 ) << " Static mode ChoosePtKernel - kernel name: " << pt_kernel_name
@@ -1781,7 +1782,7 @@ KernelSignature OperatorWithKernel::GetExpectedPtKernelArgs(
17811782 }
17821783}
17831784
1784- pt ::KernelContext OperatorWithKernel::BuildPtKernelContext (
1785+ pten ::KernelContext OperatorWithKernel::BuildPtKernelContext (
17851786 const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const {
17861787 VLOG (1 ) << RuntimeContextDebugString (ctx);
17871788
@@ -1792,7 +1793,7 @@ pt::KernelContext OperatorWithKernel::BuildPtKernelContext(
17921793 // 3. needless attributes remove
17931794 // 4. use pt Tensor directly
17941795 // 5. kernel input is not DenseTensor
1795- pt ::KernelContext op_kernel_ctx (dev_ctx);
1796+ pten ::KernelContext op_kernel_ctx (dev_ctx);
17961797
17971798 auto & input_names = std::get<0 >(pt_kernel_signature_->second );
17981799 auto & attr_names = std::get<1 >(pt_kernel_signature_->second );
@@ -1826,7 +1827,7 @@ pt::KernelContext OperatorWithKernel::BuildPtKernelContext(
18261827 << in_def.layout ;
18271828
18281829 auto ins_vector = ctx.inputs .at (input_names[i]);
1829- std::vector<std::shared_ptr<tcmpt ::TensorBase>> tmp_inputs;
1830+ std::vector<std::shared_ptr<pten ::TensorBase>> tmp_inputs;
18301831
18311832 for (auto var : ins_vector) {
18321833 auto pt_in = framework::InputVariableToPtTensor (*var, in_def);
@@ -1839,7 +1840,7 @@ pt::KernelContext OperatorWithKernel::BuildPtKernelContext(
18391840 auto out_def = output_defs.at (i);
18401841 auto outs_vector = ctx.outputs .at (output_names[i]);
18411842
1842- std::vector<std::shared_ptr<tcmpt ::TensorBase>> tmp_outputs;
1843+ std::vector<std::shared_ptr<pten ::TensorBase>> tmp_outputs;
18431844 for (auto var : outs_vector) {
18441845 auto pt_out = framework::OutputVariableToPtTensor (var, out_def);
18451846 tmp_outputs.emplace_back (pt_out);
@@ -1849,12 +1850,13 @@ pt::KernelContext OperatorWithKernel::BuildPtKernelContext(
18491850
18501851 for (size_t i = 0 ; i < attr_names.size (); ++i) {
18511852 auto & attr = Attrs ().at (attr_names[i]);
1852- if (attr_defs[i].type_index == std::type_index (typeid (pt ::Scalar))) {
1853+ if (attr_defs[i].type_index == std::type_index (typeid (pten ::Scalar))) {
18531854 // TODO(chenweihang): support other attrs later
18541855 // TODO(zhangyunfei): Scalar should hold scaler type, and we should check
18551856 // attribtue type by attr_defs
18561857 if (std::type_index (attr.type ()) == std::type_index (typeid (float ))) {
1857- op_kernel_ctx.EmplaceBackAttr (pt::Scalar (BOOST_GET_CONST (float , attr)));
1858+ op_kernel_ctx.EmplaceBackAttr (
1859+ pten::Scalar (BOOST_GET_CONST (float , attr)));
18581860 } else {
18591861 PADDLE_THROW (platform::errors::Unimplemented (
18601862 " unsupported cast op attribute `%s` to Scalar when construct "
0 commit comments