Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trail Access #187

Merged
merged 23 commits into from
Jan 13, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add unit tests for python and fix the bugs found with them
  • Loading branch information
rkaminsk committed Jan 12, 2020
commit d09a3bf70d45df25e08f8768e073e71b7b328316
79 changes: 79 additions & 0 deletions app/clingo/tests/python/assignment.lp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#script (python)

import clingo
from itertools import chain

def check_lit(assignment, l):
trail = assignment.trail
t = assignment.value(l)
if t is not None:
s = 1
if not t:
s = -1
n = 0
for k in trail:
if k == s * l:
n = n + 1
assert n == 1

def check_trail(assignment):
trail = assignment.trail

n = 0
for i in range(0, assignment.decision_level + 1):
check = False
for j in range(trail.begin(i), trail.end(i)):
check = True
assert assignment.is_true(trail[j])
n = n + 1
assert check
assert n == len(trail)

n = 0
for l in assignment:
n = n + 1
check_lit(assignment, l)
assert n == len(assignment)

n = 0
for l in chain(assignment[0::2], assignment[1::2]):
n = n + 1
assert n == len(assignment)

n = 0
for l in chain(trail[0::2], trail[1::2]):
n = n + 1
assert n == len(trail)

n = 1
for l in trail[0:-1]:
n = n + 1
assert n == len(trail)

n = 0
for l in trail[-1:-1-len(trail):-1]:
n = n + 1
assert n == len(trail)

n = 0
for l in chain(trail[-1:-1-len(trail):-2], trail[-2:-1-len(trail):-2]):
n = n + 1
assert n == len(trail)

class Propagator:
def init(self, init):
check_trail(init.assignment)

init.check_mode = clingo.PropagatorCheckMode.Fixpoint

def check(self, control):
check_trail(control.assignment)

def main(prg):
prg.register_propagator(Propagator())
prg.ground([("base", [])])
prg.solve()

#end.

{ a; b; c }.
10 changes: 10 additions & 0 deletions app/clingo/tests/python/assignment.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Step: 1

a
a b
a b c
a c
b
b c
c
SAT
29 changes: 22 additions & 7 deletions libpyclingo/pyclingo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,16 @@
#else
#define OBBASE(x) x
#define Py_hash_t long
#endif

#if PY_VERSION_HEX < 0x03060100

int PySlice_Unpack(PyObject *slice, Py_ssize_t *start, Py_ssize_t *stop, Py_ssize_t *step) {
int SliceUnpack(PyObject *slice, Py_ssize_t *start, Py_ssize_t *stop, Py_ssize_t *step) {
Py_ssize_t length;
return PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice), PY_SSIZE_T_MAX, start, stop, step, &length);
}

Py_ssize_t PySlice_AdjustIndices(Py_ssize_t length, Py_ssize_t *start, Py_ssize_t *stop, Py_ssize_t step) {
Py_ssize_t SliceAdjustIndices(Py_ssize_t length, Py_ssize_t *start, Py_ssize_t *stop, Py_ssize_t step) {
assert(step != 0);
assert(step >= -PY_SSIZE_T_MAX);
if (*start < 0) {
Expand Down Expand Up @@ -93,6 +96,10 @@ Py_ssize_t PySlice_AdjustIndices(Py_ssize_t length, Py_ssize_t *start, Py_ssize_
}
return 0;
}
#else

#define SliceUnpack PySlice_Unpack
#define SliceAdjustIndices PySlice_AdjustIndices

#endif

Expand Down Expand Up @@ -3258,7 +3265,7 @@ Helper object for slicing support.
static Object construct(Reference seq, Reference slice) {
auto self = new_();
new (&self->seq) Object{seq};
if (PySlice_Unpack(slice.toPy(), &self->start, &self->stop, &self->step) < 0) {
if (SliceUnpack(slice.toPy(), &self->start, &self->stop, &self->step) < 0) {
throw PyException();
}
return self;
Expand All @@ -3269,15 +3276,18 @@ Helper object for slicing support.
}

Py_ssize_t sq_length() {
return PySlice_AdjustIndices(seq.size(), &start, &stop, step);
auto b = start, e = stop;
return SliceAdjustIndices(seq.size(), &b, &e, step);
}

Object sq_item(Py_ssize_t index) {
if (index < 0 || index >= sq_length()) {
auto b = start, e = stop;
auto l = SliceAdjustIndices(seq.size(), &b, &e, step);
if (index < 0 || index >= l) {
PyErr_Format(PyExc_IndexError, "invalid index");
return nullptr;
}
return seq.getItem(start + index * step);
return seq.getItem(b + index * step);
}

Object mp_subscript(Reference slice) {
Expand Down Expand Up @@ -3488,7 +3498,12 @@ literals in the assignment.
}

Object mp_subscript(Reference slice) {
return Slice::construct(*this, slice);
if (PySlice_Check(slice.toPy())) {
return Slice::construct(*this, slice);
}
else {
return sq_item(pyToCpp<Py_ssize_t>(slice));
}
}

Object to_c() {
Expand Down