Skip to content

Commit d582b6f

Browse files
committed
[PyTorch] Store Tensor explicitly in IValue
Pull Request resolved: #48824 Enables following diff, which will make toTensor() return `const Tensor&` and allow callers to avoid refcounting overhead. ghstack-source-id: 118955313 Differential Revision: [D25324617](https://our.internmc.facebook.com/intern/diff/D25324617/)
1 parent c9e0521 commit d582b6f

File tree

6 files changed

+343
-115
lines changed

6 files changed

+343
-115
lines changed

aten/src/ATen/core/ivalue.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ bool IValue::ptrEqual(const IValue& lhs, const IValue& rhs) {
247247
TORCH_INTERNAL_ASSERT(lhs.is_intrusive_ptr);
248248
TORCH_INTERNAL_ASSERT(rhs.is_intrusive_ptr);
249249
return lhs.tag == rhs.tag &&
250-
lhs.payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
250+
lhs.payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
251251
}
252252

253253
IValue IValue::equals(const IValue& rhs) const {
@@ -307,17 +307,17 @@ size_t IValue::hash(const IValue& v) {
307307
case Tag::None:
308308
return 0;
309309
case Tag::Bool:
310-
return c10::get_hash(v.payload.as_bool);
310+
return c10::get_hash(v.payload.u.as_bool);
311311
case Tag::Double:
312-
return c10::get_hash(v.payload.as_double);
312+
return c10::get_hash(v.payload.u.as_double);
313313
case Tag::Tensor:
314314
// Tensor __hash__ is equivalent to `id()`, so take the pointer value of
315315
// the tensor to emulate it
316-
return c10::get_hash(v.payload.as_int);
316+
return c10::get_hash(v.payload.as_tensor.unsafeGetTensorImpl());
317317
case Tag::Storage:
318-
return c10::get_hash(v.payload.as_int);
318+
return c10::get_hash(v.payload.u.as_int);
319319
case Tag::Int:
320-
return c10::get_hash(v.payload.as_int);
320+
return c10::get_hash(v.payload.u.as_int);
321321
case Tag::String:
322322
return c10::get_hash(v.toStringRef());
323323
case Tag::Tuple:

0 commit comments

Comments
 (0)