Skip to content

Commit 56816fe

Browse files
committed
Changing recursive calls to loop (small speedup)
1 parent a47f6da commit 56816fe

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

string_kernel.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,23 @@
22

33
# Kernel defined by Lodhi et al. (2002)
44
def ssk(s, t, n, lbda, accum=False):
5-
dynamic = {}
5+
lens, lent = len(s), len(t)
6+
#dynamic = (-1)*np.ones( (n+1, lens, lent) )
7+
k_prim = np.zeros( (n, lens, lent) )
8+
indices = { x : [i for i, e in enumerate(t) if e == x] for x in set(s) }
69

7-
def k_prim(sj, tk, i):
8-
# print( "k_prim({},{},{})".format(s, t, i) )
9-
if i == 0:
10-
# print( "k_prim({},{},{}) => 1".format(s, t, i) )
11-
return 1.
12-
if min(sj, tk) < i:
13-
# print( "k_prim({},{},{}) => 0".format(s, t, i) )
14-
return 0.
15-
if (sj,tk,i) in dynamic:
16-
return dynamic[(sj,tk,i)]
10+
k_prim[0,:,:] = 1
1711

18-
x = s[sj-1]
19-
indices = [i for i in range(tk) if t[i] == x]
20-
toret = lbda * k_prim(sj-1, tk, i) \
21-
+ sum( k_prim(sj-1, k, i-1) * (lbda**(tk-k+1)) for k in indices )
22-
# print( "k_prim({},{},{}) => {}".format(s, t, i, toret) )
23-
dynamic[(sj,tk,i)] = toret
24-
return toret
12+
for i in range(1,n):
13+
for sj in range(i,lens):
14+
for tk in range(i,lent):
15+
x = s[sj-1]
16+
toret = lbda * k_prim[i, sj-1, tk]
17+
for k_ in indices[x]:
18+
if k_ >= tk:
19+
break
20+
toret += k_prim[i-1, sj-1, k_] * (lbda**(tk-k_+1))
21+
k_prim[i,sj,tk] = toret
2522

2623
def k(sj, tk, n):
2724
# print( "k({},{},{})".format(s, t, n) )
@@ -31,18 +28,20 @@ def k(sj, tk, n):
3128
# print( "k({},{},{}) => 0".format(s, t, n) )
3229
return 0.
3330
x = s[sj-1]
34-
indices = [i for i in range(tk) if t[i] == x]
35-
toret = k(sj-1, tk, n) \
36-
+ lbda**2 * sum( k_prim(sj-1, k, n-1) for k in indices )
31+
toret = k(sj-1, tk, n)
32+
for k_ in indices[x]:
33+
if k_ >= tk:
34+
break
35+
toret += lbda**2 * k_prim[n-1, sj-1, k_]
3736
# print( "k({},{},{}) => {}".format(s, t, n, toret) )
3837
return toret
3938

4039
if accum:
41-
toret = sum( k(len(s), len(t), i) for i in range(1, min(n,len(s),len(t))+1) )
40+
toret = sum( k(lens, lent, i) for i in range(1, min(n,lens,lent)+1) )
4241
else:
43-
toret = k(len(s), len(t), n)
42+
toret = k(lens, lent, n)
4443

45-
# print( len(dynamic) )
44+
# print( [len(list(i for (sj,tk,i) in k_prim if i==m-1)) for m in range(n)] )
4645
return toret
4746

4847
def string_kernel(xs, ys, n, lbda):

0 commit comments

Comments
 (0)