@@ -387,14 +387,19 @@ inline bool isComplexType(exec_aten::ScalarType t) {
387
387
t == exec_aten::ScalarType::ComplexDouble);
388
388
}
389
389
390
- inline bool isBitsType (exec_aten::ScalarType t) {
390
+ constexpr bool isBitsType (exec_aten::ScalarType t) {
391
391
return t == exec_aten::ScalarType::Bits1x8 ||
392
392
t == exec_aten::ScalarType::Bits2x4 ||
393
393
t == exec_aten::ScalarType::Bits4x2 ||
394
394
t == exec_aten::ScalarType::Bits8 || t == exec_aten::ScalarType::Bits16;
395
395
}
396
396
397
- inline bool isQIntType (exec_aten::ScalarType t) {
397
+ template <typename T>
398
+ struct is_bits_type
399
+ : std::integral_constant<bool , isBitsType(CppTypeToScalarType<T>::value)> {
400
+ };
401
+
402
+ constexpr bool isQIntType (exec_aten::ScalarType t) {
398
403
// Don't forget to extend this when adding new QInt types
399
404
return t == exec_aten::ScalarType::QInt8 ||
400
405
t == exec_aten::ScalarType::QUInt8 ||
@@ -403,6 +408,11 @@ inline bool isQIntType(exec_aten::ScalarType t) {
403
408
t == exec_aten::ScalarType::QUInt2x4;
404
409
}
405
410
411
+ template <typename T>
412
+ struct is_qint_type
413
+ : std::integral_constant<bool , isQIntType(CppTypeToScalarType<T>::value)> {
414
+ };
415
+
406
416
inline exec_aten::ScalarType toQIntType (exec_aten::ScalarType t) {
407
417
switch (t) {
408
418
case exec_aten::ScalarType::Byte:
@@ -550,6 +560,225 @@ To convert(From val) {
550
560
return static_cast <To>(val);
551
561
}
552
562
563
+ namespace internal {
564
+ template <typename T1, typename T2>
565
+ struct promote_types_lookup ;
566
+
567
+ template <typename T1>
568
+ struct promote_types_lookup <T1, T1> {
569
+ using type = T1;
570
+ };
571
+
572
+ using U1 = typename ScalarTypeToCppType<exec_aten::ScalarType::Byte>::type;
573
+ using I1 = typename ScalarTypeToCppType<exec_aten::ScalarType::Char>::type;
574
+ using I2 = typename ScalarTypeToCppType<exec_aten::ScalarType::Short>::type;
575
+ using I4 = typename ScalarTypeToCppType<exec_aten::ScalarType::Int>::type;
576
+ using I8 = typename ScalarTypeToCppType<exec_aten::ScalarType::Long>::type;
577
+ using F2 = typename ScalarTypeToCppType<exec_aten::ScalarType::Half>::type;
578
+ using F4 = typename ScalarTypeToCppType<exec_aten::ScalarType::Float>::type;
579
+ using F8 = typename ScalarTypeToCppType<exec_aten::ScalarType::Double>::type;
580
+ using C2 =
581
+ typename ScalarTypeToCppType<exec_aten::ScalarType::ComplexHalf>::type;
582
+ using C4 =
583
+ typename ScalarTypeToCppType<exec_aten::ScalarType::ComplexFloat>::type;
584
+ using C8 =
585
+ typename ScalarTypeToCppType<exec_aten::ScalarType::ComplexDouble>::type;
586
+ using B1 = typename ScalarTypeToCppType<exec_aten::ScalarType::Bool>::type;
587
+
588
+ #define TABLE_ENTRY (key1, key2, value ) \
589
+ template <> \
590
+ struct promote_types_lookup <key1, key2> { \
591
+ using type = value; \
592
+ }
593
+
594
+ /* promote_types_lookup is a compile-time-accessible version of the
595
+ * table in promoteTypes below; we cannot make promoteTypes constexpr
596
+ * and use it directly because we are on C++11 and thus don't have
597
+ * C++17 relaxed constexpr. The below series of entries is generated
598
+ * by genScalarTypeTable.py. */
599
+ TABLE_ENTRY (U1, U1, U1);
600
+ TABLE_ENTRY (U1, I1, I2);
601
+ TABLE_ENTRY (U1, I2, I2);
602
+ TABLE_ENTRY (U1, I4, I4);
603
+ TABLE_ENTRY (U1, I8, I8);
604
+ TABLE_ENTRY (U1, F2, F2);
605
+ TABLE_ENTRY (U1, F4, F4);
606
+ TABLE_ENTRY (U1, F8, F8);
607
+ TABLE_ENTRY (U1, C2, C2);
608
+ TABLE_ENTRY (U1, C4, C4);
609
+ TABLE_ENTRY (U1, C8, C8);
610
+ TABLE_ENTRY (U1, B1, U1);
611
+ TABLE_ENTRY (I1, U1, I2);
612
+ TABLE_ENTRY (I1, I1, I1);
613
+ TABLE_ENTRY (I1, I2, I2);
614
+ TABLE_ENTRY (I1, I4, I4);
615
+ TABLE_ENTRY (I1, I8, I8);
616
+ TABLE_ENTRY (I1, F2, F2);
617
+ TABLE_ENTRY (I1, F4, F4);
618
+ TABLE_ENTRY (I1, F8, F8);
619
+ TABLE_ENTRY (I1, C2, C2);
620
+ TABLE_ENTRY (I1, C4, C4);
621
+ TABLE_ENTRY (I1, C8, C8);
622
+ TABLE_ENTRY (I1, B1, I1);
623
+ TABLE_ENTRY (I2, U1, I2);
624
+ TABLE_ENTRY (I2, I1, I2);
625
+ TABLE_ENTRY (I2, I2, I2);
626
+ TABLE_ENTRY (I2, I4, I4);
627
+ TABLE_ENTRY (I2, I8, I8);
628
+ TABLE_ENTRY (I2, F2, F2);
629
+ TABLE_ENTRY (I2, F4, F4);
630
+ TABLE_ENTRY (I2, F8, F8);
631
+ TABLE_ENTRY (I2, C2, C2);
632
+ TABLE_ENTRY (I2, C4, C4);
633
+ TABLE_ENTRY (I2, C8, C8);
634
+ TABLE_ENTRY (I2, B1, I2);
635
+ TABLE_ENTRY (I4, U1, I4);
636
+ TABLE_ENTRY (I4, I1, I4);
637
+ TABLE_ENTRY (I4, I2, I4);
638
+ TABLE_ENTRY (I4, I4, I4);
639
+ TABLE_ENTRY (I4, I8, I8);
640
+ TABLE_ENTRY (I4, F2, F2);
641
+ TABLE_ENTRY (I4, F4, F4);
642
+ TABLE_ENTRY (I4, F8, F8);
643
+ TABLE_ENTRY (I4, C2, C2);
644
+ TABLE_ENTRY (I4, C4, C4);
645
+ TABLE_ENTRY (I4, C8, C8);
646
+ TABLE_ENTRY (I4, B1, I4);
647
+ TABLE_ENTRY (I8, U1, I8);
648
+ TABLE_ENTRY (I8, I1, I8);
649
+ TABLE_ENTRY (I8, I2, I8);
650
+ TABLE_ENTRY (I8, I4, I8);
651
+ TABLE_ENTRY (I8, I8, I8);
652
+ TABLE_ENTRY (I8, F2, F2);
653
+ TABLE_ENTRY (I8, F4, F4);
654
+ TABLE_ENTRY (I8, F8, F8);
655
+ TABLE_ENTRY (I8, C2, C2);
656
+ TABLE_ENTRY (I8, C4, C4);
657
+ TABLE_ENTRY (I8, C8, C8);
658
+ TABLE_ENTRY (I8, B1, I8);
659
+ TABLE_ENTRY (F2, U1, F2);
660
+ TABLE_ENTRY (F2, I1, F2);
661
+ TABLE_ENTRY (F2, I2, F2);
662
+ TABLE_ENTRY (F2, I4, F2);
663
+ TABLE_ENTRY (F2, I8, F2);
664
+ TABLE_ENTRY (F2, F2, F2);
665
+ TABLE_ENTRY (F2, F4, F4);
666
+ TABLE_ENTRY (F2, F8, F8);
667
+ TABLE_ENTRY (F2, C2, C2);
668
+ TABLE_ENTRY (F2, C4, C4);
669
+ TABLE_ENTRY (F2, C8, C8);
670
+ TABLE_ENTRY (F2, B1, F2);
671
+ TABLE_ENTRY (F4, U1, F4);
672
+ TABLE_ENTRY (F4, I1, F4);
673
+ TABLE_ENTRY (F4, I2, F4);
674
+ TABLE_ENTRY (F4, I4, F4);
675
+ TABLE_ENTRY (F4, I8, F4);
676
+ TABLE_ENTRY (F4, F2, F4);
677
+ TABLE_ENTRY (F4, F4, F4);
678
+ TABLE_ENTRY (F4, F8, F8);
679
+ TABLE_ENTRY (F4, C2, C4);
680
+ TABLE_ENTRY (F4, C4, C4);
681
+ TABLE_ENTRY (F4, C8, C8);
682
+ TABLE_ENTRY (F4, B1, F4);
683
+ TABLE_ENTRY (F8, U1, F8);
684
+ TABLE_ENTRY (F8, I1, F8);
685
+ TABLE_ENTRY (F8, I2, F8);
686
+ TABLE_ENTRY (F8, I4, F8);
687
+ TABLE_ENTRY (F8, I8, F8);
688
+ TABLE_ENTRY (F8, F2, F8);
689
+ TABLE_ENTRY (F8, F4, F8);
690
+ TABLE_ENTRY (F8, F8, F8);
691
+ TABLE_ENTRY (F8, C2, C8);
692
+ TABLE_ENTRY (F8, C4, C8);
693
+ TABLE_ENTRY (F8, C8, C8);
694
+ TABLE_ENTRY (F8, B1, F8);
695
+ TABLE_ENTRY (C2, U1, C2);
696
+ TABLE_ENTRY (C2, I1, C2);
697
+ TABLE_ENTRY (C2, I2, C2);
698
+ TABLE_ENTRY (C2, I4, C2);
699
+ TABLE_ENTRY (C2, I8, C2);
700
+ TABLE_ENTRY (C2, F2, C2);
701
+ TABLE_ENTRY (C2, F4, C4);
702
+ TABLE_ENTRY (C2, F8, C8);
703
+ TABLE_ENTRY (C2, C2, C2);
704
+ TABLE_ENTRY (C2, C4, C4);
705
+ TABLE_ENTRY (C2, C8, C8);
706
+ TABLE_ENTRY (C2, B1, C2);
707
+ TABLE_ENTRY (C4, U1, C4);
708
+ TABLE_ENTRY (C4, I1, C4);
709
+ TABLE_ENTRY (C4, I2, C4);
710
+ TABLE_ENTRY (C4, I4, C4);
711
+ TABLE_ENTRY (C4, I8, C4);
712
+ TABLE_ENTRY (C4, F2, C4);
713
+ TABLE_ENTRY (C4, F4, C4);
714
+ TABLE_ENTRY (C4, F8, C8);
715
+ TABLE_ENTRY (C4, C2, C4);
716
+ TABLE_ENTRY (C4, C4, C4);
717
+ TABLE_ENTRY (C4, C8, C8);
718
+ TABLE_ENTRY (C4, B1, C4);
719
+ TABLE_ENTRY (C8, U1, C8);
720
+ TABLE_ENTRY (C8, I1, C8);
721
+ TABLE_ENTRY (C8, I2, C8);
722
+ TABLE_ENTRY (C8, I4, C8);
723
+ TABLE_ENTRY (C8, I8, C8);
724
+ TABLE_ENTRY (C8, F2, C8);
725
+ TABLE_ENTRY (C8, F4, C8);
726
+ TABLE_ENTRY (C8, F8, C8);
727
+ TABLE_ENTRY (C8, C2, C8);
728
+ TABLE_ENTRY (C8, C4, C8);
729
+ TABLE_ENTRY (C8, C8, C8);
730
+ TABLE_ENTRY (C8, B1, C8);
731
+ TABLE_ENTRY (B1, U1, U1);
732
+ TABLE_ENTRY (B1, I1, I1);
733
+ TABLE_ENTRY (B1, I2, I2);
734
+ TABLE_ENTRY (B1, I4, I4);
735
+ TABLE_ENTRY (B1, I8, I8);
736
+ TABLE_ENTRY (B1, F2, F2);
737
+ TABLE_ENTRY (B1, F4, F4);
738
+ TABLE_ENTRY (B1, F8, F8);
739
+ TABLE_ENTRY (B1, C2, C2);
740
+ TABLE_ENTRY (B1, C4, C4);
741
+ TABLE_ENTRY (B1, C8, C8);
742
+ TABLE_ENTRY (B1, B1, B1);
743
+
744
+ } // namespace internal
745
+
746
+ template <typename T1, typename T2, bool half_to_float = false >
747
+ struct promote_types {
748
+ private:
749
+ static_assert (
750
+ std::is_same<T1, T2>::value ||
751
+ (!is_qint_type<T1>::value && !is_qint_type<T2>::value),
752
+ " promote_types not valid for quantized dtypes" );
753
+ static_assert (
754
+ std::is_same<T1, T2>::value ||
755
+ (!is_bits_type<T1>::value && !is_bits_type<T2>::value),
756
+ " promote_types not valid for bits dtypes" );
757
+
758
+ static_assert (
759
+ !std::is_same<
760
+ T1,
761
+ typename ScalarTypeToCppType<exec_aten::ScalarType::BFloat16>::type>::
762
+ value &&
763
+ !std::is_same<
764
+ T2,
765
+ typename ScalarTypeToCppType<
766
+ exec_aten::ScalarType::BFloat16>::type>::value,
767
+ " promote_types not valid for BFloat16" );
768
+ using promoted_type_not_respecting_half_to_float =
769
+ typename internal::promote_types_lookup<T1, T2>::type;
770
+
771
+ public:
772
+ using type = typename std::conditional<
773
+ half_to_float &&
774
+ std::is_same<
775
+ promoted_type_not_respecting_half_to_float,
776
+ typename ScalarTypeToCppType<exec_aten::ScalarType::Half>::type>::
777
+ value,
778
+ typename ScalarTypeToCppType<exec_aten::ScalarType::Float>::type,
779
+ promoted_type_not_respecting_half_to_float>::type;
780
+ };
781
+
553
782
/* *
554
783
* Implements type promotion rules that are consistent with ATen behaviour,
555
784
* which in turn is consistent with NumPy's promote_types.
@@ -589,6 +818,10 @@ inline exec_aten::ScalarType promoteTypes(
589
818
ET_CHECK_MSG (false , " promoteTypes not valid for bits dtypes" );
590
819
}
591
820
821
+ ET_CHECK_MSG (
822
+ a != exec_aten::ScalarType::BFloat16 &&
823
+ b != exec_aten::ScalarType::BFloat16,
824
+ " promoteTypes not valid for BFloat16" );
592
825
// 12 types are handled by this function, see the constexpr definitions above
593
826
const int NUM_PROMOTE_TYPES = 12 ;
594
827
0 commit comments