Skip to content

Commit

Permalink
Merge pull request scipy#9455 from ilayn/get_lapack_func_clnup
Browse files Browse the repository at this point in the history
MAINT: Speed up get_(lapack,blas)_func
  • Loading branch information
rgommers authored Dec 26, 2018
2 parents 5539300 + e44e45f commit 78fbaab
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 20 deletions.
38 changes: 38 additions & 0 deletions benchmarks/benchmarks/blas_lapack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import division, absolute_import, print_function

import numpy as np

try:
import scipy.linalg.lapack as la
import scipy.linalg.blas as bla
except ImportError:
pass

from .common import Benchmark


class GetBlasLapackFuncs(Benchmark):
    """
    Test the speed of grabbing the correct BLAS/LAPACK routine flavor.

    In particular, upon receiving strange dtype arrays the results shouldn't
    diverge too much. Hence the results here should be comparable.

    The benchmark is parametrized over the dtype and memory order of two
    arrays and over their size; ``params`` holds one list of values per
    entry of ``param_names``, in the same order.
    """

    param_names = ['dtype1', 'dtype2',
                   'dtype1_ord', 'dtype2_ord',
                   'size']
    params = [
        ['b', 'G', 'd'],    # dtype1
        ['d', 'F', '?'],    # dtype2
        ['C', 'F'],         # dtype1_ord
        ['C', 'F'],         # dtype2_ord
        [10, 100, 1000],    # size
    ]

    def setup(self, dtype1, dtype2, dtype1_ord, dtype2_ord, size):
        """Allocate the two input arrays for the current parameter set."""
        self.arr1 = np.empty(size, dtype=dtype1, order=dtype1_ord)
        self.arr2 = np.empty(size, dtype=dtype2, order=dtype2_ord)

    def time_find_best_blas_type(self, dtype1, dtype2,
                                 dtype1_ord, dtype2_ord, size):
        """Time ``find_best_blas_type`` on the arrays built in ``setup``.

        The benchmark receives the same parameter values as ``setup``; they
        are unused here because the arrays are already stored on ``self``.
        """
        bla.find_best_blas_type((self.arr1, self.arr2))
61 changes: 41 additions & 20 deletions scipy/linalg/blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,25 @@
from scipy.linalg._fblas import *
del empty_module

# 'd' will be default for 'i',..
_type_conv = {'f': 's', 'd': 'd', 'F': 'c', 'D': 'z', 'G': 'z'}
# all numeric dtypes '?bBhHiIlLqQefdgFDGO' that are safe to be converted to

# single precision float : '?bBhH!!!!!!ef!!!!!!'
# double precision float : '?bBhHiIlLqQefdg!!!!'
# single precision complex : '?bBhH!!!!!!ef!!F!!!'
# double precision complex : '?bBhHiIlLqQefdgFDG!'

_type_score = {x: 1 for x in '?bBhHef'}
_type_score.update({x: 2 for x in 'iIlLqQd'})

# Handle float128(g) and complex256(G) separately for non-Windows systems.
# On Windows, the values will be rewritten to the same key with the same value.
_type_score.update({'F': 3, 'D': 4, 'g': 2, 'G': 4})

# Final mapping to the actual prefixes and dtypes
_type_conv = {1: ('s', _np.dtype('float32')),
2: ('d', _np.dtype('float64')),
3: ('c', _np.dtype('complex64')),
4: ('z', _np.dtype('complex128'))}

# some convenience alias for complex functions
_blas_alias = {'cnrm2': 'scnrm2', 'znrm2': 'dznrm2',
Expand Down Expand Up @@ -270,26 +287,30 @@ def find_best_blas_type(arrays=(), dtype=None):
"""
dtype = _np.dtype(dtype)
max_score = _type_score.get(dtype.char, 5)
prefer_fortran = False

if arrays:
# use the most generic type in arrays
dtypes = [ar.dtype for ar in arrays]
dtype = _np.find_common_type(dtypes, ())
try:
index = dtypes.index(dtype)
except ValueError:
index = 0
if arrays[index].flags['FORTRAN']:
# prefer Fortran for leading array with column major order
prefer_fortran = True

prefix = _type_conv.get(dtype.char, 'd')
if dtype.char == 'G':
# complex256 -> complex128 (i.e., C long double -> C double)
dtype = _np.dtype('D')
elif dtype.char not in 'fdFD':
dtype = _np.dtype('d')
# In most cases, single element is passed through, quicker route
if len(arrays) == 1:
max_score = _type_score.get(arrays[0].dtype.char, 5)
prefer_fortran = arrays[0].flags['FORTRAN']
else:
# use the most generic type in arrays
scores = [_type_score.get(x.dtype.char, 5) for x in arrays]
max_score = max(scores)
ind_max_score = scores.index(max_score)
# safe upcasting for mix of float64 and complex64 --> prefix 'z'
if max_score == 3 and (2 in scores):
max_score = 4

if arrays[ind_max_score].flags['FORTRAN']:
# prefer Fortran for leading array with column major order
prefer_fortran = True

# Get the LAPACK prefix and the corresponding dtype; if not found, fall back
# to 'd' and double precision float.
prefix, dtype = _type_conv.get(max_score, ('d', _np.dtype('float64')))

return prefix, dtype, prefer_fortran

Expand Down Expand Up @@ -318,7 +339,7 @@ def _get_funcs(names, arrays, dtype,
if prefer_fortran:
module1, module2 = module2, module1

for i, name in enumerate(names):
for name in names:
func_name = prefix + name
func_name = alias.get(func_name, func_name)
func = getattr(module1[0], func_name, None)
Expand Down

0 comments on commit 78fbaab

Please sign in to comment.