Skip to content

Commit

Permalink
[PIR] migrate DataFeeder into pir (#60434)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarioLulab authored Jan 4, 2024
1 parent 488bd17 commit d307890
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 70 deletions.
42 changes: 30 additions & 12 deletions python/paddle/base/data_feeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import numpy as np

from paddle import pir

from ..pir import Value
from ..pir.core import ParameterMeta
from . import core
Expand Down Expand Up @@ -419,19 +421,35 @@ def __init__(self, feed_list, place, program=None):
self.feed_names = []
self.feed_shapes = []
self.feed_lod_level = []
if program is None:
program = default_main_program()
for each_var in feed_list:
if isinstance(each_var, str):
each_var = program.block(0).var(each_var)
if not isinstance(each_var, Variable):
raise TypeError("Feed list should contain a list of variable")
self.feed_dtypes.append(each_var.dtype)
self.feed_names.append(each_var.name)
self.feed_lod_level.append(each_var.lod_level)
self.feed_shapes.append(each_var.shape)

self.place = place
if in_pir_mode():
if program is None:
program = pir.core.default_main_program()
for each_var in feed_list:
if isinstance(each_var, str):
raise ValueError(
"In PIR Mode, Not supported string input yet"
)
if not isinstance(each_var, Value):
raise TypeError("Feed list should contain a list of Value")
self.feed_dtypes.append(each_var.dtype)
self.feed_names.append(each_var.name)
self.feed_lod_level.append(each_var.lod_level)
self.feed_shapes.append(each_var.shape)
else:
if program is None:
program = default_main_program()
for each_var in feed_list:
if isinstance(each_var, str):
each_var = program.block(0).var(each_var)
if not isinstance(each_var, Variable):
raise TypeError(
"Feed list should contain a list of variable"
)
self.feed_dtypes.append(each_var.dtype)
self.feed_names.append(each_var.name)
self.feed_lod_level.append(each_var.lod_level)
self.feed_shapes.append(each_var.shape)

def feed(self, iterable):
"""
Expand Down
1 change: 1 addition & 0 deletions python/paddle/static/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def _reset_data_op_insertion_point():
ir_dtype = paddle.pir.core.convert_np_dtype_to_dtype_(dtype)
_reset_data_op_insertion_point()
out = paddle._pir_ops.data(name, shape, ir_dtype, core.Place())
out.lod_level = lod_level
paddle.pir.reset_insertion_point_to_end()
return out

Expand Down
151 changes: 93 additions & 58 deletions test/legacy_test/test_data_feeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,74 +16,109 @@

import paddle
from paddle import base
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()


class TestDataFeeder(unittest.TestCase):
@test_with_pir_api
def test_lod_level_0_converter(self):
img = paddle.static.data(name='image', shape=[-1, 1, 28, 28])
label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')
feeder = base.DataFeeder([img, label], base.CPUPlace())
result = feeder.feed([([0] * 784, [9]), ([1] * 784, [1])])

self.assertEqual(result['image'].shape(), [2, 1, 28, 28])
self.assertEqual(result['label'].shape(), [2, 1])
self.assertEqual(result['image'].recursive_sequence_lengths(), [])
self.assertEqual(result['label'].recursive_sequence_lengths(), [])

try:
result = feeder.feed([([0] * 783, [9]), ([1] * 783, [1])])
self.assertTrue(False)
except ValueError:
self.assertTrue(True)
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
img = paddle.static.data(name='image', shape=[-1, 1, 28, 28])
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64'
)
feeder = base.DataFeeder([img, label], base.CPUPlace())
result = feeder.feed([([0] * 784, [9]), ([1] * 784, [1])])

self.assertEqual(result['image'].shape(), [2, 1, 28, 28])
self.assertEqual(result['label'].shape(), [2, 1])
self.assertEqual(result['image'].recursive_sequence_lengths(), [])
self.assertEqual(result['label'].recursive_sequence_lengths(), [])

try:
result = feeder.feed([([0] * 783, [9]), ([1] * 783, [1])])
self.assertTrue(False)
except ValueError:
self.assertTrue(True)

@test_with_pir_api
def test_lod_level_1_converter(self):
# lod_level = 1
# each sentence has a different number of words
sentences = paddle.static.data(
name='sentences', shape=[-1, 1], dtype='int64', lod_level=1
)
label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')
feeder = base.DataFeeder([sentences, label], base.CPUPlace())

# lod = [[0, 3, 5, 9]]
# data = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
# label = [1] * len(data)
result = feeder.feed(
[([1, 2, 3], [1]), ([4, 5], [1]), ([6, 7, 8, 9], [1])]
)

self.assertEqual(result['sentences'].shape(), [9, 1])
self.assertEqual(result['label'].shape(), [3, 1])
self.assertEqual(
result['sentences'].recursive_sequence_lengths(), [[3, 2, 4]]
)
self.assertEqual(result['label'].recursive_sequence_lengths(), [])
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
# lod_level = 1
# each sentence has a different number of words
sentences = paddle.static.data(
name='sentences', shape=[-1, 1], dtype='int64', lod_level=1
)
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64'
)
feeder = base.DataFeeder([sentences, label], base.CPUPlace())

# lod = [[0, 3, 5, 9]]
# data = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
# label = [1] * len(data)
result = feeder.feed(
[([1, 2, 3], [1]), ([4, 5], [1]), ([6, 7, 8, 9], [1])]
)

self.assertEqual(result['sentences'].shape(), [9, 1])
self.assertEqual(result['label'].shape(), [3, 1])
self.assertEqual(
result['sentences'].recursive_sequence_lengths(), [[3, 2, 4]]
)
self.assertEqual(result['label'].recursive_sequence_lengths(), [])

@test_with_pir_api
def test_lod_level_2_converter(self):
# lod_level = 2
# paragraphs -> sentences -> words
paragraphs = paddle.static.data(
name='paragraphs', shape=[-1, 1], dtype='int64', lod_level=2
)
label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')
feeder = base.DataFeeder([paragraphs, label], base.CPUPlace())

# lod = [[0, 2, 3], [0, 3, 5, 9]]
# data = [[[1, 2, 3], [4, 5]], [[6, 7, 8, 9]]]
# label = [1] * len(data)
result = feeder.feed(
[([[1, 2, 3], [4, 5]], [1]), ([[6, 7, 8, 9]], [1])]
)

self.assertEqual(result['paragraphs'].shape(), [9, 1])
self.assertEqual(result['label'].shape(), [2, 1])
self.assertEqual(
result['paragraphs'].recursive_sequence_lengths(),
[[2, 1], [3, 2, 4]],
)
self.assertEqual(result['label'].recursive_sequence_lengths(), [])
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
# lod_level = 2
# paragraphs -> sentences -> words
paragraphs = paddle.static.data(
name='paragraphs', shape=[-1, 1], dtype='int64', lod_level=2
)
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64'
)
feeder = base.DataFeeder([paragraphs, label], base.CPUPlace())

# lod = [[0, 2, 3], [0, 3, 5, 9]]
# data = [[[1, 2, 3], [4, 5]], [[6, 7, 8, 9]]]
# label = [1] * len(data)
result = feeder.feed(
[([[1, 2, 3], [4, 5]], [1]), ([[6, 7, 8, 9]], [1])]
)

self.assertEqual(result['paragraphs'].shape(), [9, 1])
self.assertEqual(result['label'].shape(), [2, 1])
self.assertEqual(
result['paragraphs'].recursive_sequence_lengths(),
[[2, 1], [3, 2, 4]],
)
self.assertEqual(result['label'].recursive_sequence_lengths(), [])

def test_errors(self):
def pir_mode_not_supported_str_feed():
with paddle.pir_utils.IrGuard():
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
img = paddle.static.data(
name='image', shape=[-1, 1, 28, 28]
)
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64'
)
feeder = base.DataFeeder(['image', label], base.CPUPlace())

self.assertRaises(ValueError, pir_mode_not_supported_str_feed)


if __name__ == '__main__':
Expand Down

0 comments on commit d307890

Please sign in to comment.