Skip to content

Make random number generator optional in creation of AgentSet, take R… #2789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,30 +162,39 @@ class AgentSet(MutableSet, Sequence):
which means that agents not referenced elsewhere in the program may be automatically removed from the AgentSet.

Notes:
A `UserWarning` is issued if `random=None`. You can resolve this warning by explicitly
passing a random number generator. In most cases, this will be the seeded random number
generator in the model. So, you would do `random=self.random` in a `Model` or `Agent` instance.
If random is None then the random number generator in the model of the first agent is used.
If the agents list is empty and random is also None a user warning is issued and the AgentSet
is an empty list and a default random number generator. This can make models non-reproducible.
If your code may create an AgentSet with no agents please pass a random number generator explicitly.

"""

def __init__(self, agents: Iterable[Agent], random: Random | None = None):
def __init__(
self,
agents: Iterable[Agent],
random: Random | np.random.Generator | None = None,
):
"""Initializes the AgentSet with a collection of agents and a reference to the model.

Args:
agents (Iterable[Agent]): An iterable of Agent objects to be included in the set.
random (Random): the random number generator
random (Random | np.random.Generator | None): the random number generator
"""
if random is None:
self._agents = weakref.WeakKeyDictionary(dict.fromkeys(agents))
if (len(self._agents) == 0) and random is None:
warnings.warn(
"Random number generator not specified, this can make models non-reproducible. Please pass a random number generator explicitly",
"No Agents specified in creation of AgentSet and no random number generator specified. "
"This can make models non-reproducible. Please pass a random number generator explicitly",
UserWarning,
stacklevel=2,
)
random = (
Random()
) # FIXME see issue 1981, how to get the central rng from model
self.random = random
self._agents = weakref.WeakKeyDictionary(dict.fromkeys(agents))
random = Random()

if random is not None:
self.random = random
else:
# all agents in an AgentSet should share the same model, just take it from first
self.random = self._agents.keys().__next__().model.random

def __len__(self) -> int:
"""Return the number of agents in the AgentSet."""
Expand Down
2 changes: 1 addition & 1 deletion mesa/experimental/continuous_space/continuous_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _add_agent(self, agent: Agent) -> int:
if self._agent_positions.shape[0] <= index:
# we are out of space
fraction = 0.2 # we add 20% Fixme
n = int(round(fraction * self._n_agents))
n = round(fraction * self._n_agents)
self._agent_positions = np.vstack(
[
self._agent_positions,
Expand Down
60 changes: 38 additions & 22 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_agentset():
model = Model()
agents = [AgentTest(model) for _ in range(10)]

agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

assert agents[0] in agentset
assert len(agentset) == len(agents)
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_function(agent):

# because AgentSet uses weakrefs, we need hard refs as well....
other_agents, another_set = pickle.loads( # noqa: S301
pickle.dumps([agents, AgentSet(agents, random=model.random)])
pickle.dumps([agents, AgentSet(agents)])
)
assert all(
a1.unique_id == a2.unique_id for a1, a2 in zip(another_set, other_agents)
Expand All @@ -131,17 +131,33 @@ def test_agentset_initialization():
model = Model()
empty_agentset = AgentSet([], random=model.random)
assert len(empty_agentset) == 0
with pytest.warns(UserWarning):
empty_agentset2 = AgentSet([])
assert len(empty_agentset2) == 0

agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
assert len(agentset) == 10


def test_agentset_initialization_w_random():
"""Test agentset initialization."""
model = Model()
empty_agentset = AgentSet([], random=model.random)
assert len(empty_agentset) == 0
assert empty_agentset.random == model.random

agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents)
assert len(agentset) == 10
assert agentset.random == model.random


def test_agentset_serialization():
"""Test pickleability of agentset."""
model = Model()
agents = [AgentTest(model) for _ in range(5)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

serialized = pickle.dumps(agentset)
deserialized = pickle.loads(serialized) # noqa: S301
Expand All @@ -156,7 +172,7 @@ def test_agent_membership():
"""Test agent membership in AgentSet."""
model = Model()
agents = [AgentTest(model) for _ in range(5)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

assert agents[0] in agentset
assert AgentTest(model) not in agentset
Expand Down Expand Up @@ -218,7 +234,7 @@ def test_agentset_get_item():
"""Test integer based access to AgentSet."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

assert agentset[0] == agents[0]
assert agentset[-1] == agents[-1]
Expand All @@ -232,7 +248,7 @@ def test_agentset_do_str():
"""Test AgentSet.do with str."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

with pytest.raises(AttributeError):
agentset.do("non_existing_method")
Expand All @@ -245,7 +261,7 @@ def test_agentset_do_str():
n = 10
model = Model()
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -255,7 +271,7 @@ def test_agentset_do_str():
# setup
model = Model()
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -267,7 +283,7 @@ def test_agentset_do_callable():
"""Test AgentSet.do with callable."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

# Test callable with non-existent function
with pytest.raises(AttributeError):
Expand All @@ -281,7 +297,7 @@ def test_agentset_do_callable():
n = 10
model = Model()
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -292,7 +308,7 @@ def test_agentset_do_callable():
# setup again for lambda function tests
model = Model()
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -310,7 +326,7 @@ def remove_function(agent):
# setup again for actual function tests
model = Model()
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -321,7 +337,7 @@ def remove_function(agent):
# setup again for actual function tests
model = Model()
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
for agent in agents:
agent.agent_set = agentset

Expand Down Expand Up @@ -386,7 +402,7 @@ def test_agentset_agg():
agent.energy = i + 1
agent.wealth = 10 * (i + 1)

agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

# Test min aggregation
min_energy = agentset.agg("energy", min)
Expand Down Expand Up @@ -435,7 +451,7 @@ def __init__(self, model, age=None):

model = Model()
agents = [TestAgentWithAttribute(model, age=i) for i in range(5)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

# Set a new attribute "health" and an existing attribute "age" for all agents
agentset.set("health", 100).set("age", 50).set("status", "active")
Expand All @@ -454,7 +470,7 @@ def test_agentset_map_str():
"""Test AgentSet.map with strings."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

with pytest.raises(AttributeError):
agentset.do("non_existing_method")
Expand All @@ -467,7 +483,7 @@ def test_agentset_map_callable():
"""Test AgentSet.map with callable."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

# Test callable with non-existent function
with pytest.raises(AttributeError):
Expand All @@ -494,7 +510,7 @@ def test_method(self):
self.called = True

agents = [TestAgentShuffleDo(model) for _ in range(100)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

# Test shuffle_do with a string method name
agentset.shuffle_do("test_method")
Expand Down Expand Up @@ -544,7 +560,7 @@ def test_agentset_get_attribute():
"""Test AgentSet.get for attributes."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

unique_ids = agentset.get("unique_id")
assert unique_ids == [agent.unique_id for agent in agents]
Expand All @@ -558,7 +574,7 @@ def test_agentset_get_attribute():
agent = AgentTest(model)
agent.i = i**2
agents.append(agent)
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

values = agentset.get(["unique_id", "i"])

Expand Down Expand Up @@ -634,7 +650,7 @@ def get_unique_identifier(self):

model = Model()
agents = [TestAgent(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

groups = agentset.groupby("even")
assert len(groups.groups[True]) == 5
Expand Down
Loading