-
Notifications
You must be signed in to change notification settings - Fork 24.7k
[PyTorch] Store Tensor explicitly in IValue #48824
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
b48bbb4
0198db9
cff661b
805c989
cd3ce5f
bda623d
1b6544b
d774458
48211e9
bad7e5e
fd28d3f
56a6608
3b45f26
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -163,23 +163,61 @@ struct CAFFE2_API IValue final { | |
c10::raw::intrusive_ptr::incref(payload.as_intrusive_ptr); | ||
} | ||
} | ||
IValue(IValue&& rhs) noexcept : IValue() { | ||
swap(rhs); | ||
IValue(IValue&& rhs) noexcept : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) { | ||
if (isTensor()) { | ||
new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor)); | ||
rhs.payload.as_tensor.~Tensor(); | ||
} else { | ||
memcpy(&payload, &rhs.payload, sizeof(payload)); | ||
} | ||
rhs.tag = Tag::None; | ||
rhs.is_intrusive_ptr = false; | ||
} | ||
|
||
/// @private [doxygen private] | ||
~IValue() { | ||
if (is_intrusive_ptr) { | ||
c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr); | ||
} else if (isTensor()) { | ||
payload.as_tensor.~Tensor(); | ||
} | ||
} | ||
IValue& operator=(IValue&& rhs) & noexcept { | ||
IValue(std::move(rhs)).swap(*this); // this also sets rhs to None | ||
|
||
// Always-inline for performance -- this gets called frequently | ||
// inside the core of the static runtime. | ||
C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept { | ||
if (&rhs == this) { | ||
return *this; | ||
} | ||
|
||
// Tear down our state. | ||
Review comment: Can we deduplicate this logic with the logic in the constructor by moving them into a [shared helper; identifier lost in extraction]? Reply: Sure, but that's more work for the inliner to get right, and these code paths are critical. I can try it. |
||
if (isTensor()) { | ||
payload.as_tensor.~Tensor(); | ||
} else if (is_intrusive_ptr) { | ||
c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr); | ||
} | ||
|
||
if (rhs.isTensor()) { | ||
Review comment: And deduplicate this logic with the one from the move constructor? |
||
new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor)); | ||
rhs.payload.as_tensor.~Tensor(); | ||
} else { | ||
// No need to specially handle rhs being an intrusive_ptr -- we | ||
// steal the reference. | ||
memcpy(&payload, &rhs.payload, sizeof(payload)); | ||
} | ||
|
||
tag = rhs.tag; | ||
is_intrusive_ptr = rhs.is_intrusive_ptr; | ||
rhs.tag = Tag::None; | ||
rhs.is_intrusive_ptr = false; | ||
return *this; | ||
} | ||
|
||
IValue& operator=(IValue const& rhs) & { | ||
IValue(rhs).swap(*this); | ||
return *this; | ||
} | ||
|
||
void dump() const; | ||
|
||
/** | ||
|
@@ -288,7 +326,19 @@ struct CAFFE2_API IValue final { | |
|
||
/// @private [doxygen private] | ||
void swap(IValue& rhs) noexcept { | ||
std::swap(payload, rhs.payload); | ||
if (isTensor() && rhs.isTensor()) { | ||
std::swap(payload.as_tensor, rhs.payload.as_tensor); | ||
} else if (isTensor()) { | ||
at::Tensor t = std::move(payload.as_tensor); | ||
Review comment: Ugh, this is more involved than I had hoped. I guess it's UB to just relocate the Tensor without destructing and constructing again? Reply: IIUC, you have to construct it in rhs.payload using placement new or it's UB. Skipping the destructor call is legit, and I'll probably try that. |
||
payload.as_tensor.~Tensor(); | ||
memcpy(&payload, &rhs.payload, sizeof(payload)); | ||
new (&rhs.payload.as_tensor) at::Tensor(std::move(t)); | ||
} else if (rhs.isTensor()) { | ||
rhs.swap(*this); | ||
Review comment: This is potentially slow because it needs to do the [inline code and remainder of comment lost in extraction] |
||
return; | ||
} else { | ||
std::swap(reinterpret_cast<char(&)[sizeof(payload)]>(*&payload), reinterpret_cast<char(&)[sizeof(payload)]>(*&rhs.payload)); | ||
} | ||
std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr); | ||
std::swap(tag, rhs.tag); | ||
} | ||
|
@@ -297,21 +347,16 @@ struct CAFFE2_API IValue final { | |
// While some of these accessors could be generated through templates, | ||
// we prefer to write them manually for clarity | ||
|
||
IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(t.defined()) { | ||
// Note: the undefined tensor is not refcounted, so while it | ||
// is tagged as a tensor, is_intrusive_ptr is set to false. | ||
// This is not an optional optimization: our incref call | ||
// *will not* do the right thing when called on an | ||
// undefined tensor. | ||
payload.as_intrusive_ptr = t.unsafeReleaseTensorImpl(); | ||
IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(false) { | ||
new (&payload.as_tensor) at::Tensor(std::move(t)); | ||
} | ||
bool isTensor() const { | ||
return Tag::Tensor == tag; | ||
} | ||
at::Tensor toTensor() &&; | ||
at::Tensor toTensor() const&; | ||
at::TensorImpl* unsafeToTensorImpl() const { | ||
return static_cast<at::TensorImpl*>(payload.as_intrusive_ptr); | ||
return payload.as_tensor.unsafeGetTensorImpl(); | ||
} | ||
|
||
const IValue& toIValue() const { | ||
|
@@ -565,7 +610,7 @@ struct CAFFE2_API IValue final { | |
c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() const&; | ||
|
||
// None | ||
IValue() : payload{0}, tag(Tag::None), is_intrusive_ptr(false) {} | ||
IValue() : tag(Tag::None), is_intrusive_ptr(false) {} | ||
bool isNone() const { | ||
return Tag::None == tag; | ||
} | ||
|
@@ -826,13 +871,23 @@ struct CAFFE2_API IValue final { | |
double as_double; | ||
bool as_bool; | ||
c10::intrusive_ptr_target* as_intrusive_ptr; | ||
at::Tensor as_tensor; | ||
struct { | ||
DeviceType type; | ||
DeviceIndex index; | ||
} as_device; | ||
|
||
Payload() : as_int(0) {} | ||
~Payload() {} | ||
Review comment: Any reason you're user-defining the destructor? Reply: Unions with non-POD types in them are a pain. The destructor cannot be defined by default -- do you run ~Tensor() or not? So, we have to define it to do nothing. |
||
}; | ||
|
||
IValue(Payload p, Tag t, bool i) : payload(p), tag(t), is_intrusive_ptr(i) {} | ||
IValue(const Payload& p, Tag t, bool i) : tag(t), is_intrusive_ptr(i) { | ||
Review comment: Even the largest Payload should be 64-bit only, and Payload has trivial copy/move constructors, so I would assume passing by value is better. Is passing by reference here related to the Itanium ABI thing you posted about?
Reply: Not with [remainder of reply lost in extraction] |
||
if (isTensor()) { | ||
new (&payload.as_tensor) at::Tensor(p.as_tensor); | ||
} else { | ||
memcpy(&payload, &p, sizeof(payload)); | ||
} | ||
} | ||
|
||
Payload payload; | ||
Tag tag; | ||
|
@@ -852,9 +907,14 @@ struct CAFFE2_API WeakIValue final { | |
} | ||
} | ||
WeakIValue(const IValue& rhs) | ||
: payload(rhs.payload), | ||
tag(rhs.tag), | ||
: tag(rhs.tag), | ||
is_intrusive_ptr(rhs.is_intrusive_ptr) { | ||
if (rhs.isTensor()) { | ||
payload.as_intrusive_ptr = rhs.unsafeToTensorImpl(); | ||
} else { | ||
static_assert(sizeof(payload) == sizeof(rhs.payload), "IValue and WeakIValue payload sizes don't match!"); | ||
memcpy(&payload, &rhs.payload, sizeof(payload)); | ||
} | ||
if (is_intrusive_ptr) { | ||
c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); | ||
} | ||
|
@@ -888,17 +948,28 @@ struct CAFFE2_API WeakIValue final { | |
|
||
IValue lock() const { | ||
if (!is_intrusive_ptr) { | ||
return IValue(payload, tag, false); | ||
IValue::Payload newPayload; | ||
memcpy(&newPayload, &payload, sizeof(newPayload)); | ||
return IValue(newPayload, tag, false); | ||
} | ||
auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim( | ||
payload.as_intrusive_ptr); | ||
IValue::Payload pl; | ||
pl.as_intrusive_ptr = temp.lock().release(); | ||
temp.release(); | ||
if (!pl.as_intrusive_ptr) { | ||
return IValue(); | ||
if (IValue::Tag::Tensor == tag) { | ||
auto ip = temp.lock().release(); | ||
if (!ip) { | ||
return IValue(); | ||
} else { | ||
return IValue(std::move(ip)); | ||
} | ||
} else { | ||
return IValue(pl, tag, true); | ||
IValue::Payload pl; | ||
pl.as_intrusive_ptr = temp.lock().release(); | ||
temp.release(); | ||
if (!pl.as_intrusive_ptr) { | ||
return IValue(); | ||
} else { | ||
return IValue(pl, tag, true); | ||
} | ||
} | ||
} | ||
|
||
|
@@ -928,7 +999,17 @@ struct CAFFE2_API WeakIValue final { | |
} | ||
|
||
private: | ||
IValue::Payload payload; | ||
union Payload { | ||
int64_t as_int; | ||
double as_double; | ||
bool as_bool; | ||
c10::intrusive_ptr_target* as_intrusive_ptr; | ||
struct { | ||
DeviceType type; | ||
DeviceIndex index; | ||
} as_device; | ||
}; | ||
Payload payload; | ||
IValue::Tag tag; | ||
bool is_intrusive_ptr; | ||
}; | ||
|
Review comment: Do we gain something by putting the `isTensor` check first? I would assume that `Tensor` objects are much more common than non-Tensor `intrusive_ptr`.