-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Default scope function #2808
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Default scope function #2808
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
""" | ||
Default scope function. | ||
|
||
`Paddle` manages Scope as programming language's scope. It just a | ||
thread-local stack of Scope. Top of that stack is current scope, the bottom | ||
of that stack is all scopes' parent. | ||
|
||
Invoking `create_var/get_var` can `create/get` variable in current scope. | ||
Invoking `enter_local_scope/leave_local_scope` can create or destroy local | ||
scope. | ||
|
||
A `scoped_function` will take a `function` as input. That function will be | ||
invoked in a new local scope. | ||
""" | ||
|
||
import paddle.v2.framework.core | ||
import threading | ||
|
||
__tl_scope__ = threading.local() | ||
|
||
__all__ = [ | ||
'get_cur_scope', 'enter_local_scope', 'leave_local_scope', 'create_var', | ||
'get_var', 'scoped_function' | ||
] | ||
|
||
|
||
def get_cur_scope(): | ||
""" | ||
Get current scope. | ||
:rtype: paddle.v2.framework.core.Scope | ||
""" | ||
cur_scope_stack = getattr(__tl_scope__, 'cur_scope', None) | ||
if cur_scope_stack is None: | ||
__tl_scope__.cur_scope = list() | ||
if len(__tl_scope__.cur_scope) == 0: | ||
__tl_scope__.cur_scope.append(paddle.v2.framework.core.Scope(None)) | ||
return __tl_scope__.cur_scope[-1] | ||
|
||
|
||
def enter_local_scope(): | ||
""" | ||
Enter a new local scope | ||
""" | ||
cur_scope = get_cur_scope() | ||
new_scope = paddle.v2.framework.core.Scope(cur_scope) | ||
__tl_scope__.cur_scope.append(new_scope) | ||
|
||
|
||
def leave_local_scope(): | ||
""" | ||
Leave local scope | ||
""" | ||
__tl_scope__.cur_scope.pop() | ||
|
||
|
||
def create_var(name): | ||
""" | ||
create variable in current scope. | ||
""" | ||
return get_cur_scope().create_var(name) | ||
|
||
|
||
def get_var(name): | ||
""" | ||
get variable in current scope. | ||
""" | ||
return get_cur_scope().get_var(name) | ||
|
||
|
||
def scoped_function(func): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is the situation to use this function? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe used to define an RNN step net? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But the scope of stepnet is create inside RnnOP as far as know |
||
""" | ||
invoke `func` in new scope. | ||
|
||
:param func: a callable function that will be run in new scope. | ||
:type func: callable | ||
""" | ||
enter_local_scope() | ||
try: | ||
func() | ||
except: | ||
raise | ||
finally: | ||
leave_local_scope() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
add_python_test(test_framework test_protobuf.py test_scope.py) | ||
add_python_test(test_framework test_protobuf.py test_scope.py | ||
test_default_scope_funcs.py) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from paddle.v2.framework.default_scope_funcs import * | ||
import unittest | ||
|
||
|
||
class TestDefaultScopeFuncs(unittest.TestCase): | ||
def test_cur_scope(self): | ||
self.assertIsNotNone(get_cur_scope()) | ||
|
||
def test_none_variable(self): | ||
self.assertIsNone(get_var("test")) | ||
|
||
def test_create_var_get_var(self): | ||
var_a = create_var("var_a") | ||
self.assertIsNotNone(var_a) | ||
self.assertIsNotNone(get_cur_scope().get_var('var_a')) | ||
enter_local_scope() | ||
self.assertIsNotNone(get_cur_scope().get_var('var_a')) | ||
leave_local_scope() | ||
|
||
def test_var_get_int(self): | ||
def __new_scope__(): | ||
i = create_var("var_i") | ||
self.assertFalse(i.is_int()) | ||
i.set_int(10) | ||
self.assertTrue(i.is_int()) | ||
self.assertEqual(10, i.get_int()) | ||
|
||
for _ in xrange(10): | ||
scoped_function(__new_scope__) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
scope 必须是连续么?
不行么?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不是特别清楚什么是『必须连续』?
但是,举的例子可以做到。