Skip to content

Commit

Permalink
ver 2.1.0 ready
Browse files Browse the repository at this point in the history
  • Loading branch information
tae898 committed Sep 25, 2024
1 parent bca245d commit 77ec4ea
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 51 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ repo](https://github.com/humemai/humemai).
- ["A Machine with Short-Term, Episodic, and Semantic Memory
Systems"](https://arxiv.org/abs/2212.02098)
- ["Leveraging Knowledge Graph-Based Human-Like Memory Systems to Solve Partially Observable Markov Decision Processes"](https://arxiv.org/abs/2408.05861)


## pdoc documentation

Expand Down
16 changes: 8 additions & 8 deletions room_env/envs/room2.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def step(self, actions: tuple[list[str], str]) -> tuple[dict, int, bool, dict]:
north, east, south, west, or stay.
Returns:
(observation, question), reward, truncated, done, info
(observation, question), rewards, truncated, done, info
"""
actions_qa, action_explore = actions
Expand All @@ -849,20 +849,20 @@ def step(self, actions: tuple[list[str], str]) -> tuple[dict, int, bool, dict]:

if len(self.answers) == 0:
assert actions_qa == [], "You shouldn't answer any questions"
reward = 0
rewards = []

else:
assert len(actions_qa) == len(
self.answers
), "You should answer all the questions."
reward = 0
rewards = []
for actions_qa, answer in zip(actions_qa, self.answers):
if actions_qa == answer["current"]:
reward += self.CORRECT
rewards.append(self.CORRECT)
elif actions_qa == answer["previous"]:
reward += self.PARTIAL
rewards.append(self.PARTIAL)
else:
reward += self.WRONG
rewards.append(self.WRONG)

if not self.make_everything_static:
for obj in self.objects["independent"]:
Expand Down Expand Up @@ -891,15 +891,15 @@ def step(self, actions: tuple[list[str], str]) -> tuple[dict, int, bool, dict]:
if (self.current_time + 1) % self.question_interval == 0:
return (
self.get_observations_and_question(generate_questions=True),
reward,
rewards,
done,
truncated,
info,
)
else:
return (
self.get_observations_and_question(generate_questions=False),
reward,
rewards,
done,
truncated,
info,
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = room_env
version = 2.0.7
version = 2.1.0
author = Taewoon Kim
author_email = info@humem.ai
description = The Room environment
Expand Down
10 changes: 5 additions & 5 deletions test/room_env2/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_all_static(self) -> None:
locations[obj.name] = [obj.location]

while True:
observations, reward, done, truncated, info = env.step((["foo"], "stay"))
observations, rewards, done, truncated, info = env.step((["foo"], "stay"))
for obj_type, objs in env.objects.items():
for obj in objs:
locations[obj.name].append(obj.location)
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_not_all_static(self) -> None:
locations[obj.name] = [obj.location]

while True:
observations, reward, done, truncated, info = env.step(
observations, rewards, done, truncated, info = env.step(
(["foo"], random.choice(["north", "south", "east", "west"]))
)
for obj_type, objs in env.objects.items():
Expand Down Expand Up @@ -122,16 +122,16 @@ def test_deterministic_objects(self) -> None:
}
env = gym.make("room_env:RoomEnv-v2", **env_config)
observations, info = env.reset()
rewards = 0
rewards = []

while True:
observations, reward, done, truncated, info = env.step(
observations, rewards_, done, truncated, info = env.step(
(
["random answer"] * len(observations["questions"]),
random.choice(["north", "east", "south", "west", "stay"]),
)
)
rewards += reward
rewards.append(rewards_)
if done or truncated:
break

Expand Down
16 changes: 9 additions & 7 deletions test/room_env2/test_num_total_questions_question_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def test_wrong_answers(self) -> None:

self.assertEqual(observations["questions"], [])
with self.assertRaises(AssertionError):
observations, reward, done, truncated, info = env.step(("foo", "stay"))
observations, rewards, done, truncated, info = env.step(("foo", "stay"))
with self.assertRaises(AssertionError):
observations, reward, done, truncated, info = env.step((["foo"], "stay"))
observations, rewards, done, truncated, info = env.step((["foo"], "stay"))

def test_correct_answers(self) -> None:
env_config = {
Expand All @@ -64,8 +64,8 @@ def test_correct_answers(self) -> None:
questions_all = []
self.assertEqual(len(observations["questions"]), 0)

observations, reward, done, truncated, info = env.step(([], "stay"))
self.assertEqual(reward, 0)
observations, rewards, done, truncated, info = env.step(([], "stay"))
self.assertEqual(rewards, [])

while True:
flag = False
Expand All @@ -77,11 +77,13 @@ def test_correct_answers(self) -> None:
for q in observations["questions"]:
questions_all.append(q)

observations, reward, done, truncated, info = env.step((actions_qa, "stay"))
observations, rewards, done, truncated, info = env.step(
(actions_qa, "stay")
)
if flag:
self.assertEqual(reward, -1 * env.num_questions_step)
self.assertEqual(rewards, [-1] * env.num_questions_step)
else:
self.assertEqual(reward, 0)
self.assertEqual(rewards, [])

if done:
break
Expand Down
8 changes: 4 additions & 4 deletions test/room_env2/test_partial_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,26 @@ def test_partial_points(self) -> None:
if break_:
break
if len(set(obj.history)) == 1:
observations, reward, done, truncated, info = env.step(
observations, rewards, done, truncated, info = env.step(
(
[obj.location],
random.choice(["north", "south", "east", "west", "stay"]),
)
)

self.assertEqual(reward, 1)
self.assertEqual(rewards, [1])
else:
for previous_location in obj.history[::-1]:
if previous_location != obj.location:
break

observations, reward, done, truncated, info = env.step(
observations, rewards, done, truncated, info = env.step(
(
[previous_location],
random.choice(["north", "south", "east", "west", "stay"]),
)
)
self.assertEqual(reward, 0)
self.assertEqual(rewards, [0])

if done:
break
50 changes: 25 additions & 25 deletions test/room_env2/test_room_env_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ def test_all(self) -> None:

question_previous = observations["questions"][0]

observations, reward, done, truncated, info = self.env.step(
observations, rewards_, done, truncated, info = self.env.step(
(actions_qa, "east")
)
rewards.append(reward)
self.assertEqual(reward, 1)
rewards.append(rewards_)
self.assertEqual(rewards_, [1])
self.assertFalse(done)
# if question_previous == ["?", "atlocation", "officeroom", 0]:
# self.assertEqual(info, {"answers": ["desk"], "timestamp": 0})
Expand Down Expand Up @@ -131,11 +131,11 @@ def test_all(self) -> None:
raise ValueError
question_previous = observations["questions"][0]

observations, reward, done, truncated, info = self.env.step(
observations, rewards_, done, truncated, info = self.env.step(
(actions_qa, "stay")
)
rewards.append(reward)
self.assertEqual(reward, 1)
rewards.append(rewards_)
self.assertEqual(rewards_, [1])
self.assertFalse(done)
# if question_previous == ["?", "atlocation", "officeroom", 1]:
# self.assertEqual(info, {"answers": ["desk"], "timestamp": 1})
Expand All @@ -151,20 +151,20 @@ def test_all(self) -> None:
raise ValueError

for _ in range(7):
observations, reward, done, truncated, info = self.env.step(
observations, rewards_, done, truncated, info = self.env.step(
(["foo"], "stay")
)
rewards.append(reward)
rewards.append(rewards_)

self.assertEqual(reward, -1)
self.assertEqual(rewards_, [-1])
self.assertFalse(done)

observations, reward, done, truncated, info = self.env.step(
observations, rewards_, done, truncated, info = self.env.step(
(["bar"], "stay")
)
rewards.append(reward)
rewards.append(rewards_)
self.assertEqual(len(rewards), 10)
self.assertEqual(reward, -1)
self.assertEqual(rewards_, [-1])
self.assertTrue(done)
# self.assertIsNone(self.env.observations_room)
# self.assertIsNone(self.env.question)
Expand Down Expand Up @@ -296,10 +296,10 @@ def test_all(self) -> None:
actions_qa = ["officeroom"]
question_previous = observations["questions"]

observations, reward, done, truncated, info = self.env.step(
observations, rewards_, done, truncated, info = self.env.step(
(actions_qa, "east")
)
rewards.append(reward)
rewards.append(rewards_)
# if question_previous[0] == "?":
# self.assertEqual(
# info, {"answers": ["desk", "tae", "laptop"], "timestamp": 0}
Expand Down Expand Up @@ -335,7 +335,7 @@ def test_all(self) -> None:
["?", "atlocation", "livingroom", 1],
],
)
self.assertEqual(reward, 1)
self.assertEqual(rewards_, [1])
self.assertFalse(done)

if observations["questions"][0] == ["desk", "atlocation", "?", 1]:
Expand All @@ -353,10 +353,10 @@ def test_all(self) -> None:

question_previous = observations["questions"][0]

observations, reward, done, truncated, info = self.env.step(
observations, rewards_, done, truncated, info = self.env.step(
(actions_qa, "west")
)
rewards.append(reward)
rewards.append(rewards_)
if question_previous == ["desk", "atlocation", "?", 1]:
self.assertEqual(
info,
Expand Down Expand Up @@ -416,24 +416,24 @@ def test_all(self) -> None:
["?", "atlocation", "officeroom", 2],
],
)
self.assertEqual(reward, 1)
self.assertEqual(rewards_, [1])
self.assertFalse(done)

for _ in range(97):
observations, reward, done, truncated, info = self.env.step(
observations, rewards_, done, truncated, info = self.env.step(
(["foo"], "stay")
)
rewards.append(reward)
rewards.append(rewards_)

self.assertEqual(reward, -1)
self.assertEqual(rewards_, [-1])
self.assertFalse(done)

observations, reward, done, truncated, info = self.env.step(
observations, rewards_, done, truncated, info = self.env.step(
(["bar"], "stay")
)
rewards.append(reward)
rewards.append(rewards_)
self.assertEqual(len(rewards), 100)
self.assertEqual(reward, -1)
self.assertEqual(rewards_, [-1])
self.assertTrue(done)
# self.assertIsNone(self.env.observations_room)
# self.assertIsNone(self.env.question)
Expand All @@ -458,7 +458,7 @@ def test_all(self) -> None:
while True:
actions_qa = [random.choice(observations["questions"][0])]
action_explore = random.choice(["north", "east", "south", "west", "stay"])
observations, reward, done, truncated, info = self.env.step(
observations, rewards_, done, truncated, info = self.env.step(
(actions_qa, action_explore)
)
if done:
Expand Down

0 comments on commit 77ec4ea

Please sign in to comment.