Skip to content

Commit

Permalink
Fix constraint support.
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin Chrétien committed May 9, 2014
1 parent 758a44d commit 8d6143d
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/roboptim/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def argumentScales(self):
def argumentScales(self, value):
setArgumentScales (self._problem, value)

def addConstraint (self):
pass #FIXME:
def addConstraint (self, constraint, bounds):
addConstraint (self._problem, constraint._function, bounds)


class PySolver(object):
Expand Down
47 changes: 42 additions & 5 deletions src/wrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ namespace roboptim
}
}

size_type inputSize () const
{
return ::roboptim::DifferentiableFunction::inputSize ();
}

size_type outputSize () const
{
return ::roboptim::DifferentiableFunction::outputSize ();
}

virtual void impl_compute (result_t& result, const argument_t& argument)
const
{
Expand Down Expand Up @@ -417,17 +427,27 @@ namespace detail
functionConverter (PyObject* obj, Function** address)
{
assert (address);

if (!PyCapsule_CheckExact (obj))
{
PyErr_SetString (PyExc_TypeError, "Invalid Python Function given.");
return 0;
}

Function* ptr = static_cast<Function*>
(PyCapsule_GetPointer
(obj, ROBOPTIM_CORE_FUNCTION_CAPSULE_NAME));

if (!ptr)
{
PyErr_SetString
(PyExc_TypeError,
"Function object expected but another type was passed");
return 0;
}

*address = ptr;

return 1;
}

Expand Down Expand Up @@ -592,6 +612,14 @@ namespace detail

return 0;
}

// See: http://www.boost.org/doc/libs/1_55_0/libs/smart_ptr/sp_techniques.html#static
struct null_deleter
{
void operator () (void const *) const
{
}
};
} // end of namespace detail.

template <typename T>
Expand Down Expand Up @@ -650,15 +678,18 @@ static PyObject*
getName (PyObject*, PyObject* args)
{
Function* function = 0;

if (!PyArg_ParseTuple(args, "O&", &detail::functionConverter, &function))
return 0;

if (!function)
{
PyErr_SetString
(PyExc_TypeError,
"argument 1 should be a function object");
return 0;
}

return Py_BuildValue("s", function->getName ().c_str ());
}

Expand All @@ -671,6 +702,7 @@ createProblem (PyObject*, PyObject* args)

DifferentiableFunction* dfunction =
dynamic_cast<DifferentiableFunction*> (costFunction);

if (!dfunction)
{
PyErr_SetString
Expand Down Expand Up @@ -1143,11 +1175,13 @@ addConstraint (PyObject*, PyObject* args)
&detail::functionConverter, &function,
&min, &max))
return 0;

if (!problem)
{
PyErr_SetString (PyExc_TypeError, "1st argument must be a problem");
return 0;
}

if (!function)
{
PyErr_SetString (PyExc_TypeError, "2nd argument must be a function");
Expand All @@ -1164,10 +1198,11 @@ addConstraint (PyObject*, PyObject* args)
return 0;
}

//FIXME: this will make everything segv.
//contraint will be freed when problem disappear...
// boost::shared_ptr<DifferentiableFunction> constraint (dfunction);
// problem->addConstraint (constraint, Function::makeInterval (min, max));
// If we just used a boost::shared_ptr, the constraint would be freed when the
// problem disappears, so we use a null deleter to prevent that.
boost::shared_ptr<DifferentiableFunction> constraint (dfunction,
detail::null_deleter ());
problem->addConstraint (constraint, Function::makeInterval (min, max));

Py_INCREF (Py_None);
return Py_None;
Expand Down Expand Up @@ -1471,7 +1506,7 @@ toDict<resultWithWarnings_t> (PyObject* obj, PyObject* args)

template <>
PyObject*
toDict<solverError_t> (PyObject* obj, PyObject* args)
toDict<solverError_t> (PyObject*, PyObject* args)
{
solverError_t* error = 0;

Expand Down Expand Up @@ -1683,6 +1718,8 @@ static PyMethodDef RobOptimCoreMethods[] =
"Convert a Result object to a Python dictionary."},
{"resultWithWarningsToDict", toDict<resultWithWarnings_t>, METH_VARARGS,
"Convert a ResultWithWarnings object to a Python dictionary."},
{"solverErrorToDict", toDict<solverError_t>, METH_VARARGS,
"Convert a SolverError object to a Python dictionary."},

// Print functions
{"strFunction", print<Function>, METH_VARARGS,
Expand Down
3 changes: 3 additions & 0 deletions tests/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def test_solver(self):
problem.argumentScales = numpy.array([2.,])
numpy.testing.assert_almost_equal (problem.argumentScales, [2.,])

g1 = Square ()
problem.addConstraint (g1, [-1., 10.,])

# Let the test fail if the solver does not exist.
try:
solver = roboptim.core.PySolver ("ipopt", problem)
Expand Down
53 changes: 53 additions & 0 deletions tests/schittkowski.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,32 @@ def impl_gradient (self, result, x, functionId):
result[1] = 200. * (x[1] - x[0]**2)


class Problem6_Cost (roboptim.core.PyDifferentiableFunction):
def __init__ (self):
roboptim.core.PyDifferentiableFunction.__init__ \
(self, 2, 1, "(1 - x₀)²")

def impl_compute (self, result, x):
result[0] = (1 - x[0])**2

def impl_gradient (self, result, x, functionId):
result[0] = -2. * (1 - x[0])
result[1] = 0.


class Problem6_G1 (roboptim.core.PyDifferentiableFunction):
def __init__ (self):
roboptim.core.PyDifferentiableFunction.__init__ \
(self, 2, 1, "10 (x₁ - x₀²)")

def impl_compute (self, result, x):
result[0] = 10. * (x[1] - x[0]**2)

def impl_gradient (self, result, x, functionId):
result[0] = -20. * x[0]
result[1] = 10.


class TestFunctionPy(unittest.TestCase):

def test_problem_1(self):
Expand Down Expand Up @@ -77,6 +103,33 @@ def test_problem_2(self):
except Exception, e:
print ("Error: " + str(e))

def test_problem_6(self):
"""
Schittkowski problem #6
"""
cost = Problem6_Cost ()
problem = roboptim.core.PyProblem (cost)
problem.startingPoint = numpy.array([-1.2, 1., ])
problem.argumentBounds = numpy.array([[float("-inf"), float("inf")],
[float("-inf"), float("inf")], ])

g1 = Problem6_G1 ()
problem.addConstraint (g1, [0., 0.,])

# Check starting value
numpy.testing.assert_almost_equal (cost (problem.startingPoint)[0], 4.84)

# Let the test fail if the solver does not exist.
try:
solver = roboptim.core.PySolver ("ipopt", problem)
print (solver)
solver.solve ()
r = solver.minimum ()
print (r)
numpy.testing.assert_almost_equal (r.value, [0.])
numpy.testing.assert_almost_equal (r.x, [1., 1.])
except Exception, e:
print ("Error: " + str(e))

if __name__ == '__main__':
unittest.main ()

0 comments on commit 8d6143d

Please sign in to comment.