Skip to content

Commit f14a1ff

Browse files
Merge pull request #497 from lxvm/integrands
add integrand interface
2 parents 434e752 + a6fd63a commit f14a1ff

File tree

5 files changed

+295
-39
lines changed

5 files changed

+295
-39
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com> and contributors"]
4-
version = "2.0.0"
4+
version = "1.99.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/SciMLBase.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,14 @@ abstract type AbstractDiffEqFunction{iip} <:
589589
"""
590590
$(TYPEDEF)
591591
592+
Base for types defining integrand functions.
593+
"""
594+
abstract type AbstractIntegralFunction{iip} <:
595+
AbstractSciMLFunction{iip} end
596+
597+
"""
598+
$(TYPEDEF)
599+
592600
Base for types defining optimization functions.
593601
"""
594602
abstract type AbstractOptimizationFunction{iip} <: AbstractSciMLFunction{iip} end
@@ -659,7 +667,9 @@ function specialization(::Union{ODEFunction{iip, specialize},
659667
RODEFunction{iip, specialize},
660668
NonlinearFunction{iip, specialize},
661669
OptimizationFunction{iip, specialize},
662-
BVPFunction{iip, specialize}}) where {iip,
670+
BVPFunction{iip, specialize},
671+
IntegralFunction{iip, specialize},
672+
BatchIntegralFunction{iip, specialize}}) where {iip,
663673
specialize}
664674
specialize
665675
end
@@ -787,7 +797,8 @@ export remake
787797

788798
export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, DAEFunction,
789799
DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction,
790-
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction
800+
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction,
801+
IntegralFunction, BatchIntegralFunction
791802

792803
export OptimizationFunction
793804

src/problems/basic_problems.jl

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -335,26 +335,16 @@ which are `Number`s or `AbstractVector`s with the same geometry as `u`.
335335
### Constructors
336336
337337
```
338-
IntegralProblem{iip}(f,lb,ub,p=NullParameters();
339-
nout=1, batch = 0, kwargs...)
338+
IntegralProblem(f,domain,p=NullParameters(); kwargs...)
339+
IntegralProblem(f,lb,ub,p=NullParameters(); kwargs...)
340340
```
341341
342-
- f: the integrand, callable function `y = f(u,p)` for out-of-place or `f(y,u,p)` for in-place.
342+
- f: the integrand, callable function `y = f(u,p)` for out-of-place (default) or an
343+
`IntegralFunction` or `BatchIntegralFunction` for inplace and batching optimizations.
344+
- domain: an object representing an integration domain, i.e. the tuple `(lb, ub)`.
343345
- lb: Either a number or vector of lower bounds.
344346
- ub: Either a number or vector of upper bounds.
345347
- p: The parameters associated with the problem.
346-
- nout: The output size of the function f. Defaults to 1, i.e., a scalar valued function.
347-
If `nout > 1` f is a vector valued function .
348-
- batch: The preferred number of points to batch. This allows user-side parallelization
349-
of the integrand. If `batch == 0` no batching is performed.
350-
If `batch > 0` both `u` and `y` get an additional dimension added to it.
351-
This means that:
352-
if `f` is a multi variable function each `u[:,i]` is a different point to evaluate `f` at,
353-
if `f` is a single variable function each `u[i]` is a different point to evaluate `f` at,
354-
if `f` is a vector valued function each `y[:,i]` is the evaluation of `f` at a different point,
355-
if `f` is a scalar valued function `y[i]` is the evaluation of `f` at a different point.
356-
Note that batch is a suggestion for the number of points,
357-
and it is not necessarily true that batch is the same as batchsize in all algorithms.
358348
- kwargs: Keyword arguments copied to the solvers.
359349
360350
Additionally, we can supply iip like IntegralProblem{iip}(...) as true or false to declare at
@@ -364,30 +354,58 @@ compile time whether the integrator function is in-place.
364354
365355
The fields match the names of the constructor arguments.
366356
"""
367-
struct IntegralProblem{isinplace, P, F, B, K} <: AbstractIntegralProblem{isinplace}
357+
struct IntegralProblem{isinplace, P, F, T, K} <: AbstractIntegralProblem{isinplace}
368358
f::F
369-
lb::B
370-
ub::B
371-
nout::Int
359+
domain::T
372360
p::P
373-
batch::Int
374361
kwargs::K
375-
@add_kwonly function IntegralProblem{iip}(f, lb, ub, p = NullParameters();
376-
nout = 1,
377-
batch = 0, kwargs...) where {iip}
378-
@assert typeof(lb)==typeof(ub) "Type of lower and upper bound must match"
362+
@add_kwonly function IntegralProblem{iip}(f::AbstractIntegralFunction{iip}, domain,
363+
p = NullParameters();
364+
kwargs...) where {iip}
379365
warn_paramtype(p)
380-
new{iip, typeof(p), typeof(f), typeof(lb), typeof(kwargs)}(f,
381-
lb, ub, nout, p,
382-
batch, kwargs)
366+
new{iip, typeof(p), typeof(f), typeof(domain), typeof(kwargs)}(f,
367+
domain, p, kwargs)
383368
end
384369
end
385370

386371
TruncatedStacktraces.@truncate_stacktrace IntegralProblem 1 4
387372

388-
function IntegralProblem(f, lb, ub, args...;
373+
function IntegralProblem(f::AbstractIntegralFunction,
374+
domain,
375+
p = NullParameters();
389376
kwargs...)
390-
IntegralProblem{isinplace(f, 3)}(f, lb, ub, args...; kwargs...)
377+
IntegralProblem{isinplace(f)}(f, domain, p; kwargs...)
378+
end
379+
380+
function IntegralProblem(f::AbstractIntegralFunction,
381+
lb::B,
382+
ub::B,
383+
p = NullParameters();
384+
kwargs...) where {B}
385+
IntegralProblem(f, (lb, ub), p; kwargs...)
386+
end
387+
388+
function IntegralProblem(f, args...; nout = nothing, batch = nothing, kwargs...)
389+
if nout !== nothing || batch !== nothing
390+
@warn "`nout` and `batch` keywords are deprecated in favor of inplace `IntegralFunction`s or `BatchIntegralFunction`s. See the updated Integrals.jl documentation for details."
391+
end
392+
393+
max_batch = batch === nothing ? 0 : batch
394+
g = if isinplace(f, 3)
395+
output_prototype = Vector{Float64}(undef, nout === nothing ? 1 : nout)
396+
if max_batch == 0
397+
IntegralFunction(f, output_prototype)
398+
else
399+
BatchIntegralFunction(f, output_prototype, max_batch=max_batch)
400+
end
401+
else
402+
if max_batch == 0
403+
IntegralFunction(f)
404+
else
405+
BatchIntegralFunction(f, max_batch=max_batch)
406+
end
407+
end
408+
IntegralProblem(g, args...; kwargs...)
391409
end
392410

393411
struct QuadratureProblem end
@@ -405,8 +423,8 @@ Sampled integral problems are defined as:
405423
```math
406424
\sum_i w_i y_i
407425
```
408-
where `y_i` are sampled values of the integrand, and `w_i` are weights
409-
assigned by a quadrature rule, which depend on sampling points `x`.
426+
where `y_i` are sampled values of the integrand, and `w_i` are weights
427+
assigned by a quadrature rule, which depend on sampling points `x`.
410428
411429
## Problem Type
412430
@@ -415,10 +433,10 @@ assigned by a quadrature rule, which depend on sampling points `x`.
415433
```
416434
SampledIntegralProblem(y::AbstractArray, x::AbstractVector; dim=ndims(y), kwargs...)
417435
```
418-
- y: The sampled integrand, must be a subtype of `AbstractArray`.
419-
It is assumed that the values of `y` along dimension `dim`
436+
- y: The sampled integrand, must be a subtype of `AbstractArray`.
437+
It is assumed that the values of `y` along dimension `dim`
420438
correspond to the integrand evaluated at sampling points `x`
421-
- x: Sampling points, must be a subtype of `AbstractVector`.
439+
- x: Sampling points, must be a subtype of `AbstractVector`.
422440
- dim: Dimension along which to integrate. Defaults to the last dimension of `y`.
423441
- kwargs: Keyword arguments copied to the solvers.
424442
@@ -434,7 +452,7 @@ struct SampledIntegralProblem{Y, X, K} <: AbstractIntegralProblem{false}
434452
@add_kwonly function SampledIntegralProblem(y::AbstractArray, x::AbstractVector;
435453
dim = ndims(y),
436454
kwargs...)
437-
@assert dim <= ndims(y) "The integration dimension `dim` is larger than the number of dimensions of the integrand `y`"
455+
@assert dim<=ndims(y) "The integration dimension `dim` is larger than the number of dimensions of the integrand `y`"
438456
@assert length(x)==size(y, dim) "The integrand `y` must have the same length as the sampling points `x` along the integrated dimension."
439457
@assert axes(x, 1)==axes(y, dim) "The integrand `y` must obey the same indexing as the sampling points `x` along the integrated dimension."
440458
new{typeof(y), typeof(x), typeof(kwargs)}(y, x, dim, kwargs)

0 commit comments

Comments
 (0)