2 changes: 1 addition & 1 deletion rllib/BUILD.bazel

```diff
@@ -2582,7 +2582,7 @@ py_test(
 py_test(
     # TODO(#50340): test is flaky.
     name = "test_offline_prelearner",
-    size = "medium",
+    size = "large",
     srcs = ["offline/tests/test_offline_prelearner.py"],
     # Include the offline data files.
     data = [
```
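Context for the size bump: Bazel enforces a default timeout per test `size`, so moving this flaky test (see TODO(#50340) above) from `medium` to `large` triples its time budget from 5 to 15 minutes without changing what the test does. For reference, Bazel's documented default timeouts, sketched here as a plain Python mapping (illustrative only, not code from this PR):

```python
# Bazel's documented default test timeouts by size, in seconds.
# Reference sketch only; these values come from the Bazel docs,
# not from this change.
BAZEL_DEFAULT_TEST_TIMEOUTS = {
    "small": 60,       # "short" timeout
    "medium": 300,     # "moderate" timeout
    "large": 900,      # "long" timeout
    "enormous": 3600,  # "eternal" timeout
}
```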
36 changes: 34 additions & 2 deletions rllib/offline/tests/test_offline_prelearner.py

```diff
@@ -263,9 +263,12 @@ def test_offline_prelearner_sample_from_old_sample_batch_data(self):

     def test_offline_prelearner_sample_from_episode_data(self):

+        print("[DEBUG] Starting test_offline_prelearner_sample_from_episode_data")
+
         # Store data only temporary.
         data_path = "/tmp/cartpole-v1_episodes/"
         # Configure PPO for recording.
+        print("[DEBUG] Configuring PPO for recording...")
         config = (
             PPOConfig()
             .environment(
@@ -281,50 +284,79 @@ def test_offline_prelearner_sample_from_episode_data(self):
         )

         # Record some episodes.
+        print("[DEBUG] Building PPO algorithm...")
         algo = config.build()
-        for _ in range(3):
+        print("[DEBUG] PPO algorithm built successfully")
+        for i in range(3):
+            print(f"[DEBUG] Starting PPO training iteration {i + 1}/3...")
             algo.train()
+            print(f"[DEBUG] Completed PPO training iteration {i + 1}/3")

+        print("[DEBUG] Stopping PPO algorithm...")
         algo.stop()
+        print("[DEBUG] PPO algorithm stopped")

         # Reset the input data and the episode read flag.
+        print("[DEBUG] Configuring BC for offline data...")
         self.config.offline_data(
             input_=[data_path],
             input_read_episodes=True,
             input_read_batch_size=50,
         )

         # Build the `BC` algorithm.
+        print("[DEBUG] Building BC algorithm...")
         algo = self.config.build()
+        print("[DEBUG] BC algorithm built successfully")

         # Read in the generated set of episode data.
+        print("[DEBUG] Reading parquet data...")
         episode_ds = ray.data.read_parquet(data_path)
+        print("[DEBUG] Parquet data read successfully")

         # Sample a batch of episodes from the episode dataset.
+        print("[DEBUG] Taking batch of 256 episodes...")
         episode_batch = episode_ds.take_batch(256)
+        print("[DEBUG] Batch taken successfully")

         # Get the module state from the `Learner`.
+        print("[DEBUG] Getting module state from Learner...")
         module_state = algo.offline_data.learner_handles[0].get_state(
             component=COMPONENT_RL_MODULE,
         )[COMPONENT_RL_MODULE]
+        print("[DEBUG] Module state retrieved successfully")

         # Set up an `OfflinePreLearner` instance.
+        print("[DEBUG] Creating OfflinePreLearner instance...")
         oplr = OfflinePreLearner(
             config=self.config,
             module_spec=algo.offline_data.module_spec,
             module_state=module_state,
             spaces=algo.offline_data.spaces[INPUT_ENV_SPACES],
         )
+        print("[DEBUG] OfflinePreLearner created successfully")

         # Sample a batch.
+        print("[DEBUG] Sampling batch through OfflinePreLearner...")
         batch = unflatten_dict(oplr(episode_batch))
+        print("[DEBUG] Batch sampled successfully")

         # Assert that we have indeed a batch of `train_batch_size_per_learner`.
         self.assertEqual(
             batch[DEFAULT_POLICY_ID][Columns.REWARDS].shape[0],
             self.config.train_batch_size_per_learner,
         )
+        print("[DEBUG] Assertion passed")

         # Remove all generated Parquet data from disk.
+        print("[DEBUG] Cleaning up temporary data...")
         shutil.rmtree(data_path)
+        print("[DEBUG] Test completed successfully")


 if __name__ == "__main__":
     import sys

     import pytest

-    sys.exit(pytest.main(["-v", __file__]))
+    sys.exit(pytest.main(["-v", __file__, "-s"]))
```
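Note on the `-s` flag appended to `pytest.main`: pytest captures stdout by default and only replays it for failing tests, so without `-s` (shorthand for `--capture=no`) the new `[DEBUG]` prints would not stream to the console during a live run. Since the point of these prints is to localize which step makes the test slow or flaky, a timestamped, flushing variant could narrow things down further; a minimal sketch, assuming a hypothetical `debug` helper that is not part of this PR:

```python
import sys
import time


def debug(msg: str) -> None:
    """Print a timestamped debug line and flush immediately.

    Hypothetical helper, not part of this PR: timestamps reveal which
    step stalls, and flushing to stderr keeps the output visible even
    if the test is killed mid-run by the Bazel timeout.
    """
    print(f"[DEBUG {time.strftime('%H:%M:%S')}] {msg}", file=sys.stderr, flush=True)


# Usage, mirroring the prints added in the test:
debug("Building PPO algorithm...")
```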