-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Allow usage with tensorflow 2.0.0 (via tf.compat.v1) #2665
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
c6b8ecf
72c9954
8d9487e
b4c88b6
2f8b430
358112b
e646367
f37cdf5
27ac64f
bf4ba5c
7d5b606
ce23d62
c28f29a
7951602
f622d70
844384e
fb64da6
86f1291
d0faa12
3d7c4b8
7054b24
dccd9ce
7d5f00e
30ff573
419bba4
a7571eb
7576c19
144be20
2e12f91
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from mlagents.trainers.tf import tf as tf # noqa | ||
from mlagents.trainers.tf import tf_flatten, tf_rnn, tf_variance_scaling # noqa |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import tensorflow as tf | ||
from mlagents.trainers import tf | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ervteng @awjuliani @andrewcoh (research pod folks) How do you feel about importing this way instead (instead of Note that either way, we should prevent directly importing "raw" tensorflow except for one central spot; I added some checks to flake8 to prevent this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Either should be OK. I kind of like the current way rather than the other one since it seems less like we have a "special" version of TensorFlow built into mlagents. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By "current way", do you mean the one in the PR now, or on develop? I feel like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The one in the PR. There could definitely be an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved the import wrapper to |
||
|
||
from mlagents.trainers.models import LearningModel | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# This should be the only place that we import tensorflow directly. | ||
# Everywhere else is caught by the banned-modules setting for flake8 | ||
import tensorflow as tf # noqa I201 | ||
from distutils.version import LooseVersion | ||
|
||
|
||
# LooseVersion handles things "1.2.3a" or "4.5.6-rc7" fairly sensibly. | ||
_is_tensorflow2 = LooseVersion(tf.__version__) >= LooseVersion("2.0.0") | ||
|
||
# A few things that we use live in different places between tensorflow 1.x and 2.x | ||
# If anything new is added, please add it here | ||
|
||
if _is_tensorflow2: | ||
import tensorflow.compat.v1 as tf | ||
|
||
tf_variance_scaling = tf.initializers.variance_scaling | ||
tf_flatten = tf.layers.flatten | ||
tf_rnn = tf.nn.rnn_cell | ||
|
||
tf.disable_v2_behavior() | ||
else: | ||
import tensorflow.contrib.layers as c_layers | ||
|
||
tf_variance_scaling = c_layers.variance_scaling_initializer | ||
tf_flatten = c_layers.flatten | ||
tf_rnn = tf.contrib.rnn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As mentioned in the PR, I don't love the name. Would like to find a better way to handle this.
Also, the parameter name changed (factor vs scale) to I just left it as positional.