Skip to content

Commit ef5a961

Browse files
committed
[PyTorch] Store Tensor explicitly in IValue
Enables the following diff, which will make `toTensor()` return `const Tensor&` and allow callers to avoid refcounting overhead.
Differential Revision: [D25324617](https://our.internmc.facebook.com/intern/diff/D25324617/)
ghstack-source-id: 117841198
Pull Request resolved: #48824
1 parent 93973ee commit ef5a961

File tree

4 files changed

+191
-28
lines changed

4 files changed

+191
-28
lines changed

aten/src/ATen/core/ivalue.h

Lines changed: 106 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -163,23 +163,61 @@ struct CAFFE2_API IValue final {
163163
c10::raw::intrusive_ptr::incref(payload.as_intrusive_ptr);
164164
}
165165
}
166-
IValue(IValue&& rhs) noexcept : IValue() {
167-
swap(rhs);
166+
IValue(IValue&& rhs) noexcept : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) {
167+
if (isTensor()) {
168+
new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
169+
rhs.payload.as_tensor.~Tensor();
170+
} else {
171+
memcpy(&payload, &rhs.payload, sizeof(payload));
172+
}
173+
rhs.tag = Tag::None;
174+
rhs.is_intrusive_ptr = false;
168175
}
176+
169177
/// @private [doxygen private]
170178
~IValue() {
171179
if (is_intrusive_ptr) {
172180
c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr);
181+
} else if (isTensor()) {
182+
payload.as_tensor.~Tensor();
173183
}
174184
}
175-
IValue& operator=(IValue&& rhs) & noexcept {
176-
IValue(std::move(rhs)).swap(*this); // this also sets rhs to None
185+
186+
// Always-inline for performance -- this gets called frequently
187+
// inside the core of the static runtime.
188+
C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept {
189+
if (&rhs == this) {
190+
return *this;
191+
}
192+
193+
// Tear down our state.
194+
if (isTensor()) {
195+
payload.as_tensor.~Tensor();
196+
} else if (is_intrusive_ptr) {
197+
c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr);
198+
}
199+
200+
if (rhs.isTensor()) {
201+
new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
202+
rhs.payload.as_tensor.~Tensor();
203+
} else {
204+
// No need to specially handle rhs being an intrusive_ptr -- we
205+
// steal the reference.
206+
memcpy(&payload, &rhs.payload, sizeof(payload));
207+
}
208+
209+
tag = rhs.tag;
210+
is_intrusive_ptr = rhs.is_intrusive_ptr;
211+
rhs.tag = Tag::None;
212+
rhs.is_intrusive_ptr = false;
177213
return *this;
178214
}
215+
179216
IValue& operator=(IValue const& rhs) & {
180217
IValue(rhs).swap(*this);
181218
return *this;
182219
}
220+
183221
void dump() const;
184222

185223
/**
@@ -288,7 +326,19 @@ struct CAFFE2_API IValue final {
288326

289327
/// @private [doxygen private]
290328
void swap(IValue& rhs) noexcept {
291-
std::swap(payload, rhs.payload);
329+
if (isTensor() && rhs.isTensor()) {
330+
std::swap(payload.as_tensor, rhs.payload.as_tensor);
331+
} else if (isTensor()) {
332+
at::Tensor t = std::move(payload.as_tensor);
333+
payload.as_tensor.~Tensor();
334+
memcpy(&payload, &rhs.payload, sizeof(payload));
335+
new (&rhs.payload.as_tensor) at::Tensor(std::move(t));
336+
} else if (rhs.isTensor()) {
337+
rhs.swap(*this);
338+
return;
339+
} else {
340+
std::swap(reinterpret_cast<char(&)[sizeof(payload)]>(*&payload), reinterpret_cast<char(&)[sizeof(payload)]>(*&rhs.payload));
341+
}
292342
std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr);
293343
std::swap(tag, rhs.tag);
294344
}
@@ -297,21 +347,16 @@ struct CAFFE2_API IValue final {
297347
// While some of these accessors could be generated through templates,
298348
// we prefer to write them manually for clarity
299349

300-
IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(t.defined()) {
301-
// Note: the undefined tensor is not refcounted, so while it
302-
// is tagged as a tensor, is_intrusive_ptr is set to false.
303-
// This is not an optional optimization: our incref call
304-
// *will not* do the right thing when called on an
305-
// undefined tensor.
306-
payload.as_intrusive_ptr = t.unsafeReleaseTensorImpl();
350+
IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(false) {
351+
new (&payload.as_tensor) at::Tensor(std::move(t));
307352
}
308353
bool isTensor() const {
309354
return Tag::Tensor == tag;
310355
}
311356
at::Tensor toTensor() &&;
312357
at::Tensor toTensor() const&;
313358
at::TensorImpl* unsafeToTensorImpl() const {
314-
return static_cast<at::TensorImpl*>(payload.as_intrusive_ptr);
359+
return payload.as_tensor.unsafeGetTensorImpl();
315360
}
316361

317362
const IValue& toIValue() const {
@@ -565,7 +610,7 @@ struct CAFFE2_API IValue final {
565610
c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() const&;
566611

567612
// None
568-
IValue() : payload{0}, tag(Tag::None), is_intrusive_ptr(false) {}
613+
IValue() : tag(Tag::None), is_intrusive_ptr(false) {}
569614
bool isNone() const {
570615
return Tag::None == tag;
571616
}
@@ -826,13 +871,23 @@ struct CAFFE2_API IValue final {
826871
double as_double;
827872
bool as_bool;
828873
c10::intrusive_ptr_target* as_intrusive_ptr;
874+
at::Tensor as_tensor;
829875
struct {
830876
DeviceType type;
831877
DeviceIndex index;
832878
} as_device;
879+
880+
Payload() : as_int(0) {}
881+
~Payload() {}
833882
};
834883

835-
IValue(Payload p, Tag t, bool i) : payload(p), tag(t), is_intrusive_ptr(i) {}
884+
IValue(const Payload& p, Tag t, bool i) : tag(t), is_intrusive_ptr(i) {
885+
if (isTensor()) {
886+
new (&payload.as_tensor) at::Tensor(p.as_tensor);
887+
} else {
888+
memcpy(&payload, &p, sizeof(payload));
889+
}
890+
}
836891

837892
Payload payload;
838893
Tag tag;
@@ -852,9 +907,14 @@ struct CAFFE2_API WeakIValue final {
852907
}
853908
}
854909
WeakIValue(const IValue& rhs)
855-
: payload(rhs.payload),
856-
tag(rhs.tag),
910+
: tag(rhs.tag),
857911
is_intrusive_ptr(rhs.is_intrusive_ptr) {
912+
if (rhs.isTensor()) {
913+
payload.as_intrusive_ptr = rhs.unsafeToTensorImpl();
914+
} else {
915+
static_assert(sizeof(payload) == sizeof(rhs.payload), "IValue and WeakIValue payload sizes don't match!");
916+
memcpy(&payload, &rhs.payload, sizeof(payload));
917+
}
858918
if (is_intrusive_ptr) {
859919
c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
860920
}
@@ -888,17 +948,28 @@ struct CAFFE2_API WeakIValue final {
888948

889949
IValue lock() const {
890950
if (!is_intrusive_ptr) {
891-
return IValue(payload, tag, false);
951+
IValue::Payload newPayload;
952+
memcpy(&newPayload, &payload, sizeof(newPayload));
953+
return IValue(newPayload, tag, false);
892954
}
893955
auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
894956
payload.as_intrusive_ptr);
895-
IValue::Payload pl;
896-
pl.as_intrusive_ptr = temp.lock().release();
897-
temp.release();
898-
if (!pl.as_intrusive_ptr) {
899-
return IValue();
957+
if (IValue::Tag::Tensor == tag) {
958+
auto ip = temp.lock().release();
959+
if (!ip) {
960+
return IValue();
961+
} else {
962+
return IValue(std::move(ip));
963+
}
900964
} else {
901-
return IValue(pl, tag, true);
965+
IValue::Payload pl;
966+
pl.as_intrusive_ptr = temp.lock().release();
967+
temp.release();
968+
if (!pl.as_intrusive_ptr) {
969+
return IValue();
970+
} else {
971+
return IValue(pl, tag, true);
972+
}
902973
}
903974
}
904975

@@ -928,7 +999,17 @@ struct CAFFE2_API WeakIValue final {
928999
}
9291000

9301001
private:
931-
IValue::Payload payload;
1002+
union Payload {
1003+
int64_t as_int;
1004+
double as_double;
1005+
bool as_bool;
1006+
c10::intrusive_ptr_target* as_intrusive_ptr;
1007+
struct {
1008+
DeviceType type;
1009+
DeviceIndex index;
1010+
} as_device;
1011+
};
1012+
Payload payload;
9321013
IValue::Tag tag;
9331014
bool is_intrusive_ptr;
9341015
};

aten/src/ATen/core/ivalue_inl.h

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -130,12 +130,20 @@ inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() const& {
130130
}
131131
inline at::Tensor IValue::toTensor() && {
132132
AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
133-
return at::Tensor(
134-
moveToIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
133+
auto result = std::move(payload.as_tensor);
134+
// As far as I can tell, omitting the usual explicit destructor call
135+
// is not UB in and of itself, and it's a slight perf win. The
136+
// destructor is a no-op, because the moved-from Tensor is
137+
// effectively an intrusive_ptr in the null state, so we don't need
138+
// the behavior for correctness reasons either. Leaving this
139+
// explanatory comment, including commented-out destructor call, to
140+
// make this abundantly clear. payload.as_tensor.~Tensor();
141+
clearToNone();
142+
return result;
135143
}
136144
inline at::Tensor IValue::toTensor() const& {
137145
AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
138-
return at::Tensor(toIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
146+
return payload.as_tensor;
139147
}
140148
inline c10::Stream IValue::toStream() && {
141149
return c10::Stream::unpack(payload.as_int);

aten/src/ATen/test/ivalue_test.cpp

Lines changed: 66 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,72 @@ TEST(IValueTest, Basic) {
5151
ASSERT_EQ(tv.use_count(), 2);
5252
}
5353

54+
TEST(IValueTest, Swap) {
55+
at::Tensor tv1 = at::rand({3, 4});
56+
IValue tensor(tv1), scalar(42);
57+
// swap() has special cases depending on which side is tensor or
58+
// not. Exercise all 4 combinations.
59+
tensor.swap(scalar);
60+
tensor.swap(scalar);
61+
tensor.swap(tensor);
62+
scalar.swap(scalar);
63+
}
64+
65+
TEST(IValueTest, MoveConstruct) {
66+
at::Tensor t1 = at::rand({3, 4});
67+
68+
{
69+
IValue sourceTensor(t1);
70+
IValue target(std::move(sourceTensor));
71+
EXPECT_TRUE(target.toTensor().equal(t1));
72+
EXPECT_TRUE(sourceTensor.isNone());
73+
}
74+
75+
{
76+
IValue sourceScalar(42);
77+
IValue target(std::move(sourceScalar));
78+
EXPECT_EQ(target, IValue(42));
79+
EXPECT_TRUE(sourceScalar.isNone());
80+
}
81+
}
82+
83+
TEST(IValueTest, MoveAssign) {
84+
at::Tensor tv1 = at::rand({3, 4});
85+
at::Tensor tv2 = at::rand({3, 4});
86+
87+
// 1: tensor to tensor
88+
{
89+
IValue targetTensor(tv1), sourceTensor(tv2);
90+
targetTensor = std::move(sourceTensor);
91+
EXPECT_TRUE(targetTensor.toTensor().equal(tv2));
92+
EXPECT_TRUE(sourceTensor.isNone());
93+
}
94+
95+
// 2: tensor to scalar
96+
{
97+
IValue targetScalar(42), sourceTensor(tv1);
98+
targetScalar = std::move(sourceTensor);
99+
EXPECT_TRUE(targetScalar.toTensor().equal(tv1));
100+
EXPECT_TRUE(sourceTensor.isNone());
101+
}
102+
103+
// 3: scalar to tensor
104+
{
105+
IValue targetTensor(tv1), sourceScalar(42);
106+
targetTensor = std::move(sourceScalar);
107+
EXPECT_EQ(targetTensor, 42);
108+
EXPECT_TRUE(sourceScalar.isNone());
109+
}
110+
111+
// 4: scalar to scalar
112+
{
113+
IValue targetScalar(42), sourceScalar(43);
114+
targetScalar = std::move(sourceScalar);
115+
EXPECT_EQ(targetScalar, 43);
116+
EXPECT_TRUE(sourceScalar.isNone());
117+
}
118+
}
119+
54120
TEST(IValueTest, Tuple) {
55121
std::tuple<int64_t, at::Tensor> t = std::make_tuple(123, at::randn({1}));
56122
auto iv = IValue(t);

c10/macros/Macros.h

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -186,6 +186,14 @@ namespace at { namespace cuda { using namespace c10::hip; }}
186186
#define C10_NOINLINE
187187
#endif
188188

189+
#if __has_attribute(always_inline) || defined(__GNUC__)
190+
#define C10_ALWAYS_INLINE __attribute__((__always_inline__)) inline
191+
#elif defined(_MSC_VER)
192+
#define C10_ALWAYS_INLINE __forceinline
193+
#else
194+
#define C10_ALWAYS_INLINE inline
195+
#endif
196+
189197
#include <sstream>
190198
#include <string>
191199

0 commit comments

Comments (0)