
Commit 4a820f8

[PyTorch] Store Tensor explicitly in IValue
Pull Request resolved: #48824

Enables the following diff, which will make toTensor() return `const Tensor&` and allow callers to avoid refcounting overhead.

ghstack-source-id: 117906329
Differential Revision: [D25324617](https://our.internmc.facebook.com/intern/diff/D25324617/)

1 parent bc2352e · commit 4a820f8
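To make the motivation concrete, here is a minimal illustrative sketch (not part of this commit) of the kind of call site the follow-up enables once toTensor() on an lvalue IValue can hand back `const Tensor&`: reading tensors out of IValues in a hot loop without paying an incref/decref per access. The helper name and loop are assumptions for illustration only.

    #include <ATen/core/ivalue.h>
    #include <cstdint>
    #include <vector>

    // Hypothetical helper: sums numel() over a stack of IValues.
    // Binding through a const reference avoids materializing a new
    // at::Tensor (and bumping its refcount) on every iteration once
    // toTensor() const& can return a reference to the stored tensor.
    int64_t totalNumel(const std::vector<c10::IValue>& stack) {
      int64_t n = 0;
      for (const c10::IValue& iv : stack) {
        const at::Tensor& t = iv.toTensor();
        n += t.numel();
      }
      return n;
    }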

File tree (4 files changed, +206 −32 lines):
  aten/src/ATen/core/ivalue.h
  aten/src/ATen/core/ivalue_inl.h
  aten/src/ATen/test/ivalue_test.cpp
  c10/macros/Macros.h

aten/src/ATen/core/ivalue.h

Lines changed: 119 additions & 29 deletions
@@ -163,23 +163,34 @@ struct CAFFE2_API IValue final {
       c10::raw::intrusive_ptr::incref(payload.as_intrusive_ptr);
     }
   }
-  IValue(IValue&& rhs) noexcept : IValue() {
-    swap(rhs);
+  IValue(IValue&& rhs) noexcept : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) {
+    moveFrom(std::move(rhs));
+    rhs.tag = Tag::None;
+    rhs.is_intrusive_ptr = false;
   }
+
   /// @private [doxygen private]
   ~IValue() {
-    if (is_intrusive_ptr) {
-      c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr);
-    }
+    destroy();
   }
-  IValue& operator=(IValue&& rhs) & noexcept {
-    IValue(std::move(rhs)).swap(*this); // this also sets rhs to None
+
+  // Always-inline for performance -- this gets called frequently
+  // inside the core of the static runtime.
+  C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept {
+    if (&rhs == this) {
+      return *this;
+    }
+
+    destroy();
+    moveFrom(std::move(rhs));
     return *this;
   }
+
   IValue& operator=(IValue const& rhs) & {
     IValue(rhs).swap(*this);
     return *this;
   }
+
   void dump() const;

   /**
@@ -288,7 +299,27 @@ struct CAFFE2_API IValue final {

   /// @private [doxygen private]
   void swap(IValue& rhs) noexcept {
-    std::swap(payload, rhs.payload);
+    if (isTensor() && rhs.isTensor()) {
+      std::swap(payload.as_tensor, rhs.payload.as_tensor);
+    } else if (isTensor()) {
+      at::Tensor t = std::move(payload.as_tensor);
+      // As far as I can tell, omitting the usual explicit destructor call
+      // is not UB in and of itself, and it's a slight perf win. The
+      // destructor is a no-op, because the moved-from Tensor is
+      // effectively an intrusive_ptr in the null state, so we don't need
+      // the behavior for correctness reasons either. Leaving this
+      // explanatory comment, including commented-out destructor call, to
+      // make this abundantly clear.
+      //
+      // payload.as_tensor.~Tensor();
+      memcpy(&payload, &rhs.payload, sizeof(payload));
+      new (&rhs.payload.as_tensor) at::Tensor(std::move(t));
+    } else if (rhs.isTensor()) {
+      rhs.swap(*this);
+      return;
+    } else {
+      std::swap(reinterpret_cast<char(&)[sizeof(payload)]>(*&payload), reinterpret_cast<char(&)[sizeof(payload)]>(*&rhs.payload));
+    }
     std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr);
     std::swap(tag, rhs.tag);
   }
@@ -297,21 +328,16 @@ struct CAFFE2_API IValue final {
   // While some of these accessors could be generated through templates,
   // we prefer to write them manually for clarity

-  IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(t.defined()) {
-    // Note: the undefined tensor is not refcounted, so while it
-    // is tagged as a tensor, is_intrusive_ptr is set to false.
-    // This is not an optional optimization: our incref call
-    // *will not* do the right thing when called on an
-    // undefined tensor.
-    payload.as_intrusive_ptr = t.unsafeReleaseTensorImpl();
+  IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(false) {
+    new (&payload.as_tensor) at::Tensor(std::move(t));
   }
   bool isTensor() const {
     return Tag::Tensor == tag;
   }
   at::Tensor toTensor() &&;
   at::Tensor toTensor() const&;
   at::TensorImpl* unsafeToTensorImpl() const {
-    return static_cast<at::TensorImpl*>(payload.as_intrusive_ptr);
+    return payload.as_tensor.unsafeGetTensorImpl();
   }

   const IValue& toIValue() const {
@@ -565,7 +591,7 @@ struct CAFFE2_API IValue final {
   c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() const&;

   // None
-  IValue() : payload{0}, tag(Tag::None), is_intrusive_ptr(false) {}
+  IValue() : tag(Tag::None), is_intrusive_ptr(false) {}
   bool isNone() const {
     return Tag::None == tag;
   }
@@ -815,7 +841,35 @@ struct CAFFE2_API IValue final {
       class NullType = c10::detail::intrusive_target_default_null_type<T>>
   c10::intrusive_ptr<T, NullType> toIntrusivePtr() const;

-  void clearToNone() {
+  void destroy() {
+    if (isTensor()) {
+      payload.as_tensor.~Tensor();
+    } else if (is_intrusive_ptr) {
+      c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr);
+    }
+  }
+
+  C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept {
+    if (rhs.isTensor()) {
+      new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
+      // As far as I can tell, omitting the usual explicit destructor call
+      // is not UB in and of itself, and it's a slight perf win. The
+      // destructor is a no-op, because the moved-from Tensor is
+      // effectively an intrusive_ptr in the null state, so we don't need
+      // the behavior for correctness reasons either. Leaving this
+      // explanatory comment, including commented-out destructor call, to
+      // make this abundantly clear.
+      //
+      // rhs.payload.as_tensor.~Tensor();
+    } else {
+      memcpy(&payload, &rhs.payload, sizeof(payload));
+    }
+    tag = rhs.tag;
+    is_intrusive_ptr = rhs.is_intrusive_ptr;
+    rhs.clearToNone();
+  }
+
+  void clearToNone() noexcept {
     payload.as_int = 0;
     tag = Tag::None;
     is_intrusive_ptr = false;
@@ -826,13 +880,23 @@ struct CAFFE2_API IValue final {
     double as_double;
     bool as_bool;
     c10::intrusive_ptr_target* as_intrusive_ptr;
+    at::Tensor as_tensor;
     struct {
       DeviceType type;
       DeviceIndex index;
     } as_device;
+
+    Payload() : as_int(0) {}
+    ~Payload() {}
   };

-  IValue(Payload p, Tag t, bool i) : payload(p), tag(t), is_intrusive_ptr(i) {}
+  IValue(const Payload& p, Tag t, bool i) : tag(t), is_intrusive_ptr(i) {
+    if (isTensor()) {
+      new (&payload.as_tensor) at::Tensor(p.as_tensor);
+    } else {
+      memcpy(&payload, &p, sizeof(payload));
+    }
+  }

   Payload payload;
   Tag tag;
@@ -852,9 +916,14 @@ struct CAFFE2_API WeakIValue final {
     }
   }
   WeakIValue(const IValue& rhs)
-      : payload(rhs.payload),
-        tag(rhs.tag),
+      : tag(rhs.tag),
         is_intrusive_ptr(rhs.is_intrusive_ptr) {
+    if (rhs.isTensor()) {
+      payload.as_intrusive_ptr = rhs.unsafeToTensorImpl();
+    } else {
+      static_assert(sizeof(payload) == sizeof(rhs.payload), "IValue and WeakIValue payload sizes don't match!");
+      memcpy(&payload, &rhs.payload, sizeof(payload));
+    }
     if (is_intrusive_ptr) {
       c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
     }
@@ -888,17 +957,28 @@ struct CAFFE2_API WeakIValue final {

  IValue lock() const {
    if (!is_intrusive_ptr) {
-      return IValue(payload, tag, false);
+      IValue::Payload newPayload;
+      memcpy(&newPayload, &payload, sizeof(newPayload));
+      return IValue(newPayload, tag, false);
    }
    auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
        payload.as_intrusive_ptr);
-    IValue::Payload pl;
-    pl.as_intrusive_ptr = temp.lock().release();
-    temp.release();
-    if (!pl.as_intrusive_ptr) {
-      return IValue();
+    if (IValue::Tag::Tensor == tag) {
+      auto ip = temp.lock().release();
+      if (!ip) {
+        return IValue();
+      } else {
+        return IValue(std::move(ip));
+      }
    } else {
-      return IValue(pl, tag, true);
+      IValue::Payload pl;
+      pl.as_intrusive_ptr = temp.lock().release();
+      temp.release();
+      if (!pl.as_intrusive_ptr) {
+        return IValue();
+      } else {
+        return IValue(pl, tag, true);
+      }
    }
  }

@@ -928,7 +1008,17 @@ struct CAFFE2_API WeakIValue final {
  }

 private:
-  IValue::Payload payload;
+  union Payload {
+    int64_t as_int;
+    double as_double;
+    bool as_bool;
+    c10::intrusive_ptr_target* as_intrusive_ptr;
+    struct {
+      DeviceType type;
+      DeviceIndex index;
+    } as_device;
+  };
+  Payload payload;
  IValue::Tag tag;
  bool is_intrusive_ptr;
};
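Note on the union change above: because Payload now holds an at::Tensor (a non-trivial type), the union needs a user-provided no-op constructor and destructor, and IValue itself must begin and end the tensor member's lifetime with placement new and an explicit ~Tensor() call, dispatched on the tag. The following standalone sketch (illustrative only, not PyTorch code) shows the same pattern with std::string standing in for at::Tensor:

    #include <cstdint>
    #include <cstring>
    #include <new>
    #include <string>
    #include <utility>

    // Minimal tagged union holding either an int or a std::string.
    // Mirrors the IValue/Payload pattern: the union's ctor/dtor are no-ops,
    // and the owning class manages the active member's lifetime via the tag.
    class Value {
     public:
      Value() : tag_(Tag::Int) {}
      explicit Value(std::string s) : tag_(Tag::Str) {
        new (&payload_.as_str) std::string(std::move(s));  // start lifetime explicitly
      }
      Value(Value&& rhs) noexcept : tag_(rhs.tag_) {
        if (tag_ == Tag::Str) {
          new (&payload_.as_str) std::string(std::move(rhs.payload_.as_str));
          rhs.payload_.as_str.~Str();  // end lifetime of the moved-from string
        } else {
          std::memcpy(&payload_, &rhs.payload_, sizeof(payload_));
        }
        rhs.tag_ = Tag::Int;
        rhs.payload_.as_int = 0;
      }
      Value(const Value&) = delete;
      Value& operator=(const Value&) = delete;
      ~Value() {
        if (tag_ == Tag::Str) {
          payload_.as_str.~Str();  // only the active member is destroyed
        }
      }

     private:
      enum class Tag { Int, Str };
      using Str = std::string;
      union Payload {
        int64_t as_int;
        Str as_str;          // non-trivial member => union needs custom ctor/dtor
        Payload() : as_int(0) {}
        ~Payload() {}        // no-op: the enclosing class knows which member is alive
      };
      Payload payload_;
      Tag tag_;
    };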

aten/src/ATen/core/ivalue_inl.h

Lines changed: 13 additions & 3 deletions
@@ -130,12 +130,22 @@ inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() const& {
 }
 inline at::Tensor IValue::toTensor() && {
   AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
-  return at::Tensor(
-      moveToIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
+  auto result = std::move(payload.as_tensor);
+  // As far as I can tell, omitting the usual explicit destructor call
+  // is not UB in and of itself, and it's a slight perf win. The
+  // destructor is a no-op, because the moved-from Tensor is
+  // effectively an intrusive_ptr in the null state, so we don't need
+  // the behavior for correctness reasons either. Leaving this
+  // explanatory comment, including commented-out destructor call, to
+  // make this abundantly clear.
+  //
+  // payload.as_tensor.~Tensor();
+  clearToNone();
+  return result;
 }
 inline at::Tensor IValue::toTensor() const& {
   AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
-  return at::Tensor(toIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
+  return payload.as_tensor;
 }
 inline c10::Stream IValue::toStream() && {
   return c10::Stream::unpack(payload.as_int);
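A brief usage note on the two toTensor() overloads above: the rvalue-qualified overload now moves the stored tensor out and resets the IValue to None via clearToNone(), while the const-lvalue overload simply copies payload.as_tensor. The snippet below is an illustrative sketch (not part of the diff; the function name is made up) of what that difference looks like at a call site, assuming the usual ATen headers:

    #include <ATen/ATen.h>
    #include <ATen/core/ivalue.h>
    #include <utility>

    // Hypothetical demo of the ref-qualified toTensor() overloads.
    void toTensorOverloads() {
      c10::IValue iv(at::rand({2, 2}));
      at::Tensor copied = iv.toTensor();            // const& overload: copies the stored tensor
      at::Tensor moved = std::move(iv).toTensor();  // && overload: moves it out; iv becomes None
      // After the move, iv.isNone() is true and `moved` owns the original storage.
    }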

aten/src/ATen/test/ivalue_test.cpp

Lines changed: 66 additions & 0 deletions
@@ -51,6 +51,72 @@ TEST(IValueTest, Basic) {
   ASSERT_EQ(tv.use_count(), 2);
 }

+TEST(IValueTest, Swap) {
+  at::Tensor tv1 = at::rand({3, 4});
+  IValue tensor(tv1), scalar(42);
+  // swap() has special cases depending on which side is tensor or
+  // not. Exercise all 4 combinations.
+  tensor.swap(scalar);
+  tensor.swap(scalar);
+  tensor.swap(tensor);
+  scalar.swap(scalar);
+}
+
+TEST(IValueTest, MoveConstruct) {
+  at::Tensor t1 = at::rand({3, 4});
+
+  {
+    IValue sourceTensor(t1);
+    IValue target(std::move(sourceTensor));
+    EXPECT_TRUE(target.toTensor().equal(t1));
+    EXPECT_TRUE(sourceTensor.isNone());
+  }
+
+  {
+    IValue sourceScalar(42);
+    IValue target(std::move(sourceScalar));
+    EXPECT_EQ(target, IValue(42));
+    EXPECT_TRUE(sourceScalar.isNone());
+  }
+}
+
+TEST(IValueTest, MoveAssign) {
+  at::Tensor tv1 = at::rand({3, 4});
+  at::Tensor tv2 = at::rand({3, 4});
+
+  // 1: tensor to tensor
+  {
+    IValue targetTensor(tv1), sourceTensor(tv2);
+    targetTensor = std::move(sourceTensor);
+    EXPECT_TRUE(targetTensor.toTensor().equal(tv2));
+    EXPECT_TRUE(sourceTensor.isNone());
+  }
+
+  // 2: tensor to scalar
+  {
+    IValue targetScalar(42), sourceTensor(tv1);
+    targetScalar = std::move(sourceTensor);
+    EXPECT_TRUE(targetScalar.toTensor().equal(tv1));
+    EXPECT_TRUE(sourceTensor.isNone());
+  }
+
+  // 3: scalar to tensor
+  {
+    IValue targetTensor(tv1), sourceScalar(42);
+    targetTensor = std::move(sourceScalar);
+    EXPECT_EQ(targetTensor, 42);
+    EXPECT_TRUE(sourceScalar.isNone());
+  }
+
+  // 4: scalar to scalar
+  {
+    IValue targetScalar(42), sourceScalar(43);
+    targetScalar = std::move(sourceScalar);
+    EXPECT_EQ(targetScalar, 43);
+    EXPECT_TRUE(sourceScalar.isNone());
+  }
+}
+
 TEST(IValueTest, Tuple) {
   std::tuple<int64_t, at::Tensor> t = std::make_tuple(123, at::randn({1}));
   auto iv = IValue(t);

c10/macros/Macros.h

Lines changed: 8 additions & 0 deletions
@@ -186,6 +186,14 @@ namespace at { namespace cuda { using namespace c10::hip; }}
 #define C10_NOINLINE
 #endif

+#if __has_attribute(always_inline) || defined(__GNUC__)
+#define C10_ALWAYS_INLINE __attribute__((__always_inline__)) inline
+#elif defined(_MSC_VER)
+#define C10_ALWAYS_INLINE __forceinline
+#else
+#define C10_ALWAYS_INLINE inline
+#endif
+
 #include <sstream>
 #include <string>

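For reference, a minimal usage sketch of the new macro (illustrative only, not part of the diff; the helper below is made up): C10_ALWAYS_INLINE goes where `inline` would normally go and expands to the strongest portable always-inline hint available for the compiler in use.

    #include <c10/macros/Macros.h>

    // Hypothetical hot-path helper. Expands to
    // __attribute__((__always_inline__)) inline on GCC/Clang,
    // __forceinline on MSVC, and plain inline elsewhere.
    C10_ALWAYS_INLINE int addOne(int x) {
      return x + 1;
    }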