Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Jan 9, 2024
1 parent 2f41a3a commit 95256ea
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 19 deletions.
6 changes: 3 additions & 3 deletions test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def eval_symbolic(self, use_cinn):
out = net(self.x)
return out

def test_eval_symolic(self):
def test_eval_symbolic(self):
cinn_out = self.eval_symbolic(use_cinn=True)
dy_out = self.eval_symbolic(use_cinn=False)
np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)
Expand All @@ -128,7 +128,7 @@ def eval_symbolic(self, use_cinn):
out = net(self.x)
return out

def test_eval_symolic(self):
def test_eval_symbolic(self):
import os

is_debug = os.getenv('IS_DEBUG_DY_SHAPE')
Expand Down Expand Up @@ -161,7 +161,7 @@ def eval_symbolic(self, use_cinn):
out = net(self.x, self.y)
return out

def test_eval_symolic(self):
def test_eval_symbolic(self):
# cinn_out = self.eval_symbolic(use_cinn=True)
dy_out = self.eval_symbolic(use_cinn=False)
# np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)
Expand Down
6 changes: 3 additions & 3 deletions test/ir/pir/fused_pass/pass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ def check_fused_ops(self, program):
)
op_names = [op.name() for op in program.global_block().ops]
for valid_op_name, valid_op_count in self.valid_op_map.items():
acctual_valid_op_count = op_names.count(valid_op_name)
actual_valid_op_count = op_names.count(valid_op_name)
self.assertTrue(
valid_op_count == acctual_valid_op_count,
valid_op_count == actual_valid_op_count,
"Checking of the number of fused operator < {} > failed. "
"Expected: {}, Received: {}".format(
valid_op_name, valid_op_count, acctual_valid_op_count
valid_op_name, valid_op_count, actual_valid_op_count
),
)

Expand Down
8 changes: 4 additions & 4 deletions test/ir/pir/fused_pass/test_conv2d_add_act_fuse_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class TestConv2dAddActFusePattern(PassTest):
def is_program_valid(self, program):
return True

def build_ir_progam(self):
def build_ir_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
Expand Down Expand Up @@ -86,7 +86,7 @@ def setUp(self):
self.skip_accuracy_verification = True

def sample_program(self):
yield self.build_ir_progam(), False
yield self.build_ir_program(), False

def test_check_output(self):
self.check_pass_correct()
Expand Down Expand Up @@ -114,7 +114,7 @@ class TestConv2dAdd2ActFusePattern(PassTest):
def is_program_valid(self, program):
return True

def build_ir_progam(self):
def build_ir_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
Expand Down Expand Up @@ -165,7 +165,7 @@ def setUp(self):
self.skip_accuracy_verification = True

def sample_program(self):
yield self.build_ir_progam(), False
yield self.build_ir_program(), False

def test_check_output(self):
self.check_pass_correct()
Expand Down
4 changes: 2 additions & 2 deletions test/ir/pir/fused_pass/test_conv2d_add_fuse_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class TestConv2dAddFusePass(PassTest):
def is_program_valid(self, program=None):
return True

def build_ir_progam(self):
def build_ir_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
Expand Down Expand Up @@ -70,7 +70,7 @@ def build_ir_progam(self):
return [main_prog, start_prog]

def sample_program(self):
yield self.build_ir_progam(), False
yield self.build_ir_program(), False

def setUp(self):
if core.is_compiled_with_cuda():
Expand Down
4 changes: 2 additions & 2 deletions test/ir/pir/fused_pass/test_conv2d_bn_fuse_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class TestConv2dBnPassPattern(PassTest):
def is_program_valid(self, program=None):
return True

def build_ir_progam(self):
def build_ir_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
Expand Down Expand Up @@ -70,7 +70,7 @@ def build_ir_progam(self):
return [main_prog, start_prog]

def sample_program(self):
pir_program = self.build_ir_progam()
pir_program = self.build_ir_program()
yield pir_program, False

def test_check_output(self):
Expand Down
4 changes: 2 additions & 2 deletions test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_cuda_version():
"weight_only_linear requires CUDA >= 11.2",
)
class TestFusedWeightOnlyLinearPass_Fp32(PassTest):
def is_conifg_valid(self, w_shape, bias_shape):
def is_config_valid(self, w_shape, bias_shape):
if w_shape[-1] != bias_shape[0]:
return False

Expand Down Expand Up @@ -99,7 +99,7 @@ def sample_program(self):
for dtype in ['float16', "float32"]:
for w_shape in [[64, 64], [64, 15]]:
for bias_shape in [[64], [15]]:
if self.is_conifg_valid(w_shape, bias_shape) is False:
if self.is_config_valid(w_shape, bias_shape) is False:
continue
with paddle.pir_utils.IrGuard():
start_prog = paddle.static.Program()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def sample_program(self):
for x_shape in [[3, 2]]:
for w_shape in [[2, 3]]:
for y_shape in [[1, 3], [3]]:
for bais_shape in [[3, 3]]:
for bias_shape in [[3, 3]]:
for with_relu in [True, False]:
with paddle.pir_utils.IrGuard():
start_prog = paddle.static.Program()
Expand Down Expand Up @@ -68,7 +68,7 @@ def sample_program(self):

bias1 = paddle.static.data(
name='bias1',
shape=bais_shape,
shape=bias_shape,
dtype='float32',
)

Expand All @@ -93,7 +93,7 @@ def sample_program(self):
"float32"
),
"bias1": np.random.random(
bais_shape
bias_shape
).astype("float32"),
}
self.fetch_list = [out]
Expand Down

0 comments on commit 95256ea

Please sign in to comment.