Skip to content

Commit 182a40e

Browse files
author
Chris Elion
authored
Allow usage with tensorflow 2.0.0 (via tf.compat.v1) (#2665)
1 parent 540519b commit 182a40e

30 files changed

+85
-46
lines changed

.circleci/config.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,13 @@ workflows:
228228
executor: python373
229229
pyversion: 3.7.3
230230
# Test python 3.7 with the newest supported versions
231-
pip_constraints: test_constraints_max_version.txt
231+
pip_constraints: test_constraints_max_tf1_version.txt
232+
- build_python:
233+
name: python_3.7.3+tf2
234+
executor: python373
235+
pyversion: 3.7.3
236+
# Test python 3.7 with the newest supported versions
237+
pip_constraints: test_constraints_max_tf2_version.txt
232238
- markdown_link_check
233239
- protobuf_generation_check
234240
- deploy:

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ repos:
4444
.*_pb2.py|
4545
.*_pb2_grpc.py
4646
)$
47-
additional_dependencies: [flake8-comprehensions]
47+
# flake8-tidy-imports is used for banned-modules, not actually tidying
48+
additional_dependencies: [flake8-comprehensions, flake8-tidy-imports]
4849
- id: trailing-whitespace
4950
name: trailing-whitespace-markdown
5051
types: [markdown]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from mlagents.tf_utils.tf import tf as tf # noqa

ml-agents/mlagents/tf_utils/tf.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# This should be the only place that we import tensorflow directly.
2+
# Everywhere else is caught by the banned-modules setting for flake8
3+
import tensorflow as tf # noqa I201
4+
from distutils.version import LooseVersion
5+
6+
7+
# LooseVersion handles things "1.2.3a" or "4.5.6-rc7" fairly sensibly.
8+
_is_tensorflow2 = LooseVersion(tf.__version__) >= LooseVersion("2.0.0")
9+
10+
if _is_tensorflow2:
11+
import tensorflow.compat.v1 as tf
12+
13+
tf.disable_v2_behavior()
14+
else:
15+
pass

ml-agents/mlagents/trainers/bc/models.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import tensorflow as tf
2-
import tensorflow.contrib.layers as c_layers
1+
from mlagents.tf_utils import tf
2+
33
from mlagents.trainers.models import LearningModel
44

55

@@ -44,9 +44,7 @@ def __init__(
4444
size,
4545
activation=None,
4646
use_bias=False,
47-
kernel_initializer=c_layers.variance_scaling_initializer(
48-
factor=0.01
49-
),
47+
kernel_initializer=tf.initializers.variance_scaling(0.01),
5048
)
5149
)
5250
self.action_probs = tf.concat(
@@ -93,7 +91,7 @@ def __init__(
9391
activation=None,
9492
use_bias=False,
9593
name="pre_action",
96-
kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01),
94+
kernel_initializer=tf.initializers.variance_scaling(0.01),
9795
)
9896
self.clipped_sample_action = tf.clip_by_value(self.policy, -1, 1)
9997
self.sample_action = tf.identity(self.clipped_sample_action, name="action")

ml-agents/mlagents/trainers/components/bc/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import tensorflow as tf
1+
from mlagents.tf_utils import tf
2+
23
from mlagents.trainers.models import LearningModel
34

45

ml-agents/mlagents/trainers/components/reward_signals/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import abc
66

7-
import tensorflow as tf
7+
from mlagents.tf_utils import tf
88

99
from mlagents.envs.brain import BrainInfo
1010
from mlagents.trainers.trainer import UnityTrainerException

ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import List, Tuple
2-
import tensorflow as tf
2+
from mlagents.tf_utils import tf
3+
34
from mlagents.trainers.models import LearningModel
45

56

ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, Dict, List
22
import numpy as np
3-
import tensorflow as tf
3+
from mlagents.tf_utils import tf
4+
45
from mlagents.envs.brain import BrainInfo
56

67
from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult

ml-agents/mlagents/trainers/components/reward_signals/gail/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Optional, Tuple
22

3-
import tensorflow as tf
3+
from mlagents.tf_utils import tf
4+
45
from mlagents.trainers.models import LearningModel
56

67
EPSILON = 1e-7

0 commit comments

Comments
 (0)