Skip to content

Commit 95c053a

Browse files
committed
candidate_wrapper
1 parent 97a6461 commit 95c053a

File tree

4 files changed

+34
-2
lines changed

4 files changed

+34
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
## Candidate wrapper
2+
3+
The `candidate_wrapper` argument can be passed to the `physo.SR` or `physo.ClassSR` to apply a wrapper function $g$ to the candidate symbolic function's output $f(X)$.
4+
5+
The wrapper function $g$ should be a callable taking the candidate symbolic function callable $f$ and the input data $X$ as arguments, and returning the wrapped output $g(f(X))$.
6+
By default `candidate_wrapper = None`, no wrapper is applied (identity).
7+
8+
Note that the wrapper function should be differentiable and written in `pytorch` if free constants are to be optimized since the free constants are optimized using gradient-based optimization.
9+
10+
In addition, it is recommended to use protected functions when writing the wrapper function to avoid evaluating the symbolic function on invalid points (eg. using log abs instead of log).
11+
See the [protected functions](https://physo.readthedocs.io/en/latest/r_features.html#protected-version-optional) documentation for more details.

physo/task/args_handler.py

+6
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def check_args_and_build_run_config(multi_X, multi_y, multi_y_weights,
4646
use_protected_ops,
4747
# Stopping
4848
epochs,
49+
# Candidate wrapper
50+
candidate_wrapper,
4951
# Default run config to use
5052
run_config,
5153
# Default run monitoring
@@ -301,6 +303,10 @@ def check_args_and_build_run_config(multi_X, multi_y, multi_y_weights,
301303
raise ValueError("entropy_weight should be castable to a float.")
302304
assert isinstance(entropy_weight, float), "entropy_weight should be a float."
303305

306+
# ------------------------------- CANDIDATE_WRAPPER -------------------------------
307+
# candidate_wrapper should be callable or None
308+
assert candidate_wrapper is None or callable(candidate_wrapper), "candidate_wrapper should be callable or None."
309+
304310
# ------------------------------- RETURN -------------------------------
305311
# Returning
306312
handled_args = {

physo/task/class_sr.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def ClassSR(multi_X, multi_y, multi_y_weights=1.,
3333
max_n_evaluations = None,
3434
stop_after_n_epochs = args_handler.default_stop_after_n_epochs,
3535
epochs = None,
36+
# Candidate wrapper
37+
candidate_wrapper = None,
3638
# Default run config to use
3739
run_config = None,
3840
# Default run monitoring
@@ -119,6 +121,9 @@ def ClassSR(multi_X, multi_y, multi_y_weights=1.,
119121
epochs : int or None (optional)
120122
Number of epochs to perform. By default, uses the number in the default config file.
121123
124+
candidate_wrapper : callable or None (optional)
125+
Wrapper to apply to candidate program's output, candidate_wrapper taking func, X as arguments where func is
126+
a candidate program callable (taking X as arg). By default = None, no wrapper is applied (identity).
122127
run_config : dict or None (optional)
123128
Run configuration (by default uses physo.task.class_sr.default_config)
124129
See physo/config/ for examples of run configurations.
@@ -182,6 +187,8 @@ def ClassSR(multi_X, multi_y, multi_y_weights=1.,
182187
use_protected_ops = use_protected_ops,
183188
# Stopping
184189
epochs = epochs,
190+
# Candidate wrapper
191+
candidate_wrapper = candidate_wrapper,
185192
# Default run config to use
186193
run_config = run_config,
187194
# Default run monitoring
@@ -209,7 +216,8 @@ def ClassSR(multi_X, multi_y, multi_y_weights=1.,
209216
rewards, candidates = fit (multi_X = multi_X,
210217
multi_y = multi_y,
211218
multi_y_weights = multi_y_weights,
212-
run_config = run_config,
219+
candidate_wrapper = candidate_wrapper,
220+
run_config = run_config,
213221
stop_reward = stop_reward,
214222
stop_after_n_epochs = stop_after_n_epochs,
215223
max_n_evaluations = max_n_evaluations,

physo/task/sr.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def SR(X, y, y_weights=1.,
3232
max_n_evaluations = None,
3333
stop_after_n_epochs = args_handler.default_stop_after_n_epochs,
3434
epochs = None,
35+
# Candidate wrapper
36+
candidate_wrapper = None,
3537
# Default run config to use
3638
run_config = None,
3739
# Default run monitoring
@@ -102,6 +104,9 @@ def SR(X, y, y_weights=1.,
102104
epochs : int or None (optional)
103105
Number of epochs to perform. By default, uses the number in the default config file.
104106
107+
candidate_wrapper : callable or None (optional)
108+
Wrapper to apply to candidate program's output, candidate_wrapper taking func, X as arguments where func is
109+
a candidate program callable (taking X as arg). By default = None, no wrapper is applied (identity).
105110
run_config : dict or None (optional)
106111
Run configuration (by default uses physo.task.sr.default_config)
107112
See physo/config/ for examples of run configurations.
@@ -163,7 +168,9 @@ def SR(X, y, y_weights=1.,
163168
stop_reward = stop_reward,
164169
max_n_evaluations = max_n_evaluations,
165170
stop_after_n_epochs = stop_after_n_epochs,
166-
epochs = epochs,
171+
epochs = epochs,
172+
# Candidate wrapper
173+
candidate_wrapper = candidate_wrapper,
167174
# Default run config to use
168175
run_config = run_config,
169176
# Default run monitoring

0 commit comments

Comments
 (0)