
[PyTorch] RFC: Add tuple inline storage #64066


Status: Closed · wants to merge 50 commits · Changes from 1 commit

Commits (50):
5f06e2f  [PyTorch] RFC: Add tuple inline storage (swolchok, Aug 26, 2021)
2e3fb4b  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Aug 27, 2021)
8496eb9  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Aug 27, 2021)
a2726bc  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Aug 27, 2021)
2451f9f  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Aug 27, 2021)
9dcf2ae  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Aug 30, 2021)
a7d23fd  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Aug 31, 2021)
ca08c5f  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Aug 31, 2021)
8fce745  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 3, 2021)
031bc4a  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 7, 2021)
8a7ef90  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 8, 2021)
e03440d  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 9, 2021)
757281a  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 9, 2021)
9a1cc85  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 9, 2021)
27eca8b  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 9, 2021)
b4c2fd4  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 13, 2021)
9e66764  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 14, 2021)
a5e5bfe  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 15, 2021)
1475ad1  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 16, 2021)
436ccc5  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 20, 2021)
305c8a1  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 20, 2021)
edad395  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 20, 2021)
a010999  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 21, 2021)
0719493  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 21, 2021)
5df6b5b  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 22, 2021)
8c71367  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 22, 2021)
23264d8  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 22, 2021)
51f9a65  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 22, 2021)
9232e2e  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 23, 2021)
5f4f9f8  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 23, 2021)
dcf699e  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 23, 2021)
121a192  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 24, 2021)
d897b89  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 27, 2021)
d82cd75  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 29, 2021)
09ee5d8  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 30, 2021)
4fddc8d  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 30, 2021)
0ce2d6d  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Sep 30, 2021)
59e3375  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Oct 5, 2021)
7e17dd6  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Oct 6, 2021)
cf1d072  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Oct 8, 2021)
a66f9a9  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Oct 8, 2021)
adfdb14  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Oct 8, 2021)
161a240  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Oct 11, 2021)
236cf76  give up on making TupleElements noncopyable on "[PyTorch] RFC: Add tu… (swolchok, Oct 11, 2021)
be384f7  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Oct 12, 2021)
5b55485  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Oct 12, 2021)
88f8ce9  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Oct 13, 2021)
eadc443  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Oct 14, 2021)
4d1ef84  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Oct 14, 2021)
b58e12f  Update on "[PyTorch] RFC: Add tuple inline storage" (swolchok, Oct 15, 2021)
Update on "[PyTorch] RFC: Add tuple inline storage"
I noticed a bunch of time being spent heap-allocating Tuples
in the unpickler. 1-, 2-, and 3-element Tuples are apparently common
enough that they get their own bytecode instructions, so I decided to
try also giving them their own representation. We store up to 3
IValues inline in `Tuple` rather than doing a second heap allocation
for a `std::vector<IValue>`.
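As a rough sketch of the idea (not the actual `TupleElements` implementation from this PR — beyond `inlineSize_` and `elementsVector_`, the names here are invented, and the real class overlaps the two representations in a union rather than storing both side by side):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

struct IValue { int64_t payload = 0; };  // stand-in for c10::IValue

class SmallTupleElements {
  // inlineSize_ doubles as a tag: 0 means "elements live in the vector";
  // 1..3 means that many IValues are stored inline, with no second
  // heap allocation.
  uint32_t inlineSize_ = 0;
  IValue elementsInline_[3];
  std::vector<IValue> elementsVector_;

 public:
  SmallTupleElements() = default;
  explicit SmallTupleElements(std::vector<IValue> v)
      : elementsVector_(std::move(v)) {}
  SmallTupleElements(IValue a) : inlineSize_(1) { elementsInline_[0] = a; }
  SmallTupleElements(IValue a, IValue b) : inlineSize_(2) {
    elementsInline_[0] = a;
    elementsInline_[1] = b;
  }
  SmallTupleElements(IValue a, IValue b, IValue c) : inlineSize_(3) {
    elementsInline_[0] = a;
    elementsInline_[1] = b;
    elementsInline_[2] = c;
  }

  bool empty() const { return inlineSize_ ? false : elementsVector_.empty(); }
  size_t size() const {
    return inlineSize_ ? inlineSize_ : elementsVector_.size();
  }
  const IValue& operator[](size_t i) const {
    return inlineSize_ ? elementsInline_[i] : elementsVector_[i];
  }
};
```

Since the unpickler creates mostly 1- to 3-element tuples, this turns the common case into a single heap allocation (the `Tuple` itself) instead of two.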

Differential Revision: [D30592622](https://our.internmc.facebook.com/intern/diff/D30592622/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D30592622/)!

[ghstack-poisoned]
swolchok committed Sep 13, 2021
commit b4c2fd4f4f41c0fee3ef4fc18b2c21674eb35be4
4 changes: 4 additions & 0 deletions aten/src/ATen/core/ivalue_inl.h
@@ -387,6 +387,10 @@ struct TORCH_API TupleElements {
}
}

C10_NODISCARD bool empty() const {
return inlineSize_ ? false : elementsVector_.empty();
}

C10_NODISCARD size_t size() const {
return inlineSize_ ? inlineSize_ : elementsVector_.size();
}
1 change: 1 addition & 0 deletions aten/src/ATen/test/ivalue_test.cpp
@@ -749,6 +749,7 @@ using ivalue::TupleElements;

namespace {
void validateTupleElements(TupleElements& te, c10::ArrayRef<IValue> contents) {
EXPECT_EQ(te.empty(), contents.empty());
EXPECT_EQ(te.size(), contents.size());
for (const auto idx: c10::irange(contents.size())) {
EXPECT_IVALUE_EQ(te[idx], contents[idx]);
150 changes: 32 additions & 118 deletions torch/csrc/jit/mobile/import.cpp
@@ -87,20 +87,6 @@ using caffe2::serialize::ReadAdapterInterface;

OpCode parseOpCode(const char* str);

IValue expect_field(
c10::ivalue::TupleElements& elements,
const std::string& expected_name,
size_t entry) {
auto row = std::move(elements.at(entry)).toTuple();
TORCH_INTERNAL_ASSERT(
row->elements().at(0).toStringRef() == expected_name,
"Expected ",
expected_name,
" found ",
row->elements().at(0).toStringRef());
return std::move(row->elements().at(1));
}

std::string operator_str(
const std::string& name,
const std::string& overloadname) {
@@ -237,22 +223,6 @@ class BytecodeDeserializer final {
IValue* schemaTable,
const int64_t& model_version,
mobile::Function* function);
/**
* Loads operators by looking them up in the Dispatcher and returns
* the set of operator names (with overload) that are not supported
* by the current runtime.
*
* Accepts an operator_cache, which allows you to cache operator
* functions for the entire model. This is keyed on
* c10::OperatorName. The value may not be what you're looking for
* even if the key is the same. You need to call has_same_arg_num()
* on the value to ensure that the number of arguments are the same.
*/
std::unordered_set<std::string> load_and_find_unsupported_operator_names(
c10::ivalue::TupleElements&& ops_list,
mobile::Function* function,
int64_t model_version,
mobile::Function::OperatorCacheType& operator_cache) const;
std::shared_ptr<CompilationUnit> compilation_unit_;
std::unordered_set<std::string> imported_libs_;
std::unique_ptr<PyTorchStreamReader> reader_{};
@@ -267,12 +237,22 @@ BytecodeDeserializer::BytecodeDeserializer(
reader_(std::move(reader)),
module_load_options_(module_load_options) {}

std::unordered_set<std::string> BytecodeDeserializer::
load_and_find_unsupported_operator_names(
c10::ivalue::TupleElements&& ops_list,
mobile::Function* function,
int64_t model_version,
mobile::Function::OperatorCacheType& operator_cache) const {
/**
* Loads operators by looking them up in the Dispatcher and returns
* the set of operator names (with overload) that are not supported
* by the current runtime.
*
* Accepts an operator_cache, which allows you to cache operator
* functions for the entire model. This is keyed on
* c10::OperatorName. The value may not be what you're looking for
* even if the key is the same. You need to call has_same_arg_num()
* on the value to ensure that the number of arguments are the same.
*/
std::unordered_set<std::string> load_and_find_unsupported_operator_names(
c10::ivalue::TupleElements&& ops_list,
mobile::Function* function,
int64_t model_version,
mobile::Function::OperatorCacheType& operator_cache) {
std::unordered_set<std::string> unsupported_op_names;
// ops_list is the list of operator names that were read in from
// bytecode.pkl for the method that is currently being processed.
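As a hedged, self-contained sketch of the caching contract described in the comment above (only the `c10::OperatorName`-style keying and the `has_same_arg_num()` check come from the source; every other name here is invented for illustration):

```cpp
#include <string>
#include <unordered_map>

// CachedOp stands in for the cached operator function; arg_num models
// why a key match alone is not enough to reuse an entry.
struct CachedOp {
  int arg_num;
  bool has_same_arg_num(int n) const { return arg_num == n; }
};

// Keyed on the operator name (the real cache is keyed on c10::OperatorName).
using OperatorCache = std::unordered_map<std::string, CachedOp>;

CachedOp resolve_via_dispatcher(const std::string& /*name*/, int num_args) {
  return CachedOp{num_args};  // placeholder for the real Dispatcher lookup
}

CachedOp lookup(OperatorCache& cache, const std::string& name, int num_args) {
  auto it = cache.find(name);
  // A hit is only usable if the cached entry expects the same number of
  // arguments as this call site; otherwise fall through and re-resolve.
  if (it != cache.end() && it->second.has_same_arg_num(num_args)) {
    return it->second;
  }
  CachedOp op = resolve_via_dispatcher(name, num_args);
  cache[name] = op;
  return op;
}
```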
@@ -315,7 +295,7 @@ void BytecodeDeserializer::parseFunctionSchema(
mobile::Function* function) {
// function schema
if (schemaTable) { // (schema is optional for back compat)
auto parseArgList = [this](std::vector<IValue>&& argTables) {
auto parseArgList = [this](c10::ivalue::TupleElements&& argTables) {
std::vector<c10::Argument> args;
for (auto&& argTable : std::move(argTables)) {
auto argTableElements =
@@ -341,14 +321,14 @@ };
};
auto schemaTableElements =
std::move(*std::move(*schemaTable).toTuple()).elements();
std::vector<IValue> arg_list =
auto arg_list =
std::move(*expect_field(
schemaTableElements,
"arguments",
BYTECODE_INDEX_SCHEMA_ARGUMENTS)
.toTuple())
.elements();
std::vector<IValue> ret_list =
auto ret_list =
std::move(
*expect_field(
schemaTableElements, "returns", BYTECODE_INDEX_SCHEMA_RETURNS)
@@ -366,14 +346,14 @@ }
}

void parseOperators(
const std::vector<IValue>& ops_list,
c10::ivalue::TupleElements&& ops_list,
const int64_t& model_version,
const uint64_t& module_load_options,
mobile::Function* function,
mobile::Function::OperatorCacheType& operator_cache) {
std::unordered_set<std::string> unsupported_op_names =
load_and_find_unsupported_operator_names(
ops_list, function, model_version, operator_cache);
std::move(ops_list), function, model_version, operator_cache);
if ((module_load_options & MobileModuleLoadOptions::OPERATOR_CHECK) &&
!unsupported_op_names.empty()) {
print_unsupported_ops_and_throw(unsupported_op_names);
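The switch from `const std::vector<IValue>&` to `c10::ivalue::TupleElements&&` in these signatures is what lets the callee move each element out rather than copy it; a minimal self-contained sketch of the pattern (simplified element type — the diff's `parseArgList` uses the same `for (auto&& x : ...)` style on `TupleElements`):

```cpp
#include <string>
#include <utility>
#include <vector>

// The callee takes the container by rvalue reference and consumes it.
void consume_rows(std::vector<std::string>&& rows) {
  for (auto&& row : rows) {
    std::string taken = std::move(row);  // move, not copy
    // ... parse `taken` ...
  }
}

int main() {
  std::vector<std::string> rows = {"aten::add", "aten::mul"};
  consume_rows(std::move(rows));  // caller explicitly gives up ownership
}
```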
@@ -455,96 +435,30 @@ void BytecodeDeserializer::parseMethods(
codeTableElements, "register_size", BYTECODE_INDEX_REGISTER_SIZE)
.toInt();

std::vector<IValue> debug_handles_m_tuple;
c10::ivalue::TupleElements debug_handles_m_tuple;
if (debug_handles) {
debug_handles_m_tuple =
std::move(*std::move((*debug_handles)[i]).toTuple()).elements();
}

OpCodeCache opCodeCache;
for (const auto j : c10::irange(ins_list.size())) {
// Can't remove this, need to keep Tuple alive!
auto ins_tuple = std::move(ins_list[j]).toTuple();
c10::ArrayRef<IValue> ins_item = ins_tuple->elements();
TORCH_CHECK(
ins_item.size() == 3,
"There should be three parts in an instruction. The function name is ",
function_name);
OpCode op_code = opCodeCache.parse(*ins_item[0].toString());
int X = ins_item[1].toInt();
int N = ins_item[2].toInt();
if (debug_handles) {
int64_t debug_handle = debug_handles_list[j];
function->append_instruction(op_code, X, N, debug_handle);
} else {
function->append_instruction(op_code, X, N);
}
}
parseInstructions(
function_name, ins_list, debug_handles_m_tuple, function.get());

std::unordered_set<std::string> unsupported_op_names =
load_and_find_unsupported_operator_names(
std::move(ops_list), function.get(), model_version, operator_cache);
if ((module_load_options_ & MobileModuleLoadOptions::OPERATOR_CHECK) &&
!unsupported_op_names.empty()) {
print_unsupported_ops_and_throw(unsupported_op_names);
}
parseOperators(
std::move(ops_list),
model_version,
module_load_options_,
function.get(),
operator_cache);

parseConstants(consts_list, function.get());

parseTypes(types_list, function.get());

function->set_register_size(register_size);

// function schema
if (schemaTable) { // (schema is optional for back compat)
auto parseArgList = [this](c10::ivalue::TupleElements&& argTables) {
std::vector<c10::Argument> args;
for (auto&& argTable : std::move(argTables)) {
auto argTableElements =
std::move(*std::move(argTable).toTuple()).elements();
auto name =
expect_field(
argTableElements, "name", BYTECODE_INDEX_ARGUMENT_NAME)
.toStringRef();
c10::TypePtr type = resolveTypeName(
(expect_field(
argTableElements, "type", BYTECODE_INDEX_ARGUMENT_TYPE))
.toStringRef());
IValue default_value = expect_field(
argTableElements,
"default_value",
BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE);
args.emplace_back(
name,
std::move(type),
c10::nullopt /*N*/,
std::move(default_value));
}
return args;
};
auto schemaTableElements =
std::move(*std::move(*schemaTable).toTuple()).elements();
auto arg_list = std::move(*expect_field(
schemaTableElements,
"arguments",
BYTECODE_INDEX_SCHEMA_ARGUMENTS)
.toTuple())
.elements();
auto ret_list = std::move(*expect_field(
schemaTableElements,
"returns",
BYTECODE_INDEX_SCHEMA_RETURNS)
.toTuple())
.elements();
c10::FunctionSchema schema(
function_name,
"" /*overload_name*/,
parseArgList(std::move(arg_list)),
parseArgList(std::move(ret_list)),
false /*is_varargs*/,
false /*is_varret*/);
function->setSchema(std::move(schema));
}
parseFunctionSchema(
function_name, schemaTable, model_version, function.get());

mcu.register_function(std::move(function));
}
18 changes: 9 additions & 9 deletions torch/csrc/jit/mobile/parse_bytecode.cpp
@@ -12,7 +12,7 @@ OpCode parseOpCode(const char* str);
using c10::IValue;

IValue expect_field(
std::vector<IValue>& elements,
c10::ivalue::TupleElements& elements,
const std::string& expected_name,
size_t entry) {
auto row = std::move(elements.at(entry)).toTuple();
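For orientation, `expect_field` walks a table whose entries are `("name", value)` rows; this self-contained sketch (simplified types, invented values) shows the contract it enforces — entry `i` must carry the expected name, and the value is moved out to the caller:

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-ins; the real code uses IValue and TupleElements.
using Row = std::pair<std::string, int>;  // ("name", value)
using Table = std::vector<Row>;

int expect_field_sketch(Table& elements,
                        const std::string& expected_name,
                        size_t entry) {
  Row& row = elements.at(entry);
  assert(row.first == expected_name && "row name mismatch");
  return std::move(row.second);  // hand the value back to the caller
}

int main() {
  Table schema = {{"arguments", 2}, {"returns", 1}};
  assert(expect_field_sketch(schema, "arguments", 0) == 2);
  assert(expect_field_sketch(schema, "returns", 1) == 1);
}
```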
@@ -68,8 +68,8 @@ class OpCodeCache {

void parseInstructions(
const std::string& function_name,
const std::vector<IValue>& ins_list,
std::vector<IValue>& debug_handles_m_tuple,
const c10::ivalue::TupleElements& ins_list,
c10::ivalue::TupleElements& debug_handles_m_tuple,
mobile::Function* function) {
c10::List<int64_t> debug_handles_list;
if (!debug_handles_m_tuple.empty()) {
@@ -79,10 +79,10 @@
debug_info_function_name == function_name,
"The function names in the bytecode table and the debug info table do not match.");
IValue& debug_handles_table = debug_handles_m_tuple[1];
auto debugHandlesElements = std::move(*std::move(debug_handles_table).toTuple()).elements();
auto debugHandlesTableElements = std::move(*std::move(debug_handles_table).toTuple()).elements();
debug_handles_list =
(expect_field(
debugHandlesElements,
debugHandlesTableElements,
"function_debug_handles",
BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
.toTuple()
@@ -99,8 +99,8 @@
// becomes an important use case.
OpCodeCache opCodeCache;
for (const auto j : c10::irange(ins_list.size())) {
std::vector<IValue> ins_item =
std::move(*std::move(ins_list[j]).toTuple()).elements();
auto ins_tuple = std::move(ins_list[j]).toTuple();
c10::ArrayRef<IValue> ins_item = ins_tuple->elements();
TORCH_CHECK(
ins_item.size() == 3,
"There should be three parts in an instruction. The function name is ",
@@ -118,15 +118,15 @@ }
}

void parseConstants(
const std::vector<IValue>& consts_list,
const c10::ivalue::TupleElements& consts_list,
mobile::Function* function) {
for (const auto& constant : consts_list) {
function->append_constant(constant);
}
}

void parseTypes(
const std::vector<IValue>& types_list,
const c10::ivalue::TupleElements& types_list,
mobile::Function* function) {
static const c10::QualifiedName classPrefix = "__torch__.torch.classes";
for (const auto& t : types_list) {
8 changes: 4 additions & 4 deletions torch/csrc/jit/mobile/parse_bytecode.h
@@ -7,14 +7,14 @@ namespace mobile {
using c10::IValue;
TORCH_API void parseInstructions(
const std::string& function_name,
const std::vector<IValue>& ins_list,
std::vector<IValue>& debug_handles_m_tuple,
const c10::ivalue::TupleElements& ins_list,
c10::ivalue::TupleElements& debug_handles_m_tuple,
mobile::Function* function);
TORCH_API void parseConstants(
const std::vector<IValue>& consts_list,
const c10::ivalue::TupleElements& consts_list,
mobile::Function* function);
TORCH_API void parseTypes(
const std::vector<IValue>& types_list,
const c10::ivalue::TupleElements& types_list,
mobile::Function* function);
TORCH_API void parseRegisterSize(size_t rsize, mobile::Function* function);
} // namespace mobile
2 changes: 1 addition & 1 deletion torch/csrc/jit/serialization/import_export_functions.h
@@ -6,7 +6,7 @@ namespace torch {
namespace jit {
using c10::IValue;
IValue expect_field(
std::vector<IValue>& elements,
c10::ivalue::TupleElements& elements,
const std::string& expected_name,
size_t entry);
std::string operator_str(
You are viewing a condensed version of this merge commit.