Skip to content

Commit

Permalink
[Dy2stat] Add Test and Example Code for Different Access to ProgramTr…
Browse files Browse the repository at this point in the history
…anslator and Fix Related Bug (PaddlePaddle#23958)

To prepare for publishing APIs, I added tests for that we can access dy2stat through:

@fluid.dygraph.declarative
@fluid.dygraph.jit.declarative
fluid.dygraph.ProgramTranslator()
fluid.dygraph.dygraph_to_static.ProgramTranslator()
fluid.dygraph.dygraph_to_static.program_translator.ProgramTranslator()

It surprised me that we had bugs on those different usages. I have fixed them.

I also added example codes for these new APIs

This PR also pulls my current PR PaddlePaddle#23880, so the PR history is long. For reviewer information, you could review this PR after PaddlePaddle#23880 is merged
  • Loading branch information
zhhsplendid authored Apr 19, 2020
1 parent 2291634 commit 45e48c3
Show file tree
Hide file tree
Showing 7 changed files with 406 additions and 28 deletions.
4 changes: 4 additions & 0 deletions python/paddle/fluid/dygraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
from . import static_runner
from .static_runner import StaticModelRunner

from . import dygraph_to_static
from .dygraph_to_static import ProgramTranslator

__all__ = []
__all__ += layers.__all__
__all__ += base.__all__
Expand All @@ -57,3 +60,4 @@
__all__ += learning_rate_scheduler.__all__
__all__ += backward_strategy.__all__
__all__ += jit.__all__
__all__ += ['ProgramTranslator']
24 changes: 20 additions & 4 deletions python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api

__all__ = ['DygraphToStaticAst', 'convert_to_static']

Expand Down Expand Up @@ -96,9 +97,24 @@ def visit_FunctionDef(self, node):
self.generic_visit(node)
# Remove the decorated name of dygraph_to_static
if hasattr(node, 'decorator_list'):
decorator_list = [
d for d in node.decorator_list if d.id not in DECORATOR_NAMES
]
decorator_list = []
for d in node.decorator_list:
if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES:
raise NotImplementedError(
"ProgramTranslator hasn't implemented multiple decorators. Please remove "
+ d.id + " in " + self.decorate_func_name)
if isinstance(d, gast.Attribute):
full_attribute_name = get_attribute_full_name(d)
has_translate_decorator = False
for deco in DECORATOR_NAMES:
if deco in full_attribute_name:
has_translate_decorator = True
break
if not has_translate_decorator:
raise NotImplementedError(
"ProgramTranslator hasn't implemented multiple decorators. Please remove "
+ full_attribute_name + " in " +
self.decorate_func_name)
node.decorator_list = decorator_list
return node

Expand Down
Loading

0 comments on commit 45e48c3

Please sign in to comment.