Skip to content

Commit fa01b79

Browse files
committed
added alternative algorithm to mp_n_root
1 parent 431ea33 commit fa01b79

File tree

5 files changed

+377
-7
lines changed

5 files changed

+377
-7
lines changed

bn_mp_n_root.c

Lines changed: 363 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
/* LibTomMath, multiple-precision integer library -- Tom St Denis */
44
/* SPDX-License-Identifier: Unlicense */
55

6-
/* find the n'th root of an integer
6+
/*
7+
* Find the n'th root of an integer.
78
*
89
* Result found such that (c)**b <= a and (c+1)**b > a
910
*
@@ -12,11 +13,13 @@
1213
* which will find the root in log(N) time where
1314
* each step involves a fair bit.
1415
*/
16+
17+
#ifdef LTM_USE_SMALLER_NTH_ROOT
1518
mp_err mp_n_root(const mp_int *a, mp_digit b, mp_int *c)
1619
{
1720
mp_int t1, t2, t3, a_;
1821
mp_ord cmp;
19-
int ilog2;
22+
int ilog2;
2023
mp_err err;
2124

2225
/* input must be positive if b is even */
@@ -75,6 +78,7 @@ mp_err mp_n_root(const mp_int *a, mp_digit b, mp_int *c)
7578
if ((err = mp_2expt(&t2,ilog2)) != MP_OKAY) {
7679
goto LBL_ERR;
7780
}
81+
7882
do {
7983
/* t1 = t2 */
8084
if ((err = mp_copy(&t2, &t1)) != MP_OKAY) {
@@ -167,4 +171,361 @@ mp_err mp_n_root(const mp_int *a, mp_digit b, mp_int *c)
167171
return err;
168172
}
169173

174+
#else /* LTM_USE_SMALLER_NTH_ROOT */
175+
176+
/*
177+
On a system with Gnu LibC > 4 you can use
178+
__builtin_clz() or the assembler command BSR
179+
(Intel) but let me assure you: the function
180+
s_floor_log2() will not be the bottleneck here.
181+
*/
182+
static int s_floor_log2(mp_digit value)
183+
{
184+
int r = 0;
185+
while ((value >>= 1) != 0) {
186+
r++;
187+
}
188+
return r;
189+
}
190+
/*
191+
Extra version for int needed because mp_digit is
192+
a) unsigned
193+
b) can be any size between 8 and 64 bits
194+
Two version with the same code, just different
195+
input types seems silly but all the ways known to
196+
me than can work around that are either complicated
197+
or dependent on compiler specifics or are ugly or
198+
all of the above.
199+
An example for "all of the above":
200+
#define FLOOR_ILOG2(T) \
201+
int s_floor_ilog2_##T(T value) { \
202+
int r = 0; \
203+
while ((value >>= 1) != 0) { \
204+
r++; \
205+
} \
206+
return r; \
207+
}
208+
FLOOR_ILOG2(int)
209+
FLOOR_ILOG2(mp_digit)
210+
*/
211+
212+
/*
213+
Here, "value" will not be negative, so it is, in theory,
214+
possible to use the function above by casting "int" to
215+
"mp_digit" but "mp_digit" can be smaller than "int", much
216+
smaller.
217+
*/
218+
static int s_floor_ilog2(int value)
219+
{
220+
int r = 0;
221+
while ((value >>= 1) != 0) {
222+
r++;
223+
}
224+
return r;
225+
}
226+
/*
227+
The cut-off between Newton's method and bisection is at
228+
about ln(x)/(ln ln (x)) * 1.2.
229+
Floating point methods are not available, so a rough
230+
approximation must do.
231+
By taking the bitcount of the number as floor(log_2(x))
232+
and, together with ln(x) ~ floor(log2(x)) * log(2)
233+
implemented as 69/100 * floor(log2(x)), we can get
234+
a sufficiently good approximation.
235+
This snippet assumes "int" is at least 16 bit wide.
236+
TODO: check if it is possible to use mp_word instead
237+
which is guaranteed to be at least 16 bit wide
238+
*/
239+
#include <limits.h>
240+
static int s_recurrence_bisection_cutoff(int value)
241+
{
242+
int lnx, lnlnx;
243+
244+
/*
245+
such small values should have been handled by a nth-root
246+
implementation with native integers
247+
*/
248+
if (value < 8) {
249+
return 1;
250+
}
251+
252+
/* ln(x) ~ floor(log2(x)) * log(2) */
253+
if (value > ((INT_MAX / 69))) {
254+
/*
255+
if "value" is so big that a multiplication
256+
with 69 overflows we can safely spend
257+
two digits of accuracy for a better sleep.
258+
*/
259+
lnx = (value / 100) * 69;
260+
} else {
261+
lnx = ((69 * value) / 100);
262+
}
263+
/* ln ln x */
264+
lnlnx = s_floor_ilog2(lnx);
265+
/* cannot overflow anymore here */
266+
lnlnx = ((69 * lnlnx) / 100);
267+
268+
lnx = lnx / lnlnx;
269+
/* floor(ln(x)/(ln ln (x))) < floor(fln2(x)/(fln2 fln2 (x))) + 1 for x >= 8 */
270+
lnx += 1;
271+
/* apply twiddle factor */
272+
/* cannot overflow */
273+
lnx = ((12 * lnx) / 10);
274+
return lnx;
275+
}
276+
277+
/*
278+
Compute log_2(b) bits of a^(1/b) or all of them with a binary search method
279+
*/
280+
static mp_err s_bisection(mp_int *a, mp_digit b, mp_int *c, int cutoff, int rootsize)
281+
{
282+
mp_int low, high, mid, midpow;
283+
mp_err err;
284+
int comp, i = 0;
285+
286+
/* force at least one run */
287+
if (cutoff == 0) {
288+
cutoff = 1;
289+
}
290+
291+
if ((err = mp_init_multi(&low, &high, &mid, &midpow, NULL)) != MP_OKAY) {
292+
return err;
293+
}
294+
if ((err = mp_2expt(&high, rootsize)) != MP_OKAY) {
295+
goto LTM_ERR;
296+
}
297+
if ((err = mp_2expt(&low, rootsize - 2)) != MP_OKAY) {
298+
goto LTM_ERR;
299+
}
300+
while (mp_cmp(&low, &high) == MP_LT) {
301+
if (i++ == cutoff) {
302+
mp_exch(&high, c);
303+
goto LTM_ERR;
304+
}
305+
if ((err = mp_add(&low, &high, &mid)) != MP_OKAY) {
306+
goto LTM_ERR;
307+
}
308+
if ((err = mp_div_2(&mid, &mid)) != MP_OKAY) {
309+
goto LTM_ERR;
310+
}
311+
if ((err = mp_expt_d(&mid, b, &midpow)) != MP_OKAY) {
312+
goto LTM_ERR;
313+
}
314+
comp = mp_cmp(&midpow, a);
315+
if (mp_cmp(&low, &mid) == MP_LT && comp == MP_LT) {
316+
mp_exch(&low, &mid);
317+
} else if (mp_cmp(&high, &mid) == MP_GT && comp == MP_GT) {
318+
mp_exch(&high, &mid);
319+
} else {
320+
mp_exch(&mid, c);
321+
goto LTM_ERR;
322+
}
323+
}
324+
if ((err = mp_add_d(&mid, 1, &mid)) != MP_OKAY) {
325+
goto LTM_ERR;
326+
}
327+
mp_exch(&mid, c);
328+
LTM_ERR:
329+
mp_clear_multi(&low, &high, &mid, &midpow, NULL);
330+
return err;
331+
}
332+
333+
static mp_err s_newton(mp_int *a, mp_digit b, mp_int *c, int cutoff, int rootsize)
334+
{
335+
mp_int xi, t1, t2;
336+
mp_err err = MP_OKAY;
337+
338+
if ((err = mp_init_multi(&xi, &t1, &t2, NULL)) != MP_OKAY) {
339+
return err;
340+
}
341+
if ((err = s_bisection(a, b, &t1, cutoff, rootsize)) != MP_OKAY) {
342+
goto LTM_ERR;
343+
}
344+
if ((err = mp_add_d(&t1, 1, &xi)) != MP_OKAY) {
345+
goto LTM_ERR;
346+
}
347+
while (mp_cmp(&t1, &xi) == MP_LT) {
348+
if ((rootsize--) == 0) {
349+
break;
350+
}
351+
if ((err = mp_copy(&t1, &xi)) != MP_OKAY) {
352+
goto LTM_ERR;
353+
}
354+
if ((err = mp_mul_d(&xi, b - 1, &t2)) != MP_OKAY) {
355+
goto LTM_ERR;
356+
}
357+
if ((err = mp_expt_d(&xi, b - 1, &t1)) != MP_OKAY) {
358+
goto LTM_ERR;
359+
}
360+
if ((err = mp_div(a, &t1, &t1, NULL)) != MP_OKAY) {
361+
goto LTM_ERR;
362+
}
363+
if ((err = mp_add(&t1, &t2, &t1)) != MP_OKAY) {
364+
goto LTM_ERR;
365+
}
366+
if ((err = mp_div_d(&t1, b, &t1, NULL)) != MP_OKAY) {
367+
goto LTM_ERR;
368+
}
369+
}
370+
mp_exch(&xi, c);
371+
LTM_ERR:
372+
mp_clear_multi(&xi, &t1, &t2, NULL);
373+
return err;
374+
}
375+
376+
mp_err mp_n_root(const mp_int *a, mp_digit b, mp_int *c)
377+
{
378+
mp_int A;
379+
mp_int t1;
380+
int cmp;
381+
mp_err err = MP_OKAY;
382+
int ilog2, rootsize, cutoff, even_faster;
383+
mp_sign neg;
384+
385+
/*
386+
* Checks, balances and shortcuts
387+
*
388+
* if b = 0 -> MP_VAL division by zero
389+
* if b even and a neg. -> MP_VAL non-real result
390+
* if a = 0 and b > 0 -> 0
391+
* if a = 0 and b < 0 -> n/a b is unsigned
392+
* if a = 1 -> 1
393+
* if a > 0 and b < 0 -> n/a b is unsigned
394+
* if b > log_2(a) -> 1
395+
*/
396+
397+
if (b == 0) {
398+
return MP_VAL;
399+
}
400+
401+
if (b == 1) {
402+
if ((err = mp_copy(a, c)) != MP_OKAY) {
403+
return err;
404+
}
405+
return MP_OKAY;
406+
}
407+
if (b == 2) {
408+
return mp_sqrt(a, c);
409+
}
410+
411+
/* TODO: check if an exception for unity is sensible */
412+
if ((a->used == 1) && (a->dp[0] == 1)) {
413+
mp_set(c, 1);
414+
if (a->sign == MP_NEG && ((b & 1) == 0)) {
415+
c->sign = MP_NEG;
416+
}
417+
return MP_OKAY;
418+
}
419+
420+
if ((a->sign == MP_NEG) && ((b & 1) == 0)) {
421+
return MP_VAL;
422+
}
423+
#if ( !(defined MP_8BIT) && !(defined MP_16BIT) )
424+
/* The type "mp_digit" can be bigger than int */
425+
if (sizeof(mp_digit) > sizeof(int) && b > INT_MAX) {
426+
/* In that case "b" is bigger than log_2(x), hence floor(x^(1/b)) = 1 */
427+
mp_set(c, 1);
428+
c->sign = a->sign;
429+
return MP_OKAY;
430+
}
431+
#endif
432+
if (mp_iszero(a)) {
433+
mp_zero(c);
434+
return MP_OKAY;
435+
}
436+
437+
#ifdef LTM_USE_SMALL_NTH_ROOT
438+
if (a->used == 1) {
439+
ilog2 = s_small_nthroot(a->dp[0], b);
440+
mp_set(c,ilog2);
441+
return MP_OKAY;
442+
}
443+
#endif
444+
if ((err = mp_init(&A)) != MP_OKAY) {
445+
return err;
446+
}
447+
if ((err = mp_copy(a, &A)) != MP_OKAY) {
448+
goto LTM_ERR_2;
449+
}
450+
neg = a->sign;
451+
A.sign = MP_ZPOS;
452+
453+
ilog2 = mp_count_bits(a);
454+
455+
if (ilog2 < (int)(b)) {
456+
mp_set(c, 1uL);
457+
c->sign = neg;
458+
goto LTM_ERR_2;
459+
}
460+
461+
rootsize = (ilog2/(int)(b)) + 1;
462+
cutoff = s_floor_log2(b);
463+
464+
even_faster = s_recurrence_bisection_cutoff(ilog2);
465+
if (b < (mp_digit)even_faster) {
466+
if ((err = s_newton(&A, b, c, cutoff, rootsize)) != MP_OKAY) {
467+
goto LTM_ERR_2;
468+
}
469+
} else {
470+
if ((err = s_bisection(&A, b, c, -1, rootsize)) != MP_OKAY) {
471+
goto LTM_ERR_2;
472+
}
473+
}
474+
475+
if ((err = mp_init(&t1)) != MP_OKAY) {
476+
goto LTM_ERR_2;
477+
}
478+
if ((err = mp_expt_d(c, b, &t1)) != MP_OKAY) {
479+
goto LTM_ERR_1;
480+
}
481+
cmp = mp_cmp(&t1, &A);
482+
if (cmp == MP_GT) {
483+
if ((err = mp_sub_d(c, 1u, c)) != MP_OKAY) {
484+
goto LTM_ERR_1;
485+
}
486+
for (;;) {
487+
if ((err = mp_expt_d(c, b, &t1)) != MP_OKAY) {
488+
goto LTM_ERR_1;
489+
}
490+
cmp = mp_cmp(&t1, &A);
491+
if (cmp != MP_GT) {
492+
break;
493+
}
494+
if ((err = mp_sub_d(c, 1u, c)) != MP_OKAY) {
495+
goto LTM_ERR_1;
496+
}
497+
}
498+
} else if (cmp == MP_LT) {
499+
if ((err = mp_add_d(c, 1u, c)) != MP_OKAY) {
500+
goto LTM_ERR_1;
501+
}
502+
for (;;) {
503+
if ((err = mp_expt_d(c, b, &t1)) != MP_OKAY) {
504+
goto LTM_ERR_1;
505+
}
506+
cmp = mp_cmp(&t1, &A);
507+
if (cmp != MP_LT) {
508+
break;
509+
}
510+
if ((err = mp_add_d(c, 1u, c)) != MP_OKAY) {
511+
goto LTM_ERR_1;
512+
}
513+
}
514+
/* Does overshoot in contrast to the other branch above */
515+
if (cmp != MP_EQ) {
516+
if ((err = mp_sub_d(c, 1u, c)) != MP_OKAY) {
517+
goto LTM_ERR_1;
518+
}
519+
}
520+
}
521+
522+
LTM_ERR_1:
523+
mp_clear(&t1);
524+
LTM_ERR_2:
525+
mp_clear(&A);
526+
c->sign = a->sign;
527+
return err;
528+
}
529+
#endif
530+
170531
#endif

0 commit comments

Comments
 (0)