Skip to content

Commit 8822518

Browse files
Parameter space noise for DQN and DDPG (openai#75)
* Export param noise * Update documentation * Final finishing touches
1 parent df82a15 commit 8822518

21 files changed

+1369
-47
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ pip install baselines
1515
- [DQN](baselines/deepq)
1616
- [PPO](baselines/pposgd)
1717
- [TRPO](baselines/trpo_mpi)
18+
- [DDPG](baselines/ddpg)

baselines/common/azure_utils.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
import zipfile
44

55
from azure.common import AzureMissingResourceHttpError
6-
from azure.storage.blob import BlobService
6+
try:
7+
from azure.storage.blob import BlobService
8+
except ImportError:
9+
from azure.storage.blob import BlockBlobService as BlobService
710
from shutil import unpack_archive
811
from threading import Event
912

@@ -114,18 +117,23 @@ def progress_callback(current, total):
114117
arcpath = os.path.join(td, "archive.zip")
115118
for backup_blob_name in [blob_name, blob_name + '.backup']:
116119
try:
117-
blob_size = self._service.get_blob_properties(
120+
properties = self._service.get_blob_properties(
118121
blob_name=backup_blob_name,
119122
container_name=self._container_name
120-
)['content-length']
123+
)
124+
if hasattr(properties, 'properties'):
125+
# Annoyingly, Azure has changed the API and this now returns a blob
126+
# instead of it's properties with up-to-date azure package.
127+
blob_size = properties.properties.content_length
128+
else:
129+
blob_size = properties['content-length']
121130
if int(blob_size) > 0:
122131
self._service.get_blob_to_path(
123132
container_name=self._container_name,
124133
blob_name=backup_blob_name,
125134
file_path=arcpath,
126135
max_connections=4,
127-
progress_callback=progress_callback,
128-
max_retries=10)
136+
progress_callback=progress_callback)
129137
unpack_archive(arcpath, dest_path)
130138
download_done.wait()
131139
return True

baselines/common/misc_util.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,9 @@ def boolean_flag(parser, name, default=False, help=None):
237237
help: str
238238
help string for the flag
239239
"""
240-
parser.add_argument("--" + name, action="store_true", default=default, help=help)
241-
parser.add_argument("--no-" + name, action="store_false", dest=name)
240+
dest = name.replace('-', '_')
241+
parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help)
242+
parser.add_argument("--no-" + name, action="store_false", dest=dest)
242243

243244

244245
def get_wrapper_by_name(env, classname):

baselines/common/mpi_fork.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os, subprocess, sys
22

3-
def mpi_fork(n):
3+
def mpi_fork(n, bind_to_core=False):
44
"""Re-launches the current script with workers
55
Returns "parent" for original parent, "child" for MPI children
66
"""
@@ -13,7 +13,11 @@ def mpi_fork(n):
1313
OMP_NUM_THREADS="1",
1414
IN_MPI="1"
1515
)
16-
subprocess.check_call(["mpirun", "-np", str(n), sys.executable] + sys.argv, env=env)
16+
args = ["mpirun", "-np", str(n)]
17+
if bind_to_core:
18+
args += ["-bind-to", "core"]
19+
args += [sys.executable] + sys.argv
20+
subprocess.check_call(args, env=env)
1721
return "parent"
1822
else:
1923
return "child"

baselines/ddpg/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)