Skip to content

Commit cb69a32

Browse files
authored
Repo sync (#586)
# Pull Request ## What problem does this PR solve? Issue Number: Fixed # ## Possible side effects? - Performance: - Backward compatibility:
1 parent dd968fa commit cb69a32

14 files changed

+178
-129
lines changed

bazel/repositories.bzl

+3-3
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ def _libpsi():
5151
http_archive,
5252
name = "psi",
5353
urls = [
54-
"https://github.com/secretflow/psi/archive/refs/tags/v0.3.0.dev240222.tar.gz",
54+
"https://github.com/secretflow/psi/archive/refs/tags/v0.3.0.dev240304.tar.gz",
5555
],
56-
strip_prefix = "psi-0.3.0.dev240222",
57-
sha256 = "a7319040510a1581741f05ac4b31e3d887ba8ba4766154736f96d76970d00de5",
56+
strip_prefix = "psi-0.3.0.dev240304",
57+
sha256 = "6e56dceaffbe67f7d17fbb32a5486ec31c6f17156aadb9ac57f47e4c7fe6b384",
5858
)
5959

6060
def _rules_proto_grpc():

libspu/core/BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ spu_cc_library(
7777
deps = [
7878
"//libspu:spu_cc_proto",
7979
"//libspu/core:prelude",
80+
"@yacl//yacl/utils:parallel",
8081
],
8182
)
8283

libspu/core/config.cc

+5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "libspu/core/config.h"
1616

1717
#include "prelude.h"
18+
#include "yacl/utils/parallel.h"
1819

1920
namespace spu {
2021
namespace {
@@ -43,6 +44,10 @@ void populateRuntimeConfig(RuntimeConfig& cfg) {
4344
SPU_ENFORCE(cfg.protocol() != ProtocolKind::PROT_INVALID);
4445
SPU_ENFORCE(cfg.field() != FieldType::FT_INVALID);
4546

47+
if (cfg.max_concurrency() == 0) {
48+
cfg.set_max_concurrency(yacl::get_num_threads());
49+
}
50+
4651
//
4752
if (cfg.fxp_fraction_bits() == 0) {
4853
cfg.set_fxp_fraction_bits(defaultFxpBits(cfg.field()));

libspu/core/context.cc

+23-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
#include "libspu/core/context.h"
1616

17+
#include "yacl/link/algorithm/allgather.h"
18+
#include "yacl/utils/parallel.h"
19+
1720
#include "libspu/core/trace.h"
1821

1922
namespace spu {
@@ -35,7 +38,26 @@ SPUContext::SPUContext(const RuntimeConfig& config,
3538
const std::shared_ptr<yacl::link::Context>& lctx)
3639
: config_(config),
3740
prot_(std::make_unique<Object>(genRootObjectId(lctx))),
38-
lctx_(lctx) {}
41+
lctx_(lctx),
42+
max_cluster_level_concurrency_(yacl::get_num_threads()) {
43+
// Limit number of threads
44+
if (config.max_concurrency() > 0) {
45+
yacl::set_num_threads(config.max_concurrency());
46+
max_cluster_level_concurrency_ = std::min<int32_t>(
47+
max_cluster_level_concurrency_, config.max_concurrency());
48+
}
49+
50+
if (lctx_) {
51+
auto other_max = yacl::link::AllGather(
52+
lctx, {&max_cluster_level_concurrency_, sizeof(int32_t)}, "num_cores");
53+
54+
// Comupte min
55+
for (const auto& o : other_max) {
56+
max_cluster_level_concurrency_ = std::min<int32_t>(
57+
max_cluster_level_concurrency_, o.data<int32_t>()[0]);
58+
}
59+
}
60+
}
3961

4062
std::unique_ptr<SPUContext> SPUContext::fork() const {
4163
std::shared_ptr<yacl::link::Context> new_lctx =

libspu/core/context.h

+9
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ class SPUContext final {
3838
// TODO(jint): do we really need a link here? how about a FHE context.
3939
std::shared_ptr<yacl::link::Context> lctx_;
4040

41+
// Min number of cores in SPU cluster
42+
int32_t max_cluster_level_concurrency_;
43+
4144
public:
4245
explicit SPUContext(const RuntimeConfig& config,
4346
const std::shared_ptr<yacl::link::Context>& lctx);
@@ -81,6 +84,12 @@ class SPUContext final {
8184
StateT* getState() {
8285
return prot_->template getState<StateT>();
8386
}
87+
88+
// If any task assumes same level of parallelism across all instances,
89+
// this is the max number of tasks to launch at the same time.
90+
int32_t getClusterLevelMaxConcurrency() const {
91+
return max_cluster_level_concurrency_;
92+
}
8493
};
8594

8695
class KernelEvalContext final {

libspu/spu.proto

+3
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ message RuntimeConfig {
174174
// 0(default) indicates implementation defined.
175175
int64 fxp_fraction_bits = 3;
176176

177+
// Max number of cores
178+
int32 max_concurrency = 4;
179+
177180
///////////////////////////////////////
178181
// Advanced
179182
///////////////////////////////////////

spu/libpsi.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ void BindLibs(py::module& m) {
6868
psi::BucketPsiConfig config;
6969
YACL_ENFORCE(config.ParseFromString(config_pb));
7070

71-
psi::BucketPsi psi(config, lctx, ic_mode);
72-
auto r = psi.Run(std::move(progress_callbacks), callbacks_interval_ms);
71+
auto r = psi::RunLegacyPsi(config, lctx, std::move(progress_callbacks),
72+
callbacks_interval_ms, ic_mode);
7373
return r.SerializeAsString();
7474
},
7575
py::arg("link_context"), py::arg("psi_config"),

spu/tests/BUILD.bazel

+2
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ py_test(
224224
name = "link_test",
225225
srcs = ["link_test.py"],
226226
deps = [
227+
":utils",
227228
"//spu:api",
228229
],
229230
)
@@ -308,6 +309,7 @@ py_test(
308309
"exclusive-if-local",
309310
],
310311
deps = [
312+
":utils",
311313
"//spu/utils:distributed",
312314
],
313315
)

spu/tests/distributed_test.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,13 @@
2727

2828
import spu.utils.distributed as ppd
2929
from spu import spu_pb2
30-
31-
32-
def unused_tcp_port() -> int:
33-
"""Return an unused port"""
34-
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
35-
sock.bind(("localhost", 0))
36-
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
37-
return cast(int, sock.getsockname()[1])
30+
from spu.tests.utils import get_free_port
3831

3932

4033
TEST_NODES_DEF = {
41-
"node:0": f"127.0.0.1:{unused_tcp_port()}",
42-
"node:1": f"127.0.0.1:{unused_tcp_port()}",
43-
"node:2": f"127.0.0.1:{unused_tcp_port()}",
34+
"node:0": f"127.0.0.1:{get_free_port()}",
35+
"node:1": f"127.0.0.1:{get_free_port()}",
36+
"node:2": f"127.0.0.1:{get_free_port()}",
4437
}
4538

4639

@@ -50,9 +43,9 @@ def unused_tcp_port() -> int:
5043
"config": {
5144
"node_ids": ["node:0", "node:1", "node:2"],
5245
"spu_internal_addrs": [
53-
f"127.0.0.1:{unused_tcp_port()}",
54-
f"127.0.0.1:{unused_tcp_port()}",
55-
f"127.0.0.1:{unused_tcp_port()}",
46+
f"127.0.0.1:{get_free_port()}",
47+
f"127.0.0.1:{get_free_port()}",
48+
f"127.0.0.1:{get_free_port()}",
5649
],
5750
"runtime_config": {
5851
"protocol": "ABY3",

spu/tests/link_test.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,14 @@
2323

2424
import spu.libspu.link as link
2525
from socket import socket
26-
27-
28-
def _rand_port():
29-
with socket() as s:
30-
s.bind(("localhost", 0))
31-
return s.getsockname()[1]
26+
from spu.tests.utils import get_free_port
3227

3328

3429
class UnitTests(unittest.TestCase):
3530
def test_link_brpc(self):
3631
desc = link.Desc()
37-
desc.add_party("alice", f"127.0.0.1:{_rand_port()}")
38-
desc.add_party("bob", f"127.0.0.1:{_rand_port()}")
32+
desc.add_party("alice", f"127.0.0.1:{get_free_port()}")
33+
desc.add_party("bob", f"127.0.0.1:{get_free_port()}")
3934

4035
def proc(rank):
4136
data = "hello" if rank == 0 else "world"
@@ -90,8 +85,8 @@ def thread(rank):
9085

9186
def test_link_send_recv(self):
9287
desc = link.Desc()
93-
desc.add_party("alice", f"127.0.0.1:{_rand_port()}")
94-
desc.add_party("bob", f"127.0.0.1:{_rand_port()}")
88+
desc.add_party("alice", f"127.0.0.1:{get_free_port()}")
89+
desc.add_party("bob", f"127.0.0.1:{get_free_port()}")
9590

9691
def proc(rank):
9792
lctx = link.create_brpc(desc, rank)
@@ -116,8 +111,8 @@ def proc(rank):
116111

117112
def test_link_send_async(self):
118113
desc = link.Desc()
119-
desc.add_party("alice", f"127.0.0.1:{_rand_port()}")
120-
desc.add_party("bob", f"127.0.0.1:{_rand_port()}")
114+
desc.add_party("alice", f"127.0.0.1:{get_free_port()}")
115+
desc.add_party("bob", f"127.0.0.1:{get_free_port()}")
121116

122117
def proc(rank):
123118
lctx = link.create_brpc(desc, rank)
@@ -140,8 +135,8 @@ def proc(rank):
140135

141136
def test_link_next_rank(self):
142137
desc = link.Desc()
143-
desc.add_party("alice", f"127.0.0.1:{_rand_port()}")
144-
desc.add_party("bob", f"127.0.0.1:{_rand_port()}")
138+
desc.add_party("alice", f"127.0.0.1:{get_free_port()}")
139+
desc.add_party("bob", f"127.0.0.1:{get_free_port()}")
145140

146141
def proc(rank):
147142
lctx = link.create_brpc(desc, rank)

spu/tests/pir_test.py

+39-29
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,28 @@
2020

2121
import spu.libspu.link as link
2222
import spu.psi as psi
23-
from spu.tests.utils import create_clean_folder, create_link_desc, wc_count
23+
from spu.tests.utils import create_link_desc, wc_count
24+
from tempfile import TemporaryDirectory
2425

2526

2627
class UnitTests(unittest.TestCase):
28+
def setUp(self) -> None:
29+
self.tempdir_ = TemporaryDirectory()
30+
return super().setUp()
31+
32+
def tearDown(self) -> None:
33+
self.tempdir_.cleanup()
34+
return super().tearDown()
35+
2736
def test_pir(self):
2837
# setup stage
29-
30-
server_setup_config = '''
31-
{
38+
server_setup_config = f'''
39+
{{
3240
"mode": "MODE_SERVER_SETUP",
3341
"pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI",
34-
"pir_server_config": {
42+
"pir_server_config": {{
3543
"input_path": "spu/tests/data/alice.csv",
36-
"setup_path": "/tmp/spu_test_pir_pir_server_setup",
44+
"setup_path": "{self.tempdir_.name}/spu_test_pir_pir_server_setup",
3745
"key_columns": [
3846
"id"
3947
],
@@ -42,56 +50,56 @@ def test_pir(self):
4250
],
4351
"label_max_len": 288,
4452
"bucket_size": 1000000,
45-
"apsi_server_config": {
46-
"oprf_key_path": "/tmp/spu_test_pir_server_secret_key.bin",
53+
"apsi_server_config": {{
54+
"oprf_key_path": "{self.tempdir_.name}/spu_test_pir_server_secret_key.bin",
4755
"num_per_query": 1,
4856
"compressed": false
49-
}
50-
}
51-
}
57+
}}
58+
}}
59+
}}
5260
'''
5361

54-
with open("/tmp/spu_test_pir_server_secret_key.bin", 'wb') as f:
62+
with open(
63+
f"{self.tempdir_.name}/spu_test_pir_server_secret_key.bin", 'wb'
64+
) as f:
5565
f.write(
5666
bytes.fromhex(
5767
"000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000"
5868
)
5969
)
6070

61-
create_clean_folder("/tmp/spu_test_pir_pir_server_setup")
62-
6371
psi.pir(json_format.ParseDict(json.loads(server_setup_config), psi.PirConfig()))
6472

65-
server_online_config = '''
66-
{
73+
server_online_config = f'''
74+
{{
6775
"mode": "MODE_SERVER_ONLINE",
6876
"pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI",
69-
"pir_server_config": {
70-
"setup_path": "/tmp/spu_test_pir_pir_server_setup"
71-
}
72-
}
77+
"pir_server_config": {{
78+
"setup_path": "{self.tempdir_.name}/spu_test_pir_pir_server_setup"
79+
}}
80+
}}
7381
'''
7482

75-
client_online_config = '''
76-
{
83+
client_online_config = f'''
84+
{{
7785
"mode": "MODE_CLIENT",
7886
"pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI",
79-
"pir_client_config": {
80-
"input_path": "/tmp/spu_test_pir_pir_client.csv",
87+
"pir_client_config": {{
88+
"input_path": "{self.tempdir_.name}/spu_test_pir_pir_client.csv",
8189
"key_columns": [
8290
"id"
8391
],
84-
"output_path": "/tmp/spu_test_pir_pir_output.csv"
85-
}
86-
}
92+
"output_path": "{self.tempdir_.name}/spu_test_pir_pir_output.csv"
93+
}}
94+
}}
8795
'''
8896

8997
pir_client_input_content = '''id
9098
user808
9199
xxx
92100
'''
93101

94-
with open("/tmp/spu_test_pir_pir_client.csv", 'w') as f:
102+
with open(f"{self.tempdir_.name}/spu_test_pir_pir_client.csv", 'w') as f:
95103
f.write(pir_client_input_content)
96104

97105
configs = [
@@ -118,7 +126,9 @@ def wrap(rank, link_desc, configs):
118126
self.assertEqual(job.exitcode, 0)
119127

120128
# including title, actual matched item cnt is 1.
121-
self.assertEqual(wc_count("/tmp/spu_test_pir_pir_output.csv"), 2)
129+
self.assertEqual(
130+
wc_count(f"{self.tempdir_.name}/spu_test_pir_pir_output.csv"), 2
131+
)
122132

123133

124134
if __name__ == '__main__':

0 commit comments

Comments
 (0)