Tensor construction codemod - 1/2 (pytorch#15598)
Summary:
Pull Request resolved: pytorch#15598

Codemod generated with clangr in shard mode, 25 files per diff.
Motivation: pytorch#12407

Reviewed By: dzhulgakov

Differential Revision: D13542429

fbshipit-source-id: db1059c78e85724d9b4fdab70466cf329db68359
jerryzh168 authored and facebook-github-bot committed Jan 4, 2019
1 parent ad0ef7a commit 9e88547
Showing 18 changed files with 68 additions and 77 deletions.
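
The change applied in every file is the same mechanical rewrite: the old two-step pattern of fetching an output tensor and then calling Resize() on it is collapsed into a single Output() call that takes the sizes and dtype up front (motivation: pytorch#12407). A minimal before/after sketch of the pattern; the operator index, shape values N and D, and variable names here are illustrative, not taken from any one file below:

// Before: fetch the output, then Resize() it and request the data separately.
auto* Y = Output(0);
Y->Resize(N, D);  // N, D are assumed dimensions for this sketch
float* Ydata = Y->template mutable_data<float>();

// After: construct the output with its shape and dtype in one call.
auto* Y = Output(0, {N, D}, at::dtype<float>());
float* Ydata = Y->template mutable_data<float>();

Dimensions that were previously passed as separate integer arguments to Resize() appear as an initializer-list shape in the new call, which is exactly what the hunks below show.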
4 changes: 2 additions & 2 deletions caffe2/operators/accuracy_op.cu
@@ -47,13 +47,13 @@ template <>
bool AccuracyOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(PREDICTION);
auto& label = Input(LABEL);
auto* Y = Output(0);

CAFFE_ENFORCE_EQ(X.ndim(), 2);
int N = X.dim32(0);
int D = X.dim32(1);
CAFFE_ENFORCE_EQ(label.ndim(), 1);
CAFFE_ENFORCE_EQ(label.dim32(0), N);
Y->Resize(vector<int64_t>());
auto* Y = Output(0, vector<int64_t>(), at::dtype<float>());
float* Ydata = Y->template mutable_data<float>();
math::Set<float, CUDAContext>(1, 0, Ydata, &context_);
AccuracyKernel<<<
4 changes: 2 additions & 2 deletions caffe2/operators/boolean_mask_ops.cu
@@ -82,8 +82,8 @@ class BooleanMaskOp<CUDAContext> final : public Operator<CUDAContext> {
auto* destData = (uint8_t*)dest->raw_mutable_data(src.meta());
const auto* srcData = (uint8_t*)src.raw_data();
if (OutputSize() == 2) {
auto* indicesOut = Output(1);
indicesOut->Resize(numOfOutput);

auto* indicesOut = Output(1, {numOfOutput}, at::dtype<int64_t>());
indicesOut->template mutable_data<int64_t>();
}

8 changes: 4 additions & 4 deletions caffe2/operators/channel_backprop_stats_op.cu
@@ -161,16 +161,16 @@ bool ChannelBackpropStatsOp<CUDAContext>::RunOnDevice() {
const int W = X.ndim() > 3 ? X.dim32(3) : 1;
const int D = X.ndim() > 4 ? X.dim32(4) : 1;

auto dScale = Output(SCALE_GRAD);
auto dBias = Output(BIAS_GRAD);



const auto Xarr = X.data<float>();
const auto dYarr = dY.data<float>();
const auto meanArr = mean.data<float>();
const auto invStddevArr = invStddev.data<float>();

dBias->Resize(C);
dScale->Resize(C);
auto dBias = Output(BIAS_GRAD, {C}, at::dtype<float>());
auto dScale = Output(SCALE_GRAD, {C}, at::dtype<float>());

const auto valsPerChannel = H * W * D;

8 changes: 4 additions & 4 deletions caffe2/operators/channel_stats_op.cu
@@ -154,8 +154,8 @@ bool ChannelStatsOp<CUDAContext>::RunOnDevice() {
const int W = X.ndim() > 3 ? X.dim32(3) : 1;
const int D = X.ndim() > 4 ? X.dim32(4) : 1;

auto sum = Output(SUM);
auto sumsq = Output(SUMSQ);



const auto X_arr = X.data<float>();
const auto valsPerChannel = H * W * D;
@@ -166,8 +166,8 @@ bool ChannelStatsOp<CUDAContext>::RunOnDevice() {
sumScratch_.Resize(numBlocksTotal);
sumsqScratch_.Resize(numBlocksTotal);

sum->Resize(C);
sumsq->Resize(C);
auto sum = Output(SUM, {C}, at::dtype<float>());
auto sumsq = Output(SUMSQ, {C}, at::dtype<float>());

ChannelStatsBlockKernel<CAFFE_CUDA_NUM_THREADS>
<<<numBlocksTotal, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
33 changes: 15 additions & 18 deletions caffe2/operators/cross_entropy_op.cu
@@ -30,7 +30,7 @@ template <>
bool LabelCrossEntropyOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto& label = Input(1);
auto* Y = Output(0);

int N, D;
if (X.ndim() > 1) {
N = X.dim32(0);
@@ -42,7 +42,7 @@ bool LabelCrossEntropyOp<float, CUDAContext>::RunOnDevice() {
CAFFE_ENFORCE(
(label.ndim() == 1) || (label.ndim() == 2 && label.dim32(1) == 1));
CAFFE_ENFORCE_EQ(label.dim32(0), N);
Y->Resize(vector<int64_t>(size_t(1), N));
auto* Y = Output(0, vector<int64_t>(size_t(1), N), at::dtype<float>());
LabelCrossEntropyKernel<<<
CAFFE_GET_BLOCKS(N),
CAFFE_CUDA_NUM_THREADS,
@@ -113,11 +113,11 @@ __global__ void MakeTwoClassGradientKernel(
template <>
bool MakeTwoClassOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
auto shape = X.dims().vec();
shape.push_back(2);
CAFFE_ENFORCE_LT(X.size(), std::numeric_limits<int>::max() / 2);
Y->Resize(shape);
auto* Y = Output(0, shape, at::dtype<float>());
int N = X.size();
MakeTwoClassKernel<<<
CAFFE_GET_BLOCKS(N),
@@ -131,13 +131,13 @@ bool MakeTwoClassOp<float, CUDAContext>::RunOnDevice() {
template <>
bool MakeTwoClassGradientOp<float, CUDAContext>::RunOnDevice() {
auto& dY = Input(0);
auto* dX = Output(0);
auto shape = dY.dims().vec();
CAFFE_ENFORCE_GE(shape.size(), 1);
CAFFE_ENFORCE_EQ(shape.back(), 2);
shape.pop_back();
CAFFE_ENFORCE_LT(dY.size(), std::numeric_limits<int>::max());
dX->Resize(shape);
auto* dX = Output(0, shape, at::dtype<float>());
int N = dX->size();
MakeTwoClassGradientKernel<<<
CAFFE_GET_BLOCKS(N),
@@ -248,13 +248,11 @@ bool SigmoidCrossEntropyWithLogitsOp<float, CUDAContext>::RunOnDevice() {
const auto inner_size = logits.ndim() > 0 ? logits.dims().back() : 1;
const auto outer_size = logits.size() / inner_size;
auto* out = Output(0);
if (logits.ndim() == 0) {
out->Resize(std::vector<int64_t>{});
} else {
std::vector<int64_t> dims(logits.dims().begin(), logits.dims().end() - 1);
out->Resize(dims);
std::vector<int64_t> dims;
if (logits.dim() != 0) {
dims = std::vector<int64_t>(logits.dims().begin(), logits.dims().end() - 1);
}
auto* out = Output(0, dims, at::dtype<float>());
auto* out_ptr = out->template mutable_data<float>();
auto* logits_ptr = logits.data<float>();
@@ -370,13 +368,12 @@ bool WeightedSigmoidCrossEntropyWithLogitsOp<float, CUDAContext>::
const auto inner_size = logits.ndim() > 0 ? logits.dims().back() : 1;
const auto outer_size = logits.size() / inner_size;
auto* out = Output(0);
if (logits.ndim() == 0) {
out->Resize(std::vector<int64_t>{});
} else {
std::vector<int64_t> dims(logits.dims().begin(), logits.dims().end() - 1);
out->Resize(dims);
std::vector<int64_t> dims;
if (logits.dim() != 0) {
dims =
std::vector<int64_t>(logits.sizes().begin(), logits.sizes().end() - 1);
}
auto* out = Output(0, dims, at::dtype<float>());
auto* out_ptr = out->template mutable_data<float>();
auto* logits_ptr = logits.data<float>();
4 changes: 2 additions & 2 deletions caffe2/operators/deform_conv_op_impl.h
@@ -297,8 +297,8 @@ bool DeformConvGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {

T* dbias_data = nullptr;
if (!no_bias_) {
auto* dbias = Output(BIAS_OR_INPUT_GRAD);
dbias->Resize(M);

auto* dbias = Output(BIAS_OR_INPUT_GRAD, {M}, at::dtype<T>());
if (bias_multiplier_.size() != output_image_size) {
// If the helper bias multiplier is not M, reshape and fill it with one.
bias_multiplier_.Resize(vector<int64_t>(1, output_image_size));
4 changes: 2 additions & 2 deletions caffe2/operators/depthwise_3x3_conv_op_cudnn.cu
@@ -455,8 +455,8 @@ class Depthwise3x3ConvGradientOp final : public ConvPoolOpBase<CUDAContext> {
M,
dY.dim32(2),
dY.dim32(3)));
auto* dbias = Output(BIAS_OR_INPUT_GRAD);
dbias->Resize(M);
auto* dbias = Output(BIAS_OR_INPUT_GRAD, {M}, at::dtype<float>());
CUDNN_ENFORCE(cudnnConvolutionBackwardBias(
cudnn_wrapper_.inline_cudnn_handle(),
cudnnTypeWrapper<float>::kOne(),
12 changes: 6 additions & 6 deletions caffe2/operators/dropout_op.cu
@@ -21,8 +21,8 @@ __global__ void DropoutKernel(
template <>
bool DropoutOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
Y->Resize(X.dims());

auto* Y = Output(0, X.dims(), at::dtype<float>());
if (is_test_) {
if (Y != &X) {
context_.CopySameDevice<float>(
@@ -34,8 +34,8 @@ bool DropoutOp<float, CUDAContext>::RunOnDevice() {
// boolean numbers, we will generate into dY and write the result to
// mask.
float* Ydata = Y->template mutable_data<float>();
auto* mask = Output(1);
mask->Resize(X.dims());

auto* mask = Output(1, X.dims(), at::dtype<bool>());
CAFFE_ENFORCE(X.data<float>() != Ydata, "In-place GPU dropout is broken");
CURAND_ENFORCE(
curandGenerateUniform(context_.curand_generator(), Ydata, X.size()));
@@ -69,8 +69,8 @@ __global__ void DropoutGradientKernel(
template <>
bool DropoutGradientOp<float, CUDAContext>::RunOnDevice() {
auto& dY = Input(0);
auto* dX = Output(0);
dX->Resize(dY.dims());
auto* dX = Output(0, dY.dims(), at::dtype<float>());
if (is_test_) {
if (dX != &dY) {
context_.CopySameDevice<float>(
4 changes: 2 additions & 2 deletions caffe2/operators/integral_image_op.cu
@@ -119,15 +119,15 @@ __global__ void ColPassGradientKernel(
template <>
bool IntegralImageOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);

CAFFE_ENFORCE(X.ndim() == 4, "Only supports 4D tensors for the momement");

// Input is (N, C, H, W)
// Output is (N, C, H + 1, W + 1)
vector<int64_t> out_shape(X.dims().vec());
out_shape[2] += 1; // H + 1 output size
out_shape[3] += 1; // W + 1 output size
Y->Resize(out_shape);
auto* Y = Output(0, out_shape, at::dtype<float>());

const int chans = X.dim32(1);
const int rows_out = Y->dim32(2);
4 changes: 2 additions & 2 deletions caffe2/operators/lengths_tile_op.cu
@@ -22,7 +22,7 @@ template <>
bool LengthsTileOp<CUDAContext>::RunOnDevice() {
auto& data = Input(DATA);
auto& lengths = Input(LENGTHS);
auto* output = Output(0);


CAFFE_ENFORCE_EQ(lengths.ndim(), 1, "LENGTHS must be 1-D");
CAFFE_ENFORCE_GE(data.ndim(), 1, "DATA should be at least 1-D");
@@ -39,7 +39,7 @@ bool LengthsTileOp<CUDAContext>::RunOnDevice() {

auto shape = data.dims().vec();
shape[0] = total_length;
output->Resize(shape);
auto* output = Output(0, shape, at::dtype<float>());

auto numElementsPerRow = data.size_from_dim(1);
auto numElements = total_length * numElementsPerRow;
4 changes: 2 additions & 2 deletions caffe2/operators/mem_query_op.cu
@@ -17,8 +17,8 @@ class GetGPUMemoryUsageOp final : public Operator<CUDAContext> {
std::vector<long> max_by_gpu = CUDAContext::MaxMemoryByGpu();
CHECK_EQ(total_by_gpu.size(), max_by_gpu.size());

auto* stats = Output(0);
stats->Resize(2, total_by_gpu.size());

auto* stats = Output(0, {2, static_cast<int64_t>(total_by_gpu.size())}, at::dtype<long>());
context_.CopyFromCPU<long>(
total_by_gpu.size(),
total_by_gpu.data(),
8 changes: 4 additions & 4 deletions caffe2/operators/multi_class_accuracy_op.cu
@@ -37,17 +37,17 @@ template <>
bool MultiClassAccuracyOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(PREDICTION);
auto& label = Input(LABEL);
auto* Y0 = Output(0);
auto* Y1 = Output(1);


DCHECK_EQ(X.ndim(), 2);
// amount, number of instances
int N = X.dim32(0);
// dimension, number of classes
int D = X.dim32(1);
DCHECK_EQ(label.ndim(), 1);
DCHECK_EQ(label.dim32(0), N);
Y0->Resize(D);
Y1->Resize(D);
auto* Y0 = Output(0, {D}, at::dtype<float>());
auto* Y1 = Output(1, {D}, at::dtype<int>());

const float* Xdata = X.data<float>();
const int* labeldata = label.data<int>();
14 changes: 6 additions & 8 deletions caffe2/operators/pad_op_gpu.cu
@@ -403,12 +403,11 @@ bool PadImageOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
template<>
bool PadImageGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
auto& dY = Input(0);
auto* dX = Output(0);
dX->Resize(
dY.dim32(0),
auto* dX = Output(0, { dY.dim32(0),
dY.dim32(1),
dY.dim32(2) - pad_t() - pad_b(),
dY.dim32(3) - pad_l() - pad_r());
dY.dim32(3) - pad_l() - pad_r()}, at::dtype<float>());
const int input_size = dY.size();
const int padded_height = dY.dim32(2);
const int padded_width = dY.dim32(3);
@@ -484,12 +483,11 @@ bool PadImageGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
template<>
bool PadImageGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
auto& dY = Input(0);
auto* dX = Output(0);
dX->Resize(
dY.dim32(0),
auto* dX = Output(0, { dY.dim32(0),
dY.dim32(1) - pad_t() - pad_b(),
dY.dim32(2) - pad_l() - pad_r(),
dY.dim32(3));
dY.dim32(3)}, at::dtype<float>());
const int input_size = dY.size();
const int padded_height = dY.dim32(1);
const int padded_width = dY.dim32(2);
8 changes: 4 additions & 4 deletions caffe2/operators/resize_op.cu
@@ -73,7 +73,7 @@ __global__ void NearestNeighborGradientKernel(
template <>
bool ResizeNearestOp<float, CUDAContext>::RunOnDevice() {
const auto& X = Input(0);
auto* Y = Output(0);


const auto inputDims = X.dims();
CAFFE_ENFORCE_EQ(4, inputDims.size());
@@ -90,7 +90,7 @@ bool ResizeNearestOp<float, CUDAContext>::RunOnDevice() {
}
int output_width = input_width * width_scale_;
int output_height = input_height * height_scale_;
Y->Resize(batch_size, num_channels, output_height, output_width);
auto* Y = Output(0, {batch_size, num_channels, output_height, output_width}, at::dtype<float>());

const auto size = Y->size();
NearestNeighborKernel<<<
@@ -116,7 +116,7 @@ template <>
bool ResizeNearestGradientOp<float, CUDAContext>::RunOnDevice() {
const auto& dY = Input(0);
const auto& X = Input(1);
auto* dX = Output(0);
const auto inputDims = dY.dims();
CAFFE_ENFORCE_EQ(4, inputDims.size());
@@ -133,7 +133,7 @@ bool ResizeNearestGradientOp<float, CUDAContext>::RunOnDevice() {
height_scale_ = scales_data[0];
width_scale_ = scales_data[1];
}
dX->Resize(batch_size, num_channels, output_height, output_width);
auto* dX = Output(0, {batch_size, num_channels, output_height, output_width}, at::dtype<float>());
math::Set<float, CUDAContext>(
dX->size(), 0.0f, dX->template mutable_data<float>(), &context_);
4 changes: 2 additions & 2 deletions caffe2/operators/reverse_packed_segs_op.cu
@@ -58,9 +58,9 @@ void ReversePackedSegsOp<CUDAContext>::DoRunWithLengthType() {
"segments, embeddings>");
CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");

auto* output = Output(0);

const auto shape = data.dims();
output->Resize(shape);
auto* output = Output(0, shape, at::dtype<T>());

const auto max_length = data.dims()[0];
const auto batch_size = data.dims()[1];
4 changes: 2 additions & 2 deletions caffe2/operators/rmac_regions_op.cu
@@ -161,7 +161,7 @@ __global__ void RMACRegionsKernel(
template <>
bool RMACRegionsOp<CUDAContext>::RunOnDevice() {
const auto& X = Input(0); // Input tensor
auto* output = Output(0); // RoIs
// RoIs

if (X.size() == 0) {
return true;
@@ -194,7 +194,7 @@ bool RMACRegionsOp<CUDAContext>::RunOnDevice() {
int num_rois = 0;
context_.CopyBytesToCPU(sizeof(int), num_rois_.data<int>(), &num_rois);
int N = batch_size * num_rois;
output->Resize(N, 5); // [batch_id x1 y1 x2 y2]
auto* output = Output(0, {N, 5}, at::dtype<float>()); // [batch_id x1 y1 x2 y2]

// Compute region coordinates
RMACRegionsKernel<<<
(Diffs for the remaining 2 changed files are not shown.)
