Skip to content

AIMA4e Search #431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Aug 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions core/src/main/java/aima/core/search/basic/SearchUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package aima.core.search.basic;

import aima.core.search.api.Node;
import aima.core.search.api.NodeFactory;
import aima.core.search.api.Problem;
import aima.core.search.basic.support.BasicNodeFactory;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
* Some utility functions for the search module
*/
public class SearchUtils {

/**
* Calculates the successors of a given node for a given problem.
*
* @param problem
* @param parent
* @param <A>
* @param <S>
* @return
*/
public static <A, S> List<Node<A, S>> successors(Problem<A, S> problem, Node<A, S> parent) {
S s = parent.state();
List<Node<A, S>> nodes = new ArrayList<>();

NodeFactory<A, S> nodeFactory = new BasicNodeFactory<>();
for (A action :
problem.actions(s)) {
S sPrime = problem.result(s, action);
double cost = parent.pathCost() + problem.stepCost(s, action, sPrime);
Node<A, S> node = nodeFactory.newChildNode(problem, parent, action);
nodes.add(node);
}
return nodes;
}

/**
* Calculates the depth of a node in a particular tree.
*
* @param node
* @return
*/
public static int depth(Node node) {
Node temp = node;
int count = 0;
while (temp != null) {
count++;
temp = temp.parent();
}
return count;
}

/**
* Extracts the list of actions from a solution state.
*
* @param solution
* @param <A>
* @param <S>
* @return
*/
public static <A, S> List<A> generateActions(Node<A, S> solution) {
Node<A, S> parent = solution;
List<A> actions = new ArrayList<>();
while (parent.parent() != null) {
actions.add(parent.action());
parent = parent.parent();
}
Collections.reverse(actions);
return actions;
}
}
211 changes: 96 additions & 115 deletions core/src/main/java/aima/core/search/basic/informed/AStarSearch.java
Original file line number Diff line number Diff line change
@@ -1,137 +1,118 @@
package aima.core.search.basic.informed;

import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import java.util.function.ToDoubleFunction;

import aima.core.search.api.Node;
import aima.core.search.api.NodeFactory;
import aima.core.search.api.Problem;
import aima.core.search.api.SearchController;
import aima.core.search.api.SearchForActionsFunction;
import aima.core.search.basic.SearchUtils;
import aima.core.search.basic.support.BasicNodeFactory;
import aima.core.search.basic.support.BasicSearchController;
import aima.core.search.basic.uninformedsearch.GenericSearchInterface;

import java.util.*;
import java.util.function.ToDoubleFunction;

/**
* <pre>
* function A*-SEARCH(problem) returns a solution, or failure
* node &larr; a node with STATE = problem.INITIAL-STATE, PATH-COST=0
* frontier &larr; a priority queue ordered by PATH-COST + h(NODE), with node as the only element
* explored &larr; an empty set
* loop do
* if EMPTY?(frontier) then return failure
* node &lt;- POP(frontier) // chooses the lowest-cost node in frontier
* if problem.GOAL-TEST(node.STATE) then return SOLUTION(node)
* add node.STATE to explored
* for each action in problem.ACTIONS(node.STATE) do
* child &larr; CHILD-NODE(problem, node, action)
* if child.STATE is not in explored or frontier then
* frontier &larr; INSERT(child, frontier)
* else if child.STATE is in frontier with higher COST then
* replace that frontier node with child
*  if problem's initial state is a goal then return empty path to initial state
* frontier a priority queue ordered by f(n) = h(n) + g(n), with a node for the initial state
*  reached ← a table of {state: the best path that reached state}; initially empty
*  solution ← failure
*  while frontier is not empty and top(frontier) is cheaper than solution do
*    parent ← pop(frontier)
*    for child in successors(parent) do
*      s ← child.state
*      if s is not in reached or child is a cheaper path than reached[s] then
*        reached[s] ← child
*        add child to the frontier
*        if child is a goal and is cheaper than solution then
*          solution = child
*  return solution
* </pre>
*
*
* @author Ciaran O'Reilly
* @author samagra
*/
public class AStarSearch<A, S> implements SearchForActionsFunction<A, S> {
// function A*-SEARCH((problem) returns a solution, or failure
@Override
public List<A> apply(Problem<A, S> problem) {
// node <- a node with STATE = problem.INITIAL-STATE, PATH-COST=0
Node<A, S> node = newRootNode(problem.initialState(), 0);
// frontier <- a priority queue ordered by PATH-COST + h(NODE), with
// node as the
// only element
Queue<Node<A, S>> frontier = newPriorityQueueOrderedByPathCostPlusH(node);
// explored <- an empty set
Set<S> explored = newExploredSet();
// loop do
while (true) {
// if EMPTY?(frontier) then return failure
if (frontier.isEmpty()) {
return failure();
}
// node <- POP(frontier) // chooses the lowest-cost node in frontier
node = frontier.remove();
// if problem.GOAL-TEST(node.STATE) then return SOLUTION(node)
if (isGoalState(node, problem)) {
return solution(node);
}
// add node.STATE to explored
explored.add(node.state());
// for each action in problem.ACTIONS(node.STATE) do
for (A action : problem.actions(node.state())) {
// child <- CHILD-NODE(problem, node, action)
Node<A, S> child = newChildNode(problem, node, action);
// if child.STATE is not in explored or frontier then
if (!(explored.contains(child.state()) || containsState(frontier, child.state()))) {
// frontier <- INSERT(child, frontier)
frontier.add(child);
} // else if child.STATE is in frontier with higher COST then
else if (removedNodeFromFrontierWithSameStateAndHigherCost(child, frontier)) {
// replace that frontier node with child
frontier.add(child);
}
}
}
}

//
// Supporting Code
protected ToDoubleFunction<Node<A, S>> h;
protected NodeFactory<A, S> nodeFactory = new BasicNodeFactory<>();
protected SearchController<A, S> searchController = new BasicSearchController<A, S>();

public AStarSearch(ToDoubleFunction<Node<A, S>> h) {
this.h = h;
}

public ToDoubleFunction<Node<A, S>> getHeuristicFunctionH() {
return h;
}

public Node<A, S> newRootNode(S initialState, double pathCost) {
return nodeFactory.newRootNode(initialState, pathCost);
}
public class AStarSearch<A, S> implements GenericSearchInterface<A, S>, SearchForActionsFunction<A, S> {

public Node<A, S> newChildNode(Problem<A, S> problem, Node<A, S> node, A action) {
return nodeFactory.newChildNode(problem, node, action);
}
// The heuristic function
protected ToDoubleFunction<Node<A, S>> h;
// A helper class to generate new nodes.
protected NodeFactory<A, S> nodeFactory = new BasicNodeFactory<>();

public Queue<Node<A, S>> newPriorityQueueOrderedByPathCostPlusH(Node<A, S> initialNode) {
Queue<Node<A, S>> frontier = new PriorityQueue<>(
Comparator.comparingDouble(n -> n.pathCost() + h.applyAsDouble(n)));
frontier.add(initialNode);
return frontier;
}
// frontier ← a priority queue ordered by f(n) = h(n)+g(n), with a node for the initial state
PriorityQueue<Node<A, S>> frontier = new PriorityQueue<>(new Comparator<Node<A, S>>() {
@Override
public int compare(Node<A, S> o1, Node<A, S> o2) {
return (int) (getCostValue(o1) - getCostValue(o2));
}
});

public Set<S> newExploredSet() {
return new HashSet<>();
}

public List<A> failure() {
return searchController.failure();
}
/**
* The constructor that takes in the heuristics function.
*
* @param h
*/
public AStarSearch(ToDoubleFunction<Node<A, S>> h) {
this.h = h;
}

public List<A> solution(Node<A, S> node) {
return searchController.solution(node);
}
@Override
public Node<A, S> search(Problem<A, S> problem) {
if (problem.isGoalState(problem.initialState())) {
return nodeFactory.newRootNode(problem.initialState());
}
frontier.clear();
frontier.add(nodeFactory.newRootNode(problem.initialState()));
// reached ← a table of {state: the best path that reached state}; initially empty
HashMap<S, Node<A, S>> reached = new HashMap<>();
Node<A, S> solution = null;
// while frontier is not empty and top(frontier) is cheaper than solution do
while (!frontier.isEmpty() &&
(solution == null || getCostValue(frontier.peek()) < getCostValue(solution))) {
Node<A, S> parent = frontier.poll();
for (Node<A, S> child :
SearchUtils.successors(problem, parent)) {
S s = child.state();
// if s is not in reached or child is a cheaper path than reached[s] then
if (!reached.containsKey(s) ||
getCostValue(child) < getCostValue(reached.get(s))) {
reached.put(s, child);
frontier.add(child);
// if child is a goal and is cheaper than solution
if (problem.isGoalState(s) &&
(solution == null || getCostValue(child) < getCostValue(solution))) {
solution = child;
}
}
}
}
return solution;
}

public boolean isGoalState(Node<A, S> node, Problem<A, S> problem) {
return searchController.isGoalState(node, problem);
}

public boolean containsState(Queue<Node<A, S>> frontier, S state) {
// NOTE: Not very efficient (i.e. linear in the size of the frontier)
return frontier.stream().anyMatch(frontierNode -> frontierNode.state().equals(state));
}
/**
* Returns the list of actions that need to be taken in order to achieve the goal.
*
* @param problem The search problem
* @return the list of actions
*/
@Override
public List<A> apply(Problem<A, S> problem) {
Node<A, S> solution = this.search(problem);
if (solution == null)
return new ArrayList<>();
else
return SearchUtils.generateActions(solution);
}

public boolean removedNodeFromFrontierWithSameStateAndHigherCost(Node<A, S> child, Queue<Node<A, S>> frontier) {
// NOTE: Not very efficient (i.e. linear in the size of the frontier)
return frontier.removeIf(n -> n.state().equals(child.state()) && n.pathCost() > child.pathCost());
}
/**
* Finds the value of f(n) = g(n)+h(n) for a node n.
*
* @param node The node n
* @return f(n)
*/
private double getCostValue(Node<A, S> node) {
return node.pathCost() + h.applyAsDouble(node);
}
}
Loading