|
3 | 3 | /* LibTomMath, multiple-precision integer library -- Tom St Denis */
|
4 | 4 | /* SPDX-License-Identifier: Unlicense */
|
5 | 5 |
|
6 |
| -/* find the n'th root of an integer |
| 6 | +/* |
| 7 | + * Find the n'th root of an integer. |
7 | 8 | *
|
8 | 9 | * Result found such that (c)**b <= a and (c+1)**b > a
|
9 | 10 | *
|
|
12 | 13 | * which will find the root in log(N) time where
|
13 | 14 | * each step involves a fair bit.
|
14 | 15 | */
|
| 16 | + |
| 17 | +#ifdef LTM_USE_SMALLER_NTH_ROOT |
15 | 18 | mp_err mp_n_root(const mp_int *a, mp_digit b, mp_int *c)
|
16 | 19 | {
|
17 | 20 | mp_int t1, t2, t3, a_;
|
18 | 21 | mp_ord cmp;
|
19 |
| - int ilog2; |
| 22 | + int ilog2; |
20 | 23 | mp_err err;
|
21 | 24 |
|
22 | 25 | /* input must be positive if b is even */
|
@@ -75,6 +78,7 @@ mp_err mp_n_root(const mp_int *a, mp_digit b, mp_int *c)
|
75 | 78 | if ((err = mp_2expt(&t2,ilog2)) != MP_OKAY) {
|
76 | 79 | goto LBL_ERR;
|
77 | 80 | }
|
| 81 | + |
78 | 82 | do {
|
79 | 83 | /* t1 = t2 */
|
80 | 84 | if ((err = mp_copy(&t2, &t1)) != MP_OKAY) {
|
@@ -167,4 +171,361 @@ mp_err mp_n_root(const mp_int *a, mp_digit b, mp_int *c)
|
167 | 171 | return err;
|
168 | 172 | }
|
169 | 173 |
|
| 174 | +#else /* LTM_USE_SMALLER_NTH_ROOT */ |
| 175 | + |
| 176 | +/* |
| 177 | + On a system with Gnu LibC > 4 you can use |
| 178 | + __builtin_clz() or the assembler command BSR |
| 179 | + (Intel) but let me assure you: the function |
| 180 | + s_floor_log2() will not be the bottleneck here. |
| 181 | +*/ |
| 182 | +static int s_floor_log2(mp_digit value) |
| 183 | +{ |
| 184 | + int r = 0; |
| 185 | + while ((value >>= 1) != 0) { |
| 186 | + r++; |
| 187 | + } |
| 188 | + return r; |
| 189 | +} |
| 190 | +/* |
| 191 | + Extra version for int needed because mp_digit is |
| 192 | + a) unsigned |
| 193 | + b) can be any size between 8 and 64 bits |
| 194 | + Two version with the same code, just different |
| 195 | + input types seems silly but all the ways known to |
| 196 | + me than can work around that are either complicated |
| 197 | + or dependent on compiler specifics or are ugly or |
| 198 | + all of the above. |
| 199 | + An example for "all of the above": |
| 200 | + #define FLOOR_ILOG2(T) \ |
| 201 | + int s_floor_ilog2_##T(T value) { \ |
| 202 | + int r = 0; \ |
| 203 | + while ((value >>= 1) != 0) { \ |
| 204 | + r++; \ |
| 205 | + } \ |
| 206 | + return r; \ |
| 207 | + } |
| 208 | + FLOOR_ILOG2(int) |
| 209 | + FLOOR_ILOG2(mp_digit) |
| 210 | +*/ |
| 211 | + |
| 212 | +/* |
| 213 | + Here, "value" will not be negative, so it is, in theory, |
| 214 | + possible to use the function above by casting "int" to |
| 215 | + "mp_digit" but "mp_digit" can be smaller than "int", much |
| 216 | + smaller. |
| 217 | + */ |
| 218 | +static int s_floor_ilog2(int value) |
| 219 | +{ |
| 220 | + int r = 0; |
| 221 | + while ((value >>= 1) != 0) { |
| 222 | + r++; |
| 223 | + } |
| 224 | + return r; |
| 225 | +} |
| 226 | +/* |
| 227 | + The cut-off between Newton's method and bisection is at |
| 228 | + about ln(x)/(ln ln (x)) * 1.2. |
| 229 | + Floating point methods are not available, so a rough |
| 230 | + approximation must do. |
| 231 | + By taking the bitcount of the number as floor(log_2(x)) |
| 232 | + and, together with ln(x) ~ floor(log2(x)) * log(2) |
| 233 | + implemented as 69/100 * floor(log2(x)), we can get |
| 234 | + a sufficiently good approximation. |
| 235 | + This snippet assumes "int" is at least 16 bit wide. |
| 236 | + TODO: check if it is possible to use mp_word instead |
| 237 | + which is guaranteed to be at least 16 bit wide |
| 238 | +*/ |
| 239 | +#include <limits.h> |
| 240 | +static int s_recurrence_bisection_cutoff(int value) |
| 241 | +{ |
| 242 | + int lnx, lnlnx; |
| 243 | + |
| 244 | + /* |
| 245 | + such small values should have been handled by a nth-root |
| 246 | + implementation with native integers |
| 247 | + */ |
| 248 | + if (value < 8) { |
| 249 | + return 1; |
| 250 | + } |
| 251 | + |
| 252 | + /* ln(x) ~ floor(log2(x)) * log(2) */ |
| 253 | + if (value > ((INT_MAX / 69))) { |
| 254 | + /* |
| 255 | + if "value" is so big that a multiplication |
| 256 | + with 69 overflows we can safely spend |
| 257 | + two digits of accuracy for a better sleep. |
| 258 | + */ |
| 259 | + lnx = (value / 100) * 69; |
| 260 | + } else { |
| 261 | + lnx = ((69 * value) / 100); |
| 262 | + } |
| 263 | + /* ln ln x */ |
| 264 | + lnlnx = s_floor_ilog2(lnx); |
| 265 | + /* cannot overflow anymore here */ |
| 266 | + lnlnx = ((69 * lnlnx) / 100); |
| 267 | + |
| 268 | + lnx = lnx / lnlnx; |
| 269 | + /* floor(ln(x)/(ln ln (x))) < floor(fln2(x)/(fln2 fln2 (x))) + 1 for x >= 8 */ |
| 270 | + lnx += 1; |
| 271 | + /* apply twiddle factor */ |
| 272 | + /* cannot overflow */ |
| 273 | + lnx = ((12 * lnx) / 10); |
| 274 | + return lnx; |
| 275 | +} |
| 276 | + |
| 277 | +/* |
| 278 | + Compute log_2(b) bits of a^(1/b) or all of them with a binary search method |
| 279 | +*/ |
| 280 | +static mp_err s_bisection(mp_int *a, mp_digit b, mp_int *c, int cutoff, int rootsize) |
| 281 | +{ |
| 282 | + mp_int low, high, mid, midpow; |
| 283 | + mp_err err; |
| 284 | + int comp, i = 0; |
| 285 | + |
| 286 | + /* force at least one run */ |
| 287 | + if (cutoff == 0) { |
| 288 | + cutoff = 1; |
| 289 | + } |
| 290 | + |
| 291 | + if ((err = mp_init_multi(&low, &high, &mid, &midpow, NULL)) != MP_OKAY) { |
| 292 | + return err; |
| 293 | + } |
| 294 | + if ((err = mp_2expt(&high, rootsize)) != MP_OKAY) { |
| 295 | + goto LTM_ERR; |
| 296 | + } |
| 297 | + if ((err = mp_2expt(&low, rootsize - 2)) != MP_OKAY) { |
| 298 | + goto LTM_ERR; |
| 299 | + } |
| 300 | + while (mp_cmp(&low, &high) == MP_LT) { |
| 301 | + if (i++ == cutoff) { |
| 302 | + mp_exch(&high, c); |
| 303 | + goto LTM_ERR; |
| 304 | + } |
| 305 | + if ((err = mp_add(&low, &high, &mid)) != MP_OKAY) { |
| 306 | + goto LTM_ERR; |
| 307 | + } |
| 308 | + if ((err = mp_div_2(&mid, &mid)) != MP_OKAY) { |
| 309 | + goto LTM_ERR; |
| 310 | + } |
| 311 | + if ((err = mp_expt_d(&mid, b, &midpow)) != MP_OKAY) { |
| 312 | + goto LTM_ERR; |
| 313 | + } |
| 314 | + comp = mp_cmp(&midpow, a); |
| 315 | + if (mp_cmp(&low, &mid) == MP_LT && comp == MP_LT) { |
| 316 | + mp_exch(&low, &mid); |
| 317 | + } else if (mp_cmp(&high, &mid) == MP_GT && comp == MP_GT) { |
| 318 | + mp_exch(&high, &mid); |
| 319 | + } else { |
| 320 | + mp_exch(&mid, c); |
| 321 | + goto LTM_ERR; |
| 322 | + } |
| 323 | + } |
| 324 | + if ((err = mp_add_d(&mid, 1, &mid)) != MP_OKAY) { |
| 325 | + goto LTM_ERR; |
| 326 | + } |
| 327 | + mp_exch(&mid, c); |
| 328 | +LTM_ERR: |
| 329 | + mp_clear_multi(&low, &high, &mid, &midpow, NULL); |
| 330 | + return err; |
| 331 | +} |
| 332 | + |
| 333 | +static mp_err s_newton(mp_int *a, mp_digit b, mp_int *c, int cutoff, int rootsize) |
| 334 | +{ |
| 335 | + mp_int xi, t1, t2; |
| 336 | + mp_err err = MP_OKAY; |
| 337 | + |
| 338 | + if ((err = mp_init_multi(&xi, &t1, &t2, NULL)) != MP_OKAY) { |
| 339 | + return err; |
| 340 | + } |
| 341 | + if ((err = s_bisection(a, b, &t1, cutoff, rootsize)) != MP_OKAY) { |
| 342 | + goto LTM_ERR; |
| 343 | + } |
| 344 | + if ((err = mp_add_d(&t1, 1, &xi)) != MP_OKAY) { |
| 345 | + goto LTM_ERR; |
| 346 | + } |
| 347 | + while (mp_cmp(&t1, &xi) == MP_LT) { |
| 348 | + if ((rootsize--) == 0) { |
| 349 | + break; |
| 350 | + } |
| 351 | + if ((err = mp_copy(&t1, &xi)) != MP_OKAY) { |
| 352 | + goto LTM_ERR; |
| 353 | + } |
| 354 | + if ((err = mp_mul_d(&xi, b - 1, &t2)) != MP_OKAY) { |
| 355 | + goto LTM_ERR; |
| 356 | + } |
| 357 | + if ((err = mp_expt_d(&xi, b - 1, &t1)) != MP_OKAY) { |
| 358 | + goto LTM_ERR; |
| 359 | + } |
| 360 | + if ((err = mp_div(a, &t1, &t1, NULL)) != MP_OKAY) { |
| 361 | + goto LTM_ERR; |
| 362 | + } |
| 363 | + if ((err = mp_add(&t1, &t2, &t1)) != MP_OKAY) { |
| 364 | + goto LTM_ERR; |
| 365 | + } |
| 366 | + if ((err = mp_div_d(&t1, b, &t1, NULL)) != MP_OKAY) { |
| 367 | + goto LTM_ERR; |
| 368 | + } |
| 369 | + } |
| 370 | + mp_exch(&xi, c); |
| 371 | +LTM_ERR: |
| 372 | + mp_clear_multi(&xi, &t1, &t2, NULL); |
| 373 | + return err; |
| 374 | +} |
| 375 | + |
| 376 | +mp_err mp_n_root(const mp_int *a, mp_digit b, mp_int *c) |
| 377 | +{ |
| 378 | + mp_int A; |
| 379 | + mp_int t1; |
| 380 | + int cmp; |
| 381 | + mp_err err = MP_OKAY; |
| 382 | + int ilog2, rootsize, cutoff, even_faster; |
| 383 | + mp_sign neg; |
| 384 | + |
| 385 | + /* |
| 386 | + * Checks, balances and shortcuts |
| 387 | + * |
| 388 | + * if b = 0 -> MP_VAL division by zero |
| 389 | + * if b even and a neg. -> MP_VAL non-real result |
| 390 | + * if a = 0 and b > 0 -> 0 |
| 391 | + * if a = 0 and b < 0 -> n/a b is unsigned |
| 392 | + * if a = 1 -> 1 |
| 393 | + * if a > 0 and b < 0 -> n/a b is unsigned |
| 394 | + * if b > log_2(a) -> 1 |
| 395 | + */ |
| 396 | + |
| 397 | + if (b == 0) { |
| 398 | + return MP_VAL; |
| 399 | + } |
| 400 | + |
| 401 | + if (b == 1) { |
| 402 | + if ((err = mp_copy(a, c)) != MP_OKAY) { |
| 403 | + return err; |
| 404 | + } |
| 405 | + return MP_OKAY; |
| 406 | + } |
| 407 | + if (b == 2) { |
| 408 | + return mp_sqrt(a, c); |
| 409 | + } |
| 410 | + |
| 411 | + /* TODO: check if an exception for unity is sensible */ |
| 412 | + if ((a->used == 1) && (a->dp[0] == 1)) { |
| 413 | + mp_set(c, 1); |
| 414 | + if (a->sign == MP_NEG && ((b & 1) == 0)) { |
| 415 | + c->sign = MP_NEG; |
| 416 | + } |
| 417 | + return MP_OKAY; |
| 418 | + } |
| 419 | + |
| 420 | + if ((a->sign == MP_NEG) && ((b & 1) == 0)) { |
| 421 | + return MP_VAL; |
| 422 | + } |
| 423 | +#if ( !(defined MP_8BIT) && !(defined MP_16BIT) ) |
| 424 | + /* The type "mp_digit" can be bigger than int */ |
| 425 | + if (sizeof(mp_digit) > sizeof(int) && b > INT_MAX) { |
| 426 | + /* In that case "b" is bigger than log_2(x), hence floor(x^(1/b)) = 1 */ |
| 427 | + mp_set(c, 1); |
| 428 | + c->sign = a->sign; |
| 429 | + return MP_OKAY; |
| 430 | + } |
| 431 | +#endif |
| 432 | + if (mp_iszero(a)) { |
| 433 | + mp_zero(c); |
| 434 | + return MP_OKAY; |
| 435 | + } |
| 436 | + |
| 437 | +#ifdef LTM_USE_SMALL_NTH_ROOT |
| 438 | + if (a->used == 1) { |
| 439 | + ilog2 = s_small_nthroot(a->dp[0], b); |
| 440 | + mp_set(c,ilog2); |
| 441 | + return MP_OKAY; |
| 442 | + } |
| 443 | +#endif |
| 444 | + if ((err = mp_init(&A)) != MP_OKAY) { |
| 445 | + return err; |
| 446 | + } |
| 447 | + if ((err = mp_copy(a, &A)) != MP_OKAY) { |
| 448 | + goto LTM_ERR_2; |
| 449 | + } |
| 450 | + neg = a->sign; |
| 451 | + A.sign = MP_ZPOS; |
| 452 | + |
| 453 | + ilog2 = mp_count_bits(a); |
| 454 | + |
| 455 | + if (ilog2 < (int)(b)) { |
| 456 | + mp_set(c, 1uL); |
| 457 | + c->sign = neg; |
| 458 | + goto LTM_ERR_2; |
| 459 | + } |
| 460 | + |
| 461 | + rootsize = (ilog2/(int)(b)) + 1; |
| 462 | + cutoff = s_floor_log2(b); |
| 463 | + |
| 464 | + even_faster = s_recurrence_bisection_cutoff(ilog2); |
| 465 | + if (b < (mp_digit)even_faster) { |
| 466 | + if ((err = s_newton(&A, b, c, cutoff, rootsize)) != MP_OKAY) { |
| 467 | + goto LTM_ERR_2; |
| 468 | + } |
| 469 | + } else { |
| 470 | + if ((err = s_bisection(&A, b, c, -1, rootsize)) != MP_OKAY) { |
| 471 | + goto LTM_ERR_2; |
| 472 | + } |
| 473 | + } |
| 474 | + |
| 475 | + if ((err = mp_init(&t1)) != MP_OKAY) { |
| 476 | + goto LTM_ERR_2; |
| 477 | + } |
| 478 | + if ((err = mp_expt_d(c, b, &t1)) != MP_OKAY) { |
| 479 | + goto LTM_ERR_1; |
| 480 | + } |
| 481 | + cmp = mp_cmp(&t1, &A); |
| 482 | + if (cmp == MP_GT) { |
| 483 | + if ((err = mp_sub_d(c, 1u, c)) != MP_OKAY) { |
| 484 | + goto LTM_ERR_1; |
| 485 | + } |
| 486 | + for (;;) { |
| 487 | + if ((err = mp_expt_d(c, b, &t1)) != MP_OKAY) { |
| 488 | + goto LTM_ERR_1; |
| 489 | + } |
| 490 | + cmp = mp_cmp(&t1, &A); |
| 491 | + if (cmp != MP_GT) { |
| 492 | + break; |
| 493 | + } |
| 494 | + if ((err = mp_sub_d(c, 1u, c)) != MP_OKAY) { |
| 495 | + goto LTM_ERR_1; |
| 496 | + } |
| 497 | + } |
| 498 | + } else if (cmp == MP_LT) { |
| 499 | + if ((err = mp_add_d(c, 1u, c)) != MP_OKAY) { |
| 500 | + goto LTM_ERR_1; |
| 501 | + } |
| 502 | + for (;;) { |
| 503 | + if ((err = mp_expt_d(c, b, &t1)) != MP_OKAY) { |
| 504 | + goto LTM_ERR_1; |
| 505 | + } |
| 506 | + cmp = mp_cmp(&t1, &A); |
| 507 | + if (cmp != MP_LT) { |
| 508 | + break; |
| 509 | + } |
| 510 | + if ((err = mp_add_d(c, 1u, c)) != MP_OKAY) { |
| 511 | + goto LTM_ERR_1; |
| 512 | + } |
| 513 | + } |
| 514 | + /* Does overshoot in contrast to the other branch above */ |
| 515 | + if (cmp != MP_EQ) { |
| 516 | + if ((err = mp_sub_d(c, 1u, c)) != MP_OKAY) { |
| 517 | + goto LTM_ERR_1; |
| 518 | + } |
| 519 | + } |
| 520 | + } |
| 521 | + |
| 522 | +LTM_ERR_1: |
| 523 | + mp_clear(&t1); |
| 524 | +LTM_ERR_2: |
| 525 | + mp_clear(&A); |
| 526 | + c->sign = a->sign; |
| 527 | + return err; |
| 528 | +} |
| 529 | +#endif |
| 530 | + |
170 | 531 | #endif
|
0 commit comments