20
20
#include < c10/core/TensorImpl.h>
21
21
#include < c10/core/UndefinedTensorImpl.h>
22
22
#include < c10/util/intrusive_ptr.h>
23
+ #include < c10/util/irange.h>
23
24
#include < c10/util/hash.h>
24
25
25
26
namespace torch {
@@ -250,9 +251,252 @@ struct TORCH_API ConstantString final : c10::intrusive_ptr_target {
250
251
251
252
struct Future ;
252
253
254
+ struct TORCH_API TupleElements {
255
+ private:
256
+ size_t inlineSize_;
257
+ // We represent TupleElements this way to save doing a heap
258
+ // allocation in the common (at least for unpickling) case where we
259
+ // have only 3 elements. We have our own union instead of
260
+ // c10::SmallVector<IValue> because c10::SmallVector<IValue> always
261
+ // stores the begin/end/capacity pointers, which would be a waste of
262
+ // space in our use case.
263
+ union {
264
+ std::vector<IValue> elementsVector_;
265
+ // Don't want to declare a std::array because the convenient
266
+ // iteration and size members are a footgun in this case -- the
267
+ // actual size of the array may be smaller than 3!
268
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
269
+ IValue elementsInline_[3 ];
270
+ };
271
+
272
+ void destroyInline () {
273
+ for (const auto ii : c10::irange (inlineSize_)) {
274
+ elementsInline_[ii].~IValue ();
275
+ }
276
+ }
277
+ public:
278
+
279
+ using iterator = IValue*;
280
+ using const_iterator = const IValue*;
281
+
282
+ TupleElements () : inlineSize_(0 ) {
283
+ new (&elementsVector_) std::vector<IValue>();
284
+ }
285
+
286
+ explicit TupleElements (std::vector<IValue> elements)
287
+ : inlineSize_(0 ), elementsVector_(std::move(elements)) {}
288
+
289
+ explicit TupleElements (IValue&& e1 )
290
+ : inlineSize_(1 ) {
291
+ new (&elementsInline_[0 ]) IValue (std::move (e1 ));
292
+ }
293
+
294
+ explicit TupleElements (IValue&& e1 , IValue&& e2 )
295
+ : inlineSize_(2 ) {
296
+ new (&elementsInline_[0 ]) IValue (std::move (e1 ));
297
+ new (&elementsInline_[1 ]) IValue (std::move (e2 ));
298
+ }
299
+
300
+ explicit TupleElements (IValue&& e1 , IValue&& e2 , IValue&& e3 )
301
+ : inlineSize_(3 ) {
302
+ new (&elementsInline_[0 ]) IValue (std::move (e1 ));
303
+ new (&elementsInline_[1 ]) IValue (std::move (e2 ));
304
+ new (&elementsInline_[2 ]) IValue (std::move (e3 ));
305
+ }
306
+
307
+ ~TupleElements () {
308
+ if (inlineSize_) {
309
+ destroyInline ();
310
+ } else {
311
+ elementsVector_.~vector ();
312
+ }
313
+ }
314
+
315
+ // Simply not implemented; no particular reason not to implement in
316
+ // the future except that it seems unnecessary.
317
+ TupleElements (const TupleElements&) = delete ;
318
+ TupleElements& operator =(const TupleElements&) = delete ;
319
+
320
+ TupleElements (TupleElements&& rhs) noexcept
321
+ : inlineSize_(rhs.inlineSize_) {
322
+ if (inlineSize_) {
323
+ for (const auto ii : c10::irange (inlineSize_)) {
324
+ new (&elementsInline_[ii]) IValue (std::move (rhs.elementsInline_ [ii]));
325
+ }
326
+ } else {
327
+ new (&elementsVector_) std::vector<IValue>(std::move (rhs.elementsVector_ ));
328
+ }
329
+ }
330
+
331
+ TupleElements& operator =(TupleElements&& rhs) noexcept {
332
+ if (inlineSize_) {
333
+ if (rhs.inlineSize_ ) {
334
+ for (const auto ii : c10::irange (std::min (inlineSize_, rhs.inlineSize_ ))) {
335
+ elementsInline_[ii] = std::move (rhs.elementsInline_ [ii]);
336
+ }
337
+ if (rhs.inlineSize_ > inlineSize_) {
338
+ for (const auto ii : c10::irange (inlineSize_, rhs.inlineSize_ )) {
339
+ new (&elementsInline_[ii]) IValue (std::move (rhs.elementsInline_ [ii]));
340
+ }
341
+ } else {
342
+ for (const auto ii : c10::irange (rhs.inlineSize_ , inlineSize_)) {
343
+ elementsInline_[ii].~IValue ();
344
+ }
345
+ }
346
+ } else {
347
+ destroyInline ();
348
+ new (&elementsVector_) std::vector<IValue>(std::move (rhs.elementsVector_ ));
349
+ }
350
+ } else {
351
+ if (rhs.inlineSize_ ) {
352
+ elementsVector_.~vector ();
353
+ for (const auto ii : c10::irange (rhs.inlineSize_ )) {
354
+ new (&elementsInline_[ii]) IValue (std::move (rhs.elementsInline_ [ii]));
355
+ }
356
+ } else {
357
+ elementsVector_ = std::move (rhs.elementsVector_ );
358
+ }
359
+ }
360
+ inlineSize_ = rhs.inlineSize_ ;
361
+ return *this ;
362
+ }
363
+
364
+ C10_NODISCARD c10::ArrayRef<IValue> asArrayRef () const {
365
+ if (inlineSize_) {
366
+ return c10::ArrayRef<IValue>(elementsInline_, inlineSize_);
367
+ } else {
368
+ return elementsVector_;
369
+ }
370
+ }
371
+
372
+ // Mimic implicit conversion from std::vector to ArrayRef.
373
+ operator c10::ArrayRef<IValue>() const {
374
+ return asArrayRef ();
375
+ }
376
+
377
+ static size_t hash (const TupleElements& v) {
378
+ return c10::hash<c10::ArrayRef<IValue>>()(v.asArrayRef ());
379
+ }
380
+
381
+ void setContents (std::vector<IValue>&& contents) {
382
+ if (inlineSize_) {
383
+ destroyInline ();
384
+ new (&elementsVector_) std::vector<IValue>(std::move (contents));
385
+ inlineSize_ = 0 ;
386
+ } else {
387
+ elementsVector_ = std::move (contents);
388
+ }
389
+ }
390
+
391
+ C10_NODISCARD bool empty () const {
392
+ return inlineSize_ ? false : elementsVector_.empty ();
393
+ }
394
+
395
+ C10_NODISCARD size_t size () const {
396
+ return inlineSize_ ? inlineSize_ : elementsVector_.size ();
397
+ }
398
+
399
+ C10_NODISCARD IValue& operator [](size_t idx) {
400
+ if (inlineSize_) {
401
+ return elementsInline_[idx];
402
+ } else {
403
+ return elementsVector_[idx];
404
+ }
405
+ }
406
+
407
+ C10_NODISCARD const IValue& operator [](size_t idx) const {
408
+ if (inlineSize_) {
409
+ return elementsInline_[idx];
410
+ } else {
411
+ return elementsVector_[idx];
412
+ }
413
+ }
414
+
415
+ C10_NODISCARD IValue& at (size_t idx) {
416
+ if (inlineSize_) {
417
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (inlineSize_ <= 3 );
418
+ TORCH_CHECK (idx < inlineSize_, " TupleElements: invalid index Index = " , idx, " ; Length = " , inlineSize_);
419
+ return elementsInline_[idx];
420
+ } else {
421
+ return elementsVector_.at (idx);
422
+ }
423
+ }
424
+
425
+ C10_NODISCARD const IValue& at (size_t idx) const {
426
+ if (inlineSize_) {
427
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (inlineSize_ <= 3 );
428
+ TORCH_CHECK (idx < inlineSize_, " TupleElements: invalid index Index = " , idx, " ; Length = " , inlineSize_);
429
+ return elementsInline_[idx];
430
+ } else {
431
+ return elementsVector_.at (idx);
432
+ }
433
+ }
434
+
435
+ C10_NODISCARD iterator begin () {
436
+ if (inlineSize_) {
437
+ return elementsInline_;
438
+ } else {
439
+ return elementsVector_.data ();
440
+ }
441
+ }
442
+
443
+ C10_NODISCARD iterator end () {
444
+ if (inlineSize_) {
445
+ return elementsInline_ + inlineSize_;
446
+ } else {
447
+ return elementsVector_.data () + elementsVector_.size ();
448
+ }
449
+ }
450
+
451
+ C10_NODISCARD const_iterator begin () const {
452
+ if (inlineSize_) {
453
+ return elementsInline_;
454
+ } else {
455
+ return elementsVector_.data ();
456
+ }
457
+ }
458
+
459
+ C10_NODISCARD const_iterator end () const {
460
+ if (inlineSize_) {
461
+ return elementsInline_ + inlineSize_;
462
+ } else {
463
+ return elementsVector_.data () + elementsVector_.size ();
464
+ }
465
+ }
466
+
467
+ C10_NODISCARD const_iterator cbegin () const {
468
+ return begin ();
469
+ }
470
+
471
+ C10_NODISCARD const_iterator cend () const {
472
+ return end ();
473
+ }
474
+
475
+ C10_NODISCARD std::vector<IValue> vec () const & {
476
+ return asArrayRef ().vec ();
477
+ }
478
+
479
+ C10_NODISCARD IValue& back () {
480
+ return *(end () - 1 );
481
+ }
482
+
483
+ C10_NODISCARD const IValue& back () const {
484
+ return *(end () - 1 );
485
+ }
486
+
487
+ C10_NODISCARD std::vector<IValue> vec () && {
488
+ std::vector<IValue> result;
489
+ result.reserve (size ());
490
+ for (auto && iv : *this ) {
491
+ result.push_back (std::move (iv));
492
+ }
493
+ return result;
494
+ }
495
+ };
496
+
253
497
struct TORCH_API Tuple : c10::intrusive_ptr_target {
254
498
private:
255
- std::vector<IValue> elements_;
499
+ TupleElements elements_;
256
500
mutable std::shared_ptr<TupleType>
257
501
type_; // lazily computed for unnamed tuples
258
502
@@ -264,25 +508,60 @@ struct TORCH_API Tuple : c10::intrusive_ptr_target {
264
508
std::shared_ptr<TupleType> type_) {
265
509
return c10::make_intrusive<Tuple>(std::move (elements_), type_);
266
510
}
511
+
512
+ static c10::intrusive_ptr<Tuple> createNamed (
513
+ TupleElements elements_,
514
+ std::shared_ptr<TupleType> type_) {
515
+ return c10::make_intrusive<Tuple>(std::move (elements_), type_);
516
+ }
517
+
518
+ // MSVC apparently can't disambiguate the other two overloads of
519
+ // create when passed an initializer_list without this.
520
+ static c10::intrusive_ptr<Tuple> create (std::initializer_list<IValue> elements_) {
521
+ return create (std::vector<IValue>(elements_));
522
+ }
523
+
267
524
static c10::intrusive_ptr<Tuple> create (std::vector<IValue> elements_) {
268
525
return c10::make_intrusive<Tuple>(std::move (elements_));
269
526
}
270
527
528
+ static c10::intrusive_ptr<Tuple> create (TupleElements elements_) {
529
+ return c10::make_intrusive<Tuple>(std::move (elements_));
530
+ }
531
+
532
+ static c10::intrusive_ptr<Tuple> create (IValue e1 ) {
533
+ return c10::make_intrusive<Tuple>(std::move (e1 ));
534
+ }
535
+
536
+ static c10::intrusive_ptr<Tuple> create (IValue e1 , IValue e2 ) {
537
+ return c10::make_intrusive<Tuple>(std::move (e1 ), std::move (e2 ));
538
+ }
539
+
540
+ static c10::intrusive_ptr<Tuple> create (IValue e1 , IValue e2 , IValue e3 ) {
541
+ return c10::make_intrusive<Tuple>(std::move (e1 ), std::move (e2 ), std::move (e3 ));
542
+ }
543
+
271
544
template <typename ... Args>
272
545
static c10::intrusive_ptr<Tuple> create (Args&&... elements_) {
273
546
return c10::make_intrusive<Tuple>(
274
547
std::vector<IValue>{IValue (std::forward<Args>(elements_))...});
275
548
}
276
549
277
- const std::vector<IValue>& elements () const & {
550
+ Tuple (const Tuple& rhs) = delete ;
551
+
552
+ const TupleElements& elements () const & {
278
553
return elements_;
279
554
}
280
555
281
- std::vector<IValue> elements () && {
556
+ TupleElements elements () && {
282
557
return std::move (elements_);
283
558
}
284
559
285
560
void setElements (std::vector<IValue>&& elements) {
561
+ elements_.setContents (std::move (elements));
562
+ }
563
+
564
+ void setElements (TupleElements&& elements) {
286
565
elements_ = std::move (elements);
287
566
}
288
567
@@ -294,6 +573,10 @@ struct TORCH_API Tuple : c10::intrusive_ptr_target {
294
573
elements_[idx] = std::move (element);
295
574
}
296
575
576
+ size_t size () const {
577
+ return elements_.size ();
578
+ }
579
+
297
580
std::shared_ptr<TupleType> type () const ;
298
581
299
582
static size_t hash (const Tuple& t) {
@@ -305,8 +588,20 @@ struct TORCH_API Tuple : c10::intrusive_ptr_target {
305
588
const ivalue::Tuple& rhs);
306
589
307
590
private:
308
- Tuple (std::vector<IValue> elements, std::shared_ptr<TupleType> type = nullptr )
309
- : elements_(std::move(elements)), type_(std::move(type)) {}
591
+ explicit Tuple (std::vector<IValue> elements, std::shared_ptr<TupleType> type = nullptr )
592
+ : elements_(std::move(elements)), type_(std::move(type)) {}
593
+
594
+ explicit Tuple (TupleElements&& elements, std::shared_ptr<TupleType> type = nullptr )
595
+ : elements_(std::move(elements)), type_(std::move(type)) {}
596
+
597
+ explicit Tuple (IValue&& e1 , std::shared_ptr<TupleType> type = nullptr )
598
+ : elements_(std::move(e1 )), type_(std::move(type)) {}
599
+
600
+ explicit Tuple (IValue&& e1 , IValue&& e2 , std::shared_ptr<TupleType> type = nullptr )
601
+ : elements_(std::move(e1 ), std::move(e2 )), type_(std::move(type)) {}
602
+
603
+ explicit Tuple (IValue&& e1 , IValue&& e2 , IValue&& e3 , std::shared_ptr<TupleType> type = nullptr )
604
+ : elements_(std::move(e1 ), std::move(e2 ), std::move(e3 )), type_(std::move(type)) {}
310
605
311
606
friend class c10 ::intrusive_ptr<Tuple>;
312
607
};
@@ -1299,7 +1594,7 @@ c10::optional<T> generic_to(IValue ivalue, _fake_type<c10::optional<T>>) {
1299
1594
namespace detail {
1300
1595
template <typename Tuple, std::size_t ... INDEX>
1301
1596
Tuple generic_to_tuple_impl (
1302
- const std::vector<IValue> & t,
1597
+ const ivalue::TupleElements & t,
1303
1598
std::index_sequence<INDEX...>) {
1304
1599
return std::make_tuple (
1305
1600
t[INDEX].to <typename std::tuple_element<INDEX, Tuple>::type>()...);
0 commit comments