Skip to content

Latest commit

 

History

History
105 lines (101 loc) · 4.87 KB

proposal1.md

File metadata and controls

105 lines (101 loc) · 4.87 KB

This is an initial proposal for a mixed-precision tensor contraction interface:

/* Signed 64-bit integer used for the per-mode extents of a tensor
 * (the element type of the extentA/extentB/extentC arrays). */
typedef int64_t extent_type;
/* Signed 64-bit integer used for the per-mode strides of a tensor
 * (the element type of the strideA/strideB/strideC arrays). */
typedef int64_t stride_type;
/* Single-character label identifying one mode (index) of a tensor,
 * e.g. 'a' in modeC[] = {'a','b','c','d'}. */
typedef char mode_type;
/**
 * \brief Status codes returned by the tensor contraction interface.
 *
 * The enumeration is declared through a typedef so that the bare name
 * `error_t` is a valid type name in C as well as in C++ (in C an enum tag
 * alone is not a type name). The tagged form `enum error_t` remains usable.
 *
 * The enumerators take the implicit values 0, 1, 2, 3 in declaration order.
 *
 * NOTE(review): the proposal does not yet define which conditions produce
 * each non-SUCCESS code.
 * NOTE(review): glibc declares its own `error_t` in <errno.h> when GNU
 * extensions are enabled; a prefixed name may be needed to avoid a clash --
 * confirm before finalizing the interface.
 */
typedef enum error_t
{
   SUCCESS,
   INVALID_ARGUMENTS,
   INTERNAL_ERROR,
   NOT_SUPPORTED
} error_t;
/**
 * \brief Identifiers for the element type of a tensor or of the
 *        intermediate computation.
 *
 * The enumeration is declared through a typedef so that the bare name
 * `data_type_t` is a valid type name in C as well as in C++ (in C an enum
 * tag alone is not a type name). The tagged form `enum data_type_t` remains
 * usable.
 *
 * The enumerators take the implicit values 0 through 7 in declaration order.
 *
 * NOTE(review): the names suggest 16/32/64-bit floating-point types,
 * 16/32/64-bit integer types, and single/double-precision complex types;
 * the exact in-memory representation of each is not specified here --
 * TODO confirm.
 */
typedef enum data_type_t
{
   TYPE_FP16,
   TYPE_FP32,
   TYPE_FP64,
   TYPE_INT16,
   TYPE_INT32,
   TYPE_INT64,
   TYPE_FCOMPLEX,
   TYPE_DCOMPLEX
} data_type_t;
/**
 * \brief Computes the tensor contraction C = alpha * op(A) * op(B) + beta * op(C).
 *
 * \f[ \mathcal{C}_{\text{modes}_\mathcal{C}} \gets \alpha * op(\mathcal{A}_{\text{modes}_\mathcal{A}}) op(\mathcal{B}_{\text{modes}_\mathcal{B}}) + \beta op(\mathcal{C}_{\text{modes}_\mathcal{C}}), \f]
 * where op(X) = X or op(X) = complex conjugate(X).
 *
 *
 * \param[in] alpha Pointer to the scaling factor for A*B (its data type is determined by 'typeCompute')
 * \param[in] A Pointer to the data corresponding to A (data type is determined by 'typeA')
 * \param[in] typeA Data type of A. One of the data_type_t enumerators, e.g., TYPE_FP32, TYPE_FP64, TYPE_FCOMPLEX, or TYPE_DCOMPLEX
 * \param[in] conjA Indicates if the entries of A should be conjugated (only applies to complex types)
 * \param[in] nmodeA Number of modes of A
 * \param[in] extentA Array with 'nmodeA' values that represents the extent of A (e.g., extentA[] = {4,8,12} represents an order-3 tensor of size 4x8x12).
 * \param[in] strideA Array with 'nmodeA' values that represents the strides of
 *            A with respect to each mode. Every entry must obey the following condition:
 *               (strideA[i] == 0) or (strideA[i] >= s * extentA[i-1], if i > 0, where s
 *               represents the last strideA[j] that is larger than 0, with j < i).
 *            strideA[i] == 0 indicates that this dimension will be broadcast.
 *
 *            This argument is optional and may be NULL; in this case a compact
 *            tensor is assumed.
 * \param[in] modeA Array with 'nmodeA' values that represent the modes of A.
 * \param[in] B Pointer to the data corresponding to B (data type is determined by 'typeB')
 * \param[in] typeB Data type of B (see typeA)
 * \param[in] conjB Indicates if the entries of B should be conjugated (only applies to complex types)
 * \param[in] nmodeB Number of modes of B
 * \param[in] extentB Array with 'nmodeB' values that represents the extent of B.
 * \param[in] strideB Array with 'nmodeB' values that represents the strides of B with respect to each mode (see strideA).
 * \param[in] modeB Array with 'nmodeB' values that represent the modes of B.
 * \param[in] beta Pointer to the scaling factor for C (its data type is determined by 'typeCompute')
 * \param[in,out] C Pointer to the data corresponding to C (data type is determined by 'typeC')
 * \param[in] typeC Data type of C (see typeA)
 * \param[in] conjC Indicates if the initial entries of C should be conjugated (only applies to complex types)
 * \param[in] nmodeC Number of modes of C
 * \param[in] extentC Array with 'nmodeC' values that represents the extent of C.
 * \param[in] strideC Array with 'nmodeC' values that represents the strides of C with respect to each mode (see strideA).
 * \param[in] modeC Array with 'nmodeC' values that represent the modes of C.
 * \param[in] typeCompute Data type used for the intermediate computation T = A * B
 *
 * \return An error_t status code.
 *         NOTE(review): the proposal does not yet specify which conditions
 *         map to INVALID_ARGUMENTS, INTERNAL_ERROR, or NOT_SUPPORTED.
 *
 *
 * Example:
 *
 * The tensor contraction C[a,b,c,d] = 1.3 * A[b,e,d,f] * B[f,e,a,c],
 * where C, A, and B respectively are double-precision tensors of size E_a x E_b x E_c x E_d,
 * E_b x E_e x E_d x E_f, and E_f x E_e x E_a x E_c can be computed as follows:
 *
 * double alpha = 1.3;
 * double beta = 0.0;
 * extent_type extentC[] = {E_a, E_b, E_c, E_d};
 * extent_type extentA[] = {E_b, E_e, E_d, E_f};
 * extent_type extentB[] = {E_f, E_e, E_a, E_c};
 * stride_type strideC[] = {1, E_a, E_a*E_b, E_a*E_b*E_c}; //optional
 * stride_type strideA[] = {1, E_b, E_b*E_e, E_b*E_e*E_d}; //optional
 * stride_type strideB[] = {1, E_f, E_f*E_e, E_f*E_e*E_a}; //optional
 * mode_type modeC[] = {'a','b','c','d'};
 * mode_type modeA[] = {'b','e','d','f'};
 * mode_type modeB[] = {'f','e','a','c'};
 * int nmodeA = 4;
 * int nmodeB = 4;
 * int nmodeC = 4;
 * data_type_t typeA = TYPE_FP64;
 * data_type_t typeB = TYPE_FP64;
 * data_type_t typeC = TYPE_FP64;
 * data_type_t typeCompute = TYPE_FP64;
 *
 * // The stride arrays above are optional: passing NULL selects a compact layout.
 * error_t error = tensorMult(&alpha, A, typeA, false, nmodeA, extentA, NULL, modeA,
 *                             B, typeB, false, nmodeB, extentB, NULL, modeB,
 *                     &beta,  C, typeC, false, nmodeC, extentC, NULL, modeC, typeCompute);
 *
 */
error_t tensorMult(const void* alpha, const void *A, data_type_t typeA, bool conjA, int nmodeA, const extent_type *extentA, const stride_type *strideA, const mode_type* modeA,
                                      const void *B, data_type_t typeB, bool conjB, int nmodeB, const extent_type *extentB, const stride_type *strideB, const mode_type* modeB,
                   const void* beta,        void *C, data_type_t typeC, bool conjC, int nmodeC, const extent_type *extentC, const stride_type *strideC, const mode_type* modeC, data_type_t typeCompute);