20
20
21
21
import spu .libspu .link as link
22
22
import spu .psi as psi
23
- from spu .tests .utils import create_clean_folder , create_link_desc , wc_count
23
+ from spu .tests .utils import create_link_desc , wc_count
24
+ from tempfile import TemporaryDirectory
24
25
25
26
26
27
class UnitTests (unittest .TestCase ):
28
+ def setUp (self ) -> None :
29
+ self .tempdir_ = TemporaryDirectory ()
30
+ return super ().setUp ()
31
+
32
+ def tearDown (self ) -> None :
33
+ self .tempdir_ .cleanup ()
34
+ return super ().tearDown ()
35
+
27
36
def test_pir (self ):
28
37
# setup stage
29
-
30
- server_setup_config = '''
31
- {
38
+ server_setup_config = f'''
39
+ {{
32
40
"mode": "MODE_SERVER_SETUP",
33
41
"pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI",
34
- "pir_server_config": {
42
+ "pir_server_config": {{
35
43
"input_path": "spu/tests/data/alice.csv",
36
- "setup_path": "/tmp /spu_test_pir_pir_server_setup",
44
+ "setup_path": "{ self . tempdir_ . name } /spu_test_pir_pir_server_setup",
37
45
"key_columns": [
38
46
"id"
39
47
],
@@ -42,56 +50,56 @@ def test_pir(self):
42
50
],
43
51
"label_max_len": 288,
44
52
"bucket_size": 1000000,
45
- "apsi_server_config": {
46
- "oprf_key_path": "/tmp /spu_test_pir_server_secret_key.bin",
53
+ "apsi_server_config": {{
54
+ "oprf_key_path": "{ self . tempdir_ . name } /spu_test_pir_server_secret_key.bin",
47
55
"num_per_query": 1,
48
56
"compressed": false
49
- }
50
- }
51
- }
57
+ }}
58
+ }}
59
+ }}
52
60
'''
53
61
54
- with open ("/tmp/spu_test_pir_server_secret_key.bin" , 'wb' ) as f :
62
+ with open (
63
+ f"{ self .tempdir_ .name } /spu_test_pir_server_secret_key.bin" , 'wb'
64
+ ) as f :
55
65
f .write (
56
66
bytes .fromhex (
57
67
"000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000"
58
68
)
59
69
)
60
70
61
- create_clean_folder ("/tmp/spu_test_pir_pir_server_setup" )
62
-
63
71
psi .pir (json_format .ParseDict (json .loads (server_setup_config ), psi .PirConfig ()))
64
72
65
- server_online_config = '''
66
- {
73
+ server_online_config = f '''
74
+ {{
67
75
"mode": "MODE_SERVER_ONLINE",
68
76
"pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI",
69
- "pir_server_config": {
70
- "setup_path": "/tmp /spu_test_pir_pir_server_setup"
71
- }
72
- }
77
+ "pir_server_config": {{
78
+ "setup_path": "{ self . tempdir_ . name } /spu_test_pir_pir_server_setup"
79
+ }}
80
+ }}
73
81
'''
74
82
75
- client_online_config = '''
76
- {
83
+ client_online_config = f '''
84
+ {{
77
85
"mode": "MODE_CLIENT",
78
86
"pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI",
79
- "pir_client_config": {
80
- "input_path": "/tmp /spu_test_pir_pir_client.csv",
87
+ "pir_client_config": {{
88
+ "input_path": "{ self . tempdir_ . name } /spu_test_pir_pir_client.csv",
81
89
"key_columns": [
82
90
"id"
83
91
],
84
- "output_path": "/tmp /spu_test_pir_pir_output.csv"
85
- }
86
- }
92
+ "output_path": "{ self . tempdir_ . name } /spu_test_pir_pir_output.csv"
93
+ }}
94
+ }}
87
95
'''
88
96
89
97
pir_client_input_content = '''id
90
98
user808
91
99
xxx
92
100
'''
93
101
94
- with open ("/tmp /spu_test_pir_pir_client.csv" , 'w' ) as f :
102
+ with open (f" { self . tempdir_ . name } /spu_test_pir_pir_client.csv" , 'w' ) as f :
95
103
f .write (pir_client_input_content )
96
104
97
105
configs = [
@@ -118,7 +126,9 @@ def wrap(rank, link_desc, configs):
118
126
self .assertEqual (job .exitcode , 0 )
119
127
120
128
# including title, actual matched item cnt is 1.
121
- self .assertEqual (wc_count ("/tmp/spu_test_pir_pir_output.csv" ), 2 )
129
+ self .assertEqual (
130
+ wc_count (f"{ self .tempdir_ .name } /spu_test_pir_pir_output.csv" ), 2
131
+ )
122
132
123
133
124
134
if __name__ == '__main__' :
0 commit comments