Skip to content

Commit

Permalink
MH related docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
jingnongqu committed Feb 5, 2025
1 parent 44a9261 commit 3926eed
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 6 deletions.
16 changes: 14 additions & 2 deletions src/ultk/language/grammar/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from importlib import import_module
from itertools import product
from typing import Any, Callable, Generator, TypedDict, TypeVar, Iterable
from typing import Any, Callable, Generator, TypedDict, TypeVar
from yaml import load
from functools import cache

Expand All @@ -20,7 +20,6 @@
from ultk.util.frozendict import FrozenDict

T = TypeVar("T")
Dataset = Iterable[tuple[Referent, T]]


@dataclass(frozen=True)
Expand Down Expand Up @@ -185,6 +184,11 @@ def replace_children(self, children) -> None:

@cache
def node_count(self) -> int:
"""Count the nodes of a GrammaticalExpression
Returns:
int: node count
"""
counter = 1
stack = [self]
while stack:
Expand Down Expand Up @@ -280,6 +284,14 @@ def probability(self, rule: Rule) -> float:
return float(rule.weight) / sum([r.weight for r in self._rules[rule.lhs]])

def prior(self, expr: GrammaticalExpression) -> float:
"""Prior of a GrammaticalExpression
Args:
expr (GrammaticalExpression): the GrammaticalExpression for computation
Returns:
float: prior
"""
probability = self.probability(self._rules_by_name[expr.rule_name])
children = expr.children if expr.children else ()
for child in children:
Expand Down
34 changes: 32 additions & 2 deletions src/ultk/language/grammar/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@ def mh_sample(
data: Dataset,
likelihood_func: Callable[[Dataset, GrammaticalExpression], float] = all_or_nothing,
) -> GrammaticalExpression:
"""Sample a new GrammaticalExpression from an existing one and data using Metropolis-Hastings
Args:
expr (GrammaticalExpression): the existing GrammaticalExpression
grammar (Grammar): the grammar for generation
data (Dataset): data used for calculation of acceptance probability
likelihood_func (Callable[[Dataset, GrammaticalExpression], float], optional): function returning the likelihood of the data given a GrammaticalExpression. Defaults to all_or_nothing.
Returns:
GrammaticalExpression: newly sampled GrammaticalExpression
"""
old_tree_prior = grammar.prior(expr)
old_node_count = expr.node_count()
while True:
Expand Down Expand Up @@ -41,7 +52,15 @@ def mh_sample(

def mh_select(
old_tree: GrammaticalExpression,
) -> (GrammaticalExpression, GrammaticalExpression):
) -> tuple[GrammaticalExpression, GrammaticalExpression]:
"""Select a node for further change from a GrammaticalExpression
Args:
old_tree (GrammaticalExpression): input GrammaticalExpression
Returns:
tuple[GrammaticalExpression, GrammaticalExpression]: the node selected for change and its parent node
"""
linearized_self = []
parents = []
stack = [(old_tree, -1)]
Expand All @@ -64,7 +83,18 @@ def mh_generate(
current_node: GrammaticalExpression,
parent_node: GrammaticalExpression,
grammar: Grammar,
) -> (GrammaticalExpression, GrammaticalExpression):
) -> tuple[GrammaticalExpression, GrammaticalExpression]:
"""Generate a new GrammaticalExpression
Args:
old_tree (GrammaticalExpression): the original full GrammaticalExpression
current_node (GrammaticalExpression): the node selected for change
parent_node (GrammaticalExpression): the parent node for the changing node
grammar (Grammar): grammar used for generation
Returns:
tuple[GrammaticalExpression, GrammaticalExpression]: the new full GrammaticalExpression and the changed node
"""
if current_node != old_tree:
new_children = []
children = parent_node.children if parent_node.children else ()
Expand Down
12 changes: 10 additions & 2 deletions src/ultk/language/grammar/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,17 @@
from ultk.language.semantics import Referent

T = TypeVar("T")

Dataset = Iterable[tuple[Referent, T]]


def all_or_nothing(data: Dataset, tree: GrammaticalExpression):
def all_or_nothing(data: Dataset, tree: GrammaticalExpression) -> float:
    """All-or-nothing likelihood of ``data`` under ``tree``.

    Args:
        data (Dataset): iterable of data points; element 0 of each is passed to ``tree`` and element 1 is the value its output is compared against
        tree (GrammaticalExpression): expression called on the first element of every datum

    Returns:
        float: 1.0 if the output of ``tree`` equals the second element for every datum (also when ``data`` is empty), 0.0 otherwise
    """
    for datum in data:
        predicted = tree(datum[0])
        if not (predicted == datum[1]):
            # One mismatch is enough to zero out the likelihood.
            return 0.0
    return 1.0

0 comments on commit 3926eed

Please sign in to comment.