Skip to content

Commit

Permalink
cleanup
Browse files (browse the repository at this point in the history)
  • Loading branch information
chr0nikler committed Dec 24, 2024
1 parent 43b41e3 commit 8801e61
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions docs/introduction/train_agent.md
Original file line number | Diff line number | Diff line change
@@ -166,30 +166,34 @@ def get_moving_avgs(arr, window, convolution_mode):
# Smooth over a 500 episode window
rolling_length = 500
fig, axs = plt.subplots(ncols=3, figsize=(12, 5))

axs[0].set_title("Episode rewards")
# compute and assign a rolling average of the data to provide a smoother graph
reward_moving_average = (
get_moving_avgs(env.return_queue, rolling_length, "valid")
/ rolling_length
reward_moving_average = get_moving_avgs(
env.return_queue,
rolling_length,
"valid"
)
axs[0].plot(range(len(reward_moving_average)), reward_moving_average)

axs[1].set_title("Episode lengths")
length_moving_average = (
get_moving_avgs(env.length_queue, rolling_length, "valid")
/ rolling_length
length_moving_average = get_moving_avgs(
env.length_queue,
rolling_length,
"valid"
)
axs[1].plot(range(len(length_moving_average)), length_moving_average)

axs[2].set_title("Training Error")
training_error_moving_average = (
get_moving_avgs(agent.training_error, rolling_length, "same")
/ rolling_length
training_error_moving_average = get_moving_avgs(
agent.training_error,
rolling_length,
"same"
)
axs[2].plot(range(len(training_error_moving_average)), training_error_moving_average)
plt.tight_layout()
plt.show()



```

![](../_static/img/tutorials/blackjack_training_plots.png "Training Plot")
Expand Down

0 comments on commit 8801e61

Please sign in to comment.