[MPS] Fix the crash in huberloss with Float16 (pytorch#94567)
- Also fix FP16 correctness issues in several other ops by relaxing their FP16 test tolerances via the new `FP16_LOW_PRECISION_LIST`.
- Add atol/rtol to the `assertEqual()` call in the gradient tests.
Pull Request resolved: pytorch#94567
Approved by: https://github.com/kulinseth
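
A minimal reproduction sketch (not part of this commit) of the crash being fixed: before this change, building the huber_loss graph with half-precision inputs on MPS combined fp16 input tensors with fp32 scalar constants, which crashed inside MPSGraph. The tensor shape and delta value below are arbitrary.

import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    x = torch.randn(16, dtype=torch.float16, device="mps", requires_grad=True)
    y = torch.randn(16, dtype=torch.float16, device="mps")
    loss = F.huber_loss(x, y, reduction="mean", delta=1.0)  # crashed on MPS before this fix
    loss.backward()                                         # gradient path exercised by the updated tests
    print(loss.item())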
razarmehr authored and pytorchmergebot committed Feb 10, 2023
1 parent d8f4026 commit 7c4acda
Showing 2 changed files with 20 additions and 9 deletions.
aten/src/ATen/native/mps/operations/LossOps.mm (8 changes: 5 additions & 3 deletions)
@@ -1010,12 +1010,14 @@ void smooth_l1_loss_backward_impl(

   MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
   MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);

+  MPSDataType input_type = getMPSScalarType(input.scalar_type());
   MPSGraphTensor* deltaTensor = [mpsGraph constantWithScalar:delta
                                                        shape:@[@1]
-                                                    dataType:MPSDataTypeFloat32];
+                                                    dataType:input_type];
   MPSGraphTensor* halfTensor = [mpsGraph constantWithScalar:.5f
                                                       shape:@[@1]
-                                                    dataType:MPSDataTypeFloat32];
+                                                    dataType:input_type];

   MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor: inputTensor
                                                      secondaryTensor: targetTensor
@@ -1144,7 +1146,7 @@ Tensor huber_loss_mps(const Tensor& input, const Tensor& target, int64_t reduction
                                                                  name:nil];
   MPSGraphTensor* deltaTensor = [mpsGraph constantWithScalar:delta
                                                        shape:getMPSShape(target)
-                                                    dataType:MPSDataTypeFloat32];
+                                                    dataType:getMPSDataType(target.scalar_type())];
   MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor
                                                       secondaryTensor:targetTensor
                                                                  name:nil];
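
A minimal sketch (assumed, not from this commit) of the consistency check the updated tests perform for this op: run huber_loss in float16 on MPS and on the CPU, then compare outputs and input gradients with the relaxed FP16 tolerances (atol=1e-2, rtol=1e-2) that `FP16_LOW_PRECISION_LIST` selects. It assumes float16 huber_loss is available on the CPU for the reference values.

import torch
import torch.nn.functional as F

def check_huber_fp16(atol=1e-2, rtol=1e-2):
    # CPU reference in float16, plus an MPS copy of the same inputs.
    x_cpu = torch.randn(8, 8, dtype=torch.float16, requires_grad=True)
    y_cpu = torch.randn(8, 8, dtype=torch.float16)
    x_mps = x_cpu.detach().clone().to("mps").requires_grad_(True)
    y_mps = y_cpu.to("mps")

    out_cpu = F.huber_loss(x_cpu, y_cpu)
    out_mps = F.huber_loss(x_mps, y_mps)
    out_cpu.backward()
    out_mps.backward()

    # Forward and gradient comparisons use the same relaxed FP16 tolerances.
    torch.testing.assert_close(out_mps.cpu(), out_cpu, atol=atol, rtol=rtol)
    torch.testing.assert_close(x_mps.grad.cpu(), x_cpu.grad, atol=atol, rtol=rtol)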
test/test_mps.py (21 changes: 15 additions & 6 deletions)
@@ -8798,7 +8798,7 @@ class TestConsistency(TestCase):
         'nn.functional.group_norm': ['f32'],
         'nn.functional.hardtanh': ['f32', 'i16', 'i32', 'i64'],
         'nn.functional.hinge_embedding_loss': ['f32'],
-        'nn.functional.huber_loss': ['f32'],
+        'nn.functional.huber_loss': ['f16', 'f32'],
         'nn.functional.instance_norm': ['f32'],
         'nn.functional.kl_div': ['f32', 'i16', 'i32', 'i64'],
         'nn.functional.l1_loss': ['f16', 'f32'],
@@ -9030,7 +9030,7 @@ class TestConsistency(TestCase):
         'nn.functional.glu': ['f32'],
         'nn.functional.hardtanh': ['f32'],
         'nn.functional.hinge_embedding_loss': ['f32'],
-        'nn.functional.huber_loss': ['f32'],
+        'nn.functional.huber_loss': ['f16', 'f32'],
         'nn.functional.instance_norm': ['f32'],
         'nn.functional.kl_div': ['f32'],
         'nn.functional.l1_loss': ['f16', 'f32'],
@@ -9139,7 +9139,6 @@ class TestConsistency(TestCase):
         'nn.functional.conv_transpose1d': [torch.int64],
         'nn.functional.conv_transpose2d': [torch.int64],
         'nn.functional.conv_transpose3d': [torch.int64, torch.float32],
-        'nn.functional.huber_loss': [torch.float16],
         'nn.functional.local_response_norm': [torch.int64],
         'nn.functional.padcircular': [torch.uint8],
         'pow': [torch.int64],
@@ -9238,6 +9237,17 @@ class TestConsistency(TestCase):
         'dot': [torch.int64],
     }

+    FP16_LOW_PRECISION_LIST = {
+        'add', 'sub', 'div',
+        '__rdiv__', '__rmul__',
+        'nn.functional.huber_loss',
+        'true_divide', 'kron',
+        'gradient', 'var', 'std',
+        'linalg.vector_norm',
+        'masked.sum', 'masked.std',
+        'masked.var',
+    }
+
     # Used for accept mode only
     NEW_ALLOW_LIST = defaultdict(list)
     NEW_ALLOW_LIST_GRAD = defaultdict(list)
@@ -9308,8 +9318,7 @@ def get_samples():
                 if op.name == "nn.functional.conv2d" and dtype == torch.float32:
                     atol = 1e-4
                     rtol = 3e-5
-                elif (op.name == "add" or op.name == "sub" or
-                      op.name == "masked.sum" or op.name == "masked.std" or op.name == "masked.var") and dtype == torch.float16:
+                elif (op.name in self.FP16_LOW_PRECISION_LIST) and dtype == torch.float16:
                     atol = 1e-2
                     rtol = 1e-2
                 elif (op.name == "masked.mean"):
@@ -9379,7 +9388,7 @@ def req_grad(t):
                 cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
                 mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)

-                self.assertEqual(cpu_grad_inputs, mps_grad_inputs)
+                self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
             except Exception as e:
                 if not generate_new_truth:
                     raise e
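For context on why the ops in FP16_LOW_PRECISION_LIST need atol=rtol=1e-2: float16 has a 10-bit mantissa, so above 2048 the spacing between representable values is at least 2 and small contributions are lost outright, which makes elementwise and reduction ops diverge from the float32 reference by more than the default tolerances. A tiny illustration (not from the diff):

import torch

x = torch.tensor(2048.0, dtype=torch.float16)
print(x + 1)         # tensor(2048., dtype=torch.float16): 1 is below the fp16 spacing at this magnitude
print((x + 1) == x)  # tensor(True)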
