Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge latest bug fixes from release/1.3.x into main #1314

Merged
merged 38 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
a6f7e75
add back sparse module
ClaudiaComito Mar 28, 2023
0ee2b38
bring back test_signal to pre-merge state
ClaudiaComito Mar 29, 2023
7677483
undo merge damage, part 2 of n
ClaudiaComito Mar 29, 2023
010eb61
undo merge damage 2 of 2(?)
ClaudiaComito Mar 29, 2023
c157b62
Merge branch 'main' into workflows/merge-release-into-main
ClaudiaComito Mar 29, 2023
b82da70
reinstate quick_start.md
ClaudiaComito Mar 29, 2023
952eac1
copy from fix/1168-update-docker-image-and-documentation-on-release13…
JuanPedroGHM Oct 31, 2023
5b37e05
corrected bug
Nov 7, 2023
2289eb3
docker scripts documentation
JuanPedroGHM Nov 14, 2023
198380c
Fix tzdata handling and merging multiple actions
bhagemeier Nov 21, 2023
e23d9d7
update pre-commit-config
ClaudiaComito Nov 22, 2023
ee39c63
Fix Pytorch release tracking workflows (#1264)
mtar Nov 22, 2023
3d3e8e8
Merge branch 'release/1.3.x' into bugs/1258-_Bug_Lasso_does_not_work_…
ClaudiaComito Nov 22, 2023
4a7b155
Merge branch 'release/1.3.x' into docker-release-update
ClaudiaComito Nov 22, 2023
ce2c3ef
Merge pull request #1257 from helmholtz-analytics/docker-release-update
mrfh92 Nov 22, 2023
e57f11e
Merge pull request #1267 from helmholtz-analytics/workflows/ci-matrix…
mrfh92 Nov 22, 2023
f1e0894
Merge pull request #1266 from helmholtz-analytics/workflows/update-pr…
mrfh92 Nov 22, 2023
a8cebca
Merge pull request #1259 from helmholtz-analytics/bugs/1258-_Bug_Lass…
mrfh92 Nov 22, 2023
94cd067
Fix `ht.diff` for 1-element-axis edge case (#1201)
mtar Nov 22, 2023
a1b0053
update version to 1.3.1 before release
ClaudiaComito Nov 23, 2023
e3af04b
revert
ClaudiaComito Nov 23, 2023
05325e2
Update version before release (#1274)
ClaudiaComito Nov 23, 2023
2de6410
Merge branch 'release/1.3.x' of github.com:helmholtz-analytics/heat i…
ClaudiaComito Nov 23, 2023
3db7af7
Update pytorch release PR workflow (#1286)
mtar Dec 6, 2023
d19f024
Pin `setup-mpi` version to 1.2.0 in CI matrix (#1313)
ClaudiaComito Dec 20, 2023
0d11791
Merge branch 'release/1.3.x' of github.com:helmholtz-analytics/heat i…
ClaudiaComito Dec 22, 2023
f077c20
Merge branch 'release/1.3.x' into workflows/merge-release-into-main
ClaudiaComito Dec 22, 2023
152239e
update version
ClaudiaComito Dec 22, 2023
656cef4
Merge branch 'main' into workflows/merge-release-into-main
ClaudiaComito Dec 22, 2023
b24a956
skip ihfftn tests for older torch versions
ClaudiaComito Dec 22, 2023
b008f53
add reason for skipping tests
ClaudiaComito Dec 22, 2023
6c3cbef
fix test skipping heuristics
ClaudiaComito Dec 22, 2023
fcd0218
raise NotImplementedError for ihfftn with torch<1.11
ClaudiaComito Dec 22, 2023
cce6057
fix check for ihfftn
ClaudiaComito Dec 22, 2023
968736e
raise error re: ihfftn support on older torch versions
ClaudiaComito Dec 22, 2023
ebadf7e
expand tests
ClaudiaComito Dec 22, 2023
1dd8d62
Apply suggestions from code review
mtar Jan 3, 2024
5bd3cb6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ jobs:
egress-policy: audit

- name: Checkout
uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0
uses: actions/checkout@v4
- name: Setup MPI
uses: mpi4py/setup-mpi@40c19a60792debf8ca403a3e6ee5f84c4e76555d # v1.2.1
uses: mpi4py/setup-mpi@v1.2.0
with:
mpi: ${{ matrix.mpi }}
- name: Use Python ${{ matrix.py-version }}
Expand Down
31 changes: 12 additions & 19 deletions .github/workflows/latest-pytorch-support.yml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
name: Support latest PyTorch

on:
push:
branches:
- 'support/new-pytorch-release-branch'
- 'support/new-pytorch-main-branch'
paths:
- '.github/pytorch-release-versions/*'
env:
working_branch: ${{ github.ref }}
workflow_call:
inputs:
working_branch:
required: true
type: string
base_branch:
required: true
type: string
permissions:
contents: write
issues: write
Expand Down Expand Up @@ -37,7 +37,7 @@ jobs:
uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0
with:
token: ${{ secrets.GITHUB_TOKEN }}
ref: '${{ env.working_branch }}'
ref: '${{ inputs.working_branch }}'
- name: Set env variables
run: |
echo "previous_pytorch=$(grep 'torch>=' setup.py | awk -F '<' '{print $2}' | tr -d '",')" >> $GITHUB_ENV
Expand All @@ -49,22 +49,15 @@ jobs:
- name: Update setup.py
run: |
sed -i '/torch>=/ s/'"${{ env.previous_pytorch }}"'/'"${{ env.setup_pytorch }}"'/g' setup.py
- name: Define base branch
run: |
if [[ ${{ github.ref }} =~ .*main.* ]]; then
echo "base_branch=main" >> $GITHUB_ENV
elif [[ ${{ github.ref }} =~ .*release.* ]]; then
echo "base_branch=release/1.2.x" >> $GITHUB_ENV
fi
- name: Create PR from branch
uses: peter-evans/create-pull-request@38e0b6e68b4c852a5500a94740f0e535e0d7ba54 # v4.2.4
with:
base: ${{ env.base_branch }}
branch: ${{ env.working_branch }}
base: ${{ inputs.base_branch }}
branch: ${{ inputs.working_branch }}
delete-branch: true
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Support latest PyTorch release
title: Support PyTorch ${{ env.new_pytorch }} on branch ${{ env.base_branch }}
title: Support PyTorch ${{ env.new_pytorch }} on branch ${{ inputs.base_branch }}
body: |
Run tests on latest PyTorch release
Issue/s resolved: #${{ steps.create-issue.outputs.number }}
Expand Down
39 changes: 0 additions & 39 deletions .github/workflows/pytorch-latest-main.yml

This file was deleted.

12 changes: 10 additions & 2 deletions .github/workflows/pytorch-latest-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ name: Get latest PyTorch version release branch
on:
workflow_dispatch:
env:
working_branch: support/new-pytorch-release-branch
base_branch: release/1.2.x
working_branch: support/new-pytorch-${{ github.ref_name }}
base_branch: ${{ github.ref_name }}
permissions:
contents: write
issues: write
pull-requests: write
jobs:
get-version:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -37,3 +39,9 @@ jobs:
git config --global user.email 'c.comito@fz-juelich.de@users.noreply.github.com'
git commit -am "New PyTorch release ${{ env.new }}"
git push --set-upstream origin ${{ env.working_branch }}
call-workflow:
needs: get-version
uses: ./.github/workflows/latest-pytorch-support.yml
with:
working_branch: support/new-pytorch-${{ github.ref_name }}
base_branch: ${{ github.ref_name }}
25 changes: 19 additions & 6 deletions heat/fft/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,25 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray:
output_shape = list(x.shape)
shift_op = fftn_op in [torch.fft.fftshift, torch.fft.ifftshift]
inverse_real_op = fftn_op in [torch.fft.irfftn, torch.fft.irfft2]
real_to_generic_fftn_ops = {
torch.fft.rfftn: torch.fft.fftn,
torch.fft.rfft2: torch.fft.fft2,
torch.fft.ihfftn: torch.fft.ifftn,
torch.fft.ihfft2: torch.fft.ifft2,
}

torch_has_ihfftn = hasattr(torch.fft, "ihfftn")

if torch_has_ihfftn:
real_to_generic_fftn_ops = {
torch.fft.rfftn: torch.fft.fftn,
torch.fft.rfft2: torch.fft.fft2,
torch.fft.ihfftn: torch.fft.ifftn,
torch.fft.ihfft2: torch.fft.ifft2,
}
else:
mtar marked this conversation as resolved.
Show resolved Hide resolved
real_to_generic_fftn_ops = {
torch.fft.rfftn: torch.fft.fftn,
torch.fft.rfft2: torch.fft.fft2,
}
if "ihfft2" in str(fftn_op) or "ihfftn" in str(fftn_op):
raise NotImplementedError(
"n-dim inverse Hermitian FFTs not implemented for torch < 1.11.0. Please upgrade torch."
)
real_op = fftn_op in real_to_generic_fftn_ops

# sanitize kwargs
Expand Down
27 changes: 19 additions & 8 deletions heat/fft/tests/test_fft.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import numpy as np
import torch
import unittest

import heat as ht
from heat.core.tests.test_suites.basic_test import TestCase

torch_ihfftn = hasattr(torch.fft, "ihfftn")


class TestFFT(TestCase):
def test_fft_ifft(self):
Expand Down Expand Up @@ -227,17 +230,25 @@ def test_hfft_ihfft(self):

def test_hfft2_ihfft2(self):
x = ht.random.randn(10, 6, 6, dtype=ht.float64)
inv_fft = ht.fft.ihfft2(x)
reconstructed_x = ht.fft.hfft2(inv_fft, s=x.shape[-2:])
self.assertTrue(ht.allclose(reconstructed_x, x))
if torch_ihfftn:
inv_fft = ht.fft.ihfft2(x)
reconstructed_x = ht.fft.hfft2(inv_fft, s=x.shape[-2:])
self.assertTrue(ht.allclose(reconstructed_x, x))
else:
with self.assertRaises(NotImplementedError):
ht.fft.ihfft2(x)

def test_hfftn_ihfftn(self):
x = ht.random.randn(10, 6, 6, dtype=ht.float64)
inv_fft = ht.fft.ifftn(x)
reconstructed_x = ht.fft.hfftn(inv_fft, s=x.shape)
self.assertTrue(ht.allclose(reconstructed_x, x))
reconstructed_x_no_s = ht.fft.hfftn(inv_fft)
self.assertEqual(reconstructed_x_no_s.shape[-1], 2 * (inv_fft.shape[-1] - 1))
if torch_ihfftn:
inv_fft = ht.fft.ihfftn(x)
reconstructed_x = ht.fft.hfftn(inv_fft, s=x.shape)
self.assertTrue(ht.allclose(reconstructed_x, x))
reconstructed_x_no_s = ht.fft.hfftn(inv_fft)
self.assertEqual(reconstructed_x_no_s.shape[-1], 2 * (inv_fft.shape[-1] - 1))
else:
with self.assertRaises(NotImplementedError):
ht.fft.ihfftn(x)

def test_rfft_irfft(self):
# n-D distributed
Expand Down
Loading