Skip to content

Commit f422cdd

Browse files
swolchokhwangdeyu
authored andcommitted
[PyTorch] Store Tensor explicitly in IValue (pytorch#48824)
Summary: Pull Request resolved: pytorch#48824 Enables following diff, which will make toTensor() return `const Tensor&` and allow callers to avoid refcounting overhead. ghstack-source-id: 119327370 Test Plan: ivalue_test Internal benchmark to ensure perf parity. Some interesting steps during the debugging process: - First version was about a 5% regression - Directly implementing move construction instead of using swap lowered the regression to 2-3% - Directly implementing move assign was maybe an 0.5% improvement - Adding C10_ALWAYS_INLINE on move assign got our regression to negligible - Fixing toTensor() to actually be correct regressed us again, but omitting the explicit dtor call as exhaustively spelled out in a comment fixed it. Reviewed By: bwasti Differential Revision: D25324617 fbshipit-source-id: 7518c1c67f6f2661f151b43310aaddf4fb6e511a
1 parent 458e2c2 commit f422cdd

File tree

5 files changed

+275
-123
lines changed

5 files changed

+275
-123
lines changed

aten/src/ATen/core/ivalue.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ bool IValue::ptrEqual(const IValue& lhs, const IValue& rhs) {
265265
TORCH_INTERNAL_ASSERT(lhs.is_intrusive_ptr);
266266
TORCH_INTERNAL_ASSERT(rhs.is_intrusive_ptr);
267267
return lhs.tag == rhs.tag &&
268-
lhs.payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
268+
lhs.payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
269269
}
270270

271271
IValue IValue::equals(const IValue& rhs) const {
@@ -325,17 +325,17 @@ size_t IValue::hash(const IValue& v) {
325325
case Tag::None:
326326
return 0;
327327
case Tag::Bool:
328-
return c10::get_hash(v.payload.as_bool);
328+
return c10::get_hash(v.payload.u.as_bool);
329329
case Tag::Double:
330-
return c10::get_hash(v.payload.as_double);
330+
return c10::get_hash(v.payload.u.as_double);
331331
case Tag::Tensor:
332332
// Tensor __hash__ is equivalent to `id()`, so take the pointer value of
333333
// the tensor to emulate it
334-
return c10::get_hash(v.payload.as_int);
334+
return c10::get_hash(v.payload.as_tensor.unsafeGetTensorImpl());
335335
case Tag::Storage:
336-
return c10::get_hash(v.payload.as_int);
336+
return c10::get_hash(v.payload.u.as_int);
337337
case Tag::Int:
338-
return c10::get_hash(v.payload.as_int);
338+
return c10::get_hash(v.payload.u.as_int);
339339
case Tag::String:
340340
return c10::get_hash(v.toStringRef());
341341
case Tag::Tuple:

0 commit comments

Comments
 (0)