Skip to content

Commit

Permalink
AgeModel class for Bayesian age modeling
Browse files Browse the repository at this point in the history
  • Loading branch information
sarttiso committed Oct 6, 2024
1 parent f149955 commit 3f84d25
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/stratage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@
'fit_floating_model',
'age_depth',
'model_ls',
'model',
'AgeModel',
'model2ages']
226 changes: 167 additions & 59 deletions src/stratage/stratage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import sys
import warnings

import numpy as np

Expand All @@ -9,8 +10,12 @@
import pytensor.tensor as pt
from pytensor.graph import Apply, Op

import arviz as az

from scipy.optimize import minimize_scalar, lsq_linear

from tqdm import tqdm

from .geochron import Geochron

from stratage import __version__
Expand Down Expand Up @@ -443,68 +448,171 @@ def model_ls(units, geochron,
return sed_rates, hiatuses


def model(units, geochron, sed_rates_prior, hiatuses_prior,
draws=1000, **kwargs):
"""MCMC modeling of sed rates.
class AgeModel:
"""Age model object for Bayesian inference of sedimentation rates and hiatuses.
User must provide priors for sedimentation rates and hiatuses, which are functions valid as dist arguments to pymc.CustomDist(dist=dist). In this case, the only valid signature is dist(size=size), since this function does not permit additional arguments to the distribution.
Args:
units (ndarray): nx2 array of unit bottom and top heights for n units.
Attributes:
units (numpy.ndarray): nx2 array of unit bottom and top heights for n units.
geochron (geochron.Geochron): Geochron object containing geochron constraints.
sed_rates_prior (function): Prior distribution for sedimentation rates. Must be valid as dist argument to pymc.CustomDist(dist=dist). Signature is sed_rate_prior(size=size).
hiatuses_prior (function): Prior distribution for hiatuses. Must be valid as dist argument to pymc.CustomDist(dist=dist). Signature is hiatus_prior(size=size).
draws (int, optional): Number of MCMC draws. Defaults to 1000.
**kwargs: Additional keyword arguments to pass to the MCMC sampler.
Returns:
arviz.InferenceData: ArviZ InferenceData object containing the MCMC trace.
n_units (int): Number of units.
n_contacts (int): Number of contacts.
units_trim (numpy.ndarray): Trimmed unit heights after adjusting for the top and bottom units.
sed_rates_ls (numpy.ndarray): Sedimentation rates from least squares model.
hiatuses_ls (numpy.ndarray): Hiatuses from least squares model.
model (pymc.Model): PyMC model object for Bayesian inference.
vars_list (list): List of variables in the pymc.model.
"""

# number of units
n_units = units.shape[0]
# number of contacts
n_contacts = n_units - 1
# confirm that geochron heights are within the section
geochron_height_check(units, geochron.h)
# trim the section to the top and bottom of the geochron constraints
units_trim = trim_units(units, geochron.h)
# create least squares model as initial guess
sed_rates_ls, hiatuses_ls = model_ls(units, geochron)
# create time increment log-like function
loglike_op = loglike_gen(geochron, units_trim)
# create model
coords = {'units': np.arange(n_units),
'contacts': np.arange(n_contacts),
'pairs': np.arange(geochron.n_pairs)}
model = pm.Model(coords=coords)
with model:
# sed rates
sed_rates = pm.CustomDist('sed_rates',
dist=sed_rates_prior,
shape=(n_units,),
dims='units')
model.set_initval(sed_rates, sed_rates_ls)
# hiatuses
hiatuses = pm.CustomDist('hiatuses',
dist=hiatuses_prior,
shape=(n_contacts,),
dims='contacts')
model.set_initval(hiatuses, hiatuses_ls)
# likelihood
likelihood = pm.CustomDist('likelihood',
pm.math.concatenate([sed_rates, hiatuses]),
observed=np.zeros(geochron.n_pairs),
logp=loglike_op)
# sample
vars_list = list(model.values_to_rvs.keys())[:-1]
with model:
trace = pm.sample(draws=draws, **kwargs)
return trace


def model2ages(trace, n_posterior=None):
    """Placeholder for converting an MCMC trace into age models.

    This function is an unimplemented stub: it performs no computation and
    returns None for every combination of arguments.

    Args:
        trace: MCMC trace (currently unused).
        n_posterior (int, optional): Number of posterior samples (currently
            unused). Defaults to None.

    Returns:
        None
    """
    # TODO: when n_posterior is None, the intended default appears to be the
    # effective sample size of the trace (e.g. trace.n_eff) -- not implemented.
    return None
def __init__(self, units, geochron, sed_rates_prior, hiatuses_prior):
    """Set up the PyMC model used to infer sedimentation rates and hiatuses.

    Args:
        units (numpy.ndarray): nx2 array of unit bottom and top heights for n units.
        geochron (geochron.Geochron): Geochron object containing geochron constraints.
        sed_rates_prior (function): Prior distribution for sedimentation rates. Must be
            valid as dist argument to pymc.CustomDist(dist=dist). Signature is
            sed_rate_prior(size=size).
        hiatuses_prior (function): Prior distribution for hiatuses. Must be valid as
            dist argument to pymc.CustomDist(dist=dist). Signature is
            hiatus_prior(size=size).
    """
    # keep the inputs on the instance
    self.units = units
    self.geochron = geochron
    self.sed_rates_prior = sed_rates_prior
    self.hiatuses_prior = hiatuses_prior

    # one sedimentation rate per unit, one hiatus per contact between units
    self.n_units = units.shape[0]
    self.n_contacts = self.n_units - 1

    # geochron heights are checked against the section before trimming
    geochron_height_check(self.units, self.geochron.h)
    # section trimmed to the top and bottom of the geochron constraints
    self.units_trim = trim_units(self.units, self.geochron.h)

    # least squares solution, used below as the sampler's initial values
    self.sed_rates_ls, self.hiatuses_ls = model_ls(self.units, self.geochron)

    # time increment log-likelihood and random-draw functions
    logp_fn = loglike_gen(self.geochron, self.units_trim)
    random_fn = randlike_gen(self.geochron, self.units_trim)

    # named coordinates for the model dimensions
    coords = dict(units=np.arange(self.n_units),
                  contacts=np.arange(self.n_contacts),
                  pairs=np.arange(self.geochron.n_pairs))
    self.model = pm.Model(coords=coords)

    with self.model:
        # sedimentation rates, initialized at the least squares values
        rates_rv = pm.CustomDist('sed_rates',
                                 dist=sed_rates_prior,
                                 shape=(self.n_units,),
                                 dims='units')
        self.model.set_initval(rates_rv, self.sed_rates_ls)

        # hiatuses, initialized at the least squares values
        hiatus_rv = pm.CustomDist('hiatuses',
                                  dist=hiatuses_prior,
                                  shape=(self.n_contacts,),
                                  dims='contacts')
        self.model.set_initval(hiatus_rv, self.hiatuses_ls)

        # likelihood; creating it inside the context registers it with the
        # model, so the returned variable does not need to be kept
        pm.CustomDist('likelihood',
                      pm.math.concatenate([rates_rv, hiatus_rv]),
                      observed=np.zeros(self.geochron.n_pairs),
                      logp=logp_fn,
                      random=random_fn)

    # every value variable except the last one registered
    value_vars = list(self.model.values_to_rvs.keys())
    self.vars_list = value_vars[:-1]

def sample_prior(self, draws=100):
    """Sample prior predictive distribution of age models.

    Sedimentation rates and hiatuses are drawn from the model priors, and each
    draw is converted to unit times. A floating-model fit against the geochron
    constraints is attempted first; if that fit emits a RuntimeWarning, the
    unanchored times are instead offset by a sampled mean age of the first
    geochron constraint.

    Args:
        draws (int, optional): Number of prior predictive draws. Defaults to 100.

    Returns:
        list: List of prior predictive samples of times; each element is a nx2
            array of unit bottom and top times for n units.
    """
    # sample prior
    prior_params = pm.sample_prior_predictive(draws=draws, model=self.model).prior

    # Reshape explicitly to (draws, n_parameters) instead of squeezing:
    # squeeze() would also drop the draw axis when draws == 1 and the
    # parameter axis when there is a single unit or a single contact, so that
    # indexing by draw would no longer return one draw's parameter vector.
    # NOTE(review): assumes the prior group has a single chain, i.e. arrays of
    # shape (1, draws, n) -- confirm against the installed PyMC version.
    sed_rates_prior = prior_params.sed_rates.to_numpy().reshape(
        draws, self.n_units)
    hiatuses_prior = prior_params.hiatuses.to_numpy().reshape(
        draws, self.n_contacts)

    # mean age of lowest geochron constraint (estimated from 100 random
    # variates, so it differs slightly between calls)
    mean_age = np.mean(self.geochron.rv[0].rvs(size=100))

    # iterate over draws to generate times
    times_prior = []
    for cur_sed_rates, cur_hiatuses in zip(sed_rates_prior, hiatuses_prior):
        # attempt to fit floating model, promoting RuntimeWarnings to errors
        # so that a problematic fit can be detected and replaced
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("error", RuntimeWarning)
                cur_time = fit_floating_model(cur_sed_rates,
                                              cur_hiatuses,
                                              self.units_trim,
                                              self.geochron)
        # if fit fails, use floating model pinned to mean age of lowest
        # geochron constraint
        except RuntimeWarning:
            cur_time = get_times(cur_sed_rates,
                                 cur_hiatuses,
                                 self.units_trim) + mean_age
        times_prior.append(cur_time)
    return times_prior

def sample(self, draws=1000, **kwargs):
    """Draw posterior samples of sedimentation rates and hiatus durations.

    The result holds parameter samples only; it must still be transformed
    to age models.

    Args:
        draws (int, optional): Number of posterior draws. Defaults to 1000.
        **kwargs: Additional keyword arguments forwarded unchanged to
            pymc.sample.

    Returns:
        arviz.InferenceData: ArviZ InferenceData object containing the MCMC trace.
    """
    # run the sampler inside this object's model context
    with self.model:
        return pm.sample(draws=draws, **kwargs)

def trace2ages(self, trace, h=None, n_posterior=None):
    """Transform MCMC trace to age models.

    Posterior samples of sedimentation rates and hiatuses are extracted from
    the trace, and each sample is anchored to the geochron constraints by
    fitting a floating model.

    Args:
        trace (arviz.InferenceData): ArviZ InferenceData object containing the
            MCMC trace.
        h (arraylike, optional): Heights at which to evaluate the age model.
            Defaults to None. If None, only times arrays are returned.
        n_posterior (int, optional): Number of posterior samples. Defaults to
            None, in which case the smaller of 10,000 and chains*draws is used.

    Returns:
        list: List of age models; each element is a nx2 array of unit bottom
            and top times for n units.
        numpy.ndarray: Array of age models at heights h; each row is an age
            model. Shape is (n_posterior, len(h)). Only returned if h is not
            None.
    """
    n_chain = trace.posterior.chain.size
    n_draws = trace.posterior.draw.size
    # if n_posterior is None, take min of 10,000 and chain*draws
    if n_posterior is None:
        n_posterior = min(10000, n_chain * n_draws)
    posterior_params = az.extract(trace, num_samples=n_posterior)

    # Put the stacked "sample" dimension first explicitly instead of using
    # squeeze().T: squeeze() also drops the parameter axis for a single unit
    # or single contact (and the sample axis for one sample), so indexing by
    # sample would no longer return one sample's parameter vector.
    sed_rates_post = posterior_params.sed_rates.transpose(
        "sample", ...).to_numpy()
    hiatuses_post = posterior_params.hiatuses.transpose(
        "sample", ...).to_numpy()

    # iterate over posterior samples to generate times
    times_post = []
    for ii in tqdm(range(n_posterior),
                   desc='Anchoring floating age models'):
        # fit floating model
        cur_time = fit_floating_model(sed_rates_post[ii],
                                      hiatuses_post[ii],
                                      self.units_trim,
                                      self.geochron)
        times_post.append(cur_time)

    # if no heights provided, return times arrays only
    if h is None:
        return times_post

    # create age-depth models for heights
    t_posterior = np.zeros((n_posterior, len(h)))
    for ii in tqdm(range(n_posterior),
                   desc='Interpolating heights to ages'):
        t_posterior[ii, :] = age(times_post[ii],
                                 self.units_trim, h)
    return times_post, t_posterior

0 comments on commit 3f84d25

Please sign in to comment.