
Explainable Vision Reinforcement Learning with PPO and Vision Transformers

This repository contains the implementation and results of a thesis project focused on enhancing the explainability of reinforcement learning (RL) agents using vision-based observations. The project investigates the use of Proximal Policy Optimization (PPO) combined with Vision Transformers (ViTs) and proposes novel methods for improving model interpretability.

Our analysis reveals a significant improvement: a 41% reduction in the mean squared error (MSE) loss between the decoded embeddings and the segmentation masks.
Furthermore, agent behavior is interpreted using tools such as decision trees. Experimental results demonstrate that the proposed methods significantly enhance both the explainability of the models and the stability of the training process.

[Figure: segmentation generation]

Core ideas

  • Explainability in RL: Methods for interpreting the decisions made by RL agents, particularly those using vision-based observations.
  • PPO with Segmentation Regularization: A modified PPO algorithm that incorporates segmentation-based regularization to improve model interpretability (see the loss sketch after this list).
  • Vision Transformers (ViTs): The use of ViTs as feature extractors for RL agents, leveraging their attention mechanisms for explainability.
  • Comparison with CNNs: A comparative study between Convolutional Neural Networks (CNNs) and ViTs in terms of performance and interpretability.
  • Explaining behavior with decision trees: Using generated and explicit segmentation masks, small and interpretable decision trees are built to explain agent behavior.

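The modified objective can be illustrated with a short sketch. This is a hedged approximation of the idea described above, not the exact loss from the thesis: the decoder output, the seg_coef weight, and the tensor shapes are assumptions.

import torch
import torch.nn.functional as F

def ppo_seg_loss(policy_loss, value_loss, entropy,
                 decoded_seg, seg_mask,
                 vf_coef=0.5, ent_coef=0.01, seg_coef=1.0):
    """Hypothetical combined objective: the usual PPO terms plus an MSE
    regularizer pulling the decoded embeddings toward the segmentation mask."""
    seg_loss = F.mse_loss(decoded_seg, seg_mask)   # segmentation regularization
    return (policy_loss                            # clipped surrogate loss
            + vf_coef * value_loss                 # value-function loss
            - ent_coef * entropy                   # entropy bonus
            + seg_coef * seg_loss)                 # extra interpretability term
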
Results

Using environments from the Procgen benchmark, we trained several architectures with the modified PPO algorithm. The results are on par with the original PPO and sometimes even surpass it. The experiments demonstrate the effectiveness of the proposed methods:

  • Improved Explainability: The use of segmentation-based regularization and attention mechanisms significantly enhances the interpretability of the agent's decisions.
  • Stability: The modified PPO algorithm has stability similar to that of the original algorithm.

[Figure: rewards]

Segmentation loss

After training, the embeddings were used to extract human-readable features. The fine-tuned model achieves asymptotically lower MSE between the decoded embeddings and the segmented image. This shows that the model has learned to identify objects on screen and carry this information to the output of the network, allowing for more robust explainability.

[Figure: explainer]
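
As a sketch of the decision-tree analysis, a shallow tree can be fit on object-presence features decoded from the embeddings to approximate the agent's action choices. The feature names, shapes, and data below are placeholders, not the exact pipeline used in the thesis.

import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Placeholder data: per-frame object-presence features decoded from the
# embeddings and the corresponding agent actions.
features = np.random.randint(0, 2, size=(1000, 5))   # hypothetical binary features
actions = np.random.randint(0, 4, size=1000)         # hypothetical discrete actions

# A small, human-readable tree approximating the policy
tree = DecisionTreeClassifier(max_depth=3)
tree.fit(features, actions)
print(export_text(tree, feature_names=[f"object_{i}" for i in range(5)]))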

Attention Rollout

[Figure: attention rollout]
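
A minimal sketch of the standard attention-rollout computation (average the attention heads, add the residual connection, re-normalize, and multiply across layers); the implementation in /explain/ may differ in details.

import torch

def attention_rollout(attentions):
    # attentions: list of per-layer attention tensors, each (heads, tokens, tokens)
    tokens = attentions[0].shape[-1]
    rollout = torch.eye(tokens)
    for attn in attentions:
        attn = attn.mean(dim=0)                       # average over heads
        attn = attn + torch.eye(tokens)               # account for residual connections
        attn = attn / attn.sum(dim=-1, keepdim=True)  # re-normalize rows
        rollout = attn @ rollout                      # propagate through layers
    return rollout[0, 1:]   # [CLS]-to-patch attention (assumes CLS token at index 0)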

Reproduction

All logs, scripts, and some of the models required to run the experiments are included in this repository.

Install dependencies

Create a conda environment from the enviroment.yml file. To compile and run the custom Procgen environment, follow the instructions in the original repository of the benchmark.
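
For example (assuming conda is installed and the file is named enviroment.yml, as referenced above):

conda env create -f enviroment.yml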

Run experiments

You can run any script from the src folder. Scripts whose names start with train_* are training scripts. Example usage:

python train_network_baseline.py
python train_network_baseline.py --resume ./path/to/resumed/model

Explainability scripts

Scripts that analyze the inner workings of the model are in the /explain/ folder, together with the resulting GIFs and plots.

Track training

tensorboard --logdir ./logs
