Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pareto dominance definition #174

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion qdax/utils/pareto_front.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@ def compute_pareto_dominance(
Return booleans when the vector is dominated by the batch.
"""
diff = jnp.subtract(batch_of_criteria, criteria_point)
return jnp.any(jnp.all(diff > 0, axis=-1))
neutral_values = -jnp.ones_like(diff)
diff = jax.vmap(lambda x1, x2: jnp.where(mask, x1, x2), in_axes=(1, 1), out_axes=1)(
neutral_values, diff
)
diff_greater_than_zero = jnp.any(diff > 0, axis=-1)
diff_geq_than_zero = jnp.all(diff >= 0, axis=-1)

return jnp.any(jnp.logical_and(diff_greater_than_zero, diff_geq_than_zero))



def compute_pareto_front(batch_of_criteria: jnp.ndarray) -> jnp.ndarray:
Expand Down
Loading