Skip to content

Commit a1f9186

Browse files
authored
Merge pull request Bears-R-Us#125 from mhmerrill/set-ops-feature
Set ops feature
2 parents 1917e96 + 279c9b2 commit a1f9186

File tree

14 files changed

+713
-15
lines changed

14 files changed

+713
-15
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
arkouda_server
22
arkouda_server_real
3+
arkouda_server_llvm
4+
arkouda_server_llvm_real
5+
*_real
36
#*#
47
.#*
58
*.~*

arkouda.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,11 @@ def __neg__(self):
339339

340340
# overload unary~ for pdarray implemented as pdarray^(~0)
341341
def __invert__(self):
342-
return self.binop(~0, "^")
342+
if self.dtype == np.int64:
343+
return self.binop(~0, "^")
344+
if self.dtype == np.bool:
345+
return self.binop(True, "^")
346+
return NotImplemented
343347

344348
# op= operators
345349
def opeq(self, other, op):
@@ -674,12 +678,12 @@ def histogram(pda, bins=10):
674678
else:
675679
raise TypeError("must be pdarray {} and bins must be an int {}".format(pda,bins))
676680

677-
def in1d(pda1, pda2):
681+
def in1d(pda1, pda2, invert=False):
678682
if isinstance(pda1, pdarray) and isinstance(pda2, pdarray):
679-
repMsg = generic_msg("in1d {} {}".format(pda1.name, pda2.name))
683+
repMsg = generic_msg("in1d {} {} {}".format(pda1.name, pda2.name, invert))
680684
return create_pdarray(repMsg)
681685
else:
682-
raise TypeError("must be pdarray {} and bins must be an int {}".format(pda,bins))
686+
raise TypeError("must be pdarray {} or {}".format(pda1,pda2))
683687

684688
def unique(pda, return_counts=False):
685689
if isinstance(pda, pdarray):
@@ -737,6 +741,85 @@ def coargsort(arrays):
737741
repMsg = generic_msg("coargsort {} {}".format(len(arrays), ' '.join([a.name for a in arrays])))
738742
return create_pdarray(repMsg)
739743

744+
def concatenate(arrays):
745+
size = 0
746+
dtype = None
747+
for a in arrays:
748+
if not isinstance(a, pdarray):
749+
raise ValueError("Argument must be an iterable of pdarrays")
750+
if dtype == None:
751+
dtype = a.dtype
752+
elif dtype != a.dtype:
753+
raise ValueError("All pdarrays must have same dtype")
754+
size += a.size
755+
if size == 0:
756+
return zeros(0, dtype=int64)
757+
repMsg = generic_msg("concatenate {} {}".format(len(arrays), ' '.join([a.name for a in arrays])))
758+
return create_pdarray(repMsg)
759+
760+
# (A1 | A2) Set Union: elements are in one or the other or both
761+
def union1d(pda1, pda2):
762+
if isinstance(pda1, pdarray) and isinstance(pda2, pdarray):
763+
if pda1.size == 0:
764+
return pda2 # union is pda2
765+
if pda2.size == 0:
766+
return pda1 # union is pda1
767+
return unique(concatenate((unique(pda1), unique(pda2))))
768+
else:
769+
raise TypeError("must be pdarray {} or {}".format(pda1,pda2))
770+
771+
# (A1 & A2) Set Intersection: elements have to be in both arrays
772+
def intersect1d(pda1, pda2, assume_unique=False):
773+
if isinstance(pda1, pdarray) and isinstance(pda2, pdarray):
774+
if pda1.size == 0:
775+
return pda1 # nothing in the intersection
776+
if pda2.size == 0:
777+
return pda2 # nothing in the intersection
778+
if not assume_unique:
779+
pda1 = unique(pda1)
780+
pda2 = unique(pda2)
781+
aux = concatenate((pda1, pda2))
782+
aux_sort_indices = argsort(aux)
783+
aux = aux[aux_sort_indices]
784+
mask = aux[1:] == aux[:-1]
785+
int1d = aux[:-1][mask]
786+
return int1d
787+
else:
788+
raise TypeError("must be pdarray {} or {}".format(pda1,pda2))
789+
790+
# (A1 - A2) Set Difference: elements have to be in first array but not second
791+
def setdiff1d(pda1, pda2, assume_unique=False):
792+
if isinstance(pda1, pdarray) and isinstance(pda2, pdarray):
793+
if pda1.size == 0:
794+
return pda1 # return a zero length pdarray
795+
if pda2.size == 0:
796+
return pda1 # subtracting nothing return orig pdarray
797+
if not assume_unique:
798+
pda1 = unique(pda1)
799+
pda2 = unique(pda2)
800+
return pda1[in1d(pda1, pda2, invert=True)]
801+
else:
802+
raise TypeError("must be pdarray {} or {}".format(pda1,pda2))
803+
804+
# (A1 ^ A2) Set Symmetric Difference: elements are not in the intersection
805+
def setxor1d(pda1, pda2, assume_unique=False):
806+
if isinstance(pda1, pdarray) and isinstance(pda2, pdarray):
807+
if pda1.size == 0:
808+
return pda2 # return other pdarray if pda1 is empty
809+
if pda2.size == 0:
810+
return pda1 # return other pdarray if pda2 is empty
811+
if not assume_unique:
812+
pda1 = unique(pda1)
813+
pda2 = unique(pda2)
814+
aux = concatenate((pda1, pda2))
815+
aux_sort_indices = argsort(aux)
816+
aux = aux[aux_sort_indices]
817+
flag = concatenate((array([True]), aux[1:] != aux[:-1], array([True])))
818+
return aux[flag[1:] & flag[:-1]]
819+
else:
820+
raise TypeError("must be pdarray {} or {}".format(pda1,pda2))
821+
822+
740823
def local_argsort(pda):
741824
if isinstance(pda, pdarray):
742825
if pda.size == 0:

benchmarks/bench_mac_llvm.log

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
4.3.1
2+
psp = tcp://localhost:5555
3+
connected to tcp://localhost:5555
4+
array size = 1,000,000
5+
number of trials = 6
6+
>>> arkouda argsort
7+
numLocales = 1, N = 1,000,000
8+
Average time = 0.0469 sec
9+
Average rate = 0.1588 GiB/sec
10+
>>> numpy argsort
11+
N = 1,000,000
12+
Average time = 0.0978 sec
13+
Average rate = 0.0762 GiB/sec
14+
4.3.1
15+
psp = tcp://localhost:5555
16+
connected to tcp://localhost:5555
17+
size of index array = 100,000,000
18+
size of values array = 100,000,000
19+
number of trials = 6
20+
>>> arkouda gather
21+
numLocales = 1, num_indices = 100,000,000 ; num_values = 100,000,000
22+
Average time = 1.3009 sec
23+
Average rate = 1.72 GiB/sec
24+
>>> numpy gather
25+
num_indices = 100,000,000 ; num_values = 100,000,000
26+
Average time = 2.6057 sec
27+
Average rate = 0.86 GiB/sec
28+
4.3.1
29+
psp = tcp://localhost:5555
30+
connected to tcp://localhost:5555
31+
array size = 100,000,000
32+
number of trials = 6
33+
>>> arkouda reduce
34+
numLocales = 1, N = 100,000,000
35+
sum = 4999999950000000
36+
Average time = 0.0391 sec
37+
Average rate = 19.04 GiB/sec
38+
prod = 0.0
39+
Average time = 0.0421 sec
40+
Average rate = 17.72 GiB/sec
41+
min = 0
42+
Average time = 0.0389 sec
43+
Average rate = 19.13 GiB/sec
44+
max = 99999999
45+
Average time = 0.0398 sec
46+
Average rate = 18.73 GiB/sec
47+
>>> numpy reduce
48+
N = 100,000,000
49+
sum = 4999999950000000
50+
Average time = 0.0557 sec
51+
Average rate = 13.38 GiB/sec
52+
prod = 0
53+
Average time = 0.0900 sec
54+
Average rate = 8.28 GiB/sec
55+
min = 0
56+
Average time = 0.0897 sec
57+
Average rate = 8.31 GiB/sec
58+
max = 99999999
59+
Average time = 0.0885 sec
60+
Average rate = 8.42 GiB/sec
61+
4.3.1
62+
psp = tcp://localhost:5555
63+
connected to tcp://localhost:5555
64+
array size = 100,000,000
65+
number of trials = 6
66+
>>> arkouda scan
67+
numLocales = 1, N = 100,000,000
68+
cumsum, final value = 4999999950000000
69+
Average time = 0.7731 sec
70+
Average rate = 1.93 GiB/sec
71+
cumprod, final value = 0
72+
Average time = 0.6849 sec
73+
Average rate = 2.18 GiB/sec
74+
>>> numpy scan
75+
N = 100,000,000
76+
cumsum, final value = 4999999950000000
77+
Average time = 0.6315 sec
78+
Average rate = 2.36 GiB/sec
79+
cumprod, final value = 0
80+
Average time = 0.7011 sec
81+
Average rate = 2.13 GiB/sec
82+
4.3.1
83+
psp = tcp://localhost:5555
84+
connected to tcp://localhost:5555
85+
size of index array = 100,000,000
86+
size of values array = 100,000,000
87+
number of trials = 6
88+
>>> arkouda scatter
89+
numLocales = 1, num_indices = 100,000,000 ; num_values = 100,000,000
90+
Average time = 0.9536 sec
91+
Average rate = 2.34 GiB/sec
92+
>>> numpy scatter
93+
num_indices = 100,000,000 ; num_values = 100,000,000
94+
Average time = 2.0967 sec
95+
Average rate = 1.07 GiB/sec
96+
4.3.1
97+
psp = tcp://localhost:5555
98+
connected to tcp://localhost:5555
99+
array size = 100,000,000
100+
number of trials = 6
101+
>>> arkouda stream
102+
numLocales = 1, N = 100,000,000
103+
Average time = 0.6236 sec
104+
Average rate = 3.58 GiB/sec
105+
>>> numpy stream
106+
N = 100,000,000
107+
Average time = 0.8965 sec
108+
Average rate = 2.49 GiB/sec

benchmarks/run_all.sh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/bin/bash
2+
3+
echo ---- argsort ----
4+
./argsort.py -n 10000000 localhost 5555
5+
echo ---- gather ----
6+
./gather.py localhost 5555
7+
echo ---- reduce ----
8+
./reduce.py -t 10 localhost 5555
9+
echo ---- scan ----
10+
./scan.py localhost 5555
11+
echo ---- scatter ----
12+
./scatter.py localhost 5555
13+
echo ---- stream ----
14+
./stream.py localhost 5555
15+

src/ConcatenateMsg.chpl

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
module ConcatenateMsg
2+
{
3+
use ServerConfig;
4+
5+
use Time only;
6+
use Math only;
7+
8+
use MultiTypeSymbolTable;
9+
use MultiTypeSymEntry;
10+
use ServerErrorStrings;
11+
12+
use AryUtil;
13+
14+
/* Concatenate a list of arrays together
15+
to form one array
16+
*/
17+
proc concatenateMsg(reqMsg: string, st: borrowed SymTab) {
18+
var pn = "concatenate";
19+
var repMsg: string;
20+
var fields = reqMsg.split();
21+
var cmd = fields[1];
22+
var n = try! fields[2]:int; // number of arrays to sort
23+
var names = fields[3..];
24+
// Check that fields contains the stated number of arrays
25+
if (n != names.size) { return try! incompatibleArgumentsError(pn, "Expected %i arrays but got %i".format(n, names.size)); }
26+
/* var arrays: [0..#n] borrowed GenSymEntry; */
27+
var size: int = 0;
28+
var dtype: DType;
29+
// Check that all arrays exist in the symbol table and have the same size
30+
for (name, i) in zip(names, 1..) {
31+
// arrays[i] = st.lookup(name): borrowed GenSymEntry;
32+
var g: borrowed GenSymEntry = st.lookup(name);
33+
if (g == nil) { return unknownSymbolError(pn, name); }
34+
if (i == 1) {dtype = g.dtype;}
35+
else {
36+
if (dtype != g.dtype) {
37+
return try! incompatibleArgumentsError(pn, "Expected %s dtype but got %s dtype".format(dtype2str(dtype), dtype2str(g.dtype)));
38+
}
39+
}
40+
// accumulate size from each array size
41+
size += g.size;
42+
}
43+
// allocate a new array in the symboltable
44+
// and copy in arrays
45+
var rname = st.nextName();
46+
select (dtype) {
47+
when DType.Int64 {
48+
// create array to copy into
49+
var e = st.addEntry(rname, size, int);
50+
var start: int;
51+
var end: int;
52+
start = 0;
53+
for (name, i) in zip(names, 1..) {
54+
// lookup and cast operand to copy from
55+
var o = toSymEntry(st.lookup(name), int);
56+
// calculate end which is inclusive
57+
end = start + o.size - 1;
58+
// copy array into concatenation array
59+
e.a[{start..end}] = o.a;
60+
// update new start for next array copy
61+
start += o.size;
62+
}
63+
}
64+
when DType.Float64 {
65+
// create array to copy into
66+
var e = st.addEntry(rname, size, real);
67+
var start: int;
68+
var end: int;
69+
start = 0;
70+
for (name, i) in zip(names, 1..) {
71+
// lookup and cast operand to copy from
72+
var o = toSymEntry(st.lookup(name), real);
73+
// calculate end which is inclusive
74+
end = start + o.size - 1;
75+
// copy array into concatenation array
76+
e.a[{start..end}] = o.a;
77+
// update new start for next array copy
78+
start += o.size;
79+
}
80+
}
81+
when DType.Bool {
82+
// create array to copy into
83+
var e = st.addEntry(rname, size, bool);
84+
var start: int;
85+
var end: int;
86+
start = 0;
87+
for (name, i) in zip(names, 1..) {
88+
// lookup and cast operand to copy from
89+
var o = toSymEntry(st.lookup(name), bool);
90+
// calculate end which is inclusive
91+
end = start + o.size - 1;
92+
// copy array into concatenation array
93+
e.a[{start..end}] = o.a;
94+
// update new start for next array copy
95+
start += o.size;
96+
}
97+
}
98+
otherwise {return notImplementedError("concatenate",dtype);}
99+
}
100+
101+
return try! "created " + st.attrib(rname);
102+
}
103+
104+
}

0 commit comments

Comments
 (0)