Mailing List Archive

bpo-37295: Optimize math.comb() and math.perm() (GH-29090)
https://github.com/python/cpython/commit/60c320c38e4e95877cde0b1d8562ebd6bc02ac61
commit: 60c320c38e4e95877cde0b1d8562ebd6bc02ac61
branch: main
author: Serhiy Storchaka <storchaka@gmail.com>
committer: serhiy-storchaka <storchaka@gmail.com>
date: 2021-12-05T22:26:10+02:00
summary:

bpo-37295: Optimize math.comb() and math.perm() (GH-29090)

For very large numbers, use a divide-and-conquer algorithm to take
advantage of Karatsuba multiplication of large numbers.

Do calculations completely in C unsigned long long instead of Python
integers if possible.

files:
A Misc/NEWS.d/next/Library/2021-10-18-16-08-55.bpo-37295.wBEWH2.rst
M Doc/whatsnew/3.11.rst
M Modules/mathmodule.c

diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst
index 10dc30939414c..b06d8d4215033 100644
--- a/Doc/whatsnew/3.11.rst
+++ b/Doc/whatsnew/3.11.rst
@@ -351,6 +351,11 @@ Optimizations
* Pure ASCII strings are now normalized in constant time by :func:`unicodedata.normalize`.
(Contributed by Dong-hee Na in :issue:`44987`.)

+* :mod:`math` functions :func:`~math.comb` and :func:`~math.perm` are now up
+ to 10 times or more faster for large arguments (the speed up is larger for
+ larger *k*).
+ (Contributed by Serhiy Storchaka in :issue:`37295`.)
+

CPython bytecode changes
========================
diff --git a/Misc/NEWS.d/next/Library/2021-10-18-16-08-55.bpo-37295.wBEWH2.rst b/Misc/NEWS.d/next/Library/2021-10-18-16-08-55.bpo-37295.wBEWH2.rst
new file mode 100644
index 0000000000000..634f0c453884e
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-10-18-16-08-55.bpo-37295.wBEWH2.rst
@@ -0,0 +1 @@
+Optimize :func:`math.comb` and :func:`math.perm`.
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index 64ce4e6a13fd5..84b5b954b1051 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -3221,6 +3221,138 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
}


+/* Number of permutations and combinations.
+ * P(n, k) = n! / (n-k)!
+ * C(n, k) = P(n, k) / k!
+ */
+
+/* Calculate P(n, k) (iscomb == 0) or C(n, k) (iscomb != 0) for n in the 63-bit range; returns NULL on error. */
+static PyObject *
+perm_comb_small(unsigned long long n, unsigned long long k, int iscomb)
+{
+ /* long long is at least 64 bit */
+ static const unsigned long long fast_comb_limits[] = { /* indexed by k: largest n for which C(n, k) takes the C-integer loop below */
+ 0, ULLONG_MAX, 4294967296ULL, 3329022, 102570, 13467, 3612, 1449, // 0-7
+ 746, 453, 308, 227, 178, 147, 125, 110, // 8-15
+ 99, 90, 84, 79, 75, 72, 69, 68, // 16-23
+ 66, 65, 64, 63, 63, 62, 62, 62, // 24-31
+ };
+ static const unsigned long long fast_perm_limits[] = { /* indexed by k: largest n for which P(n, k) takes the C-integer loop below */
+ 0, ULLONG_MAX, 4294967296ULL, 2642246, 65537, 7133, 1627, 568, // 0-7
+ 259, 142, 88, 61, 45, 36, 30, // 8-14
+ };
+
+ if (k == 0) { /* P(n, 0) == C(n, 0) == 1 */
+ return PyLong_FromLong(1);
+ }
+
+ /* For small enough n and k the result fits in the 64-bit range and can
+ * be calculated without allocating intermediate PyLong objects. */
+ if (iscomb
+ ? (k < Py_ARRAY_LENGTH(fast_comb_limits)
+ && n <= fast_comb_limits[k])
+ : (k < Py_ARRAY_LENGTH(fast_perm_limits)
+ && n <= fast_perm_limits[k]))
+ {
+ unsigned long long result = n; /* first of the k factors */
+ if (iscomb) {
+ for (unsigned long long i = 1; i < k;) {
+ result *= --n; /* multiply in the next smaller factor */
+ result /= ++i; /* exact: result now equals C(original n, i) */
+ }
+ }
+ else {
+ for (unsigned long long i = 1; i < k;) {
+ result *= --n; /* multiply in the next smaller factor */
+ ++i;
+ }
+ }
+ return PyLong_FromUnsignedLongLong(result);
+ }
+
+ /* Otherwise split at j = k/2: P(n, k) = P(n, j) * P(n-j, k-j), and */
+ /* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
+ unsigned long long j = k / 2;
+ PyObject *a, *b;
+ a = perm_comb_small(n, j, iscomb);
+ if (a == NULL) {
+ return NULL;
+ }
+ b = perm_comb_small(n - j, k - j, iscomb);
+ if (b == NULL) {
+ goto error;
+ }
+ Py_SETREF(a, PyNumber_Multiply(a, b)); /* a may be NULL after this; checked below, else returned as the error result */
+ Py_DECREF(b);
+ if (iscomb && a != NULL) {
+ b = perm_comb_small(k, j, 1); /* divisor C(k, j) */
+ if (b == NULL) {
+ goto error;
+ }
+ Py_SETREF(a, PyNumber_FloorDivide(a, b));
+ Py_DECREF(b);
+ }
+ return a;
+
+error:
+ Py_DECREF(a); /* only reached while a is non-NULL */
+ return NULL;
+}
+
+/* Calculate P(n, k) (iscomb == 0) or C(n, k) (iscomb != 0) for an object n
+ * using recursive formulas.  It is more efficient than sequential
+ * multiplication thanks to Karatsuba multiplication.  Returns NULL on error.
+ */
+static PyObject *
+perm_comb(PyObject *n, unsigned long long k, int iscomb)
+{
+ if (k == 0) { /* P(n, 0) == C(n, 0) == 1 */
+ return PyLong_FromLong(1);
+ }
+ if (k == 1) { /* P(n, 1) == C(n, 1) == n: hand back n itself with an extra reference */
+ Py_INCREF(n);
+ return n;
+ }
+
+ /* P(n, k) = P(n, j) * P(n-j, k-j) */
+ /* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
+ unsigned long long j = k / 2;
+ PyObject *a, *b;
+ a = perm_comb(n, j, iscomb);
+ if (a == NULL) {
+ return NULL;
+ }
+ PyObject *t = PyLong_FromUnsignedLongLong(j);
+ if (t == NULL) {
+ goto error;
+ }
+ n = PyNumber_Subtract(n, t); /* from here on n is the temporary n - j, released after the recursive call */
+ Py_DECREF(t);
+ if (n == NULL) {
+ goto error;
+ }
+ b = perm_comb(n, k - j, iscomb);
+ Py_DECREF(n); /* drops the temporary n - j, not the caller's n */
+ if (b == NULL) {
+ goto error;
+ }
+ Py_SETREF(a, PyNumber_Multiply(a, b)); /* a may be NULL after this; checked below, else returned as the error result */
+ Py_DECREF(b);
+ if (iscomb && a != NULL) {
+ b = perm_comb_small(k, j, 1); /* divisor C(k, j); both arguments are plain C integers */
+ if (b == NULL) {
+ goto error;
+ }
+ Py_SETREF(a, PyNumber_FloorDivide(a, b));
+ Py_DECREF(b);
+ }
+ return a;
+
+error:
+ Py_DECREF(a); /* only reached while a is non-NULL */
+ return NULL;
+}
+
/*[clinic input]
math.perm

@@ -3244,9 +3376,9 @@ static PyObject *
math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
/*[clinic end generated code: output=e021a25469653e23 input=5311c5a00f359b53]*/
{
- PyObject *result = NULL, *factor = NULL;
+ PyObject *result = NULL;
int overflow, cmp;
- long long i, factors;
+ long long ki, ni;

if (k == Py_None) {
return math_factorial(module, n);
@@ -3260,6 +3392,7 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
Py_DECREF(n);
return NULL;
}
+ assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));

if (Py_SIZE(n) < 0) {
PyErr_SetString(PyExc_ValueError,
@@ -3281,42 +3414,26 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
goto error;
}

- factors = PyLong_AsLongLongAndOverflow(k, &overflow);
+ ki = PyLong_AsLongLongAndOverflow(k, &overflow);
+ assert(overflow >= 0 && !PyErr_Occurred());
if (overflow > 0) {
PyErr_Format(PyExc_OverflowError,
"k must not exceed %lld",
LLONG_MAX);
goto error;
}
- else if (factors == -1) {
- /* k is nonnegative, so a return value of -1 can only indicate error */
- goto error;
- }
+ assert(ki >= 0);

- if (factors == 0) {
- result = PyLong_FromLong(1);
- goto done;
+ ni = PyLong_AsLongLongAndOverflow(n, &overflow);
+ assert(overflow >= 0 && !PyErr_Occurred());
+ if (!overflow && ki > 1) {
+ assert(ni >= 0);
+ result = perm_comb_small((unsigned long long)ni,
+ (unsigned long long)ki, 0);
}
-
- result = n;
- Py_INCREF(result);
- if (factors == 1) {
- goto done;
- }
-
- factor = Py_NewRef(n);
- PyObject *one = _PyLong_GetOne(); // borrowed ref
- for (i = 1; i < factors; ++i) {
- Py_SETREF(factor, PyNumber_Subtract(factor, one));
- if (factor == NULL) {
- goto error;
- }
- Py_SETREF(result, PyNumber_Multiply(result, factor));
- if (result == NULL) {
- goto error;
- }
+ else {
+ result = perm_comb(n, (unsigned long long)ki, 0);
}
- Py_DECREF(factor);

done:
Py_DECREF(n);
@@ -3324,14 +3441,11 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
return result;

error:
- Py_XDECREF(factor);
- Py_XDECREF(result);
Py_DECREF(n);
Py_DECREF(k);
return NULL;
}

-
/*[clinic input]
math.comb

@@ -3357,9 +3471,9 @@ static PyObject *
math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
/*[clinic end generated code: output=bd2cec8d854f3493 input=9a05315af2518709]*/
{
- PyObject *result = NULL, *factor = NULL, *temp;
+ PyObject *result = NULL, *temp;
int overflow, cmp;
- long long i, factors;
+ long long ki, ni;

n = PyNumber_Index(n);
if (n == NULL) {
@@ -3370,6 +3484,7 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
Py_DECREF(n);
return NULL;
}
+ assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));

if (Py_SIZE(n) < 0) {
PyErr_SetString(PyExc_ValueError,
@@ -3382,73 +3497,59 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
goto error;
}

- /* k = min(k, n - k) */
- temp = PyNumber_Subtract(n, k);
- if (temp == NULL) {
- goto error;
- }
- if (Py_SIZE(temp) < 0) {
- Py_DECREF(temp);
- result = PyLong_FromLong(0);
- goto done;
- }
- cmp = PyObject_RichCompareBool(temp, k, Py_LT);
- if (cmp > 0) {
- Py_SETREF(k, temp);
+ ni = PyLong_AsLongLongAndOverflow(n, &overflow);
+ assert(overflow >= 0 && !PyErr_Occurred());
+ if (!overflow) {
+ assert(ni >= 0);
+ ki = PyLong_AsLongLongAndOverflow(k, &overflow);
+ assert(overflow >= 0 && !PyErr_Occurred());
+ if (overflow || ki > ni) {
+ result = PyLong_FromLong(0);
+ goto done;
+ }
+ assert(ki >= 0);
+ ki = Py_MIN(ki, ni - ki);
+ if (ki > 1) {
+ result = perm_comb_small((unsigned long long)ni,
+ (unsigned long long)ki, 1);
+ goto done;
+ }
+ /* For k == 1 just return the original n in perm_comb(). */
}
else {
- Py_DECREF(temp);
- if (cmp < 0) {
+ /* k = min(k, n - k) */
+ temp = PyNumber_Subtract(n, k);
+ if (temp == NULL) {
goto error;
}
- }
-
- factors = PyLong_AsLongLongAndOverflow(k, &overflow);
- if (overflow > 0) {
- PyErr_Format(PyExc_OverflowError,
- "min(n - k, k) must not exceed %lld",
- LLONG_MAX);
- goto error;
- }
- if (factors == -1) {
- /* k is nonnegative, so a return value of -1 can only indicate error */
- goto error;
- }
-
- if (factors == 0) {
- result = PyLong_FromLong(1);
- goto done;
- }
-
- result = n;
- Py_INCREF(result);
- if (factors == 1) {
- goto done;
- }
-
- factor = Py_NewRef(n);
- PyObject *one = _PyLong_GetOne(); // borrowed ref
- for (i = 1; i < factors; ++i) {
- Py_SETREF(factor, PyNumber_Subtract(factor, one));
- if (factor == NULL) {
- goto error;
+ if (Py_SIZE(temp) < 0) {
+ Py_DECREF(temp);
+ result = PyLong_FromLong(0);
+ goto done;
}
- Py_SETREF(result, PyNumber_Multiply(result, factor));
- if (result == NULL) {
- goto error;
+ cmp = PyObject_RichCompareBool(temp, k, Py_LT);
+ if (cmp > 0) {
+ Py_SETREF(k, temp);
}
-
- temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
- if (temp == NULL) {
- goto error;
+ else {
+ Py_DECREF(temp);
+ if (cmp < 0) {
+ goto error;
+ }
}
- Py_SETREF(result, PyNumber_FloorDivide(result, temp));
- Py_DECREF(temp);
- if (result == NULL) {
+
+ ki = PyLong_AsLongLongAndOverflow(k, &overflow);
+ assert(overflow >= 0 && !PyErr_Occurred());
+ if (overflow) {
+ PyErr_Format(PyExc_OverflowError,
+ "min(n - k, k) must not exceed %lld",
+ LLONG_MAX);
goto error;
}
+ assert(ki >= 0);
}
- Py_DECREF(factor);
+
+ result = perm_comb(n, (unsigned long long)ki, 1);

done:
Py_DECREF(n);
@@ -3456,8 +3557,6 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
return result;

error:
- Py_XDECREF(factor);
- Py_XDECREF(result);
Py_DECREF(n);
Py_DECREF(k);
return NULL;

_______________________________________________
Python-checkins mailing list
Python-checkins@python.org
https://mail.python.org/mailman/listinfo/python-checkins