Skip to content

Commit bd88fd0

Browse files
izdeby authored and facebook-github-bot committed
Added .bfloat16() (pytorch#22852)
Summary: Add conversion method for bfloat16 Pull Request resolved: pytorch#22852 Differential Revision: D16256760 Pulled By: izdeby fbshipit-source-id: 01d75495f9df513a0cdf78791c3eb013ab92bd95
1 parent 8399197 commit bd88fd0

File tree

4 files changed

+15
-2
lines changed

4 files changed

+15
-2
lines changed

docs/source/tensors.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ view of a storage and defines numeric operations on it.
149149
.. autoattribute:: device
150150
.. autoattribute:: grad
151151
.. autoattribute:: ndim
152-
.. autoattribute:: T
152+
.. autoattribute:: T
153153

154154
.. automethod:: abs
155155
.. automethod:: abs_
@@ -185,6 +185,7 @@ view of a storage and defines numeric operations on it.
185185
.. automethod:: baddbmm_
186186
.. automethod:: bernoulli
187187
.. automethod:: bernoulli_
188+
.. automethod:: bfloat16
188189
.. automethod:: bincount
189190
.. automethod:: bitwise_not
190191
.. automethod:: bitwise_not_

test/test_torch.py

+1
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def test_type_conversion_via_dtype_name(self):
158158
self.assertEqual(x.float().dtype, torch.float32)
159159
self.assertEqual(x.half().dtype, torch.float16)
160160
self.assertEqual(x.int().dtype, torch.int32)
161+
self.assertEqual(x.bfloat16().dtype, torch.bfloat16)
161162

162163
def test_doc(self):
163164
checked_types = (types.MethodType, types.FunctionType,

tools/autograd/templates/python_variable_methods.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,10 @@ static PyObject * THPVariable_bool(PyObject* self, PyObject* args) {
415415
return THPVariable_to_type(self, ScalarType::Bool);
416416
}
417417

418+
static PyObject * THPVariable_bfloat16(PyObject* self, PyObject* args) {
419+
return THPVariable_to_type(self, ScalarType::BFloat16);
420+
}
421+
418422
static PyObject * THPVariable_element_size(PyObject* self, PyObject* args)
419423
{
420424
HANDLE_TH_ERRORS
@@ -731,6 +735,7 @@ PyMethodDef variable_methods[] = {
731735
{"__matmul__", (PyCFunction)THPVariable_matmul, METH_VARARGS | METH_KEYWORDS, NULL},
732736
{"_is_view", (PyCFunction)THPVariable__is_view, METH_NOARGS, NULL},
733737
{"apply_", (PyCFunction)THPVariable_apply_, METH_O, NULL},
738+
{"bfloat16", (PyCFunction)THPVariable_bfloat16, METH_NOARGS, NULL},
734739
{"byte", (PyCFunction)THPVariable_byte, METH_NOARGS, NULL},
735740
{"char", (PyCFunction)THPVariable_char, METH_NOARGS, NULL},
736741
{"contiguous", (PyCFunction)THPVariable_contiguous, METH_VARARGS | METH_KEYWORDS, NULL},

torch/_tensor_docs.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ def add_docstr_all(method, docstr):
749749
tensor([[5., 0., 0.],
750750
[0., 5., 0.],
751751
[0., 0., 5.]])
752-
>>> b = torch.zeros(7, 3)
752+
>>> b = torch.zeros(7, 3)
753753
>>> b.fill_diagonal_(5)
754754
tensor([[5., 0., 0.],
755755
[0., 5., 0.],
@@ -2562,6 +2562,12 @@ def callable(a, b) -> number
25622562
``self.char()`` is equivalent to ``self.to(torch.int8)``. See :func:`to`.
25632563
""")
25642564

2565+
add_docstr_all('bfloat16',
2566+
r"""
2567+
bfloat16() -> Tensor
2568+
``self.bfloat16()`` is equivalent to ``self.to(torch.bfloat16)``. See :func:`to`.
2569+
""")
2570+
25652571
add_docstr_all('double',
25662572
r"""
25672573
double() -> Tensor

0 commit comments

Comments (0)