Skip to content

Commit 8c8563f

Browse files
swolchokfacebook-github-bot
authored andcommitted
Add compile-time promote_types template
Summary: Now types can be promoted at compile-time. (I had to fix promoteTypes' lack of gating for BFloat16; I believe that would have caused a buffer overflow?) Reviewed By: kimishpatel, manuelcandales Differential Revision: D56643045 fbshipit-source-id: cd522e50f59ce838bba06c796e47a1d16ac55b22
1 parent 62a7a13 commit 8c8563f

File tree

3 files changed

+312
-2
lines changed

3 files changed

+312
-2
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
indexToType = ["U1", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "B1"]
8+
promoteTypesLookup = [
9+
["U1", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "U1"],
10+
["I2", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I1"],
11+
["I2", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I2"],
12+
["I4", "I4", "I4", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I4"],
13+
["I8", "I8", "I8", "I8", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I8"],
14+
["F2", "F2", "F2", "F2", "F2", "F2", "F4", "F8", "C2", "C4", "C8", "F2"],
15+
["F4", "F4", "F4", "F4", "F4", "F4", "F4", "F8", "C4", "C4", "C8", "F4"],
16+
["F8", "F8", "F8", "F8", "F8", "F8", "F8", "F8", "C8", "C8", "C8", "F8"],
17+
["C2", "C2", "C2", "C2", "C2", "C2", "C4", "C8", "C2", "C4", "C8", "C2"],
18+
["C4", "C4", "C4", "C4", "C4", "C4", "C4", "C8", "C4", "C4", "C8", "C4"],
19+
["C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8"],
20+
["U1", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "B1"],
21+
]
22+
for rowIndex, row in enumerate(promoteTypesLookup):
23+
for colIndex, col in enumerate(row):
24+
print(f"TABLE_ENTRY({indexToType[rowIndex]}, {indexToType[colIndex]}, {col});")

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 235 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,14 +387,19 @@ inline bool isComplexType(exec_aten::ScalarType t) {
387387
t == exec_aten::ScalarType::ComplexDouble);
388388
}
389389

390-
inline bool isBitsType(exec_aten::ScalarType t) {
390+
constexpr bool isBitsType(exec_aten::ScalarType t) {
391391
return t == exec_aten::ScalarType::Bits1x8 ||
392392
t == exec_aten::ScalarType::Bits2x4 ||
393393
t == exec_aten::ScalarType::Bits4x2 ||
394394
t == exec_aten::ScalarType::Bits8 || t == exec_aten::ScalarType::Bits16;
395395
}
396396

397-
inline bool isQIntType(exec_aten::ScalarType t) {
397+
template <typename T>
398+
struct is_bits_type
399+
: std::integral_constant<bool, isBitsType(CppTypeToScalarType<T>::value)> {
400+
};
401+
402+
constexpr bool isQIntType(exec_aten::ScalarType t) {
398403
// Don't forget to extend this when adding new QInt types
399404
return t == exec_aten::ScalarType::QInt8 ||
400405
t == exec_aten::ScalarType::QUInt8 ||
@@ -403,6 +408,11 @@ inline bool isQIntType(exec_aten::ScalarType t) {
403408
t == exec_aten::ScalarType::QUInt2x4;
404409
}
405410

411+
template <typename T>
412+
struct is_qint_type
413+
: std::integral_constant<bool, isQIntType(CppTypeToScalarType<T>::value)> {
414+
};
415+
406416
inline exec_aten::ScalarType toQIntType(exec_aten::ScalarType t) {
407417
switch (t) {
408418
case exec_aten::ScalarType::Byte:
@@ -550,6 +560,225 @@ To convert(From val) {
550560
return static_cast<To>(val);
551561
}
552562

563+
namespace internal {
564+
template <typename T1, typename T2>
565+
struct promote_types_lookup;
566+
567+
template <typename T1>
568+
struct promote_types_lookup<T1, T1> {
569+
using type = T1;
570+
};
571+
572+
using U1 = typename ScalarTypeToCppType<exec_aten::ScalarType::Byte>::type;
573+
using I1 = typename ScalarTypeToCppType<exec_aten::ScalarType::Char>::type;
574+
using I2 = typename ScalarTypeToCppType<exec_aten::ScalarType::Short>::type;
575+
using I4 = typename ScalarTypeToCppType<exec_aten::ScalarType::Int>::type;
576+
using I8 = typename ScalarTypeToCppType<exec_aten::ScalarType::Long>::type;
577+
using F2 = typename ScalarTypeToCppType<exec_aten::ScalarType::Half>::type;
578+
using F4 = typename ScalarTypeToCppType<exec_aten::ScalarType::Float>::type;
579+
using F8 = typename ScalarTypeToCppType<exec_aten::ScalarType::Double>::type;
580+
using C2 =
581+
typename ScalarTypeToCppType<exec_aten::ScalarType::ComplexHalf>::type;
582+
using C4 =
583+
typename ScalarTypeToCppType<exec_aten::ScalarType::ComplexFloat>::type;
584+
using C8 =
585+
typename ScalarTypeToCppType<exec_aten::ScalarType::ComplexDouble>::type;
586+
using B1 = typename ScalarTypeToCppType<exec_aten::ScalarType::Bool>::type;
587+
588+
#define TABLE_ENTRY(key1, key2, value) \
589+
template <> \
590+
struct promote_types_lookup<key1, key2> { \
591+
using type = value; \
592+
}
593+
594+
/* promote_types_lookup is a compile-time-accessible version of the
595+
* table in promoteTypes below; we cannot make promoteTypes constexpr
596+
* and use it directly because we are on C++11 and thus don't have
597+
* C++17 relaxed constexpr. The below series of entries is generated
598+
* by genScalarTypeTable.py. */
599+
TABLE_ENTRY(U1, U1, U1);
600+
TABLE_ENTRY(U1, I1, I2);
601+
TABLE_ENTRY(U1, I2, I2);
602+
TABLE_ENTRY(U1, I4, I4);
603+
TABLE_ENTRY(U1, I8, I8);
604+
TABLE_ENTRY(U1, F2, F2);
605+
TABLE_ENTRY(U1, F4, F4);
606+
TABLE_ENTRY(U1, F8, F8);
607+
TABLE_ENTRY(U1, C2, C2);
608+
TABLE_ENTRY(U1, C4, C4);
609+
TABLE_ENTRY(U1, C8, C8);
610+
TABLE_ENTRY(U1, B1, U1);
611+
TABLE_ENTRY(I1, U1, I2);
612+
TABLE_ENTRY(I1, I1, I1);
613+
TABLE_ENTRY(I1, I2, I2);
614+
TABLE_ENTRY(I1, I4, I4);
615+
TABLE_ENTRY(I1, I8, I8);
616+
TABLE_ENTRY(I1, F2, F2);
617+
TABLE_ENTRY(I1, F4, F4);
618+
TABLE_ENTRY(I1, F8, F8);
619+
TABLE_ENTRY(I1, C2, C2);
620+
TABLE_ENTRY(I1, C4, C4);
621+
TABLE_ENTRY(I1, C8, C8);
622+
TABLE_ENTRY(I1, B1, I1);
623+
TABLE_ENTRY(I2, U1, I2);
624+
TABLE_ENTRY(I2, I1, I2);
625+
TABLE_ENTRY(I2, I2, I2);
626+
TABLE_ENTRY(I2, I4, I4);
627+
TABLE_ENTRY(I2, I8, I8);
628+
TABLE_ENTRY(I2, F2, F2);
629+
TABLE_ENTRY(I2, F4, F4);
630+
TABLE_ENTRY(I2, F8, F8);
631+
TABLE_ENTRY(I2, C2, C2);
632+
TABLE_ENTRY(I2, C4, C4);
633+
TABLE_ENTRY(I2, C8, C8);
634+
TABLE_ENTRY(I2, B1, I2);
635+
TABLE_ENTRY(I4, U1, I4);
636+
TABLE_ENTRY(I4, I1, I4);
637+
TABLE_ENTRY(I4, I2, I4);
638+
TABLE_ENTRY(I4, I4, I4);
639+
TABLE_ENTRY(I4, I8, I8);
640+
TABLE_ENTRY(I4, F2, F2);
641+
TABLE_ENTRY(I4, F4, F4);
642+
TABLE_ENTRY(I4, F8, F8);
643+
TABLE_ENTRY(I4, C2, C2);
644+
TABLE_ENTRY(I4, C4, C4);
645+
TABLE_ENTRY(I4, C8, C8);
646+
TABLE_ENTRY(I4, B1, I4);
647+
TABLE_ENTRY(I8, U1, I8);
648+
TABLE_ENTRY(I8, I1, I8);
649+
TABLE_ENTRY(I8, I2, I8);
650+
TABLE_ENTRY(I8, I4, I8);
651+
TABLE_ENTRY(I8, I8, I8);
652+
TABLE_ENTRY(I8, F2, F2);
653+
TABLE_ENTRY(I8, F4, F4);
654+
TABLE_ENTRY(I8, F8, F8);
655+
TABLE_ENTRY(I8, C2, C2);
656+
TABLE_ENTRY(I8, C4, C4);
657+
TABLE_ENTRY(I8, C8, C8);
658+
TABLE_ENTRY(I8, B1, I8);
659+
TABLE_ENTRY(F2, U1, F2);
660+
TABLE_ENTRY(F2, I1, F2);
661+
TABLE_ENTRY(F2, I2, F2);
662+
TABLE_ENTRY(F2, I4, F2);
663+
TABLE_ENTRY(F2, I8, F2);
664+
TABLE_ENTRY(F2, F2, F2);
665+
TABLE_ENTRY(F2, F4, F4);
666+
TABLE_ENTRY(F2, F8, F8);
667+
TABLE_ENTRY(F2, C2, C2);
668+
TABLE_ENTRY(F2, C4, C4);
669+
TABLE_ENTRY(F2, C8, C8);
670+
TABLE_ENTRY(F2, B1, F2);
671+
TABLE_ENTRY(F4, U1, F4);
672+
TABLE_ENTRY(F4, I1, F4);
673+
TABLE_ENTRY(F4, I2, F4);
674+
TABLE_ENTRY(F4, I4, F4);
675+
TABLE_ENTRY(F4, I8, F4);
676+
TABLE_ENTRY(F4, F2, F4);
677+
TABLE_ENTRY(F4, F4, F4);
678+
TABLE_ENTRY(F4, F8, F8);
679+
TABLE_ENTRY(F4, C2, C4);
680+
TABLE_ENTRY(F4, C4, C4);
681+
TABLE_ENTRY(F4, C8, C8);
682+
TABLE_ENTRY(F4, B1, F4);
683+
TABLE_ENTRY(F8, U1, F8);
684+
TABLE_ENTRY(F8, I1, F8);
685+
TABLE_ENTRY(F8, I2, F8);
686+
TABLE_ENTRY(F8, I4, F8);
687+
TABLE_ENTRY(F8, I8, F8);
688+
TABLE_ENTRY(F8, F2, F8);
689+
TABLE_ENTRY(F8, F4, F8);
690+
TABLE_ENTRY(F8, F8, F8);
691+
TABLE_ENTRY(F8, C2, C8);
692+
TABLE_ENTRY(F8, C4, C8);
693+
TABLE_ENTRY(F8, C8, C8);
694+
TABLE_ENTRY(F8, B1, F8);
695+
TABLE_ENTRY(C2, U1, C2);
696+
TABLE_ENTRY(C2, I1, C2);
697+
TABLE_ENTRY(C2, I2, C2);
698+
TABLE_ENTRY(C2, I4, C2);
699+
TABLE_ENTRY(C2, I8, C2);
700+
TABLE_ENTRY(C2, F2, C2);
701+
TABLE_ENTRY(C2, F4, C4);
702+
TABLE_ENTRY(C2, F8, C8);
703+
TABLE_ENTRY(C2, C2, C2);
704+
TABLE_ENTRY(C2, C4, C4);
705+
TABLE_ENTRY(C2, C8, C8);
706+
TABLE_ENTRY(C2, B1, C2);
707+
TABLE_ENTRY(C4, U1, C4);
708+
TABLE_ENTRY(C4, I1, C4);
709+
TABLE_ENTRY(C4, I2, C4);
710+
TABLE_ENTRY(C4, I4, C4);
711+
TABLE_ENTRY(C4, I8, C4);
712+
TABLE_ENTRY(C4, F2, C4);
713+
TABLE_ENTRY(C4, F4, C4);
714+
TABLE_ENTRY(C4, F8, C8);
715+
TABLE_ENTRY(C4, C2, C4);
716+
TABLE_ENTRY(C4, C4, C4);
717+
TABLE_ENTRY(C4, C8, C8);
718+
TABLE_ENTRY(C4, B1, C4);
719+
TABLE_ENTRY(C8, U1, C8);
720+
TABLE_ENTRY(C8, I1, C8);
721+
TABLE_ENTRY(C8, I2, C8);
722+
TABLE_ENTRY(C8, I4, C8);
723+
TABLE_ENTRY(C8, I8, C8);
724+
TABLE_ENTRY(C8, F2, C8);
725+
TABLE_ENTRY(C8, F4, C8);
726+
TABLE_ENTRY(C8, F8, C8);
727+
TABLE_ENTRY(C8, C2, C8);
728+
TABLE_ENTRY(C8, C4, C8);
729+
TABLE_ENTRY(C8, C8, C8);
730+
TABLE_ENTRY(C8, B1, C8);
731+
TABLE_ENTRY(B1, U1, U1);
732+
TABLE_ENTRY(B1, I1, I1);
733+
TABLE_ENTRY(B1, I2, I2);
734+
TABLE_ENTRY(B1, I4, I4);
735+
TABLE_ENTRY(B1, I8, I8);
736+
TABLE_ENTRY(B1, F2, F2);
737+
TABLE_ENTRY(B1, F4, F4);
738+
TABLE_ENTRY(B1, F8, F8);
739+
TABLE_ENTRY(B1, C2, C2);
740+
TABLE_ENTRY(B1, C4, C4);
741+
TABLE_ENTRY(B1, C8, C8);
742+
TABLE_ENTRY(B1, B1, B1);
743+
744+
} // namespace internal
745+
746+
template <typename T1, typename T2, bool half_to_float = false>
747+
struct promote_types {
748+
private:
749+
static_assert(
750+
std::is_same<T1, T2>::value ||
751+
(!is_qint_type<T1>::value && !is_qint_type<T2>::value),
752+
"promote_types not valid for quantized dtypes");
753+
static_assert(
754+
std::is_same<T1, T2>::value ||
755+
(!is_bits_type<T1>::value && !is_bits_type<T2>::value),
756+
"promote_types not valid for bits dtypes");
757+
758+
static_assert(
759+
!std::is_same<
760+
T1,
761+
typename ScalarTypeToCppType<exec_aten::ScalarType::BFloat16>::type>::
762+
value &&
763+
!std::is_same<
764+
T2,
765+
typename ScalarTypeToCppType<
766+
exec_aten::ScalarType::BFloat16>::type>::value,
767+
"promote_types not valid for BFloat16");
768+
using promoted_type_not_respecting_half_to_float =
769+
typename internal::promote_types_lookup<T1, T2>::type;
770+
771+
public:
772+
using type = typename std::conditional<
773+
half_to_float &&
774+
std::is_same<
775+
promoted_type_not_respecting_half_to_float,
776+
typename ScalarTypeToCppType<exec_aten::ScalarType::Half>::type>::
777+
value,
778+
typename ScalarTypeToCppType<exec_aten::ScalarType::Float>::type,
779+
promoted_type_not_respecting_half_to_float>::type;
780+
};
781+
553782
/**
554783
* Implements type promotion rules that are consistent with ATen behaviour,
555784
* which in turn is consistent with NumPy's promote_types.
@@ -589,6 +818,10 @@ inline exec_aten::ScalarType promoteTypes(
589818
ET_CHECK_MSG(false, "promoteTypes not valid for bits dtypes");
590819
}
591820

821+
ET_CHECK_MSG(
822+
a != exec_aten::ScalarType::BFloat16 &&
823+
b != exec_aten::ScalarType::BFloat16,
824+
"promoteTypes not valid for BFloat16");
592825
// 12 types are handled by this function, see the constexpr definitions above
593826
const int NUM_PROMOTE_TYPES = 12;
594827

runtime/core/exec_aten/util/test/scalar_type_util_test.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,56 @@ TEST(ScalarTypeUtilTest, promoteTypesTest) {
162162
promoteTypes(ScalarType::Char, ScalarType::Bool) == ScalarType::Char);
163163
ET_CHECK(promoteTypes(ScalarType::Bool, ScalarType::Int) == ScalarType::Int);
164164
}
165+
166+
template <typename T1, typename T2>
167+
struct promote_types_is_valid
168+
: std::integral_constant<
169+
bool,
170+
!std::is_same<T1, torch::executor::BFloat16>::value &&
171+
!std::is_same<T2, torch::executor::BFloat16>::value &&
172+
(std::is_same<T1, T2>::value ||
173+
(!torch::executor::is_qint_type<T1>::value &&
174+
!torch::executor::is_qint_type<T2>::value &&
175+
!torch::executor::is_bits_type<T1>::value &&
176+
!torch::executor::is_bits_type<T2>::value))> {};
177+
178+
template <typename T1, bool half_to_float>
179+
struct CompileTimePromoteTypesTestCase {
180+
static void testAll() {
181+
#define CALL_TEST_ONE(cpp_type, scalar_type) \
182+
testOne<cpp_type, promote_types_is_valid<T1, cpp_type>::value>();
183+
ET_FORALL_SCALAR_TYPES(CALL_TEST_ONE)
184+
#undef CALL_TEST_ONE
185+
}
186+
187+
template <
188+
typename T2,
189+
bool valid,
190+
typename std::enable_if<valid, bool>::type = true>
191+
static void testOne() {
192+
auto actual = torch::executor::CppTypeToScalarType<
193+
typename torch::executor::promote_types<T1, T2, half_to_float>::type>::
194+
value;
195+
const auto scalarType1 = torch::executor::CppTypeToScalarType<T1>::value;
196+
const auto scalarType2 = torch::executor::CppTypeToScalarType<T2>::value;
197+
auto expected = promoteTypes(scalarType1, scalarType2, half_to_float);
198+
EXPECT_EQ(actual, expected)
199+
<< "promoting " << (int)scalarType1 << " to " << (int)scalarType2;
200+
}
201+
202+
template <
203+
typename T2,
204+
bool valid,
205+
typename std::enable_if<!valid, bool>::type = true>
206+
static void testOne() {
207+
// Skip invalid case
208+
}
209+
};
210+
211+
TEST(ScalarTypeUtilTest, compileTypePromoteTypesTest) {
212+
#define INSTANTIATE_TYPE_TEST(cpp_type, scalar_type) \
213+
CompileTimePromoteTypesTestCase<cpp_type, false>::testAll(); \
214+
CompileTimePromoteTypesTestCase<cpp_type, true>::testAll();
215+
216+
ET_FORALL_SCALAR_TYPES(INSTANTIATE_TYPE_TEST);
217+
}

0 commit comments

Comments
 (0)