Migrate to pathlib (#537)
* Initial mypy configuration

* Initial change to get the PR up

* Initial review at replacing os.path

* Bug fixes from tests

* Fix types: test_envs.py

* Fix types: conftest.py

* Fix types: tests/util

* Fix types: tests/scripts

* Fix types: tests/rewards

* Fix types: tests/policies

* Incorrect decorator in update_stats method from networks.py::BaseNorm

* Fix types: tests/algorithms (adversarial and bc)

* Fix types: tests/algorithms (dagger and pc)

* Fix types: tests/data

* Linting

* Linting

* Fix types: algorithms/preference_comparisons.py

* Fix types: algorithms/mce_irl.py

* Formatting, fixed minor bug

* Clarify why types are ignored

* Started fixing types on algorithms/density.py

* Linting

* Linting (add back type ignore after reformatting)

* Fixed types: imitation/data/types.py

* Fixed types (started): imitation/data/

* Fixed types: imitation/data/buffer.py

* Fixed bug in buffer.py

* Fixed types: imitation/data/rollout.py

* Fixed types: imitation/data/wrappers.py

* Improve makefile to support automatic cache cleaning

* Fixed types: imitation/testing/

* Linting, fixed wrong return type in rewards.predict_processed_all

* Fixed types: imitation/policies/

* Formatting

* Fixed types: imitation/rewards/

* Fixed types: imitation/rewards/

* Fixed types: imitation/scripts/

* Fixed types: imitation/util/ and formatting

* Linting and formatting

* Bug fixes for test errors

* Linting and typing

* Improve typing in algorithms

* Formatting

* Bug fix

* Formatting

* Fixes suggested by Adam.

* Fix mypy version.

* Fix bugs

* Remove unused imports

* Formatting

* Added parse_path func and refactored code to use it

* Fix typing, linting

* Update TabularPolicy.predict to match base class

* Fix not checking for dones

* Change for loop to dict comprehension

* Remove is_ensemble to clear up type checking errors

* Reduce code duplication and general cleanup

* Fix type annotation of step_dict

* Change List to Sequence

* Fix density.py::DensityAlgorithm._set_demo_from_batch

* Fixed n_steps (OnPolicyAlgorithm)

* Fix errors in tests

* Include some suggestions into rollout.py and preference_comparisons.py

* Formatting

* Fix setter error as per python/mypy#5936

* add reason for assertion.

* Fix style guide violation: https://google.github.io/styleguide/pyguide.html#22-imports

* Update src/imitation/scripts/parallel.py

Co-authored-by: Adam Gleave <adam@gleave.me>

* Move kwargs to the end.

* Swap order of expert_policy_type and expert_policy_path validation check

* Update src/imitation/util/util.py

Co-authored-by: Adam Gleave <adam@gleave.me>

* Update tests/rewards/test_reward_fn.py

Co-authored-by: Adam Gleave <adam@gleave.me>

* Explicit random state setting and fix corresponding tests (except notebooks, sacred config, scripts)

* Fix notebooks; add script to clean notebooks

* Fix all tests.

* Formatting.

* Additional fixes

* Linting

* Remove automatically generated `_api` docs files too on `make clean`

* Fix docstrings.

* Fix issue with next(iter(iterable))

* Formatting

* Remove whitespace

* Add TODO message to remove type ignore later

* Remove unnecessary assertion.

* Fixed types in density.py set_demonstrations

* Added type ignore to pytype bug

* Fix _get_first_iter_element and add tests

* Bug fix in BC and tests -- previously masked because the iterator ran out too early!

* Remove makefile for now

* Added link to SB3 issue for future reference.

* Fix types of train_imitation
Only return "expert_stats" if all trajectories have reward.

* Modify assert in test_bc to reflect correct type

* Add ci/clean_notebooks.py to CI checks

* Improve clean_notebooks.py by allowing checking only mode.

* Add ipynb notebook checks to CI

* Add support for explicit files for notebook cleaning

* Clean notebooks

* Small improvements in util.py

* Replace TransitionKind with TransitionsMinimal

* Delete unused statement in test

* Update src/imitation/util/util.py

Co-authored-by: Adam Gleave <adam@gleave.me>

* Update src/imitation/util/util.py

Co-authored-by: Adam Gleave <adam@gleave.me>

* Make type ignore specific to pytype

* Linting

* Migrate from RandomState (deprecated) to Generator

* Add backticks to error message

* Create "AnyNorm" alias

* Small fix

* Add additional checks to shapes in _set_demo_from_batch

* Fix RolloutStatsComputer type

* Improved logging/messages in clean_notebooks.py

* Fix issues resulting from merge

* Bug fix

* Bug fix (wasn't really fixed before)

* Fixed docs example of BC

* Fix bugs resulting from merge

* Fix docs (dagger.rst) caught by sphinx CI

* Add mypy to CI

* Continue fixing miscellaneous type errors

* Linting

* Fix issue with normalize_input_layer type

* Add support for checking presence of generic type ignores

* Allow subdirectories in notebook clean

* Add full typing support for TransitionsMinimal as a sequence

* Fix types for density.py

* Misc fixes

* Add support for prefix context manager in logger (from #529)

* Added back accidentally removed code

* Replaced preference comparisons prefix with ctx manager

* Fixed errors

* Bug fixes

* Docstring fixes

* Fix bug in serialize.py

* Fixed codecheck by pointing notebook checks to docs

* Add rng to mce_irl.rst (doctest)

* Add rng to density.rst (doctest)

* Fix remaining rst files

* Increase sample size to reduce flakiness

* Ignore files not passing mypy for now

* Comment in wrong line

* Comment in wrong line

* Move excluded files to argument

* Add quotes to mypy arg call

* Fix CI mypy call

* Fix CI yaml

* Break ignored files up into one line each

* Address PR comments

* Point SB3 to master to include bug fix

* Small bug fixes

* Small bug fixes

* Sort import

* Linting

* Do not follow imports for ignored files

* Fix tests for context managers

* Format / fix tests for context manager

* Switch to sb3 1.6.1

* Formatting

* Upgrade Python version in Windows CI

* Remove unused import

* Remove unused fixture

* Add coveragerc file

* Add utils test

* Add tests and asserts

* Add test to synthetic gatherer

* Add trajectory unwrap tests

* Formatting

* Remove bracket typo

* Fix .coveragerc instruction

* Improve density algo coverage and bug fixes

* Fix bug in test

* Add pragma no cover updates

* Minor coverage tweaks

* Fix iterator test

* Add test for parse_path

* Updates on sacred util

* Mark type ignore rule

* Mark type ignore rule

* Miscellaneous bug fixes and improvements

* Reformat hanging line

* Ignore parse path checks for windows

* Add trailing comma

* Minor changes

* No newline end of file

* Update src/imitation/data/types.py

Co-authored-by: Adam Gleave <adam@gleave.me>

* Update src/imitation/data/types.py

Co-authored-by: Adam Gleave <adam@gleave.me>

* Include suggestions from Adam

Co-authored-by: Adam Gleave <adam@gleave.me>
Rocamonde and AdamGleave authored Oct 11, 2022
1 parent 70c8cee commit 531fa06
Showing 33 changed files with 316 additions and 210 deletions.
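The recurring change across the files below is replacing os.path string manipulation with pathlib.Path objects. A minimal before/after sketch of that pattern (illustrative only; the function and argument names here are placeholders, not taken from the diff):

import os
import pathlib

# Before: paths as strings, joined with os.path.join and created with os.makedirs.
def checkpoint_file_old(output_dir: str, filename: str) -> str:
    os.makedirs(output_dir, exist_ok=True)
    return os.path.join(output_dir, filename)

# After: paths as pathlib.Path objects, using mkdir(parents=True) and the "/" operator.
def checkpoint_file_new(output_dir: pathlib.Path, filename: str) -> pathlib.Path:
    output_dir.mkdir(parents=True, exist_ok=True)
    return output_dir / filename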
1 change: 0 additions & 1 deletion .circleci/config.yml
@@ -41,7 +41,6 @@ executors:
(?x)(
src/imitation/algorithms/preference_comparisons.py$
| src/imitation/rewards/reward_nets.py$
| src/imitation/util/sacred.py$
| src/imitation/algorithms/base.py$
| src/imitation/scripts/train_preference_comparisons.py$
| src/imitation/rewards/serialize.py$
1 change: 0 additions & 1 deletion ci/code_checks.sh
@@ -5,7 +5,6 @@ SRC_FILES=(src/ tests/ experiments/ examples/ docs/conf.py setup.py ci/)
EXCLUDE_MYPY="(?x)(
src/imitation/algorithms/preference_comparisons.py$
| src/imitation/rewards/reward_nets.py$
| src/imitation/util/sacred.py$
| src/imitation/algorithms/base.py$
| src/imitation/scripts/train_preference_comparisons.py$
| src/imitation/rewards/serialize.py$
2 changes: 1 addition & 1 deletion docs/tutorials/1_train_bc.ipynb
@@ -200,4 +200,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
2 changes: 1 addition & 1 deletion docs/tutorials/3_train_gail.ipynb
@@ -187,4 +187,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
2 changes: 1 addition & 1 deletion docs/tutorials/4_train_airl.ipynb
@@ -181,4 +181,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
2 changes: 1 addition & 1 deletion docs/tutorials/5_train_preference_comparisons.ipynb
@@ -203,4 +203,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
@@ -236,4 +236,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
2 changes: 1 addition & 1 deletion docs/tutorials/7_train_density.ipynb
@@ -158,4 +158,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
13 changes: 6 additions & 7 deletions src/imitation/algorithms/adversarial/common.py
@@ -3,7 +3,6 @@
import collections
import dataclasses
import logging
import os
from typing import (
Callable,
Iterable,
@@ -127,7 +126,7 @@ def __init__(
gen_algo: base_class.BaseAlgorithm,
reward_net: reward_nets.RewardNet,
n_disc_updates_per_round: int = 2,
log_dir: str = "output/",
log_dir: types.AnyPath = "output/",
disc_opt_cls: Type[th.optim.Optimizer] = th.optim.Adam,
disc_opt_kwargs: Optional[Mapping] = None,
gen_train_timesteps: Optional[int] = None,
@@ -202,7 +201,7 @@ def __init__(
self.venv = venv
self.gen_algo = gen_algo
self._reward_net = reward_net.to(gen_algo.device)
self._log_dir = log_dir
self._log_dir = types.parse_path(log_dir)

# Create graph for optimising/recording stats on discriminator
self._disc_opt_cls = disc_opt_cls
@@ -215,10 +214,10 @@
)

if self._init_tensorboard:
logging.info("building summary directory at " + self._log_dir)
summary_dir = os.path.join(self._log_dir, "summary")
os.makedirs(summary_dir, exist_ok=True)
self._summary_writer = thboard.SummaryWriter(summary_dir)
logging.info(f"building summary directory at {self._log_dir}")
summary_dir = self._log_dir / "summary"
summary_dir.mkdir(parents=True, exist_ok=True)
self._summary_writer = thboard.SummaryWriter(str(summary_dir))

self.venv_buffering = wrappers.BufferingWrapper(self.venv)

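A rough, self-contained sketch of the summary-directory handling shown above (pathlib.Path(os.fsdecode(...)) stands in for types.parse_path, whose behaviour is assumed from this diff):

import os
import pathlib

def make_summary_dir(log_dir) -> pathlib.Path:
    # Normalise str/bytes/os.PathLike input to an absolute pathlib.Path,
    # roughly what types.parse_path does per the diff above.
    log_path = pathlib.Path(os.fsdecode(log_dir)).absolute()
    summary_dir = log_path / "summary"
    summary_dir.mkdir(parents=True, exist_ok=True)
    return summary_dir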
2 changes: 1 addition & 1 deletion src/imitation/algorithms/bc.py
@@ -472,4 +472,4 @@ def save_policy(self, policy_path: types.AnyPath) -> None:
Args:
policy_path: path to save policy to.
"""
th.save(self.policy, types.path_to_str(policy_path))
th.save(self.policy, types.parse_path(policy_path))
22 changes: 8 additions & 14 deletions src/imitation/algorithms/dagger.py
@@ -90,10 +90,8 @@ def reconstruct_trainer(
A deserialized `DAggerTrainer`.
"""
custom_logger = custom_logger or imit_logger.configure()
checkpoint_path = pathlib.Path(
types.path_to_str(scratch_dir),
"checkpoint-latest.pt",
)
scratch_dir = types.parse_path(scratch_dir)
checkpoint_path = scratch_dir / "checkpoint-latest.pt"
trainer = th.load(checkpoint_path, map_location=utils.get_device(device))
trainer.venv = venv
trainer._logger = custom_logger
@@ -109,14 +107,14 @@ def _save_dagger_demo(
# however that NPZ save here is likely more space efficient than
# pickle from types.save(), and types.save only accepts
# TrajectoryWithRew right now (subclass of Trajectory).
save_dir_obj = pathlib.Path(types.path_to_str(save_dir))
save_dir = types.parse_path(save_dir)
assert isinstance(trajectory, types.Trajectory)
actual_prefix = f"{prefix}-" if prefix else ""
timestamp = util.make_unique_timestamp()
filename = f"{actual_prefix}dagger-demo-{timestamp}.npz"

save_dir_obj.mkdir(parents=True, exist_ok=True)
npz_path = save_dir_obj / filename
save_dir.mkdir(parents=True, exist_ok=True)
npz_path = save_dir / filename
np.savez_compressed(npz_path, **dataclasses.asdict(trajectory))
logging.info(f"Saved demo at '{npz_path}'")

@@ -344,7 +342,7 @@ def __init__(
if beta_schedule is None:
beta_schedule = LinearBetaSchedule(15)
self.beta_schedule = beta_schedule
self.scratch_dir = pathlib.Path(types.path_to_str(scratch_dir))
self.scratch_dir = types.parse_path(scratch_dir)
self.venv = venv
self.round_num = 0
self._last_loaded_round = -1
@@ -397,11 +395,7 @@ def _load_all_demos(self):
return demo_transitions, num_demos_by_round

def _get_demo_paths(self, round_dir):
return [
os.path.join(round_dir, p)
for p in os.listdir(round_dir)
if p.endswith(".npz")
]
return [round_dir / p for p in os.listdir(round_dir) if p.endswith(".npz")]

def _demo_dir_path_for_round(self, round_num: Optional[int] = None) -> pathlib.Path:
if round_num is None:
@@ -411,7 +405,7 @@ def _try_load_demos(self) -> None:
def _try_load_demos(self) -> None:
"""Load the dataset for this round into self.bc_trainer as a DataLoader."""
demo_dir = self._demo_dir_path_for_round()
demo_paths = self._get_demo_paths(demo_dir) if os.path.isdir(demo_dir) else []
demo_paths = self._get_demo_paths(demo_dir) if demo_dir.is_dir() else []
if len(demo_paths) == 0:
raise NeedsDemosException(
f"No demos found for round {self.round_num} in dir '{demo_dir}'. "
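The list comprehension over os.listdir in _get_demo_paths above could also be expressed with pathlib's own globbing; a small equivalent sketch (an alternative, not what the commit does):

import pathlib
from typing import List

def get_demo_paths(round_dir: pathlib.Path) -> List[pathlib.Path]:
    # Same result as [round_dir / p for p in os.listdir(round_dir) if p.endswith(".npz")],
    # except glob order is unspecified, so sort for a deterministic listing.
    return sorted(round_dir.glob("*.npz"))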
4 changes: 1 addition & 3 deletions src/imitation/algorithms/mce_irl.py
@@ -176,9 +176,7 @@ def set_pi(self, pi: np.ndarray) -> None:
self.pi = pi

def _predict(self, observation: th.Tensor, deterministic: bool = False):
raise NotImplementedError(
"Should never be called as predict overridden.",
)
raise NotImplementedError("Should never be called as predict overridden.")

def forward( # type: ignore[override]
self,
88 changes: 82 additions & 6 deletions src/imitation/data/types.py
@@ -44,11 +44,87 @@ def dataclass_quick_asdict(obj) -> Dict[str, Any]:
return d


def path_to_str(path: AnyPath) -> str:
if isinstance(path, bytes):
return path.decode()
def parse_path(
path: AnyPath,
allow_relative: bool = True,
base_directory: Optional[pathlib.Path] = None,
) -> pathlib.Path:
"""Parse a path to a `pathlib.Path` object.
All resulting paths are resolved, absolute paths. If `allow_relative` is True,
then relative paths are allowed as input, and are resolved relative to the
current working directory, or relative to `base_directory` if it is
specified.
Args:
path: The path to parse. Can be a string, bytes, or `os.PathLike`.
allow_relative: If True, then relative paths are allowed as input, and
are resolved relative to the current working directory. If False,
an error is raised if the path is not absolute.
base_directory: If specified, then relative paths are resolved relative
to this directory, instead of the current working directory.
Returns:
A `pathlib.Path` object.
Raises:
ValueError: If `allow_relative` is False and the path is not absolute.
ValueError: If `base_directory` is specified and `allow_relative` is
False.
"""
if base_directory is not None and not allow_relative:
raise ValueError(
"If `base_directory` is specified, then `allow_relative` must be True.",
)

parsed_path: pathlib.Path
if isinstance(path, pathlib.Path):
parsed_path = path
elif isinstance(path, str):
parsed_path = pathlib.Path(path)
elif isinstance(path, bytes):
parsed_path = pathlib.Path(path.decode())
else:
parsed_path = pathlib.Path(str(path))

if parsed_path.is_absolute():
return parsed_path
else:
if allow_relative:
base_directory = base_directory or pathlib.Path.cwd()
# relative to current working directory
return base_directory / parsed_path
else:
raise ValueError(f"Path {str(parsed_path)} is not absolute")


def parse_optional_path(
path: Optional[AnyPath],
allow_relative: bool = True,
base_directory: Optional[pathlib.Path] = None,
) -> Optional[pathlib.Path]:
"""Parse an optional path to a `pathlib.Path` object.
All resulting paths are resolved, absolute paths. If `allow_relative` is True,
then relative paths are allowed as input, and are resolved relative to the
current working directory, or relative to `base_directory` if it is
specified.
Args:
path: The path to parse. Can be a string, bytes, or `os.PathLike`.
allow_relative: If True, then relative paths are allowed as input, and
are resolved relative to the current working directory. If False,
an error is raised if the path is not absolute.
base_directory: If specified, then relative paths are resolved relative
to this directory, instead of the current working directory.
Returns:
A `pathlib.Path` object, or None if `path` is None.
"""
if path is None:
return None
else:
return str(path)
return parse_path(path, allow_relative, base_directory)


@dataclasses.dataclass(frozen=True)
@@ -417,10 +493,10 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]):
trajectories: The trajectories to save.
Raises:
ValueError: If the trajectories are not all of the same type, i.e. some are
ValueError: If not all trajectories have the same type, i.e. some are
`Trajectory` and others are `TrajectoryWithRew`.
"""
p = pathlib.Path(path_to_str(path))
p = parse_path(path)
p.parent.mkdir(parents=True, exist_ok=True)
tmp_path = f"{p}.tmp"

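A rough usage sketch of the parse_path / parse_optional_path helpers added above, following their docstrings (POSIX-style paths assumed; exact results for relative inputs depend on the working directory):

import pathlib
from imitation.data import types

# Absolute inputs come back as pathlib.Path objects.
assert types.parse_path("/tmp/run") == pathlib.Path("/tmp/run")

# Relative inputs are joined onto base_directory (or the current working directory).
assert types.parse_path("demos", base_directory=pathlib.Path("/tmp/run")) == pathlib.Path("/tmp/run/demos")

# With allow_relative=False a relative path is rejected.
try:
    types.parse_path("demos", allow_relative=False)
except ValueError:
    pass

# parse_optional_path passes None through unchanged.
assert types.parse_optional_path(None) is None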
2 changes: 1 addition & 1 deletion src/imitation/policies/base.py
@@ -40,7 +40,7 @@ def _choose_action(self, obs: np.ndarray) -> np.ndarray:
def forward(self, *args):
# technically BasePolicy is a Torch module, so this needs a forward()
# method
raise NotImplementedError()
raise NotImplementedError() # pragma: no cover


class RandomPolicy(HardCodedPolicy):
16 changes: 8 additions & 8 deletions src/imitation/policies/serialize.py
@@ -4,13 +4,13 @@
# torch.load() and torch.save() calls

import logging
import os
import pathlib
from typing import Callable, Type, TypeVar

import huggingface_sb3 as hfsb3
from stable_baselines3.common import base_class, callbacks, policies, vec_env

from imitation.data import types
from imitation.policies import base
from imitation.util import registry

@@ -52,7 +52,7 @@ def load_stable_baselines_model(
The deserialized RL algorithm.
"""
logging.info(f"Loading Stable Baselines policy for '{cls}' from '{path}'")
path_obj = pathlib.Path(path)
path_obj = types.parse_path(path)

if path_obj.is_dir():
path_obj = path_obj / "model.zip"
@@ -181,7 +181,7 @@ def load_policy(


def save_stable_model(
output_dir: str,
output_dir: pathlib.Path,
model: base_class.BaseAlgorithm,
filename: str = "model.zip",
) -> None:
@@ -197,9 +197,9 @@
# Save each model in new directory in case we want to add metadata or other
# information in future. (E.g. we used to save `VecNormalize` statistics here,
# although that is no longer necessary.)
os.makedirs(output_dir, exist_ok=True)
model.save(os.path.join(output_dir, filename))
logging.info("Saved policy to %s", output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
model.save(output_dir / filename)
logging.info(f"Saved policy to {output_dir}")


class SavePolicyCallback(callbacks.EventCallback):
@@ -211,7 +211,7 @@ class SavePolicyCallback(callbacks.EventCallback):

def __init__(
self,
policy_dir: str,
policy_dir: pathlib.Path,
*args,
**kwargs,
):
@@ -227,6 +227,6 @@ def __init__(

def _on_step(self) -> bool:
assert self.model is not None
output_dir = os.path.join(self.policy_dir, f"{self.num_timesteps:012d}")
output_dir = self.policy_dir / f"{self.num_timesteps:012d}"
save_stable_model(output_dir, self.model)
return True
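As a small sketch of the checkpoint-directory naming used by SavePolicyCallback above (the directory and timestep values here are placeholders):

import pathlib

policy_dir = pathlib.Path("output/policies")
num_timesteps = 2048
# Zero-padding to 12 digits keeps checkpoint directories in lexicographic == numeric order.
output_dir = policy_dir / f"{num_timesteps:012d}"
print(output_dir)  # output/policies/000000002048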

