Skip to content

Commit

Permalink
Rework of Sklearn Metrics (Lightning-AI#1327)
Browse files Browse the repository at this point in the history
* Create utils.py

* Create __init__.py

* redo sklearn metrics

* add some more metrics

* add sklearn metrics

* Create __init__.py

* redo sklearn metrics

* New metric classes (Lightning-AI#1326)

* Create metrics package

* Create metric.py

* Create utils.py

* Create __init__.py

* add tests for metric utils

* add docstrings for metrics utils

* add function to recursively apply other function to collection

* add tests for this function

* update test

* Update pytorch_lightning/metrics/metric.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* update metric name

* remove example docs

* fix tests

* add metric tests

* fix to tensor conversion

* fix apply to collection

* Update CHANGELOG.md

* Update pytorch_lightning/metrics/metric.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* remove tests from init

* add missing type annotations

* rename utils to convertors

* Create metrics.rst

* Update index.rst

* Update index.rst

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* add doctest example

* rename file and fix imports

* added parametrized test

* replace lambda with inlined function

* rename apply_to_collection to apply_func

* Separated class description from init args

* Apply suggestions from code review

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* adjust random values

* suppress output when seeding

* remove gpu from doctest

* Add requested changes and add ellipsis for doctest

* forgot to push these files...

* add explicit check for dtype to convert to

* fix ddp tests

* remove explicit ddp destruction

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* add sklearn metrics

* start adding sklearn tests

* fix typo

* return x and y only for curves

* fix typo

* add missing tests for sklearn funcs

* imports

* __all__

* imports

* fix sklearn arguments

* fix imports

* update requirements

* Update CHANGELOG.md

* Update test_sklearn_metrics.py

* formatting

* formatting

* format

* fix all warnings and formatting problems

* Update environment.yml

* Update requirements-extra.txt

* Update environment.yml

* Update requirements-extra.txt

* fix all warnings and formatting problems

* Update CHANGELOG.md

* docs

* inherit

* docs inherit.

* docs

* Apply suggestions from code review

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* docs

* req

* min

* Apply suggestions from code review

Co-authored-by: Tullie Murrell <tulliemurrell@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Tullie Murrell <tulliemurrell@gmail.com>
  • Loading branch information
6 people authored Jun 10, 2020
1 parent 16a7326 commit bd49b07
Show file tree
Hide file tree
Showing 17 changed files with 983 additions and 28 deletions.
10 changes: 7 additions & 3 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,12 @@ references:
name: Make Documentation
command: |
# First run the same pipeline as Read-The-Docs
sudo apt-get update && sudo apt-get install -y cmake
sudo pip install -r docs/requirements.txt
# apt-get update && apt-get install -y cmake
# using: https://hub.docker.com/r/readthedocs/build
# we need to use py3.7 ot higher becase of an issue with metaclass inheritence
pyenv global 3.7.3
python --version
pip install -r docs/requirements.txt
cd docs; make clean; make html --debug --jobs 2 SPHINXOPTS="-W"
test_docs: &test_docs
Expand All @@ -81,7 +85,7 @@ jobs:

Build-Docs:
docker:
- image: circleci/python:3.7
- image: readthedocs/build:latest
steps:
- checkout
- *make_docs
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ jobs:
- name: Set min. dependencies
if: matrix.requires == 'minimal'
run: |
python -c "req = open('requirements.txt').read().replace('>', '=') ; open('requirements.txt', 'w').write(req)"
python -c "req = open('requirements-extra.txt').read().replace('>', '=') ; open('requirements-extra.txt', 'w').write(req)"
python -c "req = open('tests/requirements-devel.txt').read().replace('>', '=') ; open('tests/requirements-devel.txt', 'w').write(req)"
python -c "req = open('requirements.txt').read().replace('>=', '==') ; open('requirements.txt', 'w').write(req)"
python -c "req = open('requirements-extra.txt').read().replace('>=', '==') ; open('requirements-extra.txt', 'w').write(req)"
python -c "req = open('tests/requirements-devel.txt').read().replace('>=', '==') ; open('tests/requirements-devel.txt', 'w').write(req)"
# Note: This uses an internal pip API and may not always work
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
Expand Down
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [unreleased] - YYYY-MM-DD

### Added
Expand All @@ -23,7 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723))
- Allow dataloaders without sampler field present ([#1907](https://github.com/PyTorchLightning/pytorch-lightning/pull/1907))
- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
'sphinx.ext.linkcode',
'sphinx.ext.autosummary',
'sphinx.ext.napoleon',
'sphinx.ext.imgmath',
'recommonmark',
'sphinx.ext.autosectionlabel',
# 'm2r',
Expand Down
4 changes: 4 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ dependencies:
- autopep8
- check-manifest
- twine==1.13.0
- pillow<7.0.0
- scipy>=0.13.3
- scikit-learn>=0.20.0


- pip:
- test-tube>=0.7.5
Expand Down
12 changes: 7 additions & 5 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from tempfile import TemporaryDirectory
from typing import Optional, Generator, Union

from torch.nn import Module

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
Expand All @@ -47,7 +49,7 @@
# --- Utility functions ---


def _make_trainable(module: torch.nn.Module) -> None:
def _make_trainable(module: Module) -> None:
"""Unfreezes a given module.
Args:
Expand All @@ -58,7 +60,7 @@ def _make_trainable(module: torch.nn.Module) -> None:
module.train()


def _recursive_freeze(module: torch.nn.Module,
def _recursive_freeze(module: Module,
train_bn: bool = True) -> None:
"""Freezes the layers of a given module.
Expand All @@ -80,7 +82,7 @@ def _recursive_freeze(module: torch.nn.Module,
_recursive_freeze(module=child, train_bn=train_bn)


def freeze(module: torch.nn.Module,
def freeze(module: Module,
n: Optional[int] = None,
train_bn: bool = True) -> None:
"""Freezes the layers up to index n (if n is not None).
Expand All @@ -101,7 +103,7 @@ def freeze(module: torch.nn.Module,
_make_trainable(module=child)


def filter_params(module: torch.nn.Module,
def filter_params(module: Module,
train_bn: bool = True) -> Generator:
"""Yields the trainable parameters of a given module.
Expand All @@ -124,7 +126,7 @@ def filter_params(module: torch.nn.Module,
yield param


def _unfreeze_and_add_param_group(module: torch.nn.Module,
def _unfreeze_and_add_param_group(module: Module,
optimizer: Optimizer,
lr: Optional[float] = None,
train_bn: bool = True):
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/core/grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from typing import Dict, Union

import torch
from torch.nn import Module


class GradInformation(torch.nn.Module):
class GradInformation(Module):

def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
"""Compute each parameter's gradient's norm and their overall norm.
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from pytorch_lightning.utilities import move_data_to_device

Expand All @@ -14,7 +15,7 @@
APEX_AVAILABLE = True


class ModelHooks(torch.nn.Module):
class ModelHooks(Module):

# TODO: remove in v0.9.0
def on_sanity_check_start(self):
Expand Down
6 changes: 6 additions & 0 deletions pytorch_lightning/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,9 @@
"""

from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning.metrics.sklearn import (
SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
5 changes: 3 additions & 2 deletions pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Union
from typing import Any, Optional

import torch
import torch.distributed
from torch.nn import Module

from pytorch_lightning.metrics.converters import tensor_metric, numpy_metric
from pytorch_lightning.utilities.apply_func import apply_to_collection
Expand All @@ -11,7 +12,7 @@
__all__ = ['Metric', 'TensorMetric', 'NumpyMetric']


class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC):
class Metric(ABC, DeviceDtypeModuleMixin, Module):
"""
Abstract base class for metric implementation.
Expand Down
Loading

0 comments on commit bd49b07

Please sign in to comment.