Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions python/paddle/v2/framework/default_scope_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
Default scope function.

`Paddle` manages Scope as programming language's scope. It just a
thread-local stack of Scope. Top of that stack is current scope, the bottom
of that stack is all scopes' parent.

Invoking `create_var/get_var` can `create/get` variable in current scope.
Invoking `enter_local_scope/leave_local_scope` can create or destroy local
scope.

A `scoped_function` will take a `function` as input. That function will be
invoked in a new local scope.
"""

import paddle.v2.framework.core
import threading

__tl_scope__ = threading.local()

__all__ = [
'get_cur_scope', 'enter_local_scope', 'leave_local_scope', 'create_var',
'get_var', 'scoped_function'
]


def get_cur_scope():
"""
Get current scope.
:rtype: paddle.v2.framework.core.Scope
"""
cur_scope_stack = getattr(__tl_scope__, 'cur_scope', None)
if cur_scope_stack is None:
__tl_scope__.cur_scope = list()
if len(__tl_scope__.cur_scope) == 0:
__tl_scope__.cur_scope.append(paddle.v2.framework.core.Scope(None))
return __tl_scope__.cur_scope[-1]


def enter_local_scope():
Copy link
Contributor

@Superjomn Superjomn Jul 11, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scope 必须是连续么?

scopeA:
   xxx

scope B:
    xxxx

scopeA:
    xxxx

不行么?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不是特别清楚什么是『必须连续』?
但是,举的例子可以做到。

enter_local_scope()  # scope A
xxx

enter_local_scope() # scope B
xxx
leave_local_scope()
xxx
leave_local_scope()  # leave scope A

"""
Enter a new local scope
"""
cur_scope = get_cur_scope()
new_scope = paddle.v2.framework.core.Scope(cur_scope)
__tl_scope__.cur_scope.append(new_scope)


def leave_local_scope():
"""
Leave local scope
"""
__tl_scope__.cur_scope.pop()


def create_var(name):
"""
create variable in current scope.
"""
return get_cur_scope().create_var(name)


def get_var(name):
"""
get variable in current scope.
"""
return get_cur_scope().get_var(name)


def scoped_function(func):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the situation to use this function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe used to define an RNN step net?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the scope of stepnet is create inside RnnOP as far as know

"""
invoke `func` in new scope.

:param func: a callable function that will be run in new scope.
:type func: callable
"""
enter_local_scope()
try:
func()
except:
raise
finally:
leave_local_scope()
3 changes: 2 additions & 1 deletion python/paddle/v2/framework/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_python_test(test_framework test_protobuf.py test_scope.py)
add_python_test(test_framework test_protobuf.py test_scope.py
test_default_scope_funcs.py)
33 changes: 33 additions & 0 deletions python/paddle/v2/framework/tests/test_default_scope_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from paddle.v2.framework.default_scope_funcs import *
import unittest


class TestDefaultScopeFuncs(unittest.TestCase):
def test_cur_scope(self):
self.assertIsNotNone(get_cur_scope())

def test_none_variable(self):
self.assertIsNone(get_var("test"))

def test_create_var_get_var(self):
var_a = create_var("var_a")
self.assertIsNotNone(var_a)
self.assertIsNotNone(get_cur_scope().get_var('var_a'))
enter_local_scope()
self.assertIsNotNone(get_cur_scope().get_var('var_a'))
leave_local_scope()

def test_var_get_int(self):
def __new_scope__():
i = create_var("var_i")
self.assertFalse(i.is_int())
i.set_int(10)
self.assertTrue(i.is_int())
self.assertEqual(10, i.get_int())

for _ in xrange(10):
scoped_function(__new_scope__)


if __name__ == '__main__':
unittest.main()