Commit

Merge master
nikita-klsh committed Jan 21, 2020
2 parents b5e8191 + dbaddd6 commit aac2a5e
Showing 6 changed files with 269 additions and 94 deletions.
48 changes: 27 additions & 21 deletions docker_containers/picking_docker/picking_inference.py
@@ -5,14 +5,13 @@
 import sys
 import argparse
 
-import torch
 import numpy as np
 
 sys.path.append('../..')
 
-from seismicpro.batchflow import Dataset, B
-from seismicpro.batchflow.models.torch import UNet
-from seismicpro.src import FieldIndex, TraceIndex, SeismicDataset
+from seismicpro.batchflow import B, Pipeline
+from seismicpro.batchflow.models.torch import UNet  # pylint: disable=import-error
+from seismicpro.src import TraceIndex, SeismicDataset
 
 def make_prediction():
     """ Read the model and data paths and run inference pipeline.
@@ -29,9 +28,10 @@ def make_prediction():
     parser.add_argument('-bs', '--batch_size', type=int, help="The number of traces in \
                         the batch for inference stage.", default=1000)
     parser.add_argument('-ts', '--trace_len', type=int, help="The number of first samples \
-                        of the trace to load.", default=751)
-    parser.add_argument('-dvc', '--device', type=str or torch.device, help="The device for \
-                        inference. Can be 'cpu' or 'gpu'.", default=torch.device('cpu'))
+                        of the trace to load.", default=1000)
+    parser.add_argument('-dvc', '--device', type=str, help="The device for \
+                        inference. Can be 'cpu' or 'gpu'.", default='cpu')
+    parser.add_argument('-s', '--shift', type=float, help="Picking time phase shift, in radians.", default=0)
     args = parser.parse_args()
     path_raw = args.path_raw
     model = args.path_model
@@ -40,9 +40,11 @@ def make_prediction():
     batch_size = args.batch_size
     trace_len = args.trace_len
     device = args.device
-    predict(path_raw, model, num_zero, save_to, batch_size, trace_len, device)
+    shift = args.shift
+    predict(path_raw, model, num_zero, save_to, batch_size, trace_len, device, shift)
 
-def predict(path_raw, path_model, num_zero, save_to, batch_size, trace_len, device):
+def predict(path_raw, path_model, num_zero=100, save_to='dump.csv',
+            batch_size=1000, trace_len=1000, device='cpu', shift=0):
     """Make predictions and dump results using loaded model and path to data.
 
     Parameters
@@ -61,6 +63,8 @@ def predict(path_raw, path_model, num_zero, save_to, batch_size, trace_len, device):
         The number of first samples in the trace to load to the pipeline.
     device: str or torch.device, default: 'cpu'
         The device used for inference. Can be 'gpu' if a GPU is available.
+    shift: float, default: 0
+        Shift the picking times for each trace by the given phase shift, measured in radians.
     """
 
     data = SeismicDataset(TraceIndex(name='raw', path=path_raw))
@@ -76,19 +80,21 @@ def predict(path_raw, path_model, num_zero, save_to, batch_size, trace_len, device):
     except OSError:
         pass
 
-    test_pipeline = (data.p
-                     .init_model('dynamic', UNet, 'my_model', config=config_predict)
-                     .load(components='raw', fmt='segy', tslice=np.arange(trace_len))
-                     .drop_zero_traces(num_zero=num_zero, src='raw')
-                     .standardize(src='raw', dst='raw')
-                     .add_components(components='predictions')
-                     .apply_transform_all(src='raw', dst='raw', func=lambda x: np.stack(x))
-                     .predict_model('my_model', B('raw'), fetches='predictions',
-                                    save_to=B('predictions', mode='a'))
-                     .mask_to_pick(src='predictions', dst='predictions', labels=False)
-                     .dump(src='predictions', fmt='picks', path=save_to,
-                           traces='raw', to_samples=True))
+    test_tmpl = (data.p
+                 .init_model('dynamic', UNet, 'my_model', config=config_predict)
+                 .load(components='raw', fmt='segy', tslice=slice(0, trace_len))
+                 .drop_zero_traces(num_zero=num_zero, src='raw')
+                 .standardize(src='raw', dst='raw')
+                 .add_components(components='predictions')
+                 .apply_transform_all(src='raw', dst='raw', func=lambda x: np.stack(x))
+                 .predict_model('my_model', B('raw'), fetches='predictions',
+                                save_to=B('predictions', mode='a'))
+                 .mask_to_pick(src='predictions', dst='predictions', labels=False)
+                 )
+    if shift:
+        test_tmpl += Pipeline().shift_pick_phase(src='predictions', dst='predictions', src_traces='raw', shift=shift)
 
+    test_pipeline = test_tmpl + Pipeline().dump(src='predictions', fmt='picks', path=save_to, src_traces='raw')
     test_pipeline.run(batch_size, n_epochs=1, drop_last=False, shuffle=False, bar=True)
 
 if __name__ == "__main__":
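
For context, a minimal usage sketch of the updated entry point after this merge: predict now has keyword defaults, and the new shift argument forwards a phase shift in radians to shift_pick_phase. The module import and the file paths below are placeholders chosen for illustration, not part of the commit.

# Usage sketch only: the import and all paths are hypothetical placeholders.
from picking_inference import predict  # assumes the script's directory is importable

predict(
    path_raw='data/survey.sgy',        # input SEG-Y file with raw traces (placeholder)
    path_model='models/unet_picking',  # saved UNet model loaded by the pipeline (placeholder)
    num_zero=100,                      # threshold passed to drop_zero_traces (new default)
    save_to='dump.csv',                # CSV file the predicted picks are dumped to (new default)
    batch_size=1000,                   # traces per inference batch
    trace_len=1000,                    # number of first samples loaded per trace (new default)
    device='cpu',                      # 'cpu' or 'gpu'
    shift=0.0,                         # optional phase shift in radians; 0 skips shift_pick_phase
)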
8 changes: 4 additions & 4 deletions models/First_break_picking/1d_CNN/model_description.ipynb
@@ -461,7 +461,7 @@
 " .load(components='raw', fmt='segy')\n",
 " .load(components='markup', fmt='picks')\n",
 " .standardize(src='raw', dst='raw')\n",
-" .picking_to_mask(src='markup', dst='mask')\n",
+" .picking_to_mask(src='markup', dst='mask', src_traces='raw')\n",
 " .init_model('dynamic', UNet, 'my_model', config)\n",
 " .init_variable('loss', init_on_each_run=list)\n",
 " .apply_transform_all(src='raw', dst='raw', func=lambda x: np.stack(x))\n",
@@ -597,7 +597,7 @@
 " .load(components='raw', fmt='segy')\n",
 " .load(components='markup', fmt='picks')\n",
 " .standardize(src='raw', dst='raw')\n",
-" .picking_to_mask(src='markup', dst='mask')\n",
+" .picking_to_mask(src='markup', dst='mask', src_traces='raw')\n",
 " .apply_transform_all(src='raw', dst='raw', func=lambda x: np.stack(x))\n",
 " .update_variable('traces', B('raw'), mode='a')\n",
 " .apply_transform_all(src='mask', dst='mask', func=lambda x: np.stack(x))\n",
@@ -607,7 +607,7 @@
 " save_to=B('predictions', mode='a'))\n",
 " .mask_to_pick(src='predictions', dst='predictions', labels=False)\n",
 " .update_variable('predictions', B('predictions'), mode='a')\n",
-" .dump(src='predictions', fmt='picks', path='model_predictions.csv', traces='raw', to_samples=True)\n",
+" .dump(src='predictions', fmt='picks', path='model_predictions.csv', src_traces='raw')\n",
 " .run_later(1000, n_epochs=1, drop_last=False, shuffle=False, bar=True))"
 ]
 },
@@ -1306,5 +1306,5 @@
 }
 },
 "nbformat": 4,
-"nbformat_minor": 2
+"nbformat_minor": 4
 }
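
Read together with the script change above, the notebook edits replace the old traces=... / to_samples=True arguments with a single src_traces argument on both picking_to_mask and dump. A condensed sketch of how those calls read after the change, assuming data and the surrounding pipeline actions are defined as earlier in the notebook:

from seismicpro.batchflow import Pipeline

# Sketch of the updated calls only; 'data' and the rest of the pipeline are
# assumed from the notebook and omitted here.
train_tmpl = (data.p
              .load(components='raw', fmt='segy')
              .load(components='markup', fmt='picks')
              .standardize(src='raw', dst='raw')
              # the traces component is now passed explicitly:
              .picking_to_mask(src='markup', dst='mask', src_traces='raw'))

# dump likewise takes src_traces instead of traces=... and to_samples=True:
dump_tmpl = Pipeline().dump(src='predictions', fmt='picks',
                            path='model_predictions.csv', src_traces='raw')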
6 changes: 3 additions & 3 deletions models/First_break_picking/1d_CNN/research.ipynb
@@ -137,7 +137,7 @@
 " .load(components='raw', fmt='segy')\n",
 " .load(components='markup', fmt='picks')\n",
 " .standardize(src='raw', dst='raw')\n",
-" .picking_to_mask(src='markup', dst='mask')\n",
+" .picking_to_mask(src='markup', dst='mask', src_traces='raw')\n",
 " .init_model('dynamic', UNet, 'my_model', config)\n",
 " .init_variable('loss', init_on_each_run=list)\n",
 " .apply_transform_all(src='raw', dst='raw', func=lambda x: np.stack(x))\n",
@@ -153,7 +153,7 @@
 " .load(components='raw', fmt='segy')\n",
 " .load(components='markup', fmt='picks')\n",
 " .standardize(src='raw', dst='raw')\n",
-" .picking_to_mask(src='markup', dst='mask')\n",
+" .picking_to_mask(src='markup', dst='mask', src_traces='raw')\n",
 " .update_variable('true', B('mask'), mode='a')\n",
 " .add_components(components='predictions')\n",
 " .apply_transform_all(src='raw', dst='raw', func=lambda x: np.stack(x))\n",
@@ -412,5 +412,5 @@
 }
 },
 "nbformat": 4,
-"nbformat_minor": 2
+"nbformat_minor": 4
 }
