Added support for MSELoss and some other ops like lt/ge/dot
artyom-beilis committed Feb 11, 2022
1 parent 2cecc0f commit 7ec2e47
Showing 3 changed files with 165 additions and 10 deletions.
53 changes: 53 additions & 0 deletions src/loss_ops.cpp
@@ -163,6 +163,57 @@ using c10::DeviceType;
sync_if_needed(grad_output.device());
return out;
}

// {"schema": "aten::mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor", "dispatch": "True", "default"
Tensor mse_loss(const Tensor & self, const Tensor & target, int64_t reduction)
{
GUARD;
Tensor self_c = self.contiguous();
dlprim::Tensor x=todp(self_c);
Tensor target_c = target.contiguous();
dlprim::Tensor lbl=todp(target_c);
bool reduce = false;
float scale = 1;
switch(reduction) {
case 0: reduce=false; break; // None
case 1: reduce=true; scale = 1.0f/x.shape().total_size(); break; // Mean
case 2: reduce=true; break; // sum
}
Tensor output = new_tensor_as(reduce ? dlprim::Shape() : x.shape(),self_c);
dlprim::Tensor y=todp(output);
auto q = getExecutionContext(self);
dlprim::Context ctx(q);
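// The per-element kernel below computes the squared error (x0-x1)^2 and the
// reduce kernel accumulates it; the {scale} value passed to enqueue() scales
// the result, turning the accumulated sum into a mean when reduction==Mean
// (scale == 1/N) and leaving it unchanged for Sum and None (scale == 1).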
auto op = dlprim::core::PointwiseOperationBroadcastReduce::create(ctx,
{x.specs(),lbl.specs()},{y.specs()},0,x.dtype(),
"y0 = (x0-x1)*(x0-x1);",
"reduce_y0 = 0;",
"reduce_y0 += y0;");
WSGuard wsg(op->workspace(),self.device());
op->enqueue({x,lbl},{y},wsg.ws,{},{scale},{0},q);
sync_if_needed(self.device());

return output;
}
// {"schema": "aten::mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction)
Tensor mse_loss_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, int64_t reduction)
{
GUARD;
Tensor grad_output_c = grad_output.contiguous();
Tensor self_c = self.contiguous();
Tensor target_c = target.contiguous();
dlprim::Tensor x = todp(self_c);
dlprim::Tensor dy = todp(grad_output_c);
dlprim::Tensor lbl = todp(target_c);
Tensor result = new_tensor_as(x.shape(),self_c);
dlprim::Tensor dx = todp(result);
double scale = reduction == 1 ? (1.0f/x.shape().total_size()) : 1.0;
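// Kernel operands map as x0 = grad_output, x1 = self, x2 = target, w0 = scale,
// so the expression implements d/dself (self - target)^2 = 2*(self - target),
// multiplied by the upstream gradient and by 1/N when reduction is Mean.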
dlprim::core::pointwise_operation_broadcast({dy,x,lbl},{dx},{scale},
"y0 = 2*(x1 -x2) * x0 * w0;",getExecutionContext(self.device()));
sync_if_needed(self.device());
return result;
}


} // namespace dlprim

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
@@ -173,4 +224,6 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("aten::binary_cross_entropy_backward.grad_input",&ptdlprim::binary_cross_entropy_backward_out);
m.impl("aten::_log_softmax.out",&ptdlprim::_log_softmax_out);
m.impl("aten::_log_softmax_backward_data.out",&ptdlprim::_log_softmax_backward_data_out);
m.impl("aten::mse_loss",&ptdlprim::mse_loss);
m.impl("aten::mse_loss_backward",&ptdlprim::mse_loss_backward);
}
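For reference, a minimal Python sketch of how the new MSELoss path is exercised end to end, mirroring the pattern in tests/test_op.py. It assumes the backend library has already been loaded and that `device` holds whatever device string the backend registers; the helper name check_mse is illustrative only.

import torch

def check_mse(device):
    x_cpu = torch.rand(4, 3, 5, requires_grad=True)
    t_cpu = torch.rand(4, 3, 5)
    x_dev = x_cpu.detach().to(device).requires_grad_(True)
    t_dev = t_cpu.to(device)

    # Forward on both devices, then backprop through the scalar loss
    loss_cpu = torch.nn.MSELoss()(x_cpu, t_cpu)
    loss_dev = torch.nn.MSELoss()(x_dev, t_dev)
    loss_cpu.backward()
    loss_dev.backward()

    # Both the forward value and the gradient should agree with the CPU reference
    print("fwd diff:", (loss_cpu - loss_dev.cpu()).abs().item())
    print("bwd diff:", (x_cpu.grad - x_dev.grad.cpu()).abs().max().item())
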
106 changes: 104 additions & 2 deletions src/pointwise_ops.cpp
@@ -216,6 +216,67 @@ using c10::DeviceType;
sync_if_needed(self.device());
return out;
}

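// Shared helper for the scalar comparison ops below: `op` is spliced directly
// into the OpenCL expression and the result (1 or 0) is written into the
// provided output tensor.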
Tensor & comp_out(const Tensor & self, const Scalar & other, Tensor & out,std::string const &op)
{
GUARD;
Tensor self_c = self.contiguous();
dlprim::Tensor x0=todp(self_c);
dlprim::Tensor y0=todp(out);
float w0 = other.toDouble();
dlprim::core::pointwise_operation_broadcast({x0},{y0},{w0},
"y0 = x0 " + op + " w0 ? 1 : 0;",
getExecutionContext(self));

sync_if_needed(self.device());
return out;
}
// {"schema": "aten::le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
Tensor & le_out(const Tensor & self, const Scalar & other, Tensor & out)
{
return comp_out(self,other,out,"<=");
}
// {"schema": "aten::ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
Tensor & ge_out(const Tensor & self, const Scalar & other, Tensor & out)
{
return comp_out(self,other,out,">=");
}

// {"schema": "aten::lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
Tensor & lt_out(const Tensor & self, const Scalar & other, Tensor & out)
{
return comp_out(self,other,out,"<");
}
// {"schema": "aten::gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
Tensor & gt_out(const Tensor & self, const Scalar & other, Tensor & out)
{
return comp_out(self,other,out,">");
}



// {"schema": "aten::neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
Tensor & neg_out(const Tensor & self, Tensor & out)
{
GUARD;
Tensor self_c = self.contiguous();
dlprim::Tensor x0=todp(self_c);
dlprim::Tensor y0=todp(out);
dlprim::core::pointwise_operation_broadcast({x0},{y0},{},"y0=-x0;",getExecutionContext(self));
sync_if_needed(self.device());
return out;
}
// {"schema": "aten::reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
Tensor & reciprocal_out(const Tensor & self, Tensor & out)
{
GUARD;
Tensor self_c = self.contiguous();
dlprim::Tensor x0=todp(self_c);
dlprim::Tensor y0=todp(out);
dlprim::core::pointwise_operation_broadcast({x0},{y0},{},"y0=1.0/x0;",getExecutionContext(self));
sync_if_needed(self.device());
return out;
}

// {"schema": "aten::sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
Tensor & sqrt_out(const Tensor & self, Tensor & out)
@@ -613,8 +674,11 @@ using c10::DeviceType;
Tensor & sigmoid_(Tensor & self)
{
GUARD;
dlprim::Tensor X=todp(self);
Tensor self_c = self.contiguous();
dlprim::Tensor X=todp(self_c);
dlprim::core::activation_forward(X,X,dlprim::StandardActivations::sigmoid,getExecutionContext(self));
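// The activation ran on a contiguous tensor; if `self` itself was not
// contiguous, self_c is a separate copy, so write the result back.
// (tanh_ below follows the same pattern.)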
if(!self.is_contiguous())
self.copy_(self_c);
sync_if_needed(self.device());
return self;
}
@@ -705,8 +769,11 @@ using c10::DeviceType;
Tensor & tanh_(Tensor & self)
{
GUARD;
dlprim::Tensor X=todp(self);
Tensor self_c = self.contiguous();
dlprim::Tensor X=todp(self_c);
dlprim::core::activation_forward(X,X,dlprim::StandardActivations::tanh,getExecutionContext(self));
if(!self.is_contiguous())
self.copy_(self_c);
sync_if_needed(self.device());
return self;
}
@@ -822,6 +889,33 @@ using c10::DeviceType;
return min_or_max(self,false);
}

// {"schema": "aten::dot(Tensor self, Tensor tensor) -> Tensor", "dispatch": "True", "default": "False"}
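// Reuses the same broadcast+reduce helper as mse_loss: the per-element kernel
// multiplies the two inputs and the reduce kernel sums the products into a
// 0-d (scalar) output tensor.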
Tensor dot(const Tensor & self, const Tensor & tensor)
{
GUARD;
Tensor self_c = self.contiguous();
Tensor tensor_c = tensor.contiguous();
dlprim::Tensor x0=todp(self_c);
dlprim::Tensor x1=todp(tensor_c);
Tensor result = new_tensor_as(dlprim::Shape(),self_c);
dlprim::Tensor y=todp(result);
auto q = getExecutionContext(self);
dlprim::Context ctx(q);
auto op = dlprim::core::PointwiseOperationBroadcastReduce::create(
ctx,
{x0.specs(),x1.specs()},{y.specs()},
0,dlprim::float_data,
"y0=x0*x1;",
"reduce_y0 = 0;",
"reduce_y0 += y0;");

WSGuard wsg(op->workspace(),self.device());
op->enqueue({x0,x1},{y},wsg.ws,{},{1},{0},q);
sync_if_needed(self.device());
return result;
}


// {"schema": "aten::ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
Tensor & ne_out(const Tensor & self, const Scalar & other, Tensor & out)
{
@@ -962,10 +1056,18 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("aten::argmax.out",&ptdlprim::argmax_out);
m.impl("aten::ne.Scalar_out",&ptdlprim::ne_out);
m.impl("aten::eq.Tensor_out",&ptdlprim::eq_out);
m.impl("aten::le.Scalar_out",&ptdlprim::le_out);
m.impl("aten::ge.Scalar_out",&ptdlprim::ge_out);
m.impl("aten::lt.Scalar_out",&ptdlprim::lt_out);
m.impl("aten::gt.Scalar_out",&ptdlprim::gt_out);
m.impl("aten::bitwise_and.Tensor_out",&ptdlprim::bitwise_and_out);
m.impl("aten::min",&ptdlprim::min);
m.impl("aten::max",&ptdlprim::max);
m.impl("aten::clamp.out",&ptdlprim::clamp_out);
m.impl("aten::neg.out",&ptdlprim::neg_out);
m.impl("aten::reciprocal.out",&ptdlprim::reciprocal_out);
m.impl("aten::dot",&ptdlprim::dot);

}

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
16 changes: 8 additions & 8 deletions tests/test_op.py
@@ -99,7 +99,7 @@ def test_fwd_bwd_op(inputs,call,device,randgen=torch.randn):
raise Exception("Diff too big")


def test_fwd_bwd(inputs,call,device,randgen=torch.randn,with_params = False):
def test_fwd_bwd(inputs,call,device,randgen=torch.randn):
xs_cpu = []
xs_dev = []
with torch.no_grad():
@@ -115,14 +115,8 @@ def test_fwd_bwd(inputs,call,device,randgen=torch.randn,with_params = False):
xs_cpu.append(x_cpu)
xs_dev.append(x_dev)

if with_params:
for p in call.state_dict():
print(p)
call_dev = call.to(device)
else:
call_dev = call
y_cpu = call(*xs_cpu)
y_dev = call_dev(*xs_dev)
y_dev = call(*xs_dev)

print(y_cpu.shape)
print(y_dev.shape)
@@ -219,10 +213,16 @@ def test_all(device):
test_fwd_bwd([([4,3,5],-1),([4,3,5],-1)],torch.nn.BCELoss(),device,torch.rand)
print("BCE Loss no reduction")
test_fwd_bwd([([4,3,5],-1),([4,3,5],-1)],torch.nn.BCELoss(reduction='none'),device,torch.rand)
print("MSE Loss")
test_fwd_bwd([([4,3,5],-1),([4,3,5],-1)],torch.nn.MSELoss(),device,torch.rand)
print("MSE Loss no reduction")
test_fwd_bwd([([4,3,5],-1),([4,3,5],-1)],torch.nn.MSELoss(reduction='none'),device,torch.rand)
print("Min")
test_fwd([([4,3,5],-1)],torch.min,device)
print("Max")
test_fwd([([4,3,5],-1)],torch.max,device)
print("Dot")
test_fwd([([16],-1),([16],-1)],torch.dot,device)
print("Clamp 1")
test_fwd([([4,3,5],-1)],lambda x:torch.clamp(x,min=-0.2,max=0.3),device)
print("Clamp 2")