Skip to content

Commit

Permalink
polish/fix/cleanup includes
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Sep 5, 2023
1 parent 2503029 commit cafbcb1
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 66 deletions.
128 changes: 81 additions & 47 deletions ortools/algorithms/binary_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <utility>

#include "absl/functional/function_ref.h"
#include "absl/log/check.h"
#include "absl/numeric/int128.h"
#include "ortools/base/dump_vars.h"
#include "ortools/base/logging.h"
Expand Down Expand Up @@ -99,7 +103,7 @@ Point BinarySearchMidpoint(Point x, Point y);
// - We technically do not need the points to be sorted and can use
// linear-time median computation to speed this up.
//
// TODO(user): replace std::function by absl::AnyInvocable here and in
// TODO(user): replace std::function by absl::FunctionRef here and in
// BinarySearch().
template <class Point, class Value>
std::pair<Point, Value> ConvexMinimum(absl::Span<const Point> sorted_points,
Expand All @@ -115,6 +119,15 @@ std::pair<Point, Value> ConvexMinimum(bool is_to_the_right,
absl::Span<const Point> sorted_points,
std::function<Value(Point)> f);

// Searches in the range [begin, end), where Point supports basic arithmetic.
template <class Point, class Value>
std::pair<Point, Value> RangeConvexMinimum(Point begin, Point end,
absl::FunctionRef<Value(Point)> f);
template <class Point, class Value>
std::pair<Point, Value> RangeConvexMinimum(std::pair<Point, Value> current_min,
Point begin, Point end,
absl::FunctionRef<Value(Point)> f);

// _____________________________________________________________________________
// Implementation.

Expand Down Expand Up @@ -222,83 +235,104 @@ Point BinarySearch(Point x_true, Point x_false, std::function<bool(Point)> f) {
}

template <class Point, class Value>
std::pair<Point, Value> ConvexMinimum(absl::Span<const Point> sorted_points,
std::function<Value(Point)> f) {
DCHECK(!sorted_points.empty());
if (sorted_points.size() == 1) {
return {sorted_points[0], f(sorted_points[0])};
std::pair<Point, Value> RangeConvexMinimum(Point begin, Point end,
absl::FunctionRef<Value(Point)> f) {
DCHECK_LT(begin, end);
const Value size = end - begin;
if (size == 1) {
return {begin, f(begin)};
}

// Starts by splitting interval in two with two queries and getting some info.
// Note the current min will be outside the interval.
bool is_to_the_right;
std::pair<Point, Value> current_min;
{
DCHECK_GE(sorted_points.size(), 2);
const int i = sorted_points.size() / 2;
const Value v = f(sorted_points[i]);
const int before_i = i - 1;
const Value before_v = f(sorted_points[before_i]);
if (before_v == v) return {sorted_points[before_i], before_v};
DCHECK_GE(size, 2);
const Point mid = begin + (end - begin) / 2;
DCHECK_GT(mid, begin);
const Value v = f(mid);
const Point before_mid = mid - 1;
const Value before_v = f(before_mid);
if (before_v == v) return {before_mid, before_v};
if (before_v < v) {
// Note that we exclude before_i from the span.
current_min = {sorted_points[before_i], before_v};
is_to_the_right = true;
sorted_points = sorted_points.subspan(0, std::max(0, before_i));
// Note that we exclude before_mid from the range.
current_min = {before_mid, before_v};
end = before_mid;
} else {
is_to_the_right = false;
current_min = {sorted_points[i], v};
sorted_points = sorted_points.subspan(i + 1);
current_min = {mid, v};
begin = mid + 1;
}
}
if (sorted_points.empty()) return current_min;
return ConvexMinimum<Point, Value>(is_to_the_right, current_min,
sorted_points, std::move(f));
if (begin >= end) return current_min;
return RangeConvexMinimum<Point, Value>(current_min, begin, end, f);
}

template <class Point, class Value>
std::pair<Point, Value> ConvexMinimum(bool is_to_the_right,
std::pair<Point, Value> current_min,
absl::Span<const Point> sorted_points,
std::function<Value(Point)> f) {
DCHECK(!sorted_points.empty());
while (sorted_points.size() > 1) {
const int i = sorted_points.size() / 2;
const Value v = f(sorted_points[i]);
std::pair<Point, Value> RangeConvexMinimum(std::pair<Point, Value> current_min,
Point begin, Point end,
absl::FunctionRef<Value(Point)> f) {
DCHECK_LT(begin, end);
while ((end - begin) > 1) {
DCHECK(current_min.first < begin || current_min.first >= end);
bool current_is_after_end = current_min.first >= end;
const Point mid = begin + (end - begin) / 2;
const Value v = f(mid);
if (v >= current_min.second) {
// If the midpoint is no better than our current minimum, then the
// global min must lie between our midpoint and our current min.
if (is_to_the_right) {
sorted_points = sorted_points.subspan(i + 1);
if (current_is_after_end) {
begin = mid + 1;
} else {
sorted_points = sorted_points.subspan(0, i);
end = mid;
}
} else {
// v < current_min.second, we cannot decide, so we use a second value
// close to v like in the initial step.
DCHECK_GT(i, 0);
const int before_i = i - 1;
const Value before_v = f(sorted_points[before_i]);
if (before_v == v) return {sorted_points[before_i], before_v};
DCHECK_GT(mid, begin);
const Point before_mid = mid - 1;
const Value before_v = f(before_mid);
if (before_v == v) return {before_mid, before_v};
if (before_v < v) {
current_min = {sorted_points[before_i], before_v};
is_to_the_right = true;
sorted_points = sorted_points.subspan(0, std::max(0, before_i));
current_min = {before_mid, before_v};
current_is_after_end = true;
end = before_mid;
} else {
is_to_the_right = false;
current_min = {sorted_points[i], v};
sorted_points = sorted_points.subspan(i + 1);
current_is_after_end = false;
current_min = {mid, v};
begin = mid + 1;
}
}
}

if (!sorted_points.empty()) {
const Value v = f(sorted_points[0]);
if (v <= current_min.second) return {sorted_points[0], v};
if (end - begin == 1) {
const Value v = f(begin);
if (v <= current_min.second) return {begin, v};
}
return current_min;
}

template <class Point, class Value>
std::pair<Point, Value> ConvexMinimum(absl::Span<const Point> sorted_points,
std::function<Value(Point)> f) {
auto index_f = [&](int index) -> Value { return f(sorted_points[index]); };
const auto& [index, v] =
RangeConvexMinimum<int64_t, Value>(0, sorted_points.size(), index_f);
return {sorted_points[index], v};
}

template <class Point, class Value>
std::pair<Point, Value> ConvexMinimum(bool is_to_the_right,
std::pair<Point, Value> current_min,
absl::Span<const Point> sorted_points,
std::function<Value(Point)> f) {
auto index_f = [&](int index) -> Value { return f(sorted_points[index]); };
std::pair<int, Value> index_current_min = std::make_pair(
is_to_the_right ? sorted_points.size() : -1, current_min.second);
const auto& [index, v] = RangeConvexMinimum<int64_t, Value>(
index_current_min, 0, sorted_points.size(), index_f);
if (index == index_current_min.first) return current_min;
return {sorted_points[index], v};
}
} // namespace operations_research

#endif // OR_TOOLS_ALGORITHMS_BINARY_SEARCH_H_
42 changes: 39 additions & 3 deletions ortools/algorithms/binary_search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,13 @@ TEST(ConvexMinimumTest, ExhaustiveTest) {
});
total_num_queries += num_queries;
max_num_queries = std::max(max_num_queries, num_queries);
ASSERT_EQ(value, 0);
ASSERT_GE(point, b1);
ASSERT_LE(point, b2);
EXPECT_EQ(value, 0);
EXPECT_GE(point, b1);
EXPECT_LE(point, b2);
// Fail after one example.
ASSERT_TRUE(value == 0 && b1 <= point && point <= b2)
<< "queries: " << num_queries << " opt range: [" << b1 << ", " << b2
<< "]";
}
}

Expand Down Expand Up @@ -378,4 +382,36 @@ TEST(ConvexMinimumTest, TwoQueriesIfSizeTwoReversed) {
EXPECT_EQ(num_queries, 2);
}

TEST(RangeConvexMinimumTest, HugeRangeTest) {
int total_num_queries = 0;
int max_num_queries = 0;
for (int b1 = -100; b1 < 100; ++b1) {
for (int b2 = b1; b2 < b1 + 100; ++b2) {
int num_queries = 0;
const auto [point, value] = RangeConvexMinimum<int64_t, double>(
std::numeric_limits<int64_t>::min() / 2,
std::numeric_limits<int64_t>::max() / 2, [&](int64_t v) -> double {
++num_queries;
if (v < b1) {
return b1 - v;
} else if (v > b2) {
return v - b2;
}
return 0;
});
total_num_queries += num_queries;
max_num_queries = std::max(max_num_queries, num_queries);
EXPECT_EQ(value, 0);
EXPECT_GE(point, b1);
EXPECT_LE(point, b2);
// Don't continue past the first failing example to limit the number of
// errors.
ASSERT_TRUE(value == 0 && b1 <= point && point <= b2)
<< "queries: " << num_queries << " opt range: [" << b1 << ", " << b2
<< "]";
}
}
// 80 is the worst case we would expect from ternary search: 2*log_3(2^63).
EXPECT_LE(max_num_queries, 80);
}
} // namespace operations_research
3 changes: 2 additions & 1 deletion ortools/base/threadpool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "ortools/base/threadpool.h"

#include "absl/log/check.h"
#include "absl/strings/string_view.h"

namespace operations_research {
void RunWorker(void* data) {
Expand All @@ -25,7 +26,7 @@ void RunWorker(void* data) {
}
}

ThreadPool::ThreadPool(const std::string& prefix, int num_workers)
ThreadPool::ThreadPool(absl::string_view prefix, int num_workers)
: num_workers_(num_workers) {}

ThreadPool::~ThreadPool() {
Expand Down
4 changes: 3 additions & 1 deletion ortools/base/threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
#include <thread> // NOLINT
#include <vector>

#include "absl/strings/string_view.h"

namespace operations_research {
class ThreadPool {
public:
ThreadPool(const std::string& prefix, int num_threads);
ThreadPool(absl::string_view prefix, int num_threads);
~ThreadPool();

void StartWorkers();
Expand Down
14 changes: 6 additions & 8 deletions ortools/lp_data/mps_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,15 @@ MPSReaderFormat TemplateFormat(MPSReader::Form form) {
} // namespace

// Parses instance from a file.
absl::Status MPSReader::ParseFile(const std::string& file_name,
absl::Status MPSReader::ParseFile(absl::string_view file_name,
LinearProgram* data, Form form) {
DataWrapper<LinearProgram> data_wrapper(data);
return MPSReaderTemplate<DataWrapper<LinearProgram>>()
.ParseFile(file_name, &data_wrapper, TemplateFormat(form))
.status();
}

absl::Status MPSReader::ParseFile(const std::string& file_name,
absl::Status MPSReader::ParseFile(absl::string_view file_name,
MPModelProto* data, Form form) {
DataWrapper<MPModelProto> data_wrapper(data);
return MPSReaderTemplate<DataWrapper<MPModelProto>>()
Expand All @@ -339,7 +339,7 @@ absl::Status MPSReader::ParseFile(const std::string& file_name,

// Loads instance from string. Useful with MapReduce. Automatically detects
// the file's format (free or fixed).
absl::Status MPSReader::ParseProblemFromString(const std::string& source,
absl::Status MPSReader::ParseProblemFromString(absl::string_view source,
LinearProgram* data,
MPSReader::Form form) {
DataWrapper<LinearProgram> data_wrapper(data);
Expand All @@ -348,7 +348,7 @@ absl::Status MPSReader::ParseProblemFromString(const std::string& source,
.status();
}

absl::Status MPSReader::ParseProblemFromString(const std::string& source,
absl::Status MPSReader::ParseProblemFromString(absl::string_view source,
MPModelProto* data,
MPSReader::Form form) {
DataWrapper<MPModelProto> data_wrapper(data);
Expand All @@ -357,8 +357,7 @@ absl::Status MPSReader::ParseProblemFromString(const std::string& source,
.status();
}

absl::StatusOr<MPModelProto> MpsDataToMPModelProto(
const std::string& mps_data) {
absl::StatusOr<MPModelProto> MpsDataToMPModelProto(absl::string_view mps_data) {
MPModelProto model;
DataWrapper<MPModelProto> data_wrapper(&model);
RETURN_IF_ERROR(
Expand All @@ -368,8 +367,7 @@ absl::StatusOr<MPModelProto> MpsDataToMPModelProto(
return model;
}

absl::StatusOr<MPModelProto> MpsFileToMPModelProto(
const std::string& mps_file) {
absl::StatusOr<MPModelProto> MpsFileToMPModelProto(absl::string_view mps_file) {
MPModelProto model;
DataWrapper<MPModelProto> data_wrapper(&model);
RETURN_IF_ERROR(
Expand Down
13 changes: 7 additions & 6 deletions ortools/lp_data/mps_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,18 @@
#include "absl/base/attributes.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "ortools/linear_solver/linear_solver.pb.h"
#include "ortools/lp_data/lp_data.h"

namespace operations_research {
namespace glop {

// Parses an MPS model from a string.
absl::StatusOr<MPModelProto> MpsDataToMPModelProto(const std::string& mps_data);
absl::StatusOr<MPModelProto> MpsDataToMPModelProto(absl::string_view mps_data);

// Parses an MPS model from a file.
absl::StatusOr<MPModelProto> MpsFileToMPModelProto(const std::string& mps_file);
absl::StatusOr<MPModelProto> MpsFileToMPModelProto(absl::string_view mps_file);

// Implementation class. Please use the 2 functions above.
//
Expand All @@ -54,17 +55,17 @@ class ABSL_DEPRECATED("Use the direct methods instead") MPSReader {
enum Form { AUTO_DETECT, FREE, FIXED };

// Parses instance from a file.
absl::Status ParseFile(const std::string& file_name, LinearProgram* data,
absl::Status ParseFile(absl::string_view file_name, LinearProgram* data,
Form form = AUTO_DETECT);

absl::Status ParseFile(const std::string& file_name, MPModelProto* data,
absl::Status ParseFile(absl::string_view file_name, MPModelProto* data,
Form form = AUTO_DETECT);
// Loads instance from string. Useful with MapReduce. Automatically detects
// the file's format (free or fixed).
absl::Status ParseProblemFromString(const std::string& source,
absl::Status ParseProblemFromString(absl::string_view source,
LinearProgram* data,
MPSReader::Form form = AUTO_DETECT);
absl::Status ParseProblemFromString(const std::string& source,
absl::Status ParseProblemFromString(absl::string_view source,
MPModelProto* data,
MPSReader::Form form = AUTO_DETECT);
};
Expand Down

0 comments on commit cafbcb1

Please sign in to comment.