forked from ZikangXiong/diff-spec
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
350 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,65 @@ | ||
# Differentiable Logic Specification | ||
|
||
Connect differentiable components with logicial operators. | ||
Connect differentiable components with logical operators. | ||
|
||
## Install | ||
|
||
```bash | ||
pip install git+https://github.com/ZikangXiong/diff-spec.git | ||
``` | ||
|
||
## First Order Logic (On-going) | ||
[First order logic](https://en.wikipedia.org/wiki/First-order_logic) is a logic formalism to describe the behavior of a system. It contrains basic logic operators such as `and`, `or`, `not`, and quantifiers like `forall`, `exists`. | ||
## First Order Logic | ||
[First-order logic](https://en.wikipedia.org/wiki/First-order_logic) is a logic formalism to describe the behavior of a system. It contains basic logic operators such as `and`, `or`, `not`, and quantifiers like `forall`, and `exists`. | ||
|
||
## Signal Temproal Logic | ||
[Signal temproal logic](https://people.eecs.berkeley.edu/~sseshia/fmee/lectures/EECS294-98_Spring2014_STL_Lecture.pdf) is a formal language to describe the behavior of a dynamical system. It is widely used in formal verification of cyber-physical systems. | ||
We can connect any differentiable components with logical operators, with the requirement that a component outputs a great-than-0 value representing | ||
`true` and a less-than-0 value representing `false`. | ||
|
||
```python | ||
p1 = InFirstQuadrant() | ||
p2 = XGreatThan(2.0) | ||
p3 = XLessThan(3.0) | ||
p4 = YLessThan(3.0) | ||
|
||
# Define a formula | ||
f = FOL(p1).forall() & FOL(p2).exists() & FOL(p3).forall() & FOL(p4).forall() | ||
print(f) | ||
|
||
# ((((∀ InFirstQuadrant) & (∃ XGreatThan(2.0))) & (∀ XLessThan(3.0))) & (∀ YLessThan(3.0))) | ||
``` | ||
|
||
p1-4 can be any differentiable components, including neural networks. In the above example, p1 is a predicate that checks if the input is in the first quadrant. p2-4 are predicates that check if the input is greater than 2, less than 3, and less than 3 in the x and y-axis, respectively. | ||
|
||
One can optimize a world of inputs to satisfy the formula with [gradient](examples/fol/differentiability.py). | ||
|
||
## Signal Temporal Logic | ||
[Signal temporal logic](https://people.eecs.berkeley.edu/~sseshia/fmee/lectures/EECS294-98_Spring2014_STL_Lecture.pdf) is a formal language to describe the behavior of a dynamical system. It is widely used in the formal verification of cyber-physical systems. | ||
|
||
We can use STL to describe the behavior of a system. For example, we can use STL to describe repeatedly visit `goal_1` and `goal_2` in timestep 0 to 13. | ||
|
||
```python | ||
# goal_1 is a rectangle area centered in [0, 0] with width and height 1 | ||
goal_1 = STL(RectReachPredicte(np.array([0, 0]), np.array([1, 1]), "goal_1")) | ||
goal_1 = STL(RectReachPredicate(np.array([0, 0]), np.array([1, 1]), "goal_1")) | ||
# goal_2 is a rectangle area centered in [2, 2] with width and height 1 | ||
goal_2 = STL(RectReachPredicte(np.array([2, 2]), np.array([1, 1]), "goal_2")) | ||
goal_2 = STL(RectReachPredicate(np.array([2, 2]), np.array([1, 1]), "goal_2")) | ||
|
||
# form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5 | ||
# and that holds always in 0 to 8 | ||
# and that always holds in 0 to 8 | ||
# In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13 | ||
form = (goal_1.eventually(0, 5) & goal_2.eventually(0, 5)).always(0, 8) | ||
``` | ||
|
||
We can synthesize a trace with [gradient](examples/stl/diffrentiablity.py) or [mixed-integer programming](examples/stl/solver.py). | ||
We can synthesize a trace with [gradient](examples/stl/differentiability.py) or [mixed-integer programming](examples/stl/solver.py). | ||
|
||
## Probability Temproal Logic (On-going) | ||
Probability temproal logic is a on-going work intergrating probability and random variables into temproal logic. It is useful in robot planning and control, reinforcement learning, and formal verification. | ||
<!-- ## Probability Temporal Logic (Ongoing) | ||
Probability temporal logic is an ongoing work integrating probability and random variables into temporal logic. It is useful in robot planning and control, reinforcement learning, and formal verification. --> | ||
|
||
<!-- ## Citation | ||
If you find this repository useful in your research, please cite: | ||
``` | ||
## Citation | ||
If you find this repository useful in your research, considering to cite: | ||
```bibtex | ||
@misc{xiong2023diffspec, | ||
title={DiffSpec: A Differentiable Logic Specification Framework}, | ||
title={DiffSpec: A Differentiable Logic Specification Framework}, | ||
url={https://github.com/ZikangXiong/diff-spec/}, | ||
author={Zikang Xiong}, | ||
year={2023}, | ||
} | ||
``` --> | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# %% | ||
import matplotlib.pyplot as plt | ||
import torch as th | ||
|
||
from ds.fol import FOL, PredicateBase | ||
|
||
|
||
# =============================================================================== | ||
# Predicate Examples | ||
# =============================================================================== | ||
class InFirstQuadrant(PredicateBase): | ||
def __init__(self): | ||
super().__init__(name="InFirstQuadrant") | ||
|
||
def eval(self, xs: th.Tensor) -> th.Tensor: | ||
""" | ||
Return positive values if the point is in the first quadrant. | ||
Return negative values if the point is not in the first quadrant. | ||
""" | ||
return th.min(xs, dim=1).values | ||
|
||
|
||
class XGreatThan(PredicateBase): | ||
def __init__(self, x): | ||
super().__init__(name=f"XGreatThan({x})") | ||
self.x = x | ||
|
||
def eval(self, xs: th.Tensor) -> th.Tensor: | ||
""" | ||
Return positive values if the first coordinate is greater than x. | ||
Return negative values if the first coordinate is not greater than x. | ||
""" | ||
return xs[..., 0] - self.x | ||
|
||
|
||
class XLessThan(PredicateBase): | ||
def __init__(self, x): | ||
super().__init__(name=f"XLessThan({x})") | ||
self.x = x | ||
|
||
def eval(self, xs: th.Tensor) -> th.Tensor: | ||
""" | ||
Return positive values if the first coordinate is less than x. | ||
Return negative values if the first coordinate is not less than x. | ||
""" | ||
return -xs[..., 0] + self.x | ||
|
||
|
||
class YGreatThan(PredicateBase): | ||
def __init__(self, y): | ||
super().__init__(name=f"YGreatThan({y})") | ||
self.y = y | ||
|
||
def eval(self, xs: th.Tensor) -> th.Tensor: | ||
""" | ||
Return positive values if the second coordinate is greater than y. | ||
Return negative values if the second coordinate is not greater than y. | ||
""" | ||
return xs[..., 1] - self.y | ||
|
||
|
||
class YLessThan(PredicateBase): | ||
def __init__(self, y): | ||
super().__init__(name=f"YLessThan({y})") | ||
self.y = y | ||
|
||
def eval(self, xs: th.Tensor) -> th.Tensor: | ||
""" | ||
Return positive values if the second coordinate is less than y. | ||
Return negative values if the second coordinate is not less than y. | ||
""" | ||
return self.y - xs[..., 1] | ||
|
||
|
||
# =============================================================================== | ||
# Differentiability | ||
# =============================================================================== | ||
def backward(): | ||
# Define a predicate | ||
p1 = InFirstQuadrant() | ||
p2 = XGreatThan(2.0) | ||
p3 = XLessThan(3.0) | ||
p4 = YLessThan(3.0) | ||
|
||
# Define a formula | ||
f = FOL(p1).forall() & FOL(p2).exists() & FOL(p3).forall() & FOL(p4).forall() | ||
print(f) | ||
|
||
# create plot | ||
fig, ax = plt.subplots(1, 2, figsize=(10, 5)) | ||
|
||
# set x and y limits | ||
ax[0].set_xlim(-3, 3) | ||
ax[0].set_ylim(-3, 3) | ||
ax[1].set_xlim(-3, 3) | ||
ax[1].set_ylim(-3, 3) | ||
|
||
# Define two worlds, we support batch evaluation, | ||
# but the points in each world should be the same / padding to same. | ||
x = th.tensor( | ||
[ | ||
[[1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]], | ||
[[0.5, 0.5], [-0.5, 0.5], [0.5, -0.5], [-0.5, -0.5]], | ||
], | ||
requires_grad=True, | ||
) | ||
|
||
ax[0].scatter( | ||
x.numpy(force=True)[0, :, 0], x.numpy(force=True)[0, :, 1], label="world 1" | ||
) | ||
ax[0].scatter( | ||
x.numpy(force=True)[1, :, 0], x.numpy(force=True)[1, :, 1], label="world 2" | ||
) | ||
|
||
# define optimizer | ||
optimizer = th.optim.Adam([x], lr=0.1) | ||
|
||
# optimize | ||
loss = None | ||
for i in range(10000): | ||
optimizer.zero_grad() | ||
loss = -f(x).mean() | ||
loss.backward() | ||
optimizer.step() | ||
|
||
print(loss) | ||
print(x) | ||
|
||
# plot | ||
ax[1].scatter( | ||
x.numpy(force=True)[0, :, 0], x.numpy(force=True)[0, :, 1], label="world 1" | ||
) | ||
ax[1].scatter( | ||
x.numpy(force=True)[1, :, 0], x.numpy(force=True)[1, :, 1], label="world 2" | ||
) | ||
|
||
ax[0].legend() | ||
ax[1].legend() | ||
plt.show() | ||
|
||
|
||
if __name__ == "__main__": | ||
backward() | ||
|
||
# %% |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.