
Enable global weight decay to TBE (Backend) (#2498) #2516

Closed
spcyppt wants to merge 1 commit

Conversation

spcyppt (Contributor) commented Apr 18, 2024

Summary:

With the existing implementation of sparse embedding tables with rowwise Adagrad, weight decay is applied to a row's weights only when its ID (and hence its embedding row) appears in a training batch. Rows that do not show up are neither updated nor decayed, so the embedding table receives only *local*, not *global*, weight decay.
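
This diff provides an option to compensate for the missed updates by scaling the weights with a `global weight decay` value, computed with the formula from csmiler below:

```
global_weight_decay = (1 - learning_rate * weight_decay)^(current_iter - prev_iter - 1)
```

where `prev_iter` is the last iteration in which this ID (and its corresponding embedding row) showed up.

As a rough illustration of how this factor behaves, here is a minimal Python sketch; it is not the FBGEMM/TBE kernel code, and the helper name and numbers are made up for this example:

```
# Illustrative sketch only -- not the FBGEMM/TBE kernel implementation.
# A row that was absent for several iterations is "caught up" by scaling it
# with the global weight decay factor before the usual rowwise Adagrad update.

def global_weight_decay_scale(learning_rate, weight_decay, current_iter, prev_iter):
    """Decay factor covering the iterations in which the row did not appear."""
    skipped_iters = current_iter - prev_iter - 1
    return (1.0 - learning_rate * weight_decay) ** skipped_iters

# Example: a row last seen at iteration 100 shows up again at iteration 105.
scale = global_weight_decay_scale(
    learning_rate=0.1, weight_decay=0.01, current_iter=105, prev_iter=100
)
print(scale)  # (1 - 0.1 * 0.01) ** 4 ≈ 0.9960
```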


Usage:
set

```
optimizer = OptimType.EXACT_ROWWISE_ADAGRAD
weight_decay_mode = WeightDecayMode.DECOUPLE_GLOBAL
```

e.g.,

```
tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (E, D, managed_option, ComputeDevice.CUDA) for (E, D) in zip(Es, Ds)
    ],
    optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
    learning_rate=0.1,
    eps=0.1,
    output_dtype=output_dtype,
    pooling_mode=pooling_mode,
    weight_decay=0.01,
    weight_decay_mode=WeightDecayMode.DECOUPLE_GLOBAL,
)
```
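
Relevant diffs: D53866750, D55660277, D55660762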

@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D56285676

netlify bot commented Apr 18, 2024

Deploy Preview for pytorch-fbgemm-docs ready!

🔨 Latest commit: a7125bc
🔍 Latest deploy log: https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/663346376ab1fe0008df8e31
😎 Deploy Preview: https://deploy-preview-2516--pytorch-fbgemm-docs.netlify.app

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Apr 19, 2024

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Apr 23, 2024

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Apr 23, 2024

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Apr 24, 2024
spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Apr 30, 2024
spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Apr 30, 2024
spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Apr 30, 2024
spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Apr 30, 2024
spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request May 1, 2024
@spcyppt spcyppt force-pushed the export-D56285676 branch from 9051b29 to 0297a63 on May 1, 2024 21:08

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request May 2, 2024
@spcyppt spcyppt force-pushed the export-D56285676 branch from 0579a56 to 42653e9 on May 2, 2024 04:00

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request May 2, 2024
@spcyppt spcyppt force-pushed the export-D56285676 branch from 42653e9 to ad7e8f9 on May 2, 2024 05:24

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request May 2, 2024
@spcyppt spcyppt force-pushed the export-D56285676 branch from ad7e8f9 to d5d8101 on May 2, 2024 05:29
spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request May 2, 2024
@spcyppt spcyppt force-pushed the export-D56285676 branch from d5d8101 to 0d16a83 on May 2, 2024 06:11

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request May 2, 2024
@spcyppt spcyppt force-pushed the export-D56285676 branch from 0d16a83 to 8ea229b on May 2, 2024 06:37

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request May 2, 2024
@spcyppt spcyppt force-pushed the export-D56285676 branch from 8ea229b to c3b63ef on May 2, 2024 07:46
@facebook-github-bot

This pull request has been merged in c1f7a66.
