Skip to content

Commit a7c64ae

Browse files
authored
DistModel supports feed of list (#62945)
1 parent 41dc104 commit a7c64ae

File tree

1 file changed

+15
-1
lines changed
  • python/paddle/distributed/auto_parallel

1 file changed

+15
-1
lines changed

python/paddle/distributed/auto_parallel/api.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1926,7 +1926,21 @@ def __call__(self, *args):
19261926
if self._mode == "eval":
19271927
if self._engine._loss is None:
19281928
raise ValueError("Please set loss function before evaluation.")
1929-
feeds = self._make_feeds(list(args))
1929+
1930+
feed_list = []
1931+
for feed_item in list(args):
1932+
if isinstance(feed_item, (list, tuple)):
1933+
feed_list += list(feed_item)
1934+
elif isinstance(feed_item, paddle.Tensor):
1935+
feed_list += [feed_item]
1936+
elif isinstance(feed_item, core.LoDTensor):
1937+
feed_list += [feed_item]
1938+
else:
1939+
raise TypeError(
1940+
f"The inputs of DistModel should be list or tensor, but got {type(feed_item)}"
1941+
)
1942+
1943+
feeds = self._make_feeds(feed_list)
19301944
outs = self._engine.run(feeds)
19311945

19321946
if self._mode == "predict":

0 commit comments

Comments
 (0)