Skip to content

Commit c60251d

Browse files
oyamadmmcky
authored andcommitted
FIX: support_enumeration: Use _numba_linalg_solve (#311)
* support_enumeration: Remove fallback for Numba < 0.28 * support_enumeration: Add a test "LinAlgError: Matrix is singular to machine precision.” raised * FIX: support_enumeration: Use `_numba_linalg_solve` Remove `is_singular` by svd * util: Add `_numba_linalg_solve` For use in a jitted function in nopython mode * Call directly Numba internal `numba_xgesv` * Return nonzero int if input matrix is singular, allowing alternative to try-except np.linalg.LinAlgError * support_enumeration: Remove `any()` Allow `cache=True`, close #285
1 parent 34b628b commit c60251d

File tree

6 files changed

+197
-83
lines changed

6 files changed

+197
-83
lines changed

docs/source/util.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ Utilities
77
util/array
88
util/common_messages
99
util/notebooks
10+
util/numba
1011
util/random
1112
util/timing

docs/source/util/numba.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
numba
2+
=====
3+
4+
.. automodule:: quantecon.util.numba
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:

quantecon/game_theory/support_enumeration.py

Lines changed: 28 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,9 @@
1111
Tardos, and V. Vazirani eds., Algorithmic Game Theory, 2007.
1212
1313
"""
14-
from distutils.version import LooseVersion
1514
import numpy as np
16-
import numba
1715
from numba import jit
18-
19-
20-
least_numba_version = LooseVersion('0.28')
21-
is_numba_required_installed = True
22-
if LooseVersion(numba.__version__) < least_numba_version:
23-
is_numba_required_installed = False
24-
nopython = is_numba_required_installed
25-
26-
EPS = np.finfo(float).eps
16+
from ..util.numba import _numba_linalg_solve
2717

2818

2919
def support_enumeration(g):
@@ -46,11 +36,6 @@ def support_enumeration(g):
4636
list(tuple(ndarray(float, ndim=1)))
4737
List containing tuples of Nash equilibrium mixed actions.
4838
49-
Notes
50-
-----
51-
This routine is jit-complied if Numba version 0.28 or above is
52-
installed.
53-
5439
"""
5540
return list(support_enumeration_gen(g))
5641

@@ -80,7 +65,7 @@ def support_enumeration_gen(g):
8065
g.players[1].payoff_array)
8166

8267

83-
@jit(nopython=nopython) # cache=True raises _pickle.PicklingError
68+
@jit(nopython=True) # cache=True raises _pickle.PicklingError
8469
def _support_enumeration_gen(payoff_matrix0, payoff_matrix1):
8570
"""
8671
Main body of `support_enumeration_gen`.
@@ -105,32 +90,28 @@ def _support_enumeration_gen(payoff_matrix0, payoff_matrix1):
10590

10691
for k in range(1, n_min+1):
10792
supps = (np.arange(k), np.empty(k, np.int_))
108-
actions = (np.empty(k), np.empty(k))
93+
actions = (np.empty(k+1), np.empty(k+1))
10994
A = np.empty((k+1, k+1))
110-
A[:-1, -1] = -1
111-
A[-1, :-1] = 1
112-
A[-1, -1] = 0
113-
b = np.zeros(k+1)
114-
b[-1] = 1
95+
11596
while supps[0][-1] < nums_actions[0]:
11697
supps[1][:] = np.arange(k)
11798
while supps[1][-1] < nums_actions[1]:
11899
if _indiff_mixed_action(payoff_matrix0, supps[0], supps[1],
119-
A, b, actions[1]):
100+
A, actions[1]):
120101
if _indiff_mixed_action(payoff_matrix1, supps[1], supps[0],
121-
A, b, actions[0]):
102+
A, actions[0]):
122103
out = (np.zeros(nums_actions[0]),
123104
np.zeros(nums_actions[1]))
124105
for p, (supp, action) in enumerate(zip(supps,
125106
actions)):
126-
out[p][supp] = action
107+
out[p][supp] = action[:-1]
127108
yield out
128109
_next_k_array(supps[1])
129110
_next_k_array(supps[0])
130111

131112

132-
@jit(nopython=nopython)
133-
def _indiff_mixed_action(payoff_matrix, own_supp, opp_supp, A, b, out):
113+
@jit(nopython=True, cache=True)
114+
def _indiff_mixed_action(payoff_matrix, own_supp, opp_supp, A, out):
134115
"""
135116
Given a player's payoff matrix `payoff_matrix`, an array `own_supp`
136117
of this player's actions, and an array `opp_supp` of the opponent's
@@ -139,8 +120,7 @@ def _indiff_mixed_action(payoff_matrix, own_supp, opp_supp, A, b, out):
139120
among the actions in `own_supp`, if any such exists. Return `True`
140121
if such a mixed action exists and actions in `own_supp` are indeed
141122
best responses to it, in which case the outcome is stored in `out`;
142-
`False` otherwise. Arrays `A` and `b` are used in intermediate
143-
steps.
123+
`False` otherwise. Array `A` is used in intermediate steps.
144124
145125
Parameters
146126
----------
@@ -154,17 +134,11 @@ def _indiff_mixed_action(payoff_matrix, own_supp, opp_supp, A, b, out):
154134
Array containing the opponent's action indices, of length k.
155135
156136
A : ndarray(float, ndim=2)
157-
Array used in intermediate steps, of shape (k+1, k+1). The
158-
following values must be assigned in advance: `A[:-1, -1] = -1`,
159-
`A[-1, :-1] = 1`, and `A[-1, -1] = 0`.
160-
161-
b : ndarray(float, ndim=1)
162-
Array used in intermediate steps, of shape (k+1,). The following
163-
values must be assigned in advance `b[:-1] = 0` and `b[-1] = 1`.
137+
Array used in intermediate steps, of shape (k+1, k+1).
164138
165139
out : ndarray(float, ndim=1)
166-
Array of length k to store the k nonzero values of the desired
167-
mixed action.
140+
Array of length k+1 to store the k nonzero values of the desired
141+
mixed action in `out[:-1]` (and the payoff value in `out[-1]`.)
168142
169143
Returns
170144
-------
@@ -175,15 +149,22 @@ def _indiff_mixed_action(payoff_matrix, own_supp, opp_supp, A, b, out):
175149
m = payoff_matrix.shape[0]
176150
k = len(own_supp)
177151

178-
A[:-1, :-1] = payoff_matrix[own_supp, :][:, opp_supp]
179-
if _is_singular(A):
180-
return False
181-
182-
sol = np.linalg.solve(A, b)
183-
if (sol[:-1] <= 0).any():
152+
for i in range(k):
153+
for j in range(k):
154+
A[j, i] = payoff_matrix[own_supp[i], opp_supp[j]] # transpose
155+
A[:-1, -1] = 1
156+
A[-1, :-1] = -1
157+
A[-1, -1] = 0
158+
out[:-1] = 0
159+
out[-1] = 1
160+
161+
r = _numba_linalg_solve(A, out)
162+
if r != 0: # A: singular
184163
return False
185-
out[:] = sol[:-1]
186-
val = sol[-1]
164+
for i in range(k):
165+
if out[i] <= 0:
166+
return False
167+
val = out[-1]
187168

188169
if k == m:
189170
return True
@@ -280,39 +261,3 @@ def _next_k_array(a):
280261
pos += 1
281262

282263
return a
283-
284-
285-
if is_numba_required_installed:
286-
@jit(nopython=True, cache=True)
287-
def _is_singular(a):
288-
s = numba.targets.linalg._compute_singular_values(a)
289-
if s[-1] <= s[0] * EPS:
290-
return True
291-
else:
292-
return False
293-
else:
294-
def _is_singular(a):
295-
s = np.linalg.svd(a, compute_uv=False)
296-
if s[-1] <= s[0] * EPS:
297-
return True
298-
else:
299-
return False
300-
301-
_is_singular_docstr = \
302-
"""
303-
Determine whether matrix `a` is numerically singular, by checking
304-
its singular values.
305-
306-
Parameters
307-
----------
308-
a : ndarray(float, ndim=2)
309-
2-dimensional array of floats.
310-
311-
Returns
312-
-------
313-
bool
314-
Whether `a` is numerically singular.
315-
316-
"""
317-
318-
_is_singular.__doc__ = _is_singular_docstr

quantecon/game_theory/tests/test_support_enumeration.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,33 @@
44
Tests for support_enumeration.py
55
66
"""
7+
import numpy as np
78
from numpy.testing import assert_allclose
9+
from nose.tools import eq_
10+
from quantecon.util import check_random_state
811
from quantecon.game_theory import Player, NormalFormGame, support_enumeration
912

1013

14+
def random_skew_sym(n, m=None, random_state=None):
15+
"""
16+
Generate a random skew symmetric zero-sum NormalFormGame of the form
17+
O B
18+
-B.T O
19+
where B is an n x m matrix.
20+
21+
"""
22+
if m is None:
23+
m = n
24+
random_state = check_random_state(random_state)
25+
B = random_state.random_sample((n, m))
26+
A = np.empty((n+m, n+m))
27+
A[:n, :n] = 0
28+
A[n:, n:] = 0
29+
A[:n, n:] = B
30+
A[n:, :n] = -B.T
31+
return NormalFormGame([Player(A) for i in range(2)])
32+
33+
1134
class TestSupportEnumeration():
1235
def setUp(self):
1336
self.game_dicts = []
@@ -35,10 +58,18 @@ def setUp(self):
3558
def test_support_enumeration(self):
3659
for d in self.game_dicts:
3760
NEs_computed = support_enumeration(d['g'])
61+
eq_(len(NEs_computed), len(d['NEs']))
3862
for actions_computed, actions in zip(NEs_computed, d['NEs']):
3963
for action_computed, action in zip(actions_computed, actions):
4064
assert_allclose(action_computed, action)
4165

66+
def test_no_error_skew_sym(self):
67+
# Test no LinAlgError is raised.
68+
n, m = 3, 2
69+
seed = 7028
70+
g = random_skew_sym(n, m, random_state=seed)
71+
NEs = support_enumeration(g)
72+
4273

4374
if __name__ == '__main__':
4475
import sys

quantecon/util/numba.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""
2+
Utilities to support Numba jitted functions
3+
4+
"""
5+
import numpy as np
6+
from numba import generated_jit, types
7+
from numba.targets.linalg import _LAPACK
8+
9+
10+
# BLAS kinds as letters
11+
_blas_kinds = {
12+
types.float32: 's',
13+
types.float64: 'd',
14+
types.complex64: 'c',
15+
types.complex128: 'z',
16+
}
17+
18+
19+
@generated_jit(nopython=True, cache=True)
20+
def _numba_linalg_solve(a, b):
21+
"""
22+
Solve the linear equation ax = b directly calling a Numba internal
23+
function. The data in `a` and `b` are interpreted in Fortran order,
24+
and dtype of `a` and `b` must be the same, one of {float32, float64,
25+
complex64, complex128}. `a` and `b` are modified in place, and the
26+
solution is stored in `b`. *No error check is made for the inputs.*
27+
28+
Parameters
29+
----------
30+
a : ndarray(ndim=2)
31+
2-dimensional ndarray of shape (n, n).
32+
33+
b : ndarray(ndim=1 or 2)
34+
1-dimensional ndarray of shape (n,) or 2-dimensional ndarray of
35+
shape (n, nrhs).
36+
37+
Returns
38+
-------
39+
r : scalar(int)
40+
r = 0 if successful.
41+
42+
Notes
43+
-----
44+
From github.com/numba/numba/blob/master/numba/targets/linalg.py
45+
46+
"""
47+
numba_xgesv = _LAPACK().numba_xgesv(a.dtype)
48+
kind = ord(_blas_kinds[a.dtype])
49+
50+
def _numba_linalg_solve_impl(a, b): # pragma: no cover
51+
n = a.shape[-1]
52+
if b.ndim == 1:
53+
nrhs = 1
54+
else: # b.ndim == 2
55+
nrhs = b.shape[-1]
56+
F_INT_nptype = np.int32
57+
ipiv = np.empty(n, dtype=F_INT_nptype)
58+
59+
r = numba_xgesv(
60+
kind, # kind
61+
n, # n
62+
nrhs, # nhrs
63+
a.ctypes, # a
64+
n, # lda
65+
ipiv.ctypes, # ipiv
66+
b.ctypes, # b
67+
n # ldb
68+
)
69+
return r
70+
71+
return _numba_linalg_solve_impl

quantecon/util/tests/test_numba.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""
2+
Tests for Numba support utilities
3+
4+
"""
5+
import numpy as np
6+
from numpy.testing import assert_array_equal
7+
from numba import jit
8+
from nose.tools import eq_, ok_
9+
from quantecon.util.numba import _numba_linalg_solve
10+
11+
12+
@jit(nopython=True)
13+
def numba_linalg_solve_orig(a, b):
14+
return np.linalg.solve(a, b)
15+
16+
17+
class TestNumbaLinalgSolve:
18+
def setUp(self):
19+
self.dtypes = [np.float32, np.float64]
20+
self.a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
21+
self.b_1dim = np.array([2, 4, -1])
22+
self.b_2dim = np.array([[2, 3], [4, 1], [-1, 0]])
23+
self.a_singular = np.array([[0, 1, 2], [3, 4, 5], [3, 5, 7]])
24+
25+
def test_b_1dim(self):
26+
for dtype in self.dtypes:
27+
a = np.asfortranarray(self.a, dtype=dtype)
28+
b = np.asfortranarray(self.b_1dim, dtype=dtype)
29+
sol_orig = numba_linalg_solve_orig(a, b)
30+
r = _numba_linalg_solve(a, b)
31+
eq_(r, 0)
32+
assert_array_equal(b, sol_orig)
33+
34+
def test_b_2dim(self):
35+
for dtype in self.dtypes:
36+
a = np.asfortranarray(self.a, dtype=dtype)
37+
b = np.asfortranarray(self.b_2dim, dtype=dtype)
38+
sol_orig = numba_linalg_solve_orig(a, b)
39+
r = _numba_linalg_solve(a, b)
40+
eq_(r, 0)
41+
assert_array_equal(b, sol_orig)
42+
43+
def test_singular_a(self):
44+
for b in [self.b_1dim, self.b_2dim]:
45+
for dtype in self.dtypes:
46+
a = np.asfortranarray(self.a_singular, dtype=dtype)
47+
b = np.asfortranarray(b, dtype=dtype)
48+
r = _numba_linalg_solve(a, b)
49+
ok_(r != 0)
50+
51+
52+
if __name__ == '__main__':
53+
import sys
54+
import nose
55+
56+
argv = sys.argv[:]
57+
argv.append('--verbose')
58+
argv.append('--nocapture')
59+
nose.main(argv=argv, defaultTest=__file__)

0 commit comments

Comments
 (0)