@@ -163,23 +163,61 @@ struct CAFFE2_API IValue final {
163
163
c10::raw::intrusive_ptr::incref (payload.as_intrusive_ptr );
164
164
}
165
165
}
166
- IValue (IValue&& rhs) noexcept : IValue() {
167
- swap (rhs);
166
+ IValue (IValue&& rhs) noexcept : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) {
167
+ if (isTensor ()) {
168
+ new (&payload.as_tensor ) at::Tensor (std::move (rhs.payload .as_tensor ));
169
+ rhs.payload .as_tensor .~Tensor ();
170
+ } else {
171
+ memcpy (&payload, &rhs.payload , sizeof (payload));
172
+ }
173
+ rhs.tag = Tag::None;
174
+ rhs.is_intrusive_ptr = false ;
168
175
}
176
+
169
177
// / @private [doxygen private]
170
178
~IValue () {
171
179
if (is_intrusive_ptr) {
172
180
c10::raw::intrusive_ptr::decref (payload.as_intrusive_ptr );
181
+ } else if (isTensor ()) {
182
+ payload.as_tensor .~Tensor ();
173
183
}
174
184
}
175
- IValue& operator =(IValue&& rhs) & noexcept {
176
- IValue (std::move (rhs)).swap (*this ); // this also sets rhs to None
185
+
186
+ // Always-inline for performance -- this gets called frequently
187
+ // inside the core of the static runtime.
188
+ C10_ALWAYS_INLINE IValue& operator =(IValue&& rhs) & noexcept {
189
+ if (&rhs == this ) {
190
+ return *this ;
191
+ }
192
+
193
+ // Tear down our state.
194
+ if (isTensor ()) {
195
+ payload.as_tensor .~Tensor ();
196
+ } else if (is_intrusive_ptr) {
197
+ c10::raw::intrusive_ptr::decref (payload.as_intrusive_ptr );
198
+ }
199
+
200
+ if (rhs.isTensor ()) {
201
+ new (&payload.as_tensor ) at::Tensor (std::move (rhs.payload .as_tensor ));
202
+ rhs.payload .as_tensor .~Tensor ();
203
+ } else {
204
+ // No need to specially handle rhs being an intrusive_ptr -- we
205
+ // steal the reference.
206
+ memcpy (&payload, &rhs.payload , sizeof (payload));
207
+ }
208
+
209
+ tag = rhs.tag ;
210
+ is_intrusive_ptr = rhs.is_intrusive_ptr ;
211
+ rhs.tag = Tag::None;
212
+ rhs.is_intrusive_ptr = false ;
177
213
return *this ;
178
214
}
215
+
179
216
IValue& operator =(IValue const & rhs) & {
180
217
IValue (rhs).swap (*this );
181
218
return *this ;
182
219
}
220
+
183
221
void dump () const ;
184
222
185
223
/* *
@@ -288,7 +326,19 @@ struct CAFFE2_API IValue final {
288
326
289
327
// / @private [doxygen private]
290
328
void swap (IValue& rhs) noexcept {
291
- std::swap (payload, rhs.payload );
329
+ if (isTensor () && rhs.isTensor ()) {
330
+ std::swap (payload.as_tensor , rhs.payload .as_tensor );
331
+ } else if (isTensor ()) {
332
+ at::Tensor t = std::move (payload.as_tensor );
333
+ payload.as_tensor .~Tensor ();
334
+ memcpy (&payload, &rhs.payload , sizeof (payload));
335
+ new (&rhs.payload .as_tensor ) at::Tensor (std::move (t));
336
+ } else if (rhs.isTensor ()) {
337
+ rhs.swap (*this );
338
+ return ;
339
+ } else {
340
+ std::swap (reinterpret_cast <char (&)[sizeof (payload)]>(*&payload), reinterpret_cast <char (&)[sizeof (payload)]>(*&rhs.payload ));
341
+ }
292
342
std::swap (is_intrusive_ptr, rhs.is_intrusive_ptr );
293
343
std::swap (tag, rhs.tag );
294
344
}
@@ -297,21 +347,16 @@ struct CAFFE2_API IValue final {
297
347
// While some of these accessors could be generated through templates,
298
348
// we prefer to write them manually for clarity
299
349
300
- IValue (at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(t.defined()) {
301
- // Note: the undefined tensor is not refcounted, so while it
302
- // is tagged as a tensor, is_intrusive_ptr is set to false.
303
- // This is not an optional optimization: our incref call
304
- // *will not* do the right thing when called on an
305
- // undefined tensor.
306
- payload.as_intrusive_ptr = t.unsafeReleaseTensorImpl ();
350
+ IValue (at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(false ) {
351
+ new (&payload.as_tensor ) at::Tensor (std::move (t));
307
352
}
308
353
bool isTensor () const {
309
354
return Tag::Tensor == tag;
310
355
}
311
356
at::Tensor toTensor () &&;
312
357
at::Tensor toTensor () const &;
313
358
at::TensorImpl* unsafeToTensorImpl () const {
314
- return static_cast <at::TensorImpl*>( payload.as_intrusive_ptr );
359
+ return payload.as_tensor . unsafeGetTensorImpl ( );
315
360
}
316
361
317
362
const IValue& toIValue () const {
@@ -565,7 +610,7 @@ struct CAFFE2_API IValue final {
565
610
c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder () const &;
566
611
567
612
// None
568
- IValue () : payload{ 0 }, tag(Tag::None), is_intrusive_ptr(false ) {}
613
+ IValue () : tag(Tag::None), is_intrusive_ptr(false ) {}
569
614
bool isNone () const {
570
615
return Tag::None == tag;
571
616
}
@@ -826,13 +871,23 @@ struct CAFFE2_API IValue final {
826
871
double as_double;
827
872
bool as_bool;
828
873
c10::intrusive_ptr_target* as_intrusive_ptr;
874
+ at::Tensor as_tensor;
829
875
struct {
830
876
DeviceType type;
831
877
DeviceIndex index;
832
878
} as_device;
879
+
880
+ Payload () : as_int (0 ) {}
881
+ ~Payload () {}
833
882
};
834
883
835
- IValue (Payload p, Tag t, bool i) : payload(p), tag(t), is_intrusive_ptr(i) {}
884
+ IValue (const Payload& p, Tag t, bool i) : tag(t), is_intrusive_ptr(i) {
885
+ if (isTensor ()) {
886
+ new (&payload.as_tensor ) at::Tensor (p.as_tensor );
887
+ } else {
888
+ memcpy (&payload, &p, sizeof (payload));
889
+ }
890
+ }
836
891
837
892
Payload payload;
838
893
Tag tag;
@@ -852,9 +907,14 @@ struct CAFFE2_API WeakIValue final {
852
907
}
853
908
}
854
909
WeakIValue (const IValue& rhs)
855
- : payload(rhs.payload),
856
- tag(rhs.tag),
910
+ : tag(rhs.tag),
857
911
is_intrusive_ptr(rhs.is_intrusive_ptr) {
912
+ if (rhs.isTensor ()) {
913
+ payload.as_intrusive_ptr = rhs.unsafeToTensorImpl ();
914
+ } else {
915
+ static_assert (sizeof (payload) == sizeof (rhs.payload ), " IValue and WeakIValue payload sizes don't match!" );
916
+ memcpy (&payload, &rhs.payload , sizeof (payload));
917
+ }
858
918
if (is_intrusive_ptr) {
859
919
c10::raw::weak_intrusive_ptr::incref (payload.as_intrusive_ptr );
860
920
}
@@ -888,17 +948,28 @@ struct CAFFE2_API WeakIValue final {
888
948
889
949
IValue lock () const {
890
950
if (!is_intrusive_ptr) {
891
- return IValue (payload, tag, false );
951
+ IValue::Payload newPayload;
952
+ memcpy (&newPayload, &payload, sizeof (newPayload));
953
+ return IValue (newPayload, tag, false );
892
954
}
893
955
auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim (
894
956
payload.as_intrusive_ptr );
895
- IValue::Payload pl;
896
- pl.as_intrusive_ptr = temp.lock ().release ();
897
- temp.release ();
898
- if (!pl.as_intrusive_ptr ) {
899
- return IValue ();
957
+ if (IValue::Tag::Tensor == tag) {
958
+ auto ip = temp.lock ().release ();
959
+ if (!ip) {
960
+ return IValue ();
961
+ } else {
962
+ return IValue (std::move (ip));
963
+ }
900
964
} else {
901
- return IValue (pl, tag, true );
965
+ IValue::Payload pl;
966
+ pl.as_intrusive_ptr = temp.lock ().release ();
967
+ temp.release ();
968
+ if (!pl.as_intrusive_ptr ) {
969
+ return IValue ();
970
+ } else {
971
+ return IValue (pl, tag, true );
972
+ }
902
973
}
903
974
}
904
975
@@ -928,7 +999,17 @@ struct CAFFE2_API WeakIValue final {
928
999
}
929
1000
930
1001
private:
931
- IValue::Payload payload;
1002
+ union Payload {
1003
+ int64_t as_int;
1004
+ double as_double;
1005
+ bool as_bool;
1006
+ c10::intrusive_ptr_target* as_intrusive_ptr;
1007
+ struct {
1008
+ DeviceType type;
1009
+ DeviceIndex index;
1010
+ } as_device;
1011
+ };
1012
+ Payload payload;
932
1013
IValue::Tag tag;
933
1014
bool is_intrusive_ptr;
934
1015
};
0 commit comments