Conversation

@OmarPavel

Summary:
Tune the max segment length per CTA in the Triton table batched embeddings kernel, and expose the parameter via the CLI.
This improves performance by ~2% on B200.

We were using a hardcoded value of 1024, tuned for H100.
This change sets the default for B200 to 4096 (after testing smaller and larger values).
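A rough sketch of how a per-device default like this can be chosen (the helper name and the device mapping below are illustrative assumptions, not the actual FBGEMM code):

```python
import torch

# Illustrative sketch only: one plausible way to pick a device-dependent
# default for the max segment length per CTA. The helper name and the
# device mapping are assumptions, not the actual FBGEMM implementation.
def default_max_cta_segment_length() -> int:
    if not torch.cuda.is_available():
        return 1024
    device_name = torch.cuda.get_device_name(0)
    if "B200" in device_name:
        return 4096  # tuned on B200 (see benchmark tables below)
    return 1024  # prior hardcoded default, tuned on H100
```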

It also exposes this setting as a command-line flag: when used together with --no-deterministic, you can pass --max-cta-segment-length 4096 to override the value for whatever CUDA device you're targeting.
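For illustration, the flag could be wired into the benchmark CLI roughly as follows; the flag names match this PR, but the parser structure and the requires-check are assumptions:

```python
import argparse

# Sketch of the CLI wiring, assuming the override is only honored in
# non-deterministic mode; the actual benchmark parser may differ.
parser = argparse.ArgumentParser()
parser.add_argument("--no-deterministic", action="store_true",
                    help="Allow the non-deterministic backward kernel.")
parser.add_argument("--max-cta-segment-length", type=int, default=None,
                    help="Override the per-device default (e.g. 4096 on B200).")
args = parser.parse_args()

# Assumed enforcement: the override only takes effect with --no-deterministic.
if args.max_cta_segment_length is not None and not args.no_deterministic:
    parser.error("--max-cta-segment-length requires --no-deterministic")
```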

Tested at values of 512, 1024, 2048, 4096, and 8192; 4096 outperforms the baseline at both tested embedding dimensions (256 and 512).

```
Embedding Dim: 256
┌─────────────────────┬────────────────────┬──────────────┐
│ CTA Segment Length  │ Backward Time (μs) │ vs 1024      │
├─────────────────────┼────────────────────┼──────────────┤
│ 512                 │ 24,651             │ 4.1% slower  │
│ 1024 (default)      │ 23,680             │ baseline     │
│ 4096                │ 23,118             │ 2.4% faster  │
│ 8192                │ 28,698             │ 21.2% slower │
└─────────────────────┴────────────────────┴──────────────┘

Embedding Dim: 512
┌─────────────────────┬────────────────────┬──────────────┐
│ CTA Segment Length  │ Backward Time (μs) │ vs 1024      │
├─────────────────────┼────────────────────┼──────────────┤
│ 1024 (default)      │ 42,067             │ baseline     │
│ 2048                │ 41,471             │ 1.4% faster  │
│ 4096                │ 41,235             │ 2.0% faster  │
│ 8192                │ 41,414             │ 1.6% faster  │
└─────────────────────┴────────────────────┴──────────────┘
```

Reviewed By: stashuk-olek

Differential Revision: D88338051


meta-cla bot added the "cla signed" label Dec 10, 2025
OmarPavel added a commit to OmarPavel/FBGEMM that referenced this pull request Dec 13, 2025
…nd expose the param via cli (pytorch#5212)

Summary:
X-link: pytorch/pytorch#170113
X-link: facebookresearch/FBGEMM#2209

Reviewed By: stashuk-olek

Differential Revision: D88338051
OmarPavel added a commit to OmarPavel/pytorch that referenced this pull request Dec 13, 2025
…nd expose the param via cli (pytorch#170113)

Summary:
X-link: pytorch/FBGEMM#5212
X-link: facebookresearch/FBGEMM#2209

Test Plan:
```
buck run -m ovr_config//triton:beta -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -c python.package_style=inplace fbcode//deeplearning/fbgemm/fbgemm_gpu/fb/triton:triton_table_batched_embeddings_bench -- device --alpha=1.15 --batch-size=131072 --embedding-dim=512 --weights-precision=fp32 --iters=5 --no-deterministic --max-cta-segment-length 4096
```

Reviewed By: stashuk-olek

Differential Revision: D88338051
OmarPavel added a commit to OmarPavel/pytorch that referenced this pull request Dec 15, 2025
…nd expose the param via cli (pytorch#170113)
