
PyTorch feature parity #1431

Open
@CarloLucibello

Description


A list of PyTorch 1.7 features.
Items are checked if we have something more or less equivalent in Flux or in the Julia ecosystem and supported by Flux.
This list is not complete; it comes from a rough scan of PyTorch's documentation. Please feel free to add anything I missed in the comments, and anyone with write access is welcome to edit the list.
Related issue https://github.com/FluxML/ML-Coordination-Tracker/issues/16, and more generally anything in https://github.com/FluxML/ML-Coordination-Tracker/issues

PyTorch Features

Conv Layers

  • Conv1d, Conv2d, Conv3d.
  • ConvTranspose1d, ConvTranspose2d, ConvTranspose3d.
  • groups in convolution layers
  • Fold, Unfold. In progress: Add fold and unfold NNlib.jl#444

Pooling Layers

  • MaxPool1d, MaxPool2d, MaxPool3d
  • MaxUnpool1d, MaxUnpool2d, MaxUnpool3d
  • AvgPool1d, AvgPool2d, AvgPool3d
  • FractionalMaxPool2d
  • LPPool1d, LPPool2d
  • AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
  • AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d

Padding Layers

  • ReflectionPad (1d,2d)
  • ReplicationPad (1d,2d,3d) ( NNlib.pad_repeat)
  • ZeroPad (2d)
  • ConstantPad (1d,2d,3d)
  • Add corresponding layers for all of the above, wrapping the NNlib functions, or keep them as plain functions. They still need to be added to Flux's docs. A possible layer wrapper is sketched below.
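
A minimal sketch of such a wrapper, assuming NNlib's pad convention of (lo, hi) pads per padded dimension; the `ReflectionPad` name is illustrative, not an existing Flux layer:

```julia
using NNlib

# Hypothetical layer wrapping NNlib.pad_reflect; the name and exact pad
# convention are assumptions for this sketch, not an existing Flux API.
struct ReflectionPad{N}
    pad::NTuple{N,Int}   # (lo, hi) padding amounts for each padded dimension
end

(p::ReflectionPad)(x) = NNlib.pad_reflect(x, p.pad)

x = rand(Float32, 4, 4, 3, 1)          # WHCN array
y = ReflectionPad((1, 1, 1, 1))(x)     # pads the two spatial dims -> size (6, 6, 3, 1)
```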

Activations

  • ... . NNlib has an extensive collection of activation functions, and any Julia function can be used as an activation.

Normalization Layers

Recurrent Layers

  • RNN
  • GRU
  • LSTM

Attention Layers

Linear Layers

  • Identity
  • Linear
  • Bilinear

Dropout Layers

Sparse Layers

Distance Functions

  • CosineSimilarity. We have this in Distances.jl; it is also easy to handcode (see the sketch below). TODO: check that it is AD- and GPU-friendly.
  • PairwiseDistance. We have this in Distances.jl. TODO: check that it is AD- and GPU-friendly (Tullio.jl could be used to achieve both).
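
For reference, a handcoded cosine similarity is only a couple of lines and should stay AD- and GPU-friendly, since it uses broadcasts and reductions only. The function name and the dims=1 (column-vector) convention are just assumptions for this sketch:

```julia
using Flux  # re-exports Zygote's gradient

# Cosine similarity between matching columns of x and y; eps avoids division by zero.
cosine_similarity(x, y; dims = 1, eps = 1f-8) =
    sum(x .* y; dims) ./ (sqrt.(sum(x .^ 2; dims)) .* sqrt.(sum(y .^ 2; dims)) .+ eps)

x, y = rand(Float32, 128, 16), rand(Float32, 128, 16)
s = cosine_similarity(x, y)                                # 1×16 similarities
g = gradient(x -> sum(cosine_similarity(x, y)), x)[1]      # differentiates with Zygote
```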

Loss Functions

Vision Layers

Initialization

Parallelism and Distributed

  • DataParallel
  • DistributedDataParallel (solved by https://github.com/DhairyaLGandhi/DaggerFlux.jl)
  • set_num_threads, set_num_interop_threads. Not sure which operations are parallelized in PyTorch; on our side only BLAS operations are parallelized (see the sketch below).
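
For completeness, the knobs currently available on the Julia side (a sketch; note that BLAS.get_num_threads only exists on recent Julia versions):

```julia
using LinearAlgebra

Threads.nthreads()        # Julia-level threads, set by starting julia with `--threads N`
BLAS.set_num_threads(4)   # threads used by BLAS calls (matmuls etc.)
```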

Distributions

  • diff rules for logpdf offered by DistributionsAD.jl
  • rsample. Differentiability of parameters through sampling is supported by many distributions: gradient(mu -> rand(Normal(mu, 1)), 0) == (1,).
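
The example above, spelled out with imports (assuming DistributionsAD's rules are loaded so that rand is differentiable):

```julia
using Distributions, DistributionsAD, Zygote

# Gradient of a sample with respect to the distribution's parameter,
# analogous to rsample / the reparameterization trick.
Zygote.gradient(mu -> rand(Normal(mu, 1)), 0.0)   # (1.0,)
```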

ONNX

FFT

  • ... . Zygote has the adjoints for AbstractFFTs.
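
A quick check of that, assuming FFTW is loaded to provide the AbstractFFTs implementation:

```julia
using FFTW, Zygote

x = rand(Float32, 8)
g = Zygote.gradient(x -> sum(abs2, fft(x)), x)[1]   # gradient through the FFT
```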

Quantization

  • ...

Pruning

  • WIP pruning package here

Optim

LinAlg

  • det
  • norm

Tensorboard

XLA

Misc

PyTorch Extras

Torchvision

Torchaudio
...

Torchtext
...
