Commit ea1518b

Update base for Update on "Decomposition for adaptive_avg_pool2d"
This was already implemented as a lowering in pytorch/torchdynamo#962. I'm putting the idea up here (I haven't even run this code, so it surely has *many* issues, but I reckon the general idea should hopefully be alright). [ghstack-poisoned]
2 parents a417bae + 2f04ba2 commit ea1518b
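
For context on the title: a decomposition rewrites adaptive_avg_pool2d in terms of simpler primitives that the compiler stack already knows how to lower. In the special case where the input spatial size divides evenly by the output size, the op collapses to a plain avg_pool2d with a fixed kernel. Below is a minimal sketch of that case only; it is not the code from this stack or from pytorch/torchdynamo#962, and the function name and even-division restriction are mine.

import torch
import torch.nn.functional as F

# Sketch of the even-division case: every output cell averages a window of the
# same size, so adaptive_avg_pool2d equals avg_pool2d with
# kernel = stride = input_size // output_size. A full decomposition must also
# handle uneven divisions, where window extents vary per output index.
def adaptive_avg_pool2d_even(x, output_size):
    h_in, w_in = x.shape[-2], x.shape[-1]
    h_out, w_out = output_size
    assert h_in % h_out == 0 and w_in % w_out == 0, "sketch covers the even case only"
    kh, kw = h_in // h_out, w_in // w_out
    return F.avg_pool2d(x, kernel_size=(kh, kw), stride=(kh, kw))

x = torch.randn(2, 3, 8, 8)
assert torch.allclose(adaptive_avg_pool2d_even(x, (4, 4)),
                      F.adaptive_avg_pool2d(x, (4, 4)))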

File tree: 160 files changed, +8196 −7749 lines


.github/ci_commit_pins/xla.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-fix_upstream_break_83586
+9b2f7929c2dae841888a836449c25b04c8cf4045

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 1 addition & 62 deletions
@@ -1162,7 +1162,7 @@ void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
       reinterpret_cast<cuDoubleComplex*>(result)));
 }
 
-// This guards blocks use of getrsBatched, geqrfBatched, getrfBatched, getriBatched on platforms other than cuda
+// This guards blocks use of getrsBatched, geqrfBatched, getrfBatched on platforms other than cuda
 #ifdef CUDART_VERSION
 
 template <>
@@ -1323,67 +1323,6 @@ void getrfBatched<c10::complex<float>>(
       batchsize));
 }
 
-template <>
-void getriBatched<double>(
-    int n, double** dA_array, int ldda, int* ipiv_array, double** dC_array, int lddc, int* info_array, int batchsize) {
-  auto handle = at::cuda::getCurrentCUDABlasHandle();
-  TORCH_CUDABLAS_CHECK(cublasDgetriBatched(
-      handle, n, dA_array, ldda, ipiv_array, dC_array, lddc, info_array, batchsize));
-}
-
-template <>
-void getriBatched<float>(
-    int n, float** dA_array, int ldda, int* ipiv_array, float** dC_array, int lddc, int* info_array, int batchsize) {
-  auto handle = at::cuda::getCurrentCUDABlasHandle();
-  TORCH_CUDABLAS_CHECK(cublasSgetriBatched(
-      handle, n, dA_array, ldda, ipiv_array, dC_array, lddc, info_array, batchsize));
-}
-
-template <>
-void getriBatched<c10::complex<double>>(
-    int n,
-    c10::complex<double>** dA_array,
-    int ldda,
-    int* ipiv_array,
-    c10::complex<double>** dC_array,
-    int lddc,
-    int* info_array,
-    int batchsize) {
-  auto handle = at::cuda::getCurrentCUDABlasHandle();
-  TORCH_CUDABLAS_CHECK(cublasZgetriBatched(
-      handle,
-      n,
-      reinterpret_cast<cuDoubleComplex**>(dA_array),
-      ldda,
-      ipiv_array,
-      reinterpret_cast<cuDoubleComplex**>(dC_array),
-      lddc,
-      info_array,
-      batchsize));
-}
-
-template <>
-void getriBatched<c10::complex<float>>(
-    int n,
-    c10::complex<float>** dA_array,
-    int ldda,
-    int* ipiv_array,
-    c10::complex<float>** dC_array,
-    int lddc,
-    int* info_array,
-    int batchsize) {
-  auto handle = at::cuda::getCurrentCUDABlasHandle();
-  TORCH_CUDABLAS_CHECK(cublasCgetriBatched(
-      handle,
-      n,
-      reinterpret_cast<cuComplex**>(dA_array),
-      ldda,
-      ipiv_array,
-      reinterpret_cast<cuComplex**>(dC_array),
-      lddc,
-      info_array,
-      batchsize));
-}
 
 template <>
 void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double)) {
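
For reference, the deleted specializations above were thin wrappers over cuBLAS's batched inversion routines (cublasDgetriBatched and friends), which take the LU factors and pivots produced by getrfBatched and return the explicit inverse. Below is a rough Python-level illustration of that same two-step computation using public torch.linalg ops rather than these bindings; the shapes and tolerance are made up for the demo.

import torch

# getrf computes a pivoted LU factorization; getri then turns those factors
# into an explicit inverse. The same result falls out of LU-solving against
# the identity, so the inverse remains computable without a getri binding.
A = torch.randn(4, 3, 3, dtype=torch.float64)   # random batch; almost surely invertible
LU, pivots = torch.linalg.lu_factor(A)          # getrf analogue
I = torch.eye(3, dtype=torch.float64).expand(4, 3, 3).contiguous()
A_inv = torch.linalg.lu_solve(LU, pivots, I)    # getri analogue: solve A X = I
assert torch.allclose(A_inv, torch.linalg.inv(A), atol=1e-8)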

aten/src/ATen/cuda/CUDABlas.h

Lines changed: 1 addition & 17 deletions
@@ -227,7 +227,7 @@ void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
 template <>
 void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
 
-// This guards blocks use of getrsBatched, geqrfBatched, getrfBatched, getriBatched on platforms other than cuda
+// This guards blocks use of getrsBatched, geqrfBatched, getrfBatched on platforms other than cuda
 #ifdef CUDART_VERSION
 
 #define CUDABLAS_GETRS_ARGTYPES(Dtype) \
@@ -287,22 +287,6 @@ TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_ARGTYPE
 template<>
 TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<float>));
 
-#define CUDABLAS_GETRI_ARGTYPES(Dtype) \
-  int n, Dtype** dA_array, int ldda, int* ipiv_array, Dtype** dC_array, int lddc, int* info_array, int batchsize
-
-template<class Dtype>
-void getriBatched(CUDABLAS_GETRI_ARGTYPES(Dtype)) {
-  TORCH_CHECK(false, "at::cuda::blas::getriBatched: not implemented for ", typeid(Dtype).name());
-}
-template<>
-TORCH_CUDA_CU_API void getriBatched<float>(CUDABLAS_GETRI_ARGTYPES(float));
-template<>
-TORCH_CUDA_CU_API void getriBatched<double>(CUDABLAS_GETRI_ARGTYPES(double));
-template<>
-TORCH_CUDA_CU_API void getriBatched<c10::complex<double>>(CUDABLAS_GETRI_ARGTYPES(c10::complex<double>));
-template<>
-TORCH_CUDA_CU_API void getriBatched<c10::complex<float>>(CUDABLAS_GETRI_ARGTYPES(c10::complex<float>));
-
 #define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype) \
   cublasHandle_t handle, cublasOperation_t trans, int m, int n, int nrhs, Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize
