Skip to content

Commit

Permalink
[jit][edge] Migrate TupleType to DynamicType on mobile. (pytorch#70205)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#70205

Use DynamicType instead of TupleType all over the place in Lite Interpreter. Namely we need to modify the following places:
1. Type parser which produces the Type constants.
2. IValue::type() which returns reflected Type from IValues.
3. Helper functions to construct the container value.
4. Typechecks which test whether a type instance is a particular container type.
ghstack-source-id: 146818620

Test Plan: CI

Reviewed By: iseeyuan

Differential Revision: D33176925

fbshipit-source-id: 00f7a5db37ba772c912643c733db6c52dfdc695d
  • Loading branch information
zhxchen17 authored and facebook-github-bot committed Jan 11, 2022
1 parent 5cae40c commit 40b80aa
Show file tree
Hide file tree
Showing 13 changed files with 189 additions and 45 deletions.
54 changes: 50 additions & 4 deletions aten/src/ATen/core/dynamic_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/type_factory.h>
#include <c10/util/Exception.h>

namespace c10 {
Expand All @@ -19,13 +20,13 @@ bool contains(DynamicType::Tag lhs, DynamicType::Tag rhs) {
} // namespace

std::string DynamicType::str() const {
if (name_) {
return *name_;
}
std::string ret = "Dynamic<";
ret += std::to_string(static_cast<DynamicTypeBits>(tag_));
ret += ">";
if (tag_ == Tag::Class) {
auto name = class_->name();
ret += "[" + (name ? name->qualifiedName() : "Unknown Class") + "]";
} else if (arguments_.elems.size() > 0) {
if (tag_ != Tag::Class && arguments_.elems.size() > 0) {
ret += "[";
for (const auto& arg : arguments_.elems) {
if (arg.label) {
Expand Down Expand Up @@ -82,9 +83,21 @@ DynamicTypePtr DynamicType::create(Type& other) {
DynamicType::DynamicType(Tag tag, Arguments arguments)
: SharedType(Kind), tag_(tag), arguments_(std::move(arguments)) {}

DynamicType::DynamicType(Tag tag, c10::string_view name, Arguments arguments)
: SharedType(Kind),
tag_(tag),
name_(std::string{name}),
arguments_(std::move(arguments)) {}

DynamicType::DynamicType(const Type& other) : SharedType(DynamicType::Kind) {
auto kind = other.kind();
TORCH_INTERNAL_ASSERT(kind != Kind);
if (auto n = other.castRaw<NamedType>()) {
if (const auto& qn = n->name()) {
name_ = qn->qualifiedName();
}
}

if (auto cls = other.cast<ClassType>()) {
new (&class_) ClassTypePtr(std::move(cls));
tag_ = Tag::Class;
Expand Down Expand Up @@ -225,4 +238,37 @@ DynamicType::Ptr IValue::TagType<c10::DynamicType>::get(const c10::IValue& v) {
}
}

DynamicTypePtr ivalue::TupleTypeFactory<c10::DynamicType>::create(
std::vector<TypePtr> elemTypes) {
return DynamicTypeFactory::create<TupleType>(std::move(elemTypes));
}

DynamicTypePtr ivalue::TupleTypeFactory<c10::DynamicType>::fallback(
const Type&) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
return nullptr;
}

TORCH_API TupleTypePtr
ivalue::TupleTypeFactory<TupleType>::fallback(const Type& type) {
#ifdef C10_MOBILE
return nullptr;
#else
const auto& dyn = type.expectRef<DynamicType>();
std::vector<c10::string_view> fields;
std::vector<TypePtr> types;

for (const auto& elem : dyn.arguments().elems) {
types.emplace_back(elem.ty);
if (const auto& name = elem.label) {
fields.emplace_back(*elem.label);
}
}
if (const auto& name = dyn.name()) {
return TupleType::createNamed(*name, fields, types);
}
return TupleType::create(std::move(types));
#endif
}

} // namespace c10
19 changes: 19 additions & 0 deletions aten/src/ATen/core/dynamic_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class DynamicType : public SharedType {
public:
// TODO Change Ptr to DynamicTypePtr when all migrations are done.
using Ptr = TypePtr;
using ElementType = DynamicType;
~DynamicType() override;

struct Arguments {
Expand All @@ -129,11 +130,18 @@ class DynamicType : public SharedType {
static TORCH_API DynamicTypePtr create(Type& ty);

explicit DynamicType(Tag, Arguments);
explicit DynamicType(Tag, c10::string_view, Arguments);

TypePtr containedType(size_t) const override;
Tag tag() const {
return tag_;
}
const c10::optional<std::string>& name() const {
return name_;
}
const Arguments& arguments() const {
return arguments_;
}

private:
bool symmetric() const override {
Expand All @@ -158,6 +166,7 @@ class DynamicType : public SharedType {
}

Tag tag_;
c10::optional<std::string> name_;
union {
Arguments arguments_;
ClassTypePtr class_;
Expand All @@ -181,4 +190,14 @@ struct IValue::TagType<c10::DynamicType> {
static DynamicType::Ptr get(const c10::IValue& v);
};

namespace ivalue {

template <>
struct TORCH_API TupleTypeFactory<c10::DynamicType> {
static DynamicTypePtr create(std::vector<TypePtr> elemTypes);
static DynamicTypePtr fallback(const Type&);
};

} // namespace ivalue

} // namespace c10
31 changes: 22 additions & 9 deletions aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -575,18 +575,28 @@ struct TORCH_API TupleElements {
}
};

template <typename T>
struct TupleTypeFactory {};

template <>
struct TORCH_API TupleTypeFactory<TupleType> {
static TupleTypePtr create(std::vector<TypePtr> types) {
return TupleType::create(std::move(types));
}
static TupleTypePtr fallback(const Type& type);
};

struct TORCH_API Tuple : c10::intrusive_ptr_target {
private:
TupleElements elements_;
mutable std::shared_ptr<TupleType>
type_; // lazily computed for unnamed tuples
mutable c10::TypePtr type_; // lazily computed for unnamed tuples

public:
// named tuples have additional type information, so we
// directly create them tagged
static c10::intrusive_ptr<Tuple> createNamed(
std::vector<IValue> elements_,
std::shared_ptr<TupleType> type_) {
c10::TypePtr type_) {
return c10::make_intrusive<Tuple>(std::move(elements_), std::move(type_));
}

Expand Down Expand Up @@ -685,14 +695,17 @@ struct TORCH_API Tuple : c10::intrusive_ptr_target {
return elements_.size();
}

template <typename T = c10::Type>
std::shared_ptr<TupleType> type() const {
template <typename T = c10::TupleType>
std::shared_ptr<T> type() const {
if (!type_) {
type_ = TupleType::create(fmap(elements(), [&](const IValue& v) {
return v.type<T>();
type_ = TupleTypeFactory<T>::create(fmap(elements(), [&](const IValue& v) {
return v.type<typename T::ElementType>();
}));
}
return type_;
if (auto t = type_->cast<T>()) {
return t;
}
return TupleTypeFactory<T>::fallback(*type_);
}

static size_t hash(const Tuple& t) {
Expand All @@ -712,7 +725,7 @@ struct TORCH_API Tuple : c10::intrusive_ptr_target {
explicit Tuple(std::vector<IValue> elements)
: elements_(std::move(elements)){}

explicit Tuple(std::vector<IValue> elements, std::shared_ptr<TupleType> type)
explicit Tuple(std::vector<IValue> elements, c10::TypePtr type)
: elements_(std::move(elements)), type_(std::move(type)) {}

explicit Tuple(TupleElements&& elements)
Expand Down
20 changes: 20 additions & 0 deletions aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,10 @@ struct TORCH_API TupleType : public NamedType {
const std::vector<std::string>& field_names,
const std::vector<TypePtr>& field_types);

static TupleTypePtr createNamed(const c10::optional<c10::QualifiedName>& name,
const std::vector<c10::string_view>& field_names,
const std::vector<TypePtr>& field_types);

static TupleTypePtr create(
std::vector<TypePtr> types) {
return TupleTypePtr(new TupleType(
Expand Down Expand Up @@ -1059,6 +1063,13 @@ struct TORCH_API TupleType : public NamedType {
static const TypeKind Kind = TypeKind::TupleType;

private:
template <typename S>
static TupleTypePtr createWithSpec(
const c10::optional<c10::QualifiedName>& name,
const std::vector<S>& field_names,
const std::vector<TypePtr>& field_types,
std::vector<IValue>& field_defaults);

TupleType(
std::vector<TypePtr> elements_,
c10::optional<c10::QualifiedName> name,
Expand Down Expand Up @@ -2012,6 +2023,15 @@ inline typename detail::CastConstReturnType<NamedType>::type Type::cast<NamedTyp
return nullptr;
}

template<>
inline const NamedType* Type::castRaw<NamedType>() const {
if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
return static_cast<const NamedType*>(this);
}
return nullptr;
}

// Used as a return type when inferring the IValue type of a Python object.
struct InferredType {
/* implicit */ InferredType(TypePtr type) : type_(std::move(type)) {}
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/core/jit_type_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ struct TORCH_API Type {

using TypePtr = SingletonOrSharedTypePtr<Type>;
using Ptr = TypePtr;
using ElementType = Type;

// subtyping relation. By default, we return true for the case
// when the type is exactly equal or if this <: T where rhs = Optional[T]
Expand Down
28 changes: 22 additions & 6 deletions aten/src/ATen/core/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,15 +596,31 @@ TupleTypePtr TupleType::createNamed(
const c10::optional<c10::QualifiedName>& qualName,
const std::vector<std::string>& field_names,
const std::vector<TypePtr>& field_types) {
std::vector<IValue> empty_defaults;
return TupleType::createNamed(qualName, field_names, field_types, empty_defaults);
}
std::vector<IValue> empty_defaults;
return TupleType::createNamed(qualName, field_names, field_types, empty_defaults);
}

TupleTypePtr TupleType::createNamed(
const c10::optional<c10::QualifiedName>& qualName,
const std::vector<c10::string_view>& field_names,
const std::vector<TypePtr>& field_types) {
std::vector<IValue> empty_defaults;
return createWithSpec(qualName, field_names, field_types, empty_defaults);
}

TupleTypePtr TupleType::createNamed(const c10::optional<c10::QualifiedName>& qualName,
TupleTypePtr TupleType::createNamed(
const c10::optional<c10::QualifiedName>& qualName,
const std::vector<std::string>& field_names,
const std::vector<TypePtr>& field_types,
std::vector<IValue>& field_defaults) {
return createWithSpec(qualName, field_names, field_types, field_defaults);
}

template <typename S>
TupleTypePtr TupleType::createWithSpec(const c10::optional<c10::QualifiedName>& qualName,
const std::vector<S>& field_names,
const std::vector<TypePtr>& field_types,
std::vector<IValue>& field_defaults) {
TORCH_INTERNAL_ASSERT(field_names.size() == field_types.size());

std::vector<Argument> arguments;
Expand All @@ -613,7 +629,7 @@ TupleTypePtr TupleType::createNamed(const c10::optional<c10::QualifiedName>& qua
for (size_t i = 0; i < field_names.size(); ++i) {
if (i < min_default_idx) {
Argument arg{
/*name=*/field_names[i],
/*name=*/std::string{field_names[i]},
/*type=*/field_types[i],
/*N=*/i};
arguments.emplace_back(std::move(arg));
Expand All @@ -625,7 +641,7 @@ TupleTypePtr TupleType::createNamed(const c10::optional<c10::QualifiedName>& qua
"mutability could lead to potential memory aliasing "
"problems");
Argument arg{
/*name=*/field_names[i],
/*name=*/std::string{field_names[i]},
/*type=*/field_types[i],
/*N=*/i,
/*default_value=*/field_defaults[j]};
Expand Down
35 changes: 30 additions & 5 deletions aten/src/ATen/core/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,43 @@ namespace c10 {

struct TORCH_API DynamicTypeFactory {
template <typename T, typename... Args>
static c10::TypePtr create(Args&&... args) {
static c10::DynamicTypePtr create(TypePtr ty, Args&&... args) {
return std::make_shared<c10::DynamicType>(
c10::DynamicTypeTrait<T>::tagValue(),
c10::DynamicType::Arguments(
c10::ArrayRef<c10::TypePtr>({std::forward<Args>(args)...})));
c10::DynamicType::Arguments(c10::ArrayRef<c10::TypePtr>(
{std::move(ty), std::forward<Args>(args)...})));
}
template <typename T>
static c10::DynamicTypePtr create(std::vector<c10::TypePtr> types) {
return std::make_shared<c10::DynamicType>(
c10::DynamicTypeTrait<T>::tagValue(),
c10::DynamicType::Arguments(types));
}
static c10::DynamicTypePtr createNamedTuple(
const std::string& name,
const std::vector<c10::string_view>& fields,
const std::vector<c10::TypePtr>& types) {
return std::make_shared<c10::DynamicType>(
c10::DynamicType::Tag::Tuple,
name,
c10::DynamicType::Arguments(fields, types));
}
};

struct TORCH_API DefaultTypeFactory {
template <typename T, typename... Args>
static c10::TypePtr create(Args&&... args) {
return T::create(std::forward<Args>(args)...);
static c10::TypePtr create(TypePtr ty, Args&&... args) {
return T::create(std::move(ty), std::forward<Args>(args)...);
}
template <typename T>
static c10::TypePtr create(std::vector<c10::TypePtr> types) {
return T::create(std::move(types));
}
static c10::TypePtr createNamedTuple(
const std::string& name,
const std::vector<c10::string_view>& fields,
const std::vector<c10::TypePtr>& types) {
return c10::TupleType::createNamed(name, fields, types);
}
};

Expand Down
10 changes: 6 additions & 4 deletions aten/src/ATen/test/type_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,25 +184,27 @@ TEST(TypeEquality, TupleEquality) {

TEST(TypeEquality, NamedTupleEquality) {
// Named tuples should compare equal if they share a name and field names
std::vector<std::string> fields = {"a", "b", "c", "d"};
std::vector<std::string> otherFields = {"wow", "so", "very", "different"};
auto type = TupleType::createNamed(
"MyNamedTuple",
{"a", "b", "c", "d"},
fields,
{IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
auto type2 = TupleType::createNamed(
"MyNamedTuple",
{"a", "b", "c", "d"},
fields,
{IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
EXPECT_EQ(*type, *type2);

auto differentName = TupleType::createNamed(
"WowSoDifferent",
{"a", "b", "c", "d"},
fields,
{IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
EXPECT_NE(*type, *differentName);

auto differentField = TupleType::createNamed(
"MyNamedTuple",
{"wow", "so", "very", "different"},
otherFields,
{IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
EXPECT_NE(*type, *differentField);
}
Expand Down
3 changes: 1 addition & 2 deletions torch/csrc/jit/mobile/interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ bool InterpreterState::run(Stack& stack) {
frame.step();
} break;
case NAMED_TUPLE_CONSTRUCT: {
namedTupleConstruct(
stack, code.types_[inst.X]->expect<at::TupleType>(), inst.N);
namedTupleConstruct(stack, code.types_[inst.X], inst.N);
frame.step();
} break;
case CREATE_OBJECT: {
Expand Down
Loading

0 comments on commit 40b80aa

Please sign in to comment.