Skip to content

Commit

Permalink
[Op][O2c] Creation operators (apache#336)
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 authored Jan 7, 2023
1 parent 5efc8f7 commit 80808fb
Show file tree
Hide file tree
Showing 11 changed files with 1,461 additions and 17 deletions.
54 changes: 54 additions & 0 deletions include/tvm/relax/attrs/create.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/create.h
* \brief Attributes for tensor creation operators.
*/
#ifndef TVM_RELAX_ATTRS_CREATE_H_
#define TVM_RELAX_ATTRS_CREATE_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operators */
struct InitAttrs : public tvm::AttrsNode<InitAttrs> {
DataType dtype;

TVM_DECLARE_ATTRS(InitAttrs, "relax.attrs.InitAttrs") {
TVM_ATTR_FIELD(dtype).describe("The data type of the created tensor.");
}
}; // struct InitAttrs

/*! \brief Attributes used in tril and triu operator */
struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
int k;

TVM_DECLARE_ATTRS(TriluAttrs, "relax.attrs.TriluAttrs") {
TVM_ATTR_FIELD(k).describe(
"The number of diagonals above or below the main diagonal to exclude or include.");
}
}; // struct TriluAttrs

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ATTRS_CREATE_H_
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# Operators
from .base import *
from .binary import *
from .create import *
from .tensor import *
from .op_attrs import *
from .statistical import *
Expand Down
31 changes: 16 additions & 15 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from tvm.runtime.object import Object

from . import _ffi_api
from ..expr import Expr, ShapeExpr, Tuple, Call, ExternFunc
from ..expr import Expr, ShapeExpr, Call, ExternFunc
from ..expr import Tuple as RxTuple
from ..ty import DynTensorType, TupleType
from ...ir import Array, Type, PrimExpr

Expand All @@ -43,8 +44,8 @@ def null_value() -> Call:

def call_tir(
func: Union[str, Expr],
args: Union[Expr, Tuple, List[Expr]],
shape: Union[Tuple, ShapeExpr, List[int]],
args: Union[Expr, List[Expr]],
shape: Union[RxTuple, ShapeExpr, List[int]],
dtype: Union[str, List[str]],
tir_vars: Optional[ShapeExpr] = None,
) -> Call:
Expand All @@ -56,10 +57,10 @@ def call_tir(
func : Union[str, Expr]
The destination-passing-style function, can be ExternFunc or PrimFunc.
args : Union[Expr, Tuple, List[Expr]]
args : Union[Expr, List[Expr]]
The input arguments.
shape: Union[Tuple, ShapeExpr, List[int]]
shape: Union[RxTuple, ShapeExpr, List[int]]
The output shape. Tuple(ShapeExpr) if multiple outputs, ShapeExpr if single output.
dtype: Union[str, List[str]]
Expand Down Expand Up @@ -95,7 +96,7 @@ def _create_shape(shape: List[Union[int, PrimExpr]]) -> ShapeExpr:
if all([not isinstance(x, (list, tuple, Array, ShapeExpr)) for x in shape]):
shape = _create_shape(shape) # type: ignore
elif all([isinstance(x, (list, tuple, Array, ShapeExpr)) for x in shape]):
shape = Tuple(
shape = RxTuple(
[
_create_shape(x) if not isinstance(x, ShapeExpr) else x # type: ignore
for x in shape
Expand All @@ -107,10 +108,10 @@ def _create_shape(shape: List[Union[int, PrimExpr]]) -> ShapeExpr:
)

if isinstance(args, Expr): # type: ignore
args = Tuple((args,))
args = RxTuple((args,))

if isinstance(args, (list, tuple)):
args = Tuple(args)
args = RxTuple(args)

if isinstance(dtype, str):
output_type = DynTensorType(len(shape), dtype)
Expand All @@ -126,7 +127,7 @@ def _create_shape(shape: List[Union[int, PrimExpr]]) -> ShapeExpr:

def call_builtin(
func: Union[str, Expr],
args: Union[Tuple, List[Expr]],
args: Union[RxTuple, List[Expr]],
*,
type_args: Optional[List[Type]] = None,
int_args: Optional[List[int]] = None,
Expand All @@ -141,7 +142,7 @@ def call_builtin(
func : Expr
The builtin function to be called.
args : Union[Tuple, List[Expr]]
args : Union[RxTuple, List[Expr]]
The input arguments.
type_args: Optional[List[Type]]
Expand All @@ -168,7 +169,7 @@ def call_builtin(
func = ExternFunc(func)

if isinstance(args, (list, tuple)):
args = Tuple(args)
args = RxTuple(args)

return _ffi_api.call_builtin( # type: ignore
func, args, type_args, int_args, dtype_arg, str_args, require_ctx # type: ignore
Expand All @@ -177,7 +178,7 @@ def call_builtin(

def make_closure(
func: Expr,
args: Union[Tuple, List[Expr]],
args: Union[RxTuple, List[Expr]],
) -> Object:
"""
Create a closure with free variables and return the closure.
Expand All @@ -198,14 +199,14 @@ def make_closure(
"""

if isinstance(args, (list, tuple)):
args = Tuple(args)
args = RxTuple(args)

return _ffi_api.make_closure(func, args) # type: ignore


def invoke_closure(
closure: Expr,
args: Union[Tuple, List[Expr]],
args: Union[RxTuple, List[Expr]],
type_args: Union[List[Type], Type],
) -> Object:
"""
Expand All @@ -229,7 +230,7 @@ def invoke_closure(
"""

if isinstance(args, (list, tuple)):
args = Tuple(args)
args = RxTuple(args)
if not isinstance(type_args, (list, tuple)):
type_args = (type_args,)

Expand Down
209 changes: 209 additions & 0 deletions python/tvm/relax/op/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Creation operators."""
from typing import Optional, Tuple, Union

from tvm import DataType
from tvm.ir.expr import PrimExpr

from . import _ffi_api
from ..expr import Expr, ShapeExpr

PrimExprLike = Union[int, PrimExpr]


def full(
shape: Union[Tuple[PrimExprLike], Expr],
fill_value: Expr,
dtype: Optional[Union[str, DataType]] = None,
) -> Expr:
"""Fill array with scalar value.
Parameters
----------
shape : Union[Tuple[PrimExprLike], Expr]
The shape of the created tensor.
fill_value : relax.Expr
The value to fill. Must be a scalar tensor.
dtype : Optional[Union[str, DataType]]
The data type of the created tensor.
If dtype is not given, it will by default use the dtype of fill_value.
Returns
-------
result : relax.Expr
The result tensor.
"""
return _ffi_api.full(shape, fill_value, dtype) # type: ignore


def full_like(x: Expr, fill_value: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr:
"""Construct a tensor such that
- its shape is the same as the input data tensor's shape,
- its value is filled with the input scalar fill value.
Parameters
----------
x : relax.Expr
The input tensor, which provides the shape, and dtype
when the `dtype` field is not specified.
fill_value : relax.Expr
The value to fill. Must be a scalar tensor.
dtype : Optional[Union[str, DataType]]
The data type of the created tensor.
If dtype is not given, it will by default use the dtype of the input tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
return _ffi_api.full_like(x, fill_value, dtype) # type: ignore


def ones(shape: Union[Tuple[PrimExprLike], Expr], dtype: Union[str, DataType]) -> Expr:
"""Construct a tensor of all ones, with the input shape and dtype.
Parameters
----------
shape : Union[Tuple[PrimExprLike], Expr]
The shape of the created tensor.
dtype : Union[str, DataType]
The data type of the created tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
if isinstance(shape, (tuple, list)):
shape = ShapeExpr(shape)
return _ffi_api.ones(shape, dtype) # type: ignore


def ones_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr:
"""Construct a tensor with all ones, with shape of the input tensor shape.
Parameters
----------
x : relax.Expr
The input tensor, which provides the shape, and dtype
when the `dtype` field is not specified.
dtype : Optional[Union[str, DataType]]
The data type of the created tensor.
If dtype is not given, it will by default use the dtype of the input tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
return _ffi_api.ones_like(x, dtype) # type: ignore


def zeros(shape: Union[Tuple[PrimExprLike], Expr], dtype: Union[str, DataType]) -> Expr:
"""Construct a tensor of all zeros, with the input shape and dtype.
Parameters
----------
shape : Union[Tuple[PrimExprLike], Expr]
The shape of the created tensor.
dtype : Union[str, DataType]
The data type of the created tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
if isinstance(shape, (tuple, list)):
shape = ShapeExpr(shape)
return _ffi_api.zeros(shape, dtype) # type: ignore


def zeros_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr:
"""Construct a tensor with all zeros, with shape of the input tensor shape.
Parameters
----------
x : relax.Expr
The input tensor, which provides the shape, and dtype
when the `dtype` field is not specified.
dtype : Optional[Union[str, DataType]]
The data type of the created tensor.
If dtype is not given, it will by default use the dtype of the input tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
return _ffi_api.zeros_like(x, dtype) # type: ignore


def tril(x: Expr, k: int = 0) -> Expr:
"""Return the lower triangular part of a matrix or a batch of matrices.
Parameters
----------
x : relax.Expr
The tensor that tril will be applied to.
It is required to have at least two dimensions.
k : int
The index indicating the diagonal above which to zero elements.
If k = 0, the diagonal is the main diagonal.
If k < 0, the diagonal is below the main diagonal.
If k > 0, the diagonal is above the main diagonal.
Returns
-------
ret : relax.Expr
The result tensor.
"""
return _ffi_api.tril(x, k) # type: ignore


def triu(x: Expr, k: int = 0) -> Expr:
"""Return the upper triangular part of a matrix or a batch of matrices.
Parameters
----------
x : relax.Expr
The tensor that triu will be applied to.
It is required to have at least two dimensions.
k : int
The index indicating the diagonal below which to zero elements.
If k = 0, the diagonal is the main diagonal.
If k < 0, the diagonal is below the main diagonal.
If k > 0, the diagonal is above the main diagonal.
Returns
-------
ret : relax.Expr
The result tensor.
"""
return _ffi_api.triu(x, k) # type: ignore
Loading

0 comments on commit 80808fb

Please sign in to comment.