
Commit 0198db9 (merge, 2 parents: b48bbb4 + bc2352e)

reduce code duplication on "[PyTorch] Store Tensor explicitly in IValue"

Enables the following diff, which will make toTensor() return `const Tensor&` and allow callers to avoid refcounting overhead.

Differential Revision: [D25324617](https://our.internmc.facebook.com/intern/diff/D25324617/)

[ghstack-poisoned]
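As a sketch of the caller-side win the follow-up diff is after (hypothetical caller code, not part of this commit):

```cpp
#include <ATen/core/ivalue.h>

int64_t useTensor(const c10::IValue& iv) {
  // Today toTensor() returns at::Tensor by value, so extracting the
  // tensor costs an atomic incref here and a decref at scope exit.
  at::Tensor byValue = iv.toTensor();

  // Once the follow-up lands, an lvalue IValue can instead hand out a
  // const Tensor& into its own payload, making the borrow free:
  // const at::Tensor& borrowed = iv.toTensor();

  return byValue.numel();
}
```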

120 files changed: +3268 −945 lines

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -76,6 +76,7 @@ torch/lib/*.exe*
 torch/lib/*.dylib*
 torch/lib/*.h
 torch/lib/*.lib
+torch/lib/*.pdb
 torch/lib/*.so*
 torch/lib/protobuf*.pc
 torch/lib/build

.jenkins/pytorch/win-test-helpers/installation-helpers/install_mkl.bat

Lines changed: 1 addition & 1 deletion

@@ -7,4 +7,4 @@ if "%REBUILD%"=="" (
   7z x -aoa %TMP_DIR_WIN%\mkl.7z -o%TMP_DIR_WIN%\mkl
 )
 set CMAKE_INCLUDE_PATH=%TMP_DIR_WIN%\mkl\include
-set LIB=%TMP_DIR_WIN%\mkl\lib;%LIB
+set LIB=%TMP_DIR_WIN%\mkl\lib;%LIB%

The fix adds the missing closing `%`, so the previous value of `LIB` is expanded and appended rather than the literal text `%LIB`.

aten/src/ATen/core/aten_interned_strings.h

Lines changed: 2 additions & 1 deletion

@@ -383,6 +383,7 @@ _(aten, index_fill) \
 _(aten, index_put) \
 _(aten, index_select) \
 _(aten, indices) \
+_(aten, inner) \
 _(aten, instance_norm) \
 _(aten, inverse) \
 _(aten, irfft) \
@@ -496,6 +497,7 @@ _(aten, mode) \
 _(aten, mse_loss) \
 _(aten, mse_loss_backward) \
 _(aten, mse_loss_forward) \
+_(aten, msort) \
 _(aten, multi_margin_loss) \
 _(aten, multi_margin_loss_backward) \
 _(aten, multi_margin_loss_forward) \
@@ -546,7 +548,6 @@ _(aten, _euclidean_dist) \
 _(aten, pdist) \
 _(aten, cdist) \
 _(aten, permute) \
-_(aten, movedim) \
 _(aten, pin_memory) \
 _(aten, pinverse) \
 _(aten, pixel_shuffle) \

aten/src/ATen/core/interned_strings.h

Lines changed: 3 additions & 0 deletions

@@ -27,6 +27,7 @@ namespace c10 {
 _(prim, Assign) \
 _(prim, BroadcastingChunk) \
 _(prim, BroadcastSizes) \
+_(prim, ReductionSizes) \
 _(prim, Constant) \
 _(prim, ChunkSizes) \
 _(prim, Drop) \
@@ -284,6 +285,8 @@ namespace c10 {
 _(aten, swapaxes_) \
 _(aten, swapdims) \
 _(aten, swapdims_) \
+_(aten, movedim) \
+_(aten, moveaxis) \
 FORALL_ATEN_BASE_SYMBOLS(_) \
 _(onnx, Add) \
 _(onnx, Concat) \
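Both of the headers above are X-macro symbol lists: each `_(ns, s)` row is stamped out by whatever macro the consumer passes in for `_`. A simplified sketch of the mechanism (illustrative only; the real expansion defines interned `c10::Symbol` constants, not an enum):

```cpp
#include <cstdint>

// Simplified stand-in for the FORALL_*_SYMBOLS lists: each row becomes
// one enumerator when a concrete macro is substituted for `_`.
#define FORALL_MY_SYMBOLS(_) \
  _(aten, movedim)           \
  _(aten, moveaxis)          \
  _(prim, ReductionSizes)

enum class SymbolId : uint32_t {
#define DEFINE_SYMBOL(ns, s) ns##_##s,
  FORALL_MY_SYMBOLS(DEFINE_SYMBOL)
#undef DEFINE_SYMBOL
};

// SymbolId::aten_movedim now exists without a hand-written enumerator
// list; adding one row to the list above adds the symbol everywhere
// the list is expanded.
```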

aten/src/ATen/core/ivalue.h

Lines changed: 42 additions & 33 deletions

@@ -164,23 +164,14 @@ struct CAFFE2_API IValue final {
     }
   }
   IValue(IValue&& rhs) noexcept : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) {
-    if (isTensor()) {
-      new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
-      rhs.payload.as_tensor.~Tensor();
-    } else {
-      memcpy(&payload, &rhs.payload, sizeof(payload));
-    }
+    moveFrom(std::move(rhs));
     rhs.tag = Tag::None;
     rhs.is_intrusive_ptr = false;
   }

   /// @private [doxygen private]
   ~IValue() {
-    if (is_intrusive_ptr) {
-      c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr);
-    } else if (isTensor()) {
-      payload.as_tensor.~Tensor();
-    }
+    destroy();
   }

   // Always-inline for performance -- this gets called frequently
@@ -190,26 +181,8 @@
       return *this;
     }

-    // Tear down our state.
-    if (isTensor()) {
-      payload.as_tensor.~Tensor();
-    } else if (is_intrusive_ptr) {
-      c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr);
-    }
-
-    if (rhs.isTensor()) {
-      new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
-      rhs.payload.as_tensor.~Tensor();
-    } else {
-      // No need to specially handle rhs being an intrusive_ptr -- we
-      // steal the reference.
-      memcpy(&payload, &rhs.payload, sizeof(payload));
-    }
-
-    tag = rhs.tag;
-    is_intrusive_ptr = rhs.is_intrusive_ptr;
-    rhs.tag = Tag::None;
-    rhs.is_intrusive_ptr = false;
+    destroy();
+    moveFrom(std::move(rhs));
     return *this;
   }

@@ -330,7 +303,15 @@
       std::swap(payload.as_tensor, rhs.payload.as_tensor);
     } else if (isTensor()) {
       at::Tensor t = std::move(payload.as_tensor);
-      payload.as_tensor.~Tensor();
+      // As far as I can tell, omitting the usual explicit destructor call
+      // is not UB in and of itself, and it's a slight perf win. The
+      // destructor is a no-op, because the moved-from Tensor is
+      // effectively an intrusive_ptr in the null state, so we don't need
+      // the behavior for correctness reasons either. Leaving this
+      // explanatory comment, including commented-out destructor call, to
+      // make this abundantly clear.
+      //
+      // payload.as_tensor.~Tensor();
       memcpy(&payload, &rhs.payload, sizeof(payload));
       new (&rhs.payload.as_tensor) at::Tensor(std::move(t));
     } else if (rhs.isTensor()) {
@@ -860,7 +841,35 @@
       class NullType = c10::detail::intrusive_target_default_null_type<T>>
   c10::intrusive_ptr<T, NullType> toIntrusivePtr() const;

-  void clearToNone() {
+  void destroy() {
+    if (isTensor()) {
+      payload.as_tensor.~Tensor();
+    } else if (is_intrusive_ptr) {
+      c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr);
+    }
+  }
+
+  C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept {
+    if (rhs.isTensor()) {
+      new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
+      // As far as I can tell, omitting the usual explicit destructor call
+      // is not UB in and of itself, and it's a slight perf win. The
+      // destructor is a no-op, because the moved-from Tensor is
+      // effectively an intrusive_ptr in the null state, so we don't need
+      // the behavior for correctness reasons either. Leaving this
+      // explanatory comment, including commented-out destructor call, to
+      // make this abundantly clear.
+      //
+      // rhs.payload.as_tensor.~Tensor();
+    } else {
+      memcpy(&payload, &rhs.payload, sizeof(payload));
+    }
+    tag = rhs.tag;
+    is_intrusive_ptr = rhs.is_intrusive_ptr;
+    rhs.clearToNone();
+  }
+
+  void clearToNone() noexcept {
     payload.as_int = 0;
     tag = Tag::None;
     is_intrusive_ptr = false;
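The factoring above is easiest to see in isolation. Below is a minimal, self-contained sketch (not the real IValue; `Handle` stands in for `at::Tensor`) of how moving the tag dispatch into `destroy()` and `moveFrom()` lets the move constructor, move assignment, and destructor share one code path:

```cpp
#include <cstdint>
#include <cstring>
#include <memory>
#include <new>
#include <utility>

// Stand-in for at::Tensor: non-trivially destructible, and its
// moved-from state is "null", so the destructor becomes a no-op.
struct Handle {
  std::shared_ptr<int> impl;
};

class Value {
 public:
  Value() = default;
  explicit Value(Handle h) : tag_(Tag::Handle) {
    new (&payload_.as_handle) Handle(std::move(h));
  }
  Value(Value&& rhs) noexcept { moveFrom(std::move(rhs)); }
  Value& operator=(Value&& rhs) noexcept {
    if (&rhs == this) {
      return *this;
    }
    destroy();                 // tear down our state...
    moveFrom(std::move(rhs));  // ...then steal rhs's, as in the diff
    return *this;
  }
  ~Value() { destroy(); }

 private:
  enum class Tag { None, Int, Handle };

  union Payload {
    Payload() : as_int(0) {}
    ~Payload() {}  // Value manages the active member's lifetime
    int64_t as_int;
    Handle as_handle;
  };

  // The one place that knows how to tear down each payload kind.
  void destroy() {
    if (tag_ == Tag::Handle) {
      payload_.as_handle.~Handle();
    }
  }

  // The one place that knows how to steal another Value's payload.
  void moveFrom(Value&& rhs) noexcept {
    if (rhs.tag_ == Tag::Handle) {
      new (&payload_.as_handle) Handle(std::move(rhs.payload_.as_handle));
      // As in the real code, the moved-from Handle is null, so its
      // destructor would be a no-op and is deliberately skipped.
    } else {
      std::memcpy(&payload_, &rhs.payload_, sizeof(payload_));
    }
    tag_ = rhs.tag_;
    // clearToNone(): leave rhs in a valid, empty state.
    rhs.payload_.as_int = 0;
    rhs.tag_ = Tag::None;
  }

  Payload payload_;
  Tag tag_ = Tag::None;
};
```

Compared with the original, the Tensor-vs-trivial-payload branching now lives in exactly one place per direction (teardown vs. theft), which is what lets the follow-up diff change Tensor handling without touching three separate call sites.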

aten/src/ATen/core/ivalue_inl.h

Lines changed: 3 additions & 1 deletion

@@ -137,7 +137,9 @@ inline at::Tensor IValue::toTensor() && {
   // effectively an intrusive_ptr in the null state, so we don't need
   // the behavior for correctness reasons either. Leaving this
   // explanatory comment, including commented-out destructor call, to
-  // make this abundantly clear. payload.as_tensor.~Tensor();
+  // make this abundantly clear.
+  //
+  // payload.as_tensor.~Tensor();
   clearToNone();
   return result;
 }
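For reference, a hypothetical snippet (not from this commit) showing how the rvalue-qualified overload above is exercised: moving out of the IValue transfers the TensorImpl reference instead of copying it.

```cpp
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>

at::Tensor takeTensor() {
  c10::IValue iv(at::ones({2, 2}));
  // The && overload moves the payload out and clears `iv` to None,
  // so no refcount bump happens on the way out.
  return std::move(iv).toTensor();
}
```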

aten/src/ATen/cpp_custom_type_hack.h

Lines changed: 55 additions & 8 deletions

@@ -1,11 +1,52 @@
-// WARNING! WARNING! WARNING!
-// This file is a temporary hack to enable development of pytorch quantization
-//
-// It's a stub for wrapping arbitrary cpp types in TorchScript. Proper
-// implementation (under development) is to use TorchScript custom types.
-// In the meantime, we abuse ByteTensor with custom deleter for this purpose.
-//
-// Template argument <T> has to be registered with CAFFE_KNOWN_TYPE mechanism.
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+
+// YOU ARE IN THE WRONG PLACE! TURN BACK NOW!
+
+// This code was a temporary hack to enable embedding arbitrary C++ structures
+// into Tensors. THIS IS UNSAFE AND IS NOT SUPPORTED. IF YOU USE THIS CODE,
+// IT __WILL__ BREAK.
+
+// This code has been superseded by custom classes:
+// https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html
+
+// Please use custom classes and **DO NOT ADD MORE CALLSITES TO THINGS DEFINED
+// IN THIS FILE**.
+
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP

 #include <ATen/ATen.h>
 #include <ATen/TracerMode.h>
@@ -14,13 +55,17 @@ namespace at {
 namespace cpp_custom_type_hack {

 template <typename T>
+[[deprecated("Use custom classes instead: "
+             "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]]
 bool isa(const Tensor& packed) {
   return (packed.scalar_type() == kByte) &&
       (packed.storage().data_ptr().get_deleter() ==
        caffe2::TypeMeta::Make<T>().deleteFn());
 }

 template <typename T>
+[[deprecated("Use custom classes instead: "
+             "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]]
 T& cast(const Tensor& packed) {
   TORCH_CHECK(
       packed.scalar_type() == kByte, "Expected temporary cpp type wrapper");
@@ -33,6 +78,8 @@ T& cast(const Tensor& packed) {
 }

 template <typename T>
+[[deprecated("Use custom classes instead: "
+             "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]]
 Tensor create(std::unique_ptr<T> ptr, TensorOptions options) {
   // None of this should trace, so turn off Tracer dispatching
   at::AutoNonVariableTypeMode guard; // TODO: remove
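For anyone migrating off this header, here is a minimal sketch of the recommended replacement from the linked custom-classes tutorial (the class name and methods are illustrative):

```cpp
#include <torch/custom_class.h>

#include <string>
#include <utility>
#include <vector>

// A C++ type exposed to TorchScript the supported way: inherit from
// torch::CustomClassHolder instead of smuggling the object through a
// ByteTensor with a custom deleter.
struct MyStackClass : torch::CustomClassHolder {
  std::vector<std::string> stack_;
  explicit MyStackClass(std::vector<std::string> init)
      : stack_(std::move(init)) {}
  void push(std::string x) { stack_.push_back(std::move(x)); }
  std::string pop() {
    auto elem = stack_.back();
    stack_.pop_back();
    return elem;
  }
};

// Registration runs at static-initialization time; the class is then
// available in TorchScript as torch.classes.my_classes.MyStackClass.
static auto registration =
    torch::class_<MyStackClass>("my_classes", "MyStackClass")
        .def(torch::init<std::vector<std::string>>())
        .def("push", &MyStackClass::push)
        .def("pop", &MyStackClass::pop);
```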

aten/src/ATen/native/CompositeRandomAccessorCommon.h

Lines changed: 1 addition & 1 deletion

@@ -129,7 +129,7 @@ class CompositeRandomAccessor {

   // Pointer-like operations {
   C10_HOST_DEVICE
-  reference operator*() {
+  reference operator*() const {
     return TupleInfo::tie(*keys, *values);
   }
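A brief illustration of why the added const qualifier matters (hypothetical code, not from this commit): generic code often dereferences an accessor it only holds by const reference, which fails to compile if `operator*` is non-const.

```cpp
// Hypothetical generic helper: takes the accessor by const reference,
// as standard-library-style algorithms commonly do.
template <typename Accessor>
auto firstValue(const Accessor& acc) -> decltype(*acc) {
  // Requires Accessor::operator*() to be const-qualified; dereferencing
  // does not mutate the accessor itself, so const is the right contract.
  return *acc;
}
```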
