Skip to content

Removed comma in some json files #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,7 @@ gentbl_cc_library(
"include/tfrt/gpu/kernels/gpu_ccl_ops.td",
"include/tfrt/gpu/kernels/gpu_dnn_ops.td",
"include/tfrt/gpu/kernels/gpu_driver_ops.td",
"include/tfrt/gpu/kernels/gpu_fft_ops.td",
"include/tfrt/gpu/kernels/gpu_solver_ops.td",
],
deps = [
Expand Down Expand Up @@ -880,6 +881,7 @@ tfrt_cc_library(
"lib/kernels/ccl_kernels.cc",
"lib/kernels/dnn_kernels.cc",
"lib/kernels/driver_kernels.cc",
"lib/kernels/fft_kernels.cc",
"lib/kernels/solver_kernels.cc",
],
alwayslink_static_registration_src = "lib/kernels/static_registration.cc",
Expand Down
22 changes: 22 additions & 0 deletions backends/gpu/include/tfrt/gpu/gpu_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "tfrt/gpu/wrapper/ccl_types.h"
#include "tfrt/gpu/wrapper/dnn_wrapper.h"
#include "tfrt/gpu/wrapper/driver_wrapper.h"
#include "tfrt/gpu/wrapper/fft_wrapper.h"
#include "tfrt/gpu/wrapper/solver_wrapper.h"
#include "tfrt/host_context/async_value_ref.h"
#include "tfrt/support/forward_decls.h"
Expand Down Expand Up @@ -418,6 +419,27 @@ class GpuSolverHandle {
wrapper::OwningSolverHandle handle_;
};

// Bundles an owning FFT handle with the FftType it was created for and a
// reference to the GpuContext it was created with.
//
// Only the move operations are declared (defaulted); declaring them suppresses
// the implicitly generated copy operations, so instances cannot be copied.
class GpuFftHandle {
public:
// Takes ownership of `handle` and stores `context` and `type` alongside it.
explicit GpuFftHandle(AsyncValueRef<GpuContext> context,
wrapper::OwningFftHandle handle, wrapper::FftType type);
~GpuFftHandle();

GpuFftHandle(GpuFftHandle&&) = default;
GpuFftHandle& operator=(GpuFftHandle&&) = default;

// Access to the owning handle itself and to the non-owning handle it wraps.
const wrapper::OwningFftHandle& operator->() const { return handle_; }
wrapper::FftHandle get() const { return handle_.get(); }

// The FFT type passed at construction.
wrapper::FftType type() const { return type_; }
// The context reference passed at construction.
const AsyncValueRef<GpuContext>& context() const { return context_; }

private:
// NOTE(review): context_ is declared before handle_, so it is destroyed after
// it; presumably this keeps the context alive while the handle is released --
// confirm against the other Gpu*Handle classes.
AsyncValueRef<GpuContext> context_;
wrapper::OwningFftHandle handle_;
wrapper::FftType type_;
};

template <typename T>
T* GetRawPointer(const GpuBuffer& buffer) {
return static_cast<T*>(buffer.pointer().raw());
Expand Down
91 changes: 91 additions & 0 deletions backends/gpu/include/tfrt/gpu/kernels/gpu_fft_ops.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright 2022 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//===- gpu_fft_ops.td ----------------------------------------------------===//
//
// cuFFT-based GPU FFT operation definitions.
// The FFT ops are kept separate from the ops for other GPU libraries (such as
// cuDNN) because cuFFT is a separate library.
//
//===----------------------------------------------------------------------===//

#ifdef GPU_FFT_OPS
#else
#define GPU_FFT_OPS

include "tfrt/gpu/kernels/gpu_ops_base.td"

//===----------------------------------------------------------------------===//
// Fast Fourier Transform Type enum definitions.
//===----------------------------------------------------------------------===//
// FFT handle type; it is spelled with the `fft.handle` mnemonic in assembly.
def GPU_FftHandleType : GPU_Type<"OwningFftHandle"> { let mnemonic = "fft.handle"; }

// Attributes wrapping the "FftType" and "FftDirection" wrapper enums.
def GPU_FftTypeAttr : GPU_WrapperAttr<"FftType">;
def GPU_FftDirectionAttr : GPU_WrapperAttr<"FftDirection">;

// tfrt_gpu.fft.create: takes a context operand plus the attributes describing
// the transform (type, batch, dims, input strides, output strides) and yields
// an fft.handle. The op declares a custom verifier (hasVerifier = 1).
def GPU_FftCreateOp : GPU_Op<"fft.create", []> {
let description = [{
tfrt_gpu.fft.create returns a tfrt_gpu.fft.handle for the FFT operation
that is described by the supplied parameters.

Example:
%fft = tfrt_gpu.fft.create %ctx, CUFFT_R2C, 1, [4, 4], [16, 4, 1], [16, 4, 1]

'dims' are the FFT dimensions, with rank one smaller than the two strides.
}];
// Note: these attributes will have to become values to support dynamic
// shapes. But it's unclear at the moment what type they would be and XLIR
// as the only user is operating on static shapes still.
let arguments = (ins GPU_ContextType:$context,
GPU_FftTypeAttr:$type, I64Attr:$batch, I64ArrayAttr:$dims,
I64ArrayAttr:$in_strides, I64ArrayAttr:$out_strides);
let results = (outs GPU_FftHandleType:$handle);
// Textual form: context, type enum, batch, dims, in_strides, out_strides.
let assemblyFormat = [{
$context`,` custom<Enum>($type)`,` $batch`,` $dims`,`
$in_strides`,` $out_strides attr-dict
}];
let hasVerifier = 1;
}

// tfrt_gpu.fft.get_workspace_size: takes an fft.handle and yields the
// workspace size as an I64 value.
def GPU_FftGetWorkspaceSizeOp : GPU_Op<"fft.get_workspace_size", []> {
  let description = [{
    tfrt_gpu.fft.get_workspace_size returns the workspace size in bytes
    required to execute the FFT operation.
  }];
  let arguments = (ins GPU_FftHandleType:$handle);
  let results = (outs I64:$workspace_size);
  let assemblyFormat = [{$handle attr-dict}];
}

// tfrt_gpu.fft.execute: takes a stream, an fft.handle, the input, output and
// workspace buffers, a direction attribute and an incoming chain; it yields a
// chain.
def GPU_FftExecuteOp : GPU_Op<"fft.execute", []> {
let description = [{
tfrt_gpu.fft.execute runs the FFT according to the supplied parameters.
}];
let arguments = (ins
GPU_StreamType:$stream,
GPU_FftHandleType:$handle,
GPU_BufferType:$input,
GPU_BufferType:$output,
GPU_BufferType:$workspace,
GPU_FftDirectionAttr:$direction,
TFRT_ChainType:$chain
);
let results = (outs TFRT_ChainType);
// Textual form: operands in declaration order, with the direction enum placed
// before the trailing chain.
let assemblyFormat = [{
$stream`,` $handle`,` $input`,` $output`,` $workspace`,`
custom<Enum>($direction)`,` $chain attr-dict
}];
}

#endif // GPU_FFT_OPS
3 changes: 3 additions & 0 deletions backends/gpu/include/tfrt/gpu/kernels/gpu_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "tfrt/gpu/wrapper/blas_wrapper.h"
#include "tfrt/gpu/wrapper/ccl_types.h"
#include "tfrt/gpu/wrapper/dnn_wrapper.h"
#include "tfrt/gpu/wrapper/fft_wrapper.h"
#include "tfrt/gpu/wrapper/wrapper.h"
#include "tfrt/tensor/opdefs/host_tensor.h"
#include "tfrt/tensor/opdefs/tensor.h"
Expand Down Expand Up @@ -131,6 +132,8 @@ using BlasFillModeAttr = EnumAttr<wrapper::BlasFillMode>;
using BlasSideModeAttr = EnumAttr<wrapper::BlasSideMode>;
using CclDataTypeAttr = EnumAttr<wrapper::CclDataType>;
using CclReductionOpAttr = EnumAttr<wrapper::CclReductionOp>;
using FftTypeAttr = EnumAttr<wrapper::FftType>;
using FftDirectionAttr = EnumAttr<wrapper::FftDirection>;

namespace conversion {

Expand Down
1 change: 1 addition & 0 deletions backends/gpu/include/tfrt/gpu/kernels/gpu_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ include "tfrt/gpu/kernels/gpu_blas_ops.td"
include "tfrt/gpu/kernels/gpu_ccl_ops.td"
include "tfrt/gpu/kernels/gpu_dnn_ops.td"
include "tfrt/gpu/kernels/gpu_driver_ops.td"
include "tfrt/gpu/kernels/gpu_fft_ops.td"
include "tfrt/gpu/kernels/gpu_solver_ops.td"

#endif // GPU_OPS
40 changes: 4 additions & 36 deletions backends/gpu/include/tfrt/gpu/wrapper/cusolver_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,47 +32,15 @@ llvm::Expected<OwningSolverHandle> CusolverDnCreate();
llvm::Error CusolverDnDestroy(cusolverDnHandle_t handle);
llvm::Error CusolverDnSetStream(cusolverDnHandle_t handle, cudaStream_t stream);
llvm::Expected<Stream> CusolverDnGetStream(cusolverDnHandle_t handle);

// TODO(hanbinyoon): Replace with a type-punned version of CusolverDnPotrf.
llvm::Error CusolverDnPotrf(CurrentContext current, cusolverDnHandle_t handle,
cublasFillMode_t fillMode, int n, Pointer<float> A,
int heightA, Pointer<float> workspace,
int workspaceSize, Pointer<int> devInfo);
llvm::Error CusolverDnPotrf(CurrentContext current, cusolverDnHandle_t handle,
cublasFillMode_t fillMode, int n, Pointer<double> A,
int heightA, Pointer<double> workspace,
cudaDataType dataType, cublasFillMode_t fillMode, int n,
Pointer<void> A, int heightA, Pointer<void> workspace,
int workspaceSize, Pointer<int> devInfo);
llvm::Error CusolverDnPotrf(CurrentContext current, cusolverDnHandle_t handle,
cublasFillMode_t fillMode, int n,
Pointer<cuComplex> A, int heightA,
Pointer<cuComplex> workspace, int workspaceSize,
Pointer<int> devInfo);
llvm::Error CusolverDnPotrf(CurrentContext current, cusolverDnHandle_t handle,
cublasFillMode_t fillMode, int n,
Pointer<cuDoubleComplex> A, int heightA,
Pointer<cuDoubleComplex> workspace,
int workspaceSize, Pointer<int> devInfo);
llvm::Error CusolverDnPotrfBatched(CurrentContext current,
cusolverDnHandle_t handle,
cublasFillMode_t fillMode, int n,
Pointer<float *> Aarray, int heightA,
Pointer<int> devInfoArray, int batchSize);
llvm::Error CusolverDnPotrfBatched(CurrentContext current,
cusolverDnHandle_t handle,
cusolverDnHandle_t handle, cudaDataType dataType,
cublasFillMode_t fillMode, int n,
Pointer<double *> Aarray, int heightA,
Pointer<void *> Aarray, int heightA,
Pointer<int> devInfoArray, int batchSize);
llvm::Error CusolverDnPotrfBatched(CurrentContext current,
cusolverDnHandle_t handle,
cublasFillMode_t fillMode, int n,
Pointer<cuComplex *> Aarray, int heightA,
Pointer<int> devInfoArray, int batchSize);
llvm::Error CusolverDnPotrfBatched(CurrentContext current,
cusolverDnHandle_t handle,
cublasFillMode_t fillMode, int n,
Pointer<cuDoubleComplex *> Aarray,
int heightA, Pointer<int> devInfoArray,
int batchSize);
llvm::Expected<int> CusolverDnPotrfBufferSize(CurrentContext current,
cusolverDnHandle_t handle,
cublasFillMode_t fillMode, int n,
Expand Down
20 changes: 7 additions & 13 deletions backends/gpu/include/tfrt/gpu/wrapper/rocsolver_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,13 @@ llvm::Error RocsolverDestroy(rocblas_handle handle);
llvm::Error RocsolverSetStream(rocblas_handle handle, hipStream_t stream);
llvm::Expected<Stream> RocsolverGetStream(rocblas_handle handle);
llvm::Error RocsolverPotrf(CurrentContext current, rocblas_handle handle,
rocblas_fill fillMode, int n, Pointer<float> A,
int heightA, Pointer<int> devInfo);
llvm::Error RocsolverPotrf(CurrentContext current, rocblas_handle handle,
rocblas_fill fillMode, int n, Pointer<double> A,
int heightA, Pointer<int> devInfo);
llvm::Error RocsolverPotrf(CurrentContext current, rocblas_handle handle,
rocblas_fill fillMode, int n,
Pointer<rocblas_float_complex> A, int heightA,
Pointer<int> devInfo);
llvm::Error RocsolverPotrf(CurrentContext current, rocblas_handle handle,
rocblas_fill fillMode, int n,
Pointer<rocblas_double_complex> A, int heightA,
Pointer<int> devInfo);
rocblas_datatype dataType, rocblas_fill fillMode, int n,
Pointer<void> A, int heightA, Pointer<int> devInfo);
llvm::Error RocsolverPotrfBatched(CurrentContext current, rocblas_handle handle,
rocblas_datatype dataType, rocblas_fill fillMode,
int n, Pointer<void *> Aarray,
int heightA, Pointer<int> devInfo,
int batchSize);

} // namespace wrapper
} // namespace gpu
Expand Down
10 changes: 10 additions & 0 deletions backends/gpu/include/tfrt/gpu/wrapper/solver_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <memory>

#include "tfrt/gpu/wrapper/wrapper.h"
#include "tfrt/gpu/wrapper/blas_wrapper.h"

namespace tfrt {
namespace gpu {
Expand All @@ -46,6 +47,15 @@ llvm::Expected<OwningSolverHandle> SolverCreate(Platform platform);
llvm::Error SolverDestroy(SolverHandle handle);
llvm::Error SolverSetStream(SolverHandle handle, Stream stream);
llvm::Expected<Stream> SolverGetStream(SolverHandle handle);
// Potrf entry point taking type-erased buffers (Pointer<void>) together with a
// BlasDataType tag describing the element type.
// NOTE(review): presumably dispatches to the platform-specific cuSOLVER /
// rocSOLVER wrappers based on the handle's platform -- confirm in the
// implementation, which is not visible here.
llvm::Error SolverPotrf(CurrentContext current, SolverHandle handle,
BlasDataType dataType, BlasFillMode fillMode, int n,
Pointer<void> buffer, int stride,
Pointer<void> workspace, int workspaceSize,
Pointer<int> devInfo);
// Batched variant: takes an array of type-erased matrix pointers, a per-batch
// devInfo array and the batch size; it has no workspace parameters.
llvm::Error SolverPotrfBatched(CurrentContext current, SolverHandle handle,
BlasDataType dataType, BlasFillMode fillMode, int n,
Pointer<void*> Aarray, int heightA,
Pointer<int> devInfoArray, int batchSize);

} // namespace wrapper
} // namespace gpu
Expand Down
7 changes: 7 additions & 0 deletions backends/gpu/lib/gpu_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,5 +381,12 @@ GpuSolverHandle::GpuSolverHandle(AsyncValueRef<GpuContext> context,

GpuSolverHandle::~GpuSolverHandle() = default;

// Moves `context` and `handle` into the new object and records `type`.
GpuFftHandle::GpuFftHandle(AsyncValueRef<GpuContext> context,
wrapper::OwningFftHandle handle,
wrapper::FftType type)
: context_(std::move(context)), handle_(std::move(handle)), type_(type) {}

// Defaulted out of line; the members release whatever they own.
GpuFftHandle::~GpuFftHandle() = default;

} // namespace gpu
} // namespace tfrt
110 changes: 110 additions & 0 deletions backends/gpu/lib/kernels/fft_kernels.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright 2022 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This file implements the tfrt_gpu.fft kernel.
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <utility>

#include "llvm/ADT/SmallVector.h"
#include "tfrt/gpu/gpu_types.h"
#include "tfrt/gpu/kernels/kernels_detail.h"
#include "tfrt/gpu/wrapper/fft_wrapper.h"
#include "tfrt/gpu/wrapper/wrapper.h"
#include "tfrt/host_context/attribute_utils.h"
#include "tfrt/host_context/kernel_registry.h"
#include "tfrt/support/fp16.h"

namespace tfrt {
namespace gpu {

// tfrt_gpu.fft.create creates an FFT handle and plans the transform described
// by the attributes.
//
// `in_strides` and `out_strides` must each hold exactly one more entry than
// `dims`, be sorted in non-increasing order (row-major layout), and every
// stride must be a multiple of the following, non-zero stride. The per-axis
// extents of the input/output buffers are derived as the ratio of consecutive
// strides.
//
// Returns the handle (with automatic workspace allocation disabled) bundled
// with the context and the FFT type, or an error if the attributes are
// inconsistent or any wrapper call fails.
static Expected<GpuFftHandle> FftCreate(
    Argument<GpuContext> context,
    // Needs to be sorted alphabetically by attribute name!
    Attribute<int64_t> batch, ArrayAttribute<int64_t> dims,
    ArrayAttribute<int64_t> in_strides, ArrayAttribute<int64_t> out_strides,
    Attribute<int> type) {
  if (!llvm::is_sorted(in_strides.data(), std::greater<int64_t>()) ||
      !llvm::is_sorted(out_strides.data(), std::greater<int64_t>())) {
    return MakeStringError("Only row-major layout is supported");
  }

  // Validate the lengths before deriving any dimensions: this also rejects
  // empty strides, for which `size() - 1` below would wrap around.
  const size_t rank = dims.size();
  if (in_strides.size() != rank + 1 || out_strides.size() != rank + 1)
    return MakeStringError("Inconsistent dims/strides lengths");

  // Report malformed strides as an error instead of relying on an assert,
  // which is compiled out of release builds and would leave a division by
  // zero behind.
  auto has_divisible_strides = [](ArrayRef<int64_t> strides) {
    for (size_t i = 0; i + 1 < strides.size(); ++i) {
      if (strides[i + 1] == 0 || strides[i] % strides[i + 1] != 0)
        return false;
    }
    return true;
  };
  if (!has_divisible_strides(in_strides.data()) ||
      !has_divisible_strides(out_strides.data())) {
    return MakeStringError(
        "Each stride must be a multiple of the next, non-zero stride");
  }

  // Extent of axis i is strides[i] / strides[i + 1]; the strides were
  // validated above, so the array is non-empty and the division is exact.
  auto get_dimensions = [](ArrayRef<int64_t> strides) {
    llvm::SmallVector<int64_t, 4> dimensions(strides.size() - 1);
    for (size_t i = 0; i < dimensions.size(); ++i)
      dimensions[i] = strides[i] / strides[i + 1];
    return dimensions;
  };
  llvm::SmallVector<int64_t, 4> in_dims = get_dimensions(in_strides.data());
  llvm::SmallVector<int64_t, 4> out_dims = get_dimensions(out_strides.data());

  auto current = wrapper::CtxSetCurrent(context->get());
  if (!current) return current.takeError();

  auto handle = wrapper::FftCreate(*current);
  if (!handle) return handle.takeError();
  // The workspace is supplied explicitly when the plan is executed.
  if (auto error = wrapper::FftDisableAutoAllocation(handle->get()))
    return std::move(error);

  auto fft_type = wrapper::FftType::FromOpaqueValue(*type);
  // The innermost stride is passed as the element stride and the outermost
  // stride as the distance between consecutive batches.
  auto workspace_size = wrapper::FftMakePlanMany(
      handle->get(), fft_type, *batch, dims.data(), in_dims, in_strides[rank],
      in_strides[0], out_dims, out_strides[rank], out_strides[0]);
  if (!workspace_size) return workspace_size.takeError();

  return GpuFftHandle(context.ValueRef(), std::move(*handle), fft_type);
}

// Kernel registered as tfrt_gpu.fft.get_workspace_size: forwards the wrapped
// handle to wrapper::FftGetWorkspaceSize and returns its result.
static Expected<int64_t> FftGetWorkspaceSize(const GpuFftHandle& handle) {
return wrapper::FftGetWorkspaceSize(handle.get());
}

// tfrt_gpu.fft.execute: binds the plan held by `handle` to `stream` and to the
// caller-provided `workspace`, then runs it from `input` to `output` in the
// requested `direction`. `chain` is accepted but not inspected here.
static Error FftExecute(const GpuStream& stream, const GpuFftHandle& handle,
                        const GpuBuffer& input, const GpuBuffer& output,
                        const GpuBuffer& workspace, const Chain& chain,
                        Attribute<int> direction) {
  // Make the stream's context current before touching the plan.
  auto current = wrapper::CtxSetCurrent(stream.context()->get());
  if (!current) return current.takeError();

  const auto plan = handle.get();

  if (auto stream_error = wrapper::FftSetStream(plan, stream.get()))
    return stream_error;

  if (auto workspace_error = wrapper::FftSetWorkspace(
          plan, workspace.pointer(), workspace.size()))
    return workspace_error;

  const auto fft_direction =
      wrapper::FftDirection::FromOpaqueValue(*direction);
  return wrapper::FftExec(*current, plan, input.pointer(), output.pointer(),
                          handle.type(), fft_direction);
}

// Registers the three FFT kernels defined in this file with `kernel_reg` under
// the tfrt_gpu.fft.* names.
void RegisterGpuFftKernels(KernelRegistry* kernel_reg) {
kernel_reg->AddKernel("tfrt_gpu.fft.create", TFRT_KERNEL(FftCreate));
kernel_reg->AddKernel("tfrt_gpu.fft.get_workspace_size",
TFRT_KERNEL(FftGetWorkspaceSize));
kernel_reg->AddKernel("tfrt_gpu.fft.execute", TFRT_KERNEL(FftExecute));
}

} // namespace gpu
} // namespace tfrt
Loading