Skip to content

Commit b58e12f

Browse files
committed
Update on "[PyTorch] RFC: Add tuple inline storage"
I noticed a bunch of time being spent heap-allocating Tuples in the unpickler. 1-, 2-, and 3-element Tuples are apparently common enough that they get their own bytecode instructions, so I decided to try also giving them their own representation. We store up to 3 IValues inline in `Tuple` rather than doing a second heap allocation for a `std::vector<IValue>`. Differential Revision: [D30592622](https://our.internmc.facebook.com/intern/diff/D30592622/) **NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D30592622/)! cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse @SciPioneer @H-Huang gcramer23 [ghstack-poisoned]
2 parents 4d1ef84 + cd35ed5 commit b58e12f

Some content is hidden

Large commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1209
-857
lines changed

BUILD.bazel

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,7 @@ header_template_rule(
525525
substitutions = {
526526
"@AT_CUDNN_ENABLED@": "1",
527527
"@AT_ROCM_ENABLED@": "0",
528+
"@AT_MAGMA_ENABLED@": "0",
528529
"@NVCC_FLAGS_EXTRA@": "",
529530
},
530531
)
@@ -537,15 +538,6 @@ header_template_rule(
537538
},
538539
)
539540

540-
header_template_rule(
541-
name = "aten_src_THC_THCGeneral",
542-
src = "aten/src/THC/THCGeneral.h.in",
543-
out = "aten/src/THC/THCGeneral.h",
544-
substitutions = {
545-
"#cmakedefine USE_MAGMA": "",
546-
},
547-
)
548-
549541
cc_library(
550542
name = "aten_headers",
551543
hdrs = [
@@ -572,7 +564,6 @@ cc_library(
572564
deps = [
573565
":c10_headers",
574566
":aten_src_TH_THGeneral",
575-
":aten_src_THC_THCGeneral",
576567
],
577568
)
578569

aten/src/ATen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ set_bool(AT_BUILD_WITH_BLAS USE_BLAS)
3030
set_bool(AT_BUILD_WITH_LAPACK USE_LAPACK)
3131
set_bool(AT_BLAS_F2C BLAS_F2C)
3232
set_bool(AT_BLAS_USE_CBLAS_DOT BLAS_USE_CBLAS_DOT)
33+
set_bool(AT_MAGMA_ENABLED USE_MAGMA)
3334
set_bool(CAFFE2_STATIC_LINK_CUDA_INT CAFFE2_STATIC_LINK_CUDA)
3435

3536
configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")

aten/src/ATen/core/ivalue_inl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ struct TORCH_API Tuple : c10::intrusive_ptr_target {
579579
static c10::intrusive_ptr<Tuple> createNamed(
580580
std::initializer_list<IValue> elements_,
581581
std::shared_ptr<TupleType> type_) {
582-
return create(std::vector<IValue>(elements_));
582+
return createNamed(std::vector<IValue>(elements_), std::move(type_));
583583
}
584584

585585
// MSVC apparently can't disambiguate the other two overloads of

aten/src/ATen/cuda/ApplyGridUtils.cuh

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include <ATen/cuda/CUDAContext.h>
2+
3+
#include <cuda_runtime.h>
4+
5+
namespace at { namespace cuda {
6+
7+
/**
8+
Computes ceil(a / b)
9+
*/
10+
template <typename T>
11+
__host__ __device__ __forceinline__ T ATenCeilDiv(T a, T b) {
12+
return (a + b - 1) / b;
13+
}
14+
15+
namespace {
16+
17+
// Threads per block for our apply kernel
18+
// FIXME: use occupancy calculator instead
19+
constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512;
20+
constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4;
21+
22+
template <int step = 1>
23+
inline bool getApplyGrid(uint64_t totalElements, dim3& grid, int64_t curDevice, int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
24+
if (curDevice == -1) return false;
25+
uint64_t numel_per_thread = static_cast<uint64_t>(max_threads_per_block) * static_cast<uint64_t>(step);
26+
uint64_t numBlocks = ATenCeilDiv(totalElements, numel_per_thread);
27+
uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0];
28+
if (numBlocks > maxGridX)
29+
numBlocks = maxGridX;
30+
grid = dim3(numBlocks);
31+
return true;
32+
}
33+
34+
constexpr int getApplyBlocksPerSM() {
35+
return AT_APPLY_BLOCKS_PER_SM;
36+
}
37+
38+
constexpr int getApplyBlockSize() {
39+
return AT_APPLY_THREADS_PER_BLOCK;
40+
}
41+
42+
inline dim3 getApplyBlock(int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
43+
return dim3(max_threads_per_block);
44+
}
45+
46+
}
47+
}} // namespace at::cuda

aten/src/ATen/cuda/CUDAApplyUtils.cuh

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <ATen/cuda/ApplyGridUtils.cuh>
34
#include <ATen/cuda/detail/IndexUtils.cuh>
45
#include <ATen/TensorUtils.h>
56
#include <ATen/ceil_div.h>
@@ -199,11 +200,6 @@ inline void rearrangeDims(detail::TensorInfo<T1, IndexType>* aInfo,
199200
}
200201
}
201202

202-
// Threads per block for our apply kernel
203-
// FIXME: use occupancy calculator instead
204-
constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512;
205-
constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4;
206-
207203
// The `remaining_steps` argument is used to support Op that operates on
208204
// multiple elements at the same time. Generally, the strategy of ApplyOpN is to
209205
// 1. Initialize `remaining_steps = step`, where `step` is the template arg of
@@ -379,40 +375,6 @@ kernelPointwiseApply2(detail::TensorInfo<scalar1, IndexType> a,
379375

380376
} // namespace
381377

382-
/**
383-
Computes ceil(a / b)
384-
*/
385-
template <typename T>
386-
C10_DEPRECATED_MESSAGE("at::cuda::ATenCeilDiv is deprecated. Instead use at::ceil_div in <ATen/ceil_div.h>.")
387-
__host__ __device__ __forceinline__ T ATenCeilDiv(T a, T b) {
388-
// TODO: Delete when torchvision stops using this function
389-
return at::ceil_div(a, b);
390-
}
391-
392-
template <int step = 1>
393-
inline bool getApplyGrid(uint64_t totalElements, dim3& grid, int64_t curDevice, int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
394-
if (curDevice == -1) return false;
395-
uint64_t numel_per_thread = static_cast<uint64_t>(max_threads_per_block) * static_cast<uint64_t>(step);
396-
uint64_t numBlocks = ceil_div(totalElements, numel_per_thread);
397-
uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0];
398-
if (numBlocks > maxGridX)
399-
numBlocks = maxGridX;
400-
grid = dim3(numBlocks);
401-
return true;
402-
}
403-
404-
constexpr int getApplyBlocksPerSM() {
405-
return AT_APPLY_BLOCKS_PER_SM;
406-
}
407-
408-
constexpr int getApplyBlockSize() {
409-
return AT_APPLY_THREADS_PER_BLOCK;
410-
}
411-
412-
inline dim3 getApplyBlock(int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
413-
return dim3(max_threads_per_block);
414-
}
415-
416378
template <typename scalar1, typename scalar2, int step, typename Op,
417379
int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
418380
int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>

aten/src/ATen/cuda/CUDAConfig.h.in

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,11 @@
99

1010
#define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@
1111
#define AT_ROCM_ENABLED() @AT_ROCM_ENABLED@
12+
#define AT_MAGMA_ENABLED() @AT_MAGMA_ENABLED@
13+
14+
// Needed for hipMAGMA to correctly identify implementation
15+
#if (AT_ROCM_ENABLED() && AT_MAGMA_ENABLED())
16+
#define HAVE_HIP 1
17+
#endif
1218

1319
#define NVCC_FLAGS_EXTRA "@NVCC_FLAGS_EXTRA@"

aten/src/ATen/cuda/detail/CUDAHooks.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include <ATen/cudnn/cudnn-wrapper.h>
2222
#endif
2323

24-
#ifdef USE_MAGMA
24+
#if AT_MAGMA_ENABLED()
2525
#include <magma_v2.h>
2626
#endif
2727

@@ -118,7 +118,7 @@ bool CUDAHooks::hasCUDA() const {
118118
}
119119

120120
bool CUDAHooks::hasMAGMA() const {
121-
#ifdef USE_MAGMA
121+
#if AT_MAGMA_ENABLED()
122122
return true;
123123
#else
124124
return false;
@@ -337,7 +337,7 @@ std::string CUDAHooks::showConfig() const {
337337
oss << " - MIOpen " << MIOPEN_VERSION_MAJOR << "." << MIOPEN_VERSION_MINOR << "." << MIOPEN_VERSION_PATCH << "\n";
338338
#endif
339339

340-
#ifdef USE_MAGMA
340+
#if AT_MAGMA_ENABLED()
341341
oss << " - Magma " << MAGMA_VERSION_MAJOR << "." << MAGMA_VERSION_MINOR << "." << MAGMA_VERSION_MICRO << "\n";
342342
#endif
343343

aten/src/ATen/native/Activation.h

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,40 @@
11
#pragma once
22

3-
#include <ATen/ATen.h>
43
#include <ATen/native/DispatchStub.h>
5-
#include <c10/core/Scalar.h>
64

7-
namespace at {
5+
namespace c10 {
6+
class Scalar;
7+
}
88

9+
namespace at {
910
struct TensorIterator;
11+
struct TensorIteratorBase;
12+
class TensorBase;
13+
}
1014

11-
namespace native {
15+
namespace at { namespace native {
1216

1317
using structured_activation_fn = void (*)(TensorIteratorBase&);
1418
using structured_activation_backward_fn = void (*)(TensorIteratorBase&);
1519

1620
using activation_fn = void (*)(TensorIterator&);
1721
using activation_backward_fn = void (*)(TensorIterator&);
18-
using softplus_fn = void (*)(TensorIteratorBase&, const Scalar&, const Scalar&);
19-
using softplus_backward_fn = void (*)(TensorIteratorBase&, const Scalar&, const Scalar&);
20-
using threshold_fn = void (*)(TensorIteratorBase&, const Scalar&, const Scalar&);
21-
using hardtanh_backward_fn = void (*)(TensorIterator&, const Scalar&, const Scalar&);
22+
using softplus_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
23+
using softplus_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
24+
using threshold_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
25+
using hardtanh_backward_fn = void (*)(TensorIterator&, const c10::Scalar&, const c10::Scalar&);
2226
using hardsigmoid_fn = void(*)(TensorIteratorBase&);
2327
using hardsigmoid_backward_fn = void(*)(TensorIteratorBase&);
2428
using hardswish_fn = void(*)(TensorIterator&);
2529
using hardswish_backward_fn = void(*)(TensorIterator&);
26-
using shrink_fn = void (*)(TensorIteratorBase&, const Scalar&);
27-
using softshrink_fn = void (*)(TensorIteratorBase&, const Scalar&);
28-
using shrink_backward_fn = void (*)(TensorIteratorBase&, const Scalar&);
29-
using elu_fn = void (*)(TensorIteratorBase&, const Scalar&, const Scalar&, const Scalar&);
30-
using elu_backward_fn = void (*)(TensorIteratorBase&, const Scalar&, const Scalar&, const Scalar&, bool);
31-
using leaky_relu_fn = void (*)(TensorIteratorBase&, const Scalar&);
32-
using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const Scalar&);
33-
using log_sigmoid_cpu_fn = void (*)(Tensor& , Tensor&, const Tensor& );
30+
using shrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
31+
using softshrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
32+
using shrink_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
33+
using elu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&);
34+
using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&, bool);
35+
using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
36+
using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
37+
using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);
3438

3539
DECLARE_DISPATCH(elu_fn, elu_stub);
3640
DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub);

aten/src/ATen/native/cpu/Activation.cpp

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#define TORCH_ASSERT_NO_OPERATORS
12
#ifndef _USE_MATH_DEFINES
23
#define _USE_MATH_DEFINES
34
#endif
@@ -7,24 +8,22 @@
78
#include <cmath>
89
#include <functional>
910

10-
#include <ATen/ATen.h>
11-
#include <ATen/Config.h>
11+
#include <ATen/Dispatch.h>
12+
#include <ATen/core/TensorBase.h>
1213
#include <ATen/cpu/vec/vec.h>
1314
#include <ATen/native/TensorIterator.h>
1415
#include <ATen/native/cpu/Loops.h>
1516
#include <ATen/Parallel.h>
1617

17-
#if AT_MKL_ENABLED()
18-
#include <mkl.h>
19-
#endif // AT_MKL_ENABLED()
18+
#include <c10/core/Scalar.h>
2019

2120
namespace at {
2221
namespace native {
2322

2423
namespace {
2524

2625
template <typename scalar_t>
27-
inline void _vec_log_sigmoid(Tensor& output, Tensor& buffer, const Tensor& input) {
26+
inline void _vec_log_sigmoid(TensorBase &output, TensorBase &buffer, const TensorBase &input) {
2827
using Vec = Vectorized<scalar_t>;
2928
scalar_t* output_data = output.data_ptr<scalar_t>();
3029
scalar_t* buffer_data = buffer.data_ptr<scalar_t>();
@@ -34,24 +33,25 @@ inline void _vec_log_sigmoid(Tensor& output, Tensor& buffer, const Tensor& input
3433
int64_t d = 0;
3534
for (; d < size - (size % Vec::size()); d += Vec::size()) {
3635
Vec data_vec = Vec::loadu(input_data + begin+ d);
37-
Vec max_vec = vec::maximum(data_vec.neg(), Vec(scalar_t(0)));
38-
Vec buffer_vec = max_vec.neg().exp() + (data_vec.neg() - max_vec).exp();
39-
Vec output_vec = (max_vec + buffer_vec.log()).neg();
36+
Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
37+
Vec buffer_vec = data_vec.abs().neg().exp();
38+
Vec output_vec = min_vec - buffer_vec.log1p();
4039
buffer_vec.store(buffer_data + begin + d);
4140
output_vec.store(output_data + begin + d);
4241
}
4342
if (size - d > 0) {
4443
Vec data_vec = Vec::loadu(input_data + begin + d, size - d);
45-
Vec max_vec = vec::maximum(data_vec.neg(), Vec(scalar_t(0)));
46-
Vec buffer_vec = max_vec.neg().exp() + (data_vec.neg() - max_vec).exp();
47-
Vec output_vec = (max_vec + buffer_vec.log()).neg();
44+
Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
45+
Vec buffer_vec = data_vec.abs().neg().exp();
46+
Vec output_vec = min_vec - buffer_vec.log1p();
4847
buffer_vec.store(buffer_data + begin + d, size - d);
4948
output_vec.store(output_data + begin + d, size - d);
5049
}
5150
});
5251
}
5352

54-
static void log_sigmoid_cpu_kernel(Tensor& output, Tensor& buffer, const Tensor& input) {
53+
static void log_sigmoid_cpu_kernel(
54+
TensorBase &output, TensorBase &buffer, const TensorBase &input) {
5555
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "log_sigmoid_cpu", [&] {
5656
_vec_log_sigmoid<scalar_t>(output, buffer, input);
5757
});
@@ -66,19 +66,16 @@ static void log_sigmoid_backward_cpu_kernel(TensorIterator& iter) {
6666
auto one_vec = Vec(one_val);
6767
cpu_kernel_vec(iter,
6868
[=](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
69-
auto max_deriv_val = zero_val;
70-
auto sign_val = -one_val;
71-
if (a < zero_val) {
72-
max_deriv_val = -one_val;
73-
sign_val = one_val;
74-
}
75-
return (-max_deriv_val - sign_val * ((b - one_val) / b)) * c;
69+
auto in_negative = a < scalar_t(0);
70+
auto max_deriv = in_negative ? scalar_t(1) : scalar_t(0);
71+
auto sign = in_negative ? scalar_t(1) : -scalar_t(1);
72+
return (max_deriv - sign * (b / (scalar_t(1) + b))) * c;
7673
},
7774
[=](Vec a, Vec b, Vec c) -> Vec {
7875
auto mask = a < zero_vec;
79-
auto max_deriv_vec = Vec::blendv(zero_vec, one_vec.neg(), mask);
76+
auto max_deriv_vec = Vec::blendv(zero_vec, one_vec, mask);
8077
auto sign_vec = Vec::blendv(one_vec.neg(), one_vec, mask);
81-
return (max_deriv_vec + sign_vec * ((b - one_vec) / b)).neg() * c;
78+
return (max_deriv_vec - sign_vec * (b / (one_vec + b))) * c;
8279
});
8380
});
8481
}

0 commit comments

Comments
 (0)