Skip to content

Commit 2c86e39

Browse files
committed
Call new MPI init fn in mpi_learn.py; call print_unique() about signals
1 parent 64bfa8a commit 2c86e39

File tree

6 files changed

+36
-24
lines changed

6 files changed

+36
-24
lines changed

data/signals.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import print_function
2+
import plasma.global_vars as g
23
import numpy as np
34
import sys
45

@@ -57,27 +58,27 @@ def get_units(str):
5758
found = True
5859

5960
except Exception as e:
60-
print(e)
61+
g.print_unique(e)
6162
sys.stdout.flush()
6263
pass
6364

6465
# Retrieve data from PTDATA if node not found
6566
if not found:
66-
# print("not in full path {}".format(signal))
67+
# g.print_unique("not in full path {}".format(signal))
6768
data = c.get('_s = ptdata2("'+signal+'",'+str(shot)+')').data()
6869
if len(data) != 1:
6970
rank = np.ndim(data)
7071
found = True
7172
# Retrieve data from Pseudo-pointname if not in ptdata
7273
if not found:
73-
# print("not in PTDATA {}".format(signal))
74+
# g.print_unique("not in PTDATA {}".format(signal))
7475
data = c.get('_s = pseudo("'+signal+'",'+str(shot)+')').data()
7576
if len(data) != 1:
7677
rank = np.ndim(data)
7778
found = True
7879
# this means the signal wasn't found
7980
if not found:
80-
print("No such signal: {}".format(signal))
81+
g.print_unique("No such signal: {}".format(signal))
8182
pass
8283

8384
# get time base
@@ -125,7 +126,7 @@ def fetch_jet_data(signal_path, shot_num, c):
125126
signal_path, shot_num)).data()
126127
found = True
127128
except Exception as e:
128-
print(e)
129+
g.print_unique(e)
129130
sys.stdout.flush()
130131
# pass
131132
return time, data, ydata, found
@@ -361,8 +362,9 @@ def fetch_nstx_data(signal_path, shot_num, c):
361362

362363
all_signals_restricted = all_signals
363364

364-
print('All signals (determines which signals are downloaded & preprocessed):')
365-
print(all_signals.values())
365+
g.print_unique('All signals (determines which signals are downloaded'
366+
' & preprocessed):')
367+
g.print_unique(all_signals.values())
366368

367369
fully_defined_signals = {
368370
sig_name: sig for (sig_name, sig) in all_signals_restricted.items() if (

examples/mpi_learn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import plasma.global_vars as g
2+
g.init_MPI()
23
from plasma.models.mpi_runner import (
34
mpi_train, mpi_make_predictions_and_evaluate
45
)

plasma/conf_parser.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import print_function
2+
import plasma.global_vars as g
13
from plasma.primitives.shots import ShotListFiles
24
import data.signals as sig
35
from plasma.utils.hashing import myhash_signals
@@ -74,7 +76,7 @@ def parameters(input_file):
7476
elif params['target'] == 'ttdlinear':
7577
params['data']['target'] = TTDLinearTarget
7678
else:
77-
print('Unkown type of target. Exiting')
79+
g.print_unique('Unkown type of target. Exiting')
7880
exit(1)
7981

8082
# params['model']['output_activation'] =
@@ -344,15 +346,17 @@ def parameters(input_file):
344346
params['paths']['use_signals_dict'] = sig.fully_defined_signals_1D
345347

346348
else:
347-
print("Unkown data set {}".format(params['paths']['data']))
349+
g.print_unique("Unknown dataset {}".format(
350+
params['paths']['data']))
348351
exit(1)
349352

350353
if len(params['paths']['specific_signals']):
351354
for s in params['paths']['specific_signals']:
352355
if s not in params['paths']['use_signals_dict'].keys():
353-
print("Signal {} is not fully defined for {} machine. ",
354-
"Skipping...".format(
355-
s, params['paths']['data'].split("_")[0]))
356+
g.print_unique(
357+
"Signal {} is not fully defined for {} machine. ",
358+
"Skipping...".format(
359+
s, params['paths']['data'].split("_")[0]))
356360
params['paths']['specific_signals'] = list(
357361
filter(
358362
lambda x: x in params['paths']['use_signals_dict'].keys(),
@@ -370,9 +374,9 @@ def parameters(input_file):
370374
params['paths']['all_signals'] = sort_by_channels(
371375
list(params['paths']['all_signals_dict'].values()))
372376

373-
print("Selected signals (determines which signals are used for",
374-
"training):\n{}".format(params['paths']['use_signals']))
375-
377+
g.print_unique("Selected signals (determines which signals are used",
378+
" for training):\n{}".format(
379+
params['paths']['use_signals']))
376380
params['paths']['shot_files_all'] = (
377381
params['paths']['shot_files'] + params['paths']['shot_files_test'])
378382
params['paths']['all_machines'] = list(

plasma/global_vars.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import print_function
12
import sys
23

34
# global variable defaults for non-MPI runs
@@ -9,15 +10,16 @@
910
backend = ''
1011

1112

12-
def init_MPI(conf):
13+
def init_MPI():
1314
from mpi4py import MPI
1415
global comm, task_index, num_workers
15-
global NUM_GPUS, MY_GPU, backend
16-
1716
comm = MPI.COMM_WORLD
1817
task_index = comm.Get_rank()
1918
num_workers = comm.Get_size()
2019

20+
21+
def init_GPU_backend(conf):
22+
global NUM_GPUS, MY_GPU, backend
2123
NUM_GPUS = conf['num_gpus']
2224
MY_GPU = task_index % NUM_GPUS
2325
backend = conf['model']['backend']

plasma/models/mpi_runner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@
3939
import socket
4040
sys.setrecursionlimit(10000)
4141

42-
4342
# import keras sequentially because it otherwise reads from ~/.keras/keras.json
4443
# with too many threads:
4544
# from mpi_launch_tensorflow import get_mpi_task_index
4645

47-
# set global variables for entire module regarding MPI environment
48-
# TODO(KGF): consider moving this fn/init call to mpi_learn.py and/or
49-
# setting "mpi_initialized" global bool flag, since that is client-facing
50-
g.init_MPI(conf)
46+
# set global variables for entire module regarding MPI & GPU environment
47+
g.init_GPU_backend(conf)
48+
# moved this fn/init call to client-facing mpi_learn.py
49+
# g.init_MPI()
50+
# TODO(KGF): set "mpi_initialized" global bool flag?
5151

5252
# initialization code for mpi_runner.py module:
5353
if g.backend == 'tf' or g.backend == 'tensorflow':

setup.cfg

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,7 @@ ignore =
2626
E731,
2727
# W5: Line break warning
2828
# W503: line break before binary operator (use mutually exclusive W504)
29-
W503
29+
W503
30+
# suppres linter warning about MPI init fn call before module-level imports
31+
per-file-ignores =
32+
examples/mpi_learn.py:E402

0 commit comments

Comments
 (0)