Skip to content

Commit

Permalink
Clear memoized values after updating messages
Browse files Browse the repository at this point in the history
  • Loading branch information
jkawamoto committed Jan 17, 2023
1 parent abf9d1c commit 9bc5cd5
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 35 deletions.
20 changes: 12 additions & 8 deletions fraud_eagle/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
# along with rgmining-fraud-eagle. If not, see <http://www.gnu.org/licenses/>.
"""Provide a bipartite graph class implementing Fraud Eagle algorithm.
"""
from functools import lru_cache
from logging import getLogger
from typing import Any, Final, Optional, cast

import networkx as nx
import numpy as np
from common import memoized

from fraud_eagle.labels import ProductLabel, ReviewLabel, UserLabel
from fraud_eagle.likelihood import psi
Expand Down Expand Up @@ -328,7 +328,7 @@ def add_review(self, reviewer: Reviewer, product: Product, rating: float, *_args
self.graph.add_edge(reviewer, product, review=review)
return review

@memoized
@lru_cache
def retrieve_reviewers(self, product: Product) -> list[Reviewer]:
"""Retrieve reviewers review a given product.
Expand All @@ -340,7 +340,7 @@ def retrieve_reviewers(self, product: Product) -> list[Reviewer]:
"""
return list(self.graph.predecessors(product))

@memoized
@lru_cache
def retrieve_products(self, reviewer: Reviewer) -> list[Product]:
"""Retrieve products a given reviewer reviews.
Expand All @@ -352,7 +352,7 @@ def retrieve_products(self, reviewer: Reviewer) -> list[Product]:
"""
return list(self.graph.successors(reviewer))

@memoized
@lru_cache
def retrieve_review(self, reviewer: Reviewer, product: Product) -> Review:
"""Retrieve a review a given reviewer posts to a given product.
Expand Down Expand Up @@ -450,6 +450,10 @@ def update(self) -> float:
+ "\n".join(" {0}-{1}: {2}".format(edges[i], edges[i + 1], v) for i, v in enumerate(histo))
)

# Clear memoized values since messages are updated.
self.prod_message_from_all_users.cache_clear()
self.prod_message_from_all_products.cache_clear()

return max(diffs)

def _update_user_to_product(self, reviewer: Reviewer, product: Product, p_label: ProductLabel) -> float:
Expand Down Expand Up @@ -530,7 +534,7 @@ def _update_product_to_user(self, reviewer: Reviewer, product: Product, u_label:
)
return _logaddexp(*res.values())

@memoized
@lru_cache
def prod_message_from_all_users(self, product: Product, p_label: ProductLabel) -> float:
"""Compute a product of messages to a product.
Expand Down Expand Up @@ -581,13 +585,13 @@ def prod_message_from_users(self, reviewer: Optional[Reviewer], product: Product
Returns:
a logarithm of the product defined above.
"""
sum_all: float = self.prod_message_from_all_users(product, p_label)
sum_all = self.prod_message_from_all_users(product, p_label)
sum_reviewer = 0.0
if reviewer is not None:
sum_reviewer = self.retrieve_review(reviewer, product).user_to_product(p_label)
return sum_all - sum_reviewer

@memoized
@lru_cache
def prod_message_from_all_products(self, reviewer: Reviewer, u_label: UserLabel) -> float:
"""Compute a product of messages sending to a reviewer.
Expand Down Expand Up @@ -638,7 +642,7 @@ def prod_message_from_products(self, reviewer: Reviewer, product: Optional[Produ
Returns:
a logarithm of the product defined above.
"""
sum_all: float = self.prod_message_from_all_products(reviewer, u_label)
sum_all = self.prod_message_from_all_products(reviewer, u_label)
sum_product = 0.0
if product is not None:
sum_product = self.retrieve_review(reviewer, product).product_to_user(u_label)
Expand Down
25 changes: 7 additions & 18 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ include = ["COPYING"]
python = "^3.9"
numpy = "^1.24.1"
networkx = "^3.0"
rgmining-common = "^0.9.1"

[tool.poetry.group.dev.dependencies]
bump2version = "^1.0.1"
Expand All @@ -57,10 +56,6 @@ disallow_untyped_defs = true
module = "networkx"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "common"
ignore_missing_imports = true

[tool.black]
target-version = ['py39']
line-length = 120
Expand Down
12 changes: 8 additions & 4 deletions tests/graph/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,16 @@ def test_prod_message_from_users_update(review_graph: ReviewGraph) -> None:

assert_almost_equal(
review_graph.prod_message_from_users(reviewers[0], products[0], ProductLabel.GOOD),
sum(r.user_to_product(ProductLabel.GOOD) for r in reviews[1:]))
sum(r.user_to_product(ProductLabel.GOOD) for r in reviews[1:]),
)

for _ in range(10):
review_graph.update()

assert_almost_equal(
review_graph.prod_message_from_users(reviewers[0], products[0], ProductLabel.GOOD),
sum(r.user_to_product(ProductLabel.GOOD) for r in reviews[1:]))
sum(r.user_to_product(ProductLabel.GOOD) for r in reviews[1:]),
)


def test_prod_message_from_products(review_graph: ReviewGraph) -> None:
Expand Down Expand Up @@ -183,14 +185,16 @@ def test_prod_message_from_products_update(review_graph: ReviewGraph) -> None:

assert_almost_equal(
review_graph.prod_message_from_products(reviewers[0], products[0], UserLabel.HONEST),
sum(r.product_to_user(UserLabel.HONEST) for r in reviews[1:]))
sum(r.product_to_user(UserLabel.HONEST) for r in reviews[1:]),
)

for _ in range(10):
review_graph.update()

assert_almost_equal(
review_graph.prod_message_from_products(reviewers[0], products[0], UserLabel.HONEST),
sum(r.product_to_user(UserLabel.HONEST) for r in reviews[1:]))
sum(r.product_to_user(UserLabel.HONEST) for r in reviews[1:]),
)


def test_update_user_to_product(review_graph: ReviewGraph) -> None:
Expand Down

0 comments on commit 9bc5cd5

Please sign in to comment.