
[PyTorch] Store Tensor explicitly in IValue #48824

Closed · wants to merge 13 commits · Changes from 1 commit
Update on "[PyTorch] Store Tensor explicitly in IValue"
Enables the following diff, which will make `toTensor()` return
`const Tensor&` and allow callers to avoid refcounting overhead.

Differential Revision: [D25324617](https://our.internmc.facebook.com/intern/diff/D25324617/)

[ghstack-poisoned]
swolchok committed Dec 12, 2020
commit 805c9896d9830c0ecd813296b0ee0d95a3b20d5a
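To make the motivation concrete: once the payload holds a real `at::Tensor`, an accessor can return a reference instead of a fresh handle. Below is a minimal standalone sketch (not PyTorch code; `Handle` and `Value` are stand-in names) of why a by-reference accessor avoids refcount traffic on a hot path:

```cpp
// Standalone sketch: `Handle` plays the role of at::Tensor, whose copy
// constructor bumps an atomic refcount much like shared_ptr's does.
#include <cstdio>
#include <memory>

using Handle = std::shared_ptr<int>;

struct Value {
  Handle handle = std::make_shared<int>(42);

  // By-value return: every call copies the handle -> atomic incref + decref.
  Handle getByValue() const { return handle; }

  // By-reference return: no refcount traffic at the call site.
  const Handle& getByRef() const { return handle; }
};

int main() {
  Value v;
  std::printf("by value: use_count=%ld\n", v.getByValue().use_count()); // 2
  std::printf("by ref:   use_count=%ld\n", v.getByRef().use_count());   // 1
}
```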
aten/src/ATen/core/ivalue.cpp (2 changes: 1 addition & 1 deletion)
@@ -309,7 +309,7 @@ size_t IValue::hash(const IValue& v) {
     case Tag::Tensor:
       // Tensor __hash__ is equivalent to `id()`, so take the pointer value of
       // the tensor to emulate it
-      return c10::get_hash(v.payload.as_int);
+      return c10::get_hash(v.payload.as_tensor.unsafeGetTensorImpl());
     case Tag::Int:
       return c10::get_hash(v.payload.as_int);
     case Tag::String:
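The `id()` analogy in that comment can be checked directly. A hedged usage sketch (assumes a libtorch build and that `IValue::hash` remains a public static member):

```cpp
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>

int main() {
  at::Tensor a = at::ones({2, 2});
  at::Tensor b = a;          // shares the same TensorImpl
  at::Tensor c = a.clone();  // equal contents, distinct TensorImpl
  c10::IValue ia(a), ib(b), ic(c);
  // Identity-based hashing: ia/ib collide by construction; ic generally
  // does not, even though a and c are element-wise equal.
  bool same_impl = c10::IValue::hash(ia) == c10::IValue::hash(ib);
  bool diff_impl = c10::IValue::hash(ia) != c10::IValue::hash(ic);
  return (same_impl && diff_impl) ? 0 : 1;
}
```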
aten/src/ATen/core/ivalue.h (52 changes: 33 additions & 19 deletions)
@@ -159,10 +159,11 @@ struct Capsule {
 struct CAFFE2_API IValue final {
   IValue(const IValue& rhs)
       : IValue(rhs.payload, rhs.tag, rhs.is_intrusive_ptr) {
-    if (is_intrusive_ptr) {
+    if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
       c10::raw::intrusive_ptr::incref(payload.as_intrusive_ptr);
     }
   }
+
   IValue(IValue&& rhs) noexcept : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) {
     moveFrom(std::move(rhs));
     rhs.tag = Tag::None;
@@ -174,9 +175,7 @@ struct CAFFE2_API IValue final {
     destroy();
   }

-  // Always-inline for performance -- this gets called frequently
-  // inside the core of the static runtime.
-  C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept {
+  IValue& operator=(IValue&& rhs) & noexcept {
     if (&rhs == this) {
       return *this;
     }
@@ -294,6 +293,9 @@ struct CAFFE2_API IValue final {
       return 1;
     }

+    if (payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) {
+      return 0;
+    }
     return c10::raw::intrusive_ptr::use_count(payload.as_intrusive_ptr);
   }

@@ -352,7 +354,7 @@ struct CAFFE2_API IValue final {
       : tag(Tag::Blob), is_intrusive_ptr(true) {
     // TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract
     // and store it as a Tensor instead.
-    payload.as_intrusive_ptr = blob.release();
+    payload.as_intrusive_ptr = blob.release() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
   }

   /// @private [doxygen private]
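The `release() ?: static_cast<...>(singleton())` pattern recurs throughout this diff. `p ?: q` is the GNU "Elvis" extension (supported by GCC and Clang, not MSVC) for `p ? p : q` with `p` evaluated once; here it launders a possibly-null released pointer into the never-null payload representation. A standalone sketch of the idiom (stand-in types, not PyTorch code):

```cpp
#include <cassert>

struct Target {};

// Stand-in for c10::UndefinedTensorImpl::singleton().
Target* undefined_singleton() {
  static Target sentinel;
  return &sentinel;
}

// Store-side normalization: the payload never holds nullptr.
Target* normalize(Target* released) {
  return released ?: undefined_singleton();  // GNU extension: p ? p : q
}

int main() {
  Target t;
  assert(normalize(&t) == &t);
  assert(normalize(nullptr) == undefined_singleton());
}
```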
@@ -691,7 +693,7 @@ struct CAFFE2_API IValue final {
     // This is not an optional optimization: our incref call
     // *will not* do the right thing when called on an
     // undefined generator.
-    payload.as_intrusive_ptr = g.unsafeReleaseGeneratorImpl();
+    payload.as_intrusive_ptr = g.unsafeReleaseGeneratorImpl() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
   }
   bool isGenerator() const {
     return Tag::Generator == tag;
@@ -775,7 +777,7 @@ struct CAFFE2_API IValue final {
   const void* internalToPointer() const {
     TORCH_INTERNAL_ASSERT(
         isPtrType(), "Can only call internalToPointer() for pointer types");
-    return payload.as_intrusive_ptr;
+    return payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton() ? payload.as_intrusive_ptr : nullptr;
   }

   TypePtr type() const;
@@ -842,10 +844,11 @@ struct CAFFE2_API IValue final {
   c10::intrusive_ptr<T, NullType> toIntrusivePtr() const;

   void destroy() {
-    if (isTensor()) {
-      payload.as_tensor.~Tensor();
+    if (isTensor() || is_intrusive_ptr) {
+      c10::intrusive_ptr<intrusive_ptr_target, c10::UndefinedTensorImpl>::reclaim(payload.as_tensor.unsafeReleaseTensorImpl());
+      // payload.as_tensor.~Tensor();
     } else if (is_intrusive_ptr) {
-      c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr);
+      c10::intrusive_ptr<intrusive_ptr_target, c10::UndefinedTensorImpl>::reclaim(payload.as_intrusive_ptr);
     }
   }
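The reclaim-based `destroy()` works because `reclaim()` adopts a pointer without increfing it, and the resulting temporary decrefs it when destroyed, except when the pointer equals the NullType singleton, which `intrusive_ptr` treats as null and never touches. A small sketch of the adopt-then-drop pattern (assumes only c10/util/intrusive_ptr.h):

```cpp
#include <c10/util/intrusive_ptr.h>

struct Node : c10::intrusive_ptr_target {};

void drop(Node* raw) {
  // Adopt `raw` without bumping its count; the temporary intrusive_ptr
  // decrefs (and possibly frees) it when the full expression ends.
  c10::intrusive_ptr<Node>::reclaim(raw);
}

int main() {
  auto n = c10::make_intrusive<Node>();
  drop(n.release());  // hand over our reference; drop() releases it
}
```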

Expand Down Expand Up @@ -879,6 +882,9 @@ struct CAFFE2_API IValue final {
int64_t as_int;
double as_double;
bool as_bool;
// Invariant: never nullptr; null state is represented as
// c10::UndefinedTensorImpl::singleton() for consistency of
// representation with Tensor.
c10::intrusive_ptr_target* as_intrusive_ptr;
at::Tensor as_tensor;
struct {
@@ -911,7 +917,7 @@ struct CAFFE2_API WeakIValue final {
       : payload(rhs.payload),
         tag(rhs.tag),
         is_intrusive_ptr(rhs.is_intrusive_ptr) {
-    if (is_intrusive_ptr) {
+    if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
       c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
     }
   }
@@ -925,14 +931,16 @@ struct CAFFE2_API WeakIValue final {
       memcpy(&payload, &rhs.payload, sizeof(payload));
     }
     if (is_intrusive_ptr) {
-      c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
+      if (payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
+        c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
+      }
     }
   }
   WeakIValue(WeakIValue&& rhs) noexcept : WeakIValue() {
     swap(rhs);
   }
   ~WeakIValue() {
-    if (is_intrusive_ptr) {
+    if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
       c10::raw::weak_intrusive_ptr::decref(payload.as_intrusive_ptr);
     }
   }
@@ -961,16 +969,20 @@ struct CAFFE2_API WeakIValue final {
       memcpy(&newPayload, &payload, sizeof(newPayload));
       return IValue(newPayload, tag, false);
     }
-    auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
-        payload.as_intrusive_ptr);
     if (IValue::Tag::Tensor == tag) {
-      auto ip = temp.lock().release();
+      auto temp = c10::weak_intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl>::reclaim(
+          static_cast<at::TensorImpl*>(payload.as_intrusive_ptr));
+      c10::intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl> ip(temp.lock());
       if (!ip) {
         return IValue();
       } else {
-        return IValue(ip);
+        return IValue(at::Tensor(std::move(ip)));
       }
     } else {
+      auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
+          payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
+              ? nullptr
+              : payload.as_intrusive_ptr);
       IValue::Payload pl;
       pl.as_intrusive_ptr = temp.lock().release();
       temp.release();
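The `lock()` used here is the standard weak-to-strong upgrade: it atomically acquires a strong reference if the object is still alive and yields a null handle otherwise. A hedged sketch against the c10 API (assumes c10/util/intrusive_ptr.h):

```cpp
#include <c10/util/intrusive_ptr.h>

struct Obj : c10::intrusive_ptr_target {};

int main() {
  auto strong = c10::make_intrusive<Obj>();
  c10::weak_intrusive_ptr<Obj> weak(strong);

  if (auto alive = weak.lock()) {
    // Upgrade succeeded: `alive` holds a strong reference here.
  }

  strong.reset();              // drop the last strong reference
  return weak.lock() ? 1 : 0;  // object expired: lock() now yields null
}
```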
@@ -986,7 +998,7 @@ struct CAFFE2_API WeakIValue final {
     if (!is_intrusive_ptr) {
       return 1;
     }
-    auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
+    auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target, c10::UndefinedTensorImpl>::reclaim(
         payload.as_intrusive_ptr);
     size_t result = temp.use_count();
     temp.release();
@@ -997,7 +1009,7 @@ struct CAFFE2_API WeakIValue final {
     if (!is_intrusive_ptr) {
       return 1;
     }
-    auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
+    auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target, c10::UndefinedTensorImpl>::reclaim(
         payload.as_intrusive_ptr);
     size_t result = temp.weak_use_count();
     temp.release();
@@ -1012,6 +1024,8 @@ struct CAFFE2_API WeakIValue final {
     int64_t as_int;
     double as_double;
     bool as_bool;
+    // Invariant: never nullptr; null state is represented as
+    // UndefinedTensorImpl::singleton() for consistency with Tensor.
     c10::intrusive_ptr_target* as_intrusive_ptr;
     struct {
       DeviceType type;
aten/src/ATen/core/ivalue_inl.h (57 changes: 40 additions & 17 deletions)
@@ -47,14 +47,18 @@ struct tagged_capsule {
 template <class T, class NullType>
 c10::intrusive_ptr<T, NullType> IValue::moveToIntrusivePtr() {
   auto t = c10::intrusive_ptr<T, NullType>::reclaim(
-      static_cast<T*>(payload.as_intrusive_ptr));
+      payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
+          ? NullType::singleton()
+          : static_cast<T*>(payload.as_intrusive_ptr));
   clearToNone();
   return t;
 }
 template <typename T, class NullType>
 c10::intrusive_ptr<T, NullType> IValue::toIntrusivePtr() const {
   auto r = c10::intrusive_ptr<T, NullType>::reclaim(
-      static_cast<T*>(payload.as_intrusive_ptr));
+      payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
+          ? NullType::singleton()
+          : static_cast<T*>(payload.as_intrusive_ptr));
   auto p = r;
   r.release();
   return p;
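The ternary in both accessors translates the payload's one universal sentinel into whatever null representation the requested pointer type uses, since `NullType::singleton()` differs per type. A standalone sketch with stand-in types (not PyTorch code):

```cpp
#include <cassert>

struct Widget {};
struct WidgetNull {
  static Widget* singleton() { static Widget null_widget; return &null_widget; }
};

// Stand-in for the one process-wide UndefinedTensorImpl sentinel.
static char universal_sentinel;

template <class T, class NullType>
T* translate(void* stored) {
  return stored == &universal_sentinel
      ? NullType::singleton()     // this type's own null representation
      : static_cast<T*>(stored);  // a real object
}

int main() {
  Widget w;
  assert((translate<Widget, WidgetNull>(&w) == &w));
  assert((translate<Widget, WidgetNull>(&universal_sentinel) == WidgetNull::singleton()));
}
```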
@@ -736,6 +740,7 @@ using _guarded_unsigned_long = std::conditional_t<

 inline const ivalue::Object& IValue::toObjectRef() const {
   AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference");
   return *static_cast<const c10::ivalue::Object*>(payload.as_intrusive_ptr);
 }

@@ -982,6 +987,9 @@ inline c10::List<int64_t> IValue::toIntList() const& {
 }
 inline std::vector<int64_t> IValue::toIntVector() const {
   AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toIntVector on null intrusive_ptr IValue");
   return createVectorFromList<int64_t>(
       static_cast<const c10::detail::ListImpl*>(payload.as_intrusive_ptr));
 }
@@ -995,6 +1003,9 @@ inline c10::List<double> IValue::toDoubleList() const& {
 }
 inline std::vector<double> IValue::toDoubleVector() const {
   AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toDoubleVector on null intrusive_ptr IValue");
   return createVectorFromList<double>(
       static_cast<const c10::detail::ListImpl*>(payload.as_intrusive_ptr));
 }
@@ -1016,6 +1027,9 @@ inline c10::List<at::Tensor> IValue::toTensorList() const& {
 }
 inline std::vector<at::Tensor> IValue::toTensorVector() const {
   AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toTensorVector on null intrusive_ptr IValue");
   return createVectorFromList<at::Tensor>(
       static_cast<const c10::detail::ListImpl*>(payload.as_intrusive_ptr));
 }
@@ -1029,6 +1043,9 @@ inline c10::List<IValue> IValue::toList() const& {
 }
 inline c10::ArrayRef<IValue> IValue::toListRef() const {
   AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toListRef on null intrusive_ptr IValue");
   return static_cast<const c10::detail::ListImpl*>(payload.as_intrusive_ptr)
       ->list;
 }
@@ -1051,7 +1068,7 @@ inline c10::intrusive_ptr<ivalue::Tuple> IValue::toTuple() const& {

 inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
     : tag(Tag::Tuple), is_intrusive_ptr(true) {
-  payload.as_intrusive_ptr = v.release();
+  payload.as_intrusive_ptr = v.release() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
 }
 template <
     typename... Args,
@@ -1067,14 +1084,14 @@ inline IValue::IValue(const std::tuple<Args...>& t)

 inline IValue::IValue(c10::intrusive_ptr<ivalue::ConstantString> v)
     : tag(Tag::String), is_intrusive_ptr(true) {
-  payload.as_intrusive_ptr = v.release();
+  payload.as_intrusive_ptr = v.release() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
 }
 inline IValue::IValue(std::string v)
     : IValue(ivalue::ConstantString::create(std::move(v))) {}

 inline IValue::IValue(c10::impl::GenericList v)
     : tag(Tag::GenericList), is_intrusive_ptr(true) {
-  payload.as_intrusive_ptr = v.impl_.release();
+  payload.as_intrusive_ptr = v.impl_.release() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
 }

 template <class T, IValue::enable_if_ivalue_constructible<T>>
@@ -1106,7 +1123,7 @@ inline IValue::IValue(std::array<T, N> v) : IValue(c10::List<T>()) {

 inline IValue::IValue(c10::impl::GenericDict v)
     : tag(Tag::GenericDict), is_intrusive_ptr(true) {
-  payload.as_intrusive_ptr = v.impl_.release();
+  payload.as_intrusive_ptr = v.impl_.release() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
 }
 template <class Key, class Value>
 inline IValue::IValue(c10::Dict<Key, Value> v)
@@ -1133,25 +1150,25 @@ inline IValue::IValue(c10::nullopt_t) : IValue() {}

 inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
     : tag(Tag::Object), is_intrusive_ptr(true) {
-  payload.as_intrusive_ptr = v.release();
+  payload.as_intrusive_ptr = v.release() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
 }

 inline IValue::IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v)
     : tag(Tag::PyObject), is_intrusive_ptr(true) {
-  payload.as_intrusive_ptr = v.release();
+  payload.as_intrusive_ptr = v.release() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
 }

 inline IValue::IValue(c10::intrusive_ptr<ivalue::EnumHolder> v)
     : tag(Tag::Enum), is_intrusive_ptr(true) {
-  payload.as_intrusive_ptr = v.release();
+  payload.as_intrusive_ptr = v.release() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
 }

 inline IValue IValue::make_capsule(
     intrusive_ptr<torch::CustomClassHolder> blob) {
   IValue iv;
   iv.tag = Tag::Capsule;
   iv.is_intrusive_ptr = true;
-  iv.payload.as_intrusive_ptr = blob.release();
+  iv.payload.as_intrusive_ptr = blob.release() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
   return iv;
 }

@@ -1169,28 +1186,31 @@ IValue::IValue(c10::intrusive_ptr<T> custom_class) {
   auto ivalue_obj = c10::ivalue::Object::create(
       c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1);
   ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class)));
-  payload.as_intrusive_ptr = ivalue_obj.release();
+  payload.as_intrusive_ptr = ivalue_obj.release() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
   tag = Tag::Object;
   is_intrusive_ptr = true;
 }

 inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
     : tag(Tag::Future), is_intrusive_ptr(true) {
-  payload.as_intrusive_ptr = v.release();
+  payload.as_intrusive_ptr = v.release() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
 }

 inline IValue::IValue(c10::intrusive_ptr<c10::RRefInterface> v)
     : tag(Tag::RRef), is_intrusive_ptr(true) {
-  payload.as_intrusive_ptr = v.release();
+  payload.as_intrusive_ptr = v.release() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
 }

 inline IValue::IValue(c10::intrusive_ptr<at::Quantizer> v)
     : tag(Tag::Quantizer), is_intrusive_ptr(true) {
-  payload.as_intrusive_ptr = v.release();
+  payload.as_intrusive_ptr = v.release() ?: static_cast<intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
 }

 inline const std::string& IValue::toStringRef() const {
   AT_ASSERT(isString(), "Expected String but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toStringRef on null intrusive_ptr IValue");
   return static_cast<const c10::ivalue::ConstantString*>(
       payload.as_intrusive_ptr)
       ->string();
@@ -1201,6 +1221,9 @@ inline c10::optional<std::reference_wrapper<const std::string>> IValue::
     return c10::nullopt;
   }
   AT_ASSERT(isString(), "Expected optional<string> but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toOptionalStringRef on null intrusive_ptr IValue");
   return std::reference_wrapper<const std::string>(
       static_cast<const c10::ivalue::ConstantString*>(payload.as_intrusive_ptr)
           ->string());
@@ -1256,13 +1279,13 @@ inline bool IValue::isSameIdentity(const IValue& rhs) const {
   } else if (this->isTensor() && rhs.isTensor()) {
     // for tensor type, just check the as_intrusive_ptr since is_intrusive_ptr
     // is false for undefined tensor
-    return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
+    return this->payload.as_tensor.is_same(rhs.payload.as_tensor);
   } else if (this->isTensor() && rhs.isNone()) {
     // special case: undefined tensor and None are the same identity
-    return !this->is_intrusive_ptr;
+    return !this->payload.as_tensor.defined();
   } else if (this->isNone() && rhs.isTensor()) {
     // special case: undefined tensor and None are the same identity
-    return !rhs.is_intrusive_ptr;
+    return !rhs.payload.as_tensor.defined();
   } else if (this->isInt() && rhs.isInt()) {
     return this->toInt() == rhs.toInt();
   } else if (this->isDouble() && rhs.isDouble()) {
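A hedged usage sketch of the None/undefined-tensor special case above (assumes a libtorch build):

```cpp
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>

int main() {
  c10::IValue none;                        // None
  c10::IValue undef{at::Tensor()};         // undefined (default) tensor
  c10::IValue defined{at::ones({1})};      // a real tensor

  bool a = undef.isSameIdentity(none);     // true: treated as the same identity
  bool b = defined.isSameIdentity(none);   // false: a defined tensor is not None
  return (a && !b) ? 0 : 1;
}
```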
c10/util/intrusive_ptr.h (4 changes: 2 additions & 2 deletions)
@@ -206,7 +206,7 @@ class intrusive_ptr final {
       "NullType must have a constexpr singleton() method");
 #endif
   static_assert(
-      std::is_same<TTarget*, decltype(NullType::singleton())>::value,
+      std::is_base_of<TTarget, typename std::remove_pointer<decltype(NullType::singleton())>::type>::value,
       "NullType::singleton() must return a element_type* pointer");

   TTarget* target_;
@@ -509,7 +509,7 @@ class weak_intrusive_ptr final {
       "NullType must have a constexpr singleton() method");
 #endif
   static_assert(
-      std::is_same<TTarget*, decltype(NullType::singleton())>::value,
+      std::is_base_of<TTarget, typename std::remove_pointer<decltype(NullType::singleton())>::type>::value,
       "NullType::singleton() must return a element_type* pointer");

   TTarget* target_;
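The relaxed static_assert is what lets `c10::UndefinedTensorImpl`, whose `singleton()` returns a pointer to a class derived from `intrusive_ptr_target`, serve as the NullType of an `intrusive_ptr<intrusive_ptr_target, ...>`. A standalone sketch with stand-in types:

```cpp
#include <type_traits>

struct Base {};
struct Derived : Base {};

struct DerivedNull {
  static Derived* singleton() { static Derived d; return &d; }
};

// Old check: requires singleton() to return exactly Base*, so DerivedNull
// could not be the NullType of a Base-typed smart pointer.
static_assert(
    !std::is_same<Base*, decltype(DerivedNull::singleton())>::value,
    "exact-type check rejects a derived-pointer singleton");

// New check: any pointer-to-derived is accepted, since it converts to Base*.
static_assert(
    std::is_base_of<Base,
        std::remove_pointer<decltype(DerivedNull::singleton())>::type>::value,
    "base-of check accepts it");

int main() {}
```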
You are viewing a condensed version of this merge commit.