Repo sync #821

Merged: 1 commit, Aug 15, 2024
README.md (13 additions, 13 deletions)

````diff
@@ -75,19 +75,19 @@ If you think SPU is helpful for your research or development, please consider ci

 ```text
 @inproceedings{ditto,
-    title = {Ditto: Quantization-aware Secure Inference of Transformers upon {MPC}},
-    author = {Wu, Haoqi and Fang, Wenjing and Zheng, Yancheng and Ma, Junming and Tan, Jin and Wang, Lei},
-    booktitle = {Proceedings of the 41st International Conference on Machine Learning},
-    pages = {53346--53365},
-    year = {2024},
-    editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
-    volume = {235},
-    series = {Proceedings of Machine Learning Research},
-    month = {21--27 Jul},
-    publisher = {PMLR},
-    pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/wu24d/wu24d.pdf},
-    url = {https://proceedings.mlr.press/v235/wu24d.html},
-    abstract = {Due to the rising privacy concerns on sensitive client data and trained models like Transformers, secure multi-party computation (MPC) techniques are employed to enable secure inference despite attendant overhead. Existing works attempt to reduce the overhead using more MPC-friendly non-linear function approximations. However, the integration of quantization widely used in plaintext inference into the MPC domain remains unclear. To bridge this gap, we propose the framework named Ditto to enable more efficient quantization-aware secure Transformer inference. Concretely, we first incorporate an MPC-friendly quantization into Transformer inference and employ a quantization-aware distillation procedure to maintain the model utility. Then, we propose novel MPC primitives to support the type conversions that are essential in quantization and implement the quantization-aware MPC execution of secure quantized inference. This approach significantly decreases both computation and communication overhead, leading to improvements in overall efficiency. We conduct extensive experiments on Bert and GPT2 models to evaluate the performance of Ditto. The results demonstrate that Ditto is about $3.14\sim 4.40\times$ faster than MPCFormer (ICLR 2023) and $1.44\sim 2.35\times$ faster than the state-of-the-art work PUMA with negligible utility degradation.}
+    title = {Ditto: Quantization-aware Secure Inference of Transformers upon {MPC}},
+    author = {Wu, Haoqi and Fang, Wenjing and Zheng, Yancheng and Ma, Junming and Tan, Jin and Wang, Lei},
+    booktitle = {Proceedings of the 41st International Conference on Machine Learning},
+    pages = {53346--53365},
+    year = {2024},
+    editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
+    volume = {235},
+    series = {Proceedings of Machine Learning Research},
+    month = {21--27 Jul},
+    publisher = {PMLR},
+    pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/wu24d/wu24d.pdf},
+    url = {https://proceedings.mlr.press/v235/wu24d.html},
+    abstract = {Due to the rising privacy concerns on sensitive client data and trained models like Transformers, secure multi-party computation (MPC) techniques are employed to enable secure inference despite attendant overhead. Existing works attempt to reduce the overhead using more MPC-friendly non-linear function approximations. However, the integration of quantization widely used in plaintext inference into the MPC domain remains unclear. To bridge this gap, we propose the framework named Ditto to enable more efficient quantization-aware secure Transformer inference. Concretely, we first incorporate an MPC-friendly quantization into Transformer inference and employ a quantization-aware distillation procedure to maintain the model utility. Then, we propose novel MPC primitives to support the type conversions that are essential in quantization and implement the quantization-aware MPC execution of secure quantized inference. This approach significantly decreases both computation and communication overhead, leading to improvements in overall efficiency. We conduct extensive experiments on Bert and GPT2 models to evaluate the performance of Ditto. The results demonstrate that Ditto is about $3.14\sim 4.40\times$ faster than MPCFormer (ICLR 2023) and $1.44\sim 2.35\times$ faster than the state-of-the-art work PUMA with negligible utility degradation.}
 }
 ```
````
bazel/repositories.bzl (5 additions, 5 deletions)

```diff
@@ -50,10 +50,10 @@ def _libpsi():
         http_archive,
         name = "psi",
         urls = [
-            "https://github.com/secretflow/psi/archive/refs/tags/v0.4.0.dev240801.tar.gz",
+            "https://github.com/secretflow/psi/archive/refs/tags/v0.4.0.dev240814.tar.gz",
         ],
-        strip_prefix = "psi-0.4.0.dev240801",
-        sha256 = "541ad74de0cd9e6bffe348c3bc97e659fccb1f1811e612f9d8e6b1debdd7c2a0",
+        strip_prefix = "psi-0.4.0.dev240814",
+        sha256 = "2a16a5751d1b7051f01edd11f1fcf01b67ff4d67ec136e7bc6d1d729d7f22634",
     )

 def _rules_proto_grpc():
```
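This pin and the OpenXLA pin below follow the same Bazel idiom: `maybe(http_archive, ...)` downloads a fixed tarball and verifies it against a pinned SHA-256, so the build fails loudly if the upstream artifact changes. A minimal sketch of the idiom as it plausibly appears here (the two `load` lines from `@bazel_tools` are an assumption; the hunk shows only the call site):

```starlark
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")

def _libpsi():
    # maybe() defines the repository only if no repo named "psi" exists yet,
    # so a downstream workspace can override this pin with its own.
    maybe(
        http_archive,
        name = "psi",
        urls = [
            "https://github.com/secretflow/psi/archive/refs/tags/v0.4.0.dev240814.tar.gz",
        ],
        # Top-level directory inside the tarball, stripped on extraction.
        strip_prefix = "psi-0.4.0.dev240814",
        # Bazel aborts the fetch if the downloaded tarball hashes differently.
        sha256 = "2a16a5751d1b7051f01edd11f1fcf01b67ff4d67ec136e7bc6d1d729d7f22634",
    )
```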
```diff
@@ -135,8 +135,8 @@ def _bazel_skylib():
     )

 def _com_github_openxla_xla():
-    OPENXLA_COMMIT = "04f2bfe797408c9efe742b89e2e4db6cf526ebb7"
-    OPENXLA_SHA256 = "7e1d24737815be7607eed5f02fe7f81d97ffe358dfb7b4876f97bce8f48b3b3e"
+    OPENXLA_COMMIT = "64bdcc53a1b24abf19b1fe598e6f9b0fe6454470"
+    OPENXLA_SHA256 = "60918b3a0391fe9e0bd506c9b90170b7b5fa64d06de7ec1f4f0e351a303a88fa"

     # We need openxla to handle xla/mhlo/stablehlo
     maybe(
```
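The OpenXLA dependency is pinned to a commit rather than a release tag. The hunk cuts off before the two constants are used, so the following is a sketch under stated assumptions: the repository name `xla` is inferred from the `@xla//...` labels elsewhere in this PR, and the GitHub commit-archive URL template is the conventional one, not confirmed by the diff:

```starlark
def _com_github_openxla_xla():
    OPENXLA_COMMIT = "64bdcc53a1b24abf19b1fe598e6f9b0fe6454470"
    OPENXLA_SHA256 = "60918b3a0391fe9e0bd506c9b90170b7b5fa64d06de7ec1f4f0e351a303a88fa"

    # We need openxla to handle xla/mhlo/stablehlo
    maybe(
        http_archive,
        name = "xla",  # assumed: matches the @xla//... labels in this PR
        urls = [
            # GitHub serves a tarball for any commit hash (assumed template).
            "https://github.com/openxla/xla/archive/%s.tar.gz" % OPENXLA_COMMIT,
        ],
        strip_prefix = "xla-" + OPENXLA_COMMIT,
        sha256 = OPENXLA_SHA256,
    )
```

The commit bump lines up with the path changes below, where `dot_dimension_sorter` moves under `xla/service/gpu/transforms`.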
libspu/compiler/BUILD.bazel (1 addition, 1 deletion)

```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-load("//bazel:spu.bzl", "spu_cc_binary", "spu_cc_library")
+load("//bazel:spu.bzl", "spu_cc_library")

 package(
     default_visibility = ["//visibility:public"],
```
libspu/compiler/front_end/BUILD.bazel (1 addition, 1 deletion)

```diff
@@ -58,7 +58,7 @@ spu_cc_library(
         "@xla//xla/service:while_loop_constant_sinking",
         "@xla//xla/service:while_loop_simplifier",
         "@xla//xla/service:zero_sized_hlo_elimination",
-        "@xla//xla/service/gpu:dot_dimension_sorter",
+        "@xla//xla/service/gpu/transforms:dot_dimension_sorter",
         "@xla//xla/translate/hlo_to_mhlo:hlo_module_importer",
     ],
 )
```
libspu/compiler/front_end/hlo_importer.cc (1 addition, 1 deletion)

```diff
@@ -30,7 +30,7 @@
 #include "xla/service/float_support.h"
 #include "xla/service/gather_expander.h"
 #include "xla/service/gather_simplifier.h"
-#include "xla/service/gpu/dot_dimension_sorter.h"
+#include "xla/service/gpu/transforms/dot_dimension_sorter.h"
 #include "xla/service/hlo_constant_folding.h"
 #include "xla/service/hlo_cse.h"
 #include "xla/service/hlo_dce.h"
```
libspu/device/utils/BUILD.bazel (4 additions)

```diff
@@ -14,6 +14,10 @@

 load("//bazel:spu.bzl", "spu_cc_binary", "spu_cc_library")

+package(
+    default_visibility = ["//visibility:public"],
+)
+
 spu_cc_library(
     name = "debug_dump_constant",
     srcs = [
```
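For readers skimming the Bazel changes: a `package()` declaration sets the default visibility for every target in the file, while a target-level `visibility` attribute (used in the next file) overrides the default for a single target. A small illustration, with the `srcs` file name hypothetical since the hunk truncates the list:

```starlark
load("//bazel:spu.bzl", "spu_cc_binary", "spu_cc_library")

# Package-level default: targets below are visible to all packages
# unless they set their own `visibility`.
package(
    default_visibility = ["//visibility:public"],
)

spu_cc_library(
    name = "debug_dump_constant",
    srcs = ["debug_dump_constant.cc"],  # hypothetical; the hunk cuts off here
)
```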
libspu/dialect/utils/BUILD.bazel (1 addition)

```diff
@@ -22,6 +22,7 @@ spu_cc_library(
     hdrs = glob([
         "*.h",
     ]),
+    visibility = ["//visibility:public"],
     deps = [
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
```
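With the target now public, any package in the workspace can list it in `deps`. A hypothetical consumer (the dep label assumes the package's default target; the actual target name sits above the visible hunk):

```starlark
spu_cc_library(
    name = "uses_dialect_utils",        # illustrative consumer target
    srcs = ["uses_dialect_utils.cc"],   # illustrative source file
    deps = [
        "//libspu/dialect/utils",  # assumed target label; now publicly visible
    ],
)
```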