Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbopd committed Jan 10, 2024
1 parent 5ad04bd commit f79fdb2
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 10 deletions.
1 change: 1 addition & 0 deletions cmake/flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ checkcompilercxx14flag()
if(NOT WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
else()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -std=c++17")
set(CMAKE_CXX_STANDARD 17)
endif()

Expand Down
13 changes: 8 additions & 5 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
#include "paddle/pir/core/type.h"
#include "paddle/pir/core/value.h"
#include "paddle/pir/dialect/control_flow/ir/cf_dialect.h"
#include "paddle/pir/dialect/shape/ir/shape_attribute.h"
#include "paddle/pir/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_manager.h"
Expand Down Expand Up @@ -451,6 +452,7 @@ void BindOperation(py::module *m) {
[](Operation &self) -> py::dict {
py::dict attrs_dict;
for (auto &pair : self.attributes()) {
if (pair.second.isa<pir::shape::SymbolAttribute>()) continue;
attrs_dict[pair.first.c_str()] =
paddle::dialect::GetAttributeData(pair.second);
}
Expand Down Expand Up @@ -1564,11 +1566,12 @@ static bool HasDynamicShape(const Program &program) {
continue;
}
for (uint32_t i = 0; i < op.num_results(); ++i) {
if (op.result(i) && op.result(i)
.type()
.dyn_cast<pir::ShapedTypeInterface>()
.IsDynamicShape()) {
return true;
if (op.result(i) && op.result(i).type()) {
auto shaped_type =
op.result(i).type().dyn_cast<pir::ShapedTypeInterface>();
if (shaped_type &&
pir::ShapedTypeInterface::IsDynamicShape(shaped_type.GetShape()))
return true;
}
}
}
Expand Down
4 changes: 1 addition & 3 deletions python/paddle/jit/dy2static/pir_partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,7 @@ def _create_program(self, is_infer_mode=False):
# TODO(lanxianghit) mv this into pass_fn
def shape_pass_fn(forward_program, backward_program):
pm = paddle.base.libpaddle.pir.PassManager()
paddle.base.libpaddle.pir.infer_symbolic_shape_pass(
pm, forward_program
)
paddle.base.libpaddle.pir.infer_symbolic_shape_pass(pm)
pm.run(forward_program)
return forward_program, backward_program

Expand Down
3 changes: 1 addition & 2 deletions test/cpp/pir/core/type_interface_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ TEST(shapedtype_test, shapedtype_test) {
dense_tensor_type_interface.GetElementType().isa<pir::Float32Type>(),
true);
EXPECT_EQ(dense_tensor_type_interface.GetShape(), dims);
EXPECT_EQ(dense_tensor_type_interface.kDynamic,
std::numeric_limits<int64_t>::min());
EXPECT_EQ(dense_tensor_type_interface.kDynamic, -1);
EXPECT_EQ(dense_tensor_type_interface.GetRank(), 2);
EXPECT_EQ(dense_tensor_type_interface.IsDynamic(2), false);
EXPECT_EQ(dense_tensor_type_interface.IsDynamicShape(dims), false);
Expand Down

0 comments on commit f79fdb2

Please sign in to comment.