Skip to content

Commit d206b56

Browse files
committed
Adding support for subquests in the Inform7 code.
1 parent c77a105 commit d206b56

File tree

22 files changed

+290
-130
lines changed

22 files changed

+290
-130
lines changed

scripts/tw-make

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def parse_args():
4646
help="Nb. of objects in the world.")
4747
custom_parser.add_argument("--quest-length", type=int, default=5, metavar="LENGTH",
4848
help="Minimum nb. of actions the quest requires to be completed.")
49+
custom_parser.add_argument("--quest-breadth", type=int, default=3, metavar="BREADTH",
50+
help="Control how non-linear a quest can be.")
4951

5052
challenge_parser = subparsers.add_parser("challenge", parents=[general_parser],
5153
help='Generate a game for one of the challenges.')
@@ -72,7 +74,7 @@ if __name__ == "__main__":
7274
}
7375

7476
if args.subcommand == "custom":
75-
game_file, game = textworld.make(args.world_size, args.nb_objects, args.quest_length, grammar_flags,
77+
game_file, game = textworld.make(args.world_size, args.nb_objects, args.quest_length, args.quest_breadth, grammar_flags,
7678
seed=args.seed, games_dir=args.output)
7779

7880
elif args.subcommand == "challenge":
@@ -87,7 +89,7 @@ if __name__ == "__main__":
8789

8890
print("Game generated: {}".format(game_file))
8991
if args.verbose:
90-
print(game.quests[0].desc)
92+
print(game.objective)
9193

9294
if args.view:
9395
textworld.render.visualize(game, interactive=True)

scripts/tw-stats

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ if __name__ == "__main__":
3838
continue
3939

4040
if len(game.quests) > 0:
41-
objectives[game_filename] = game.quests[0].desc
41+
objectives[game_filename] = game.objective
4242

4343
names |= set(info.name for info in game.infos.values() if info.name is not None)
4444
game_logger.collect(game)

tests/test_make_game.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ def test_making_game_with_names_to_exclude():
1111
g_rng.set_seed(42)
1212

1313
with make_temp_directory(prefix="test_render_wrapper") as tmpdir:
14-
game_file1, game1 = textworld.make(2, 20, 3, {"names_to_exclude": []},
14+
game_file1, game1 = textworld.make(2, 20, 3, 3, {"names_to_exclude": []},
1515
seed=123, games_dir=tmpdir)
1616

1717
game1_objects_names = [info.name for info in game1.infos.values() if info.name is not None]
18-
game_file2, game2 = textworld.make(2, 20, 3, {"names_to_exclude": game1_objects_names},
18+
game_file2, game2 = textworld.make(2, 20, 3, 3, {"names_to_exclude": game1_objects_names},
1919
seed=123, games_dir=tmpdir)
2020
game2_objects_names = [info.name for info in game2.infos.values() if info.name is not None]
2121
assert len(set(game1_objects_names) & set(game2_objects_names)) == 0
@@ -24,8 +24,8 @@ def test_making_game_with_names_to_exclude():
2424
def test_making_game_is_reproducible_with_seed():
2525
grammar_flags = {}
2626
with make_temp_directory(prefix="test_render_wrapper") as tmpdir:
27-
game_file1, game1 = textworld.make(2, 20, 3, grammar_flags, seed=123, games_dir=tmpdir)
28-
game_file2, game2 = textworld.make(2, 20, 3, grammar_flags, seed=123, games_dir=tmpdir)
27+
game_file1, game1 = textworld.make(2, 20, 3, 3, grammar_flags, seed=123, games_dir=tmpdir)
28+
game_file2, game2 = textworld.make(2, 20, 3, 3, grammar_flags, seed=123, games_dir=tmpdir)
2929
assert game_file1 == game_file2
3030
assert game1 == game2
3131
# Make sure they are not the same Python objects.

tests/test_play_generated_games.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@ def test_play_generated_games():
1616
# Sample game specs.
1717
world_size = rng.randint(1, 10)
1818
nb_objects = rng.randint(0, 20)
19-
quest_length = rng.randint(1, 10)
19+
quest_length = rng.randint(2, 5)
20+
quest_breadth = rng.randint(3, 7)
2021
game_seed = rng.randint(0, 65365)
2122
grammar_flags = {} # Default grammar.
2223

2324
with make_temp_directory(prefix="test_play_generated_games") as tmpdir:
24-
game_file, game = textworld.make(world_size, nb_objects, quest_length, grammar_flags,
25+
game_file, game = textworld.make(world_size, nb_objects, quest_length, quest_breadth, grammar_flags,
2526
seed=game_seed, games_dir=tmpdir)
2627

2728
# Solve the game using WalkthroughAgent.

tests/test_textworld.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_game_walkthrough_agent(self):
5858
agent = textworld.agents.WalkthroughAgent()
5959
env = textworld.start(self.game_file)
6060
env.activate_state_tracking()
61-
commands = self.game.quests[0].commands
61+
commands = self.game.main_quest.commands
6262
agent.reset(env)
6363
game_state = env.reset()
6464

tests/test_tw_play.py renamed to tests/test_tw-play.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from textworld.utils import make_temp_directory
88

99

10-
def test_making_a_custom_game():
11-
with make_temp_directory(prefix="test_tw-play") as tmpdir:
12-
game_file, _ = textworld.make(5, 10, 5, {}, seed=1234, games_dir=tmpdir)
10+
def test_playing_a_game():
11+
with make_temp_directory(prefix="test_tw-play") as tmpdir:
12+
game_file, _ = textworld.make(5, 10, 5, 4, {}, seed=1234, games_dir=tmpdir)
1313

1414
command = ["tw-play", "--max-steps", "100", "--mode", "random", game_file]
1515
assert check_call(command) == 0
@@ -18,4 +18,4 @@ def test_making_a_custom_game():
1818
assert check_call(command) == 0
1919

2020
command = ["tw-play", "--max-steps", "100", "--mode", "walkthrough", game_file]
21-
assert check_call(command) == 0
21+
assert check_call(command) == 0

textworld/agents/walkthrough.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def reset(self, env):
2626
raise NameError(msg)
2727

2828
# Load command from the generated game.
29-
self._commands = iter(env.game.quests[0].commands)
29+
self._commands = iter(env.game.main_quest.commands)
3030

3131
def act(self, game_state, reward, done):
3232
try:

textworld/envs/glulx/git_glulx_ml.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,7 @@ def init(self, output: str, game=None,
149149
self._game_progression = GameProgression(game, track_quests=compute_intermediate_reward)
150150
self._state_tracking = state_tracking
151151
self._compute_intermediate_reward = compute_intermediate_reward and len(game.quests) > 0
152-
153-
self._objective = ""
154-
if len(game.quests) > 0:
155-
self._objective = game.quests[0].desc
152+
self._objective = game.objective
156153

157154
def view(self) -> "GlulxGameState":
158155
"""
@@ -317,6 +314,7 @@ def intermediate_reward(self):
317314

318315
@property
319316
def score(self):
317+
# XXX: Should the score reflect the sum of all subquests' reward?
320318
if self.has_won:
321319
return 1
322320
elif self.has_lost:
@@ -326,6 +324,7 @@ def score(self):
326324

327325
@property
328326
def max_score(self):
327+
# XXX: Should the score reflect the sum of all subquests' reward?
329328
return 1
330329

331330
@property

textworld/envs/wrappers/tests/test_viewer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_html_viewer():
1717
num_items = 10
1818
g_rng.set_seed(1234)
1919
grammar_flags = {"theme": "house", "include_adj": True}
20-
game = textworld.generator.make_game(world_size=num_nodes, nb_objects=num_items, quest_length=3, grammar_flags=grammar_flags)
20+
game = textworld.generator.make_game(world_size=num_nodes, nb_objects=num_items, quest_length=3, quest_breadth=1, grammar_flags=grammar_flags)
2121

2222
game_name = "test_html_viewer_wrapper"
2323
with make_temp_directory(prefix=game_name) as tmpdir:

textworld/generator/__init__.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def make_game_with(world, quests=None, grammar=None):
147147
return game
148148

149149

150-
def make_game(world_size: int, nb_objects: int, quest_length: int,
150+
def make_game(world_size: int, nb_objects: int, quest_length: int, quest_breadth: int,
151151
grammar_flags: Mapping = {},
152152
rngs: Optional[Dict[str, RandomState]] = None
153153
) -> Game:
@@ -158,6 +158,7 @@ def make_game(world_size: int, nb_objects: int, quest_length: int,
158158
world_size: Number of rooms in the world.
159159
nb_objects: Number of objects in the world.
160160
quest_length: Minimum nb. of actions the quest requires to be completed.
161+
quest_breadth: How many branches the quest can have.
161162
grammar_flags: Options for the grammar.
162163
163164
Returns:
@@ -175,14 +176,34 @@ def make_game(world_size: int, nb_objects: int, quest_length: int,
175176
world = make_world(world_size, nb_objects=0, rngs=rngs)
176177

177178
# Sample a quest according to quest_length.
178-
options = ChainingOptions()
179+
class Options(ChainingOptions):
180+
181+
def get_rules(self, depth):
182+
if depth == 0:
183+
# Last action should not be "go <dir>".
184+
return data.get_rules().get_matching("^(?!go.*).*")
185+
else:
186+
return super().get_rules(depth)
187+
188+
options = Options()
179189
options.backward = True
190+
options.min_depth = 1
180191
options.max_depth = quest_length
192+
options.min_breadth = 1
193+
options.max_breadth = quest_breadth
181194
options.create_variables = True
182195
options.rng = rngs['rng_quest']
183196
options.restricted_types = {"r", "d"}
184197
chain = sample_quest(world.state, options)
198+
199+
subquests = []
200+
for i in range(1, len(chain.nodes)):
201+
if chain.nodes[i].breadth != chain.nodes[i - 1].breadth:
202+
quest = Quest(chain.actions[:i])
203+
subquests.append(quest)
204+
185205
quest = Quest(chain.actions)
206+
subquests.append(quest)
186207

187208
# Set the initial state required for the quest.
188209
world.state = chain.initial_state
@@ -191,7 +212,9 @@ def make_game(world_size: int, nb_objects: int, quest_length: int,
191212
world.populate(nb_objects, rng=rngs['rng_objects'])
192213

193214
grammar = make_grammar(grammar_flags, rng=rngs['rng_grammar'])
194-
game = make_game_with(world, [quest], grammar)
215+
game = make_game_with(world, subquests, grammar)
216+
game.change_grammar(grammar)
217+
195218
return game
196219

197220

0 commit comments

Comments
 (0)