On Sun, 22 Mar 2020 at 22:16, Tom Lane <t...@sss.pgh.pa.us> wrote:
>
> With resolutions of the XXX items, I think this'd be committable.
>

Thanks for looking at this!

Here is an updated patch with the following changes based on your comments:

* Now uses integer arithmetic to compute res_weight and res_ndigits,
instead of floor() and ceil().

* New comment giving a more detailed explanation of how blen is
chosen, and why it must sometimes examine the first digit of the input
and reduce blen by 1 (which can occur at any step, as shown in the
example given).

* New comment giving a proof that the number of steps required is
guaranteed to be less than 32.

* New comment explaining why the initial integer square root using
Newton's method is guaranteed to converge. I couldn't find a formal
reference for this, but there's a Wikipedia article on it
(https://en.wikipedia.org/wiki/Integer_square_root), and I think it's a
well-known result in the field.

Regards,
Dean
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
new file mode 100644
index 10229eb..9986132
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -393,16 +393,6 @@ static const NumericVar const_ten =
 #endif
 
 #if DEC_DIGITS == 4
-static const NumericDigit const_zero_point_five_data[1] = {5000};
-#elif DEC_DIGITS == 2
-static const NumericDigit const_zero_point_five_data[1] = {50};
-#elif DEC_DIGITS == 1
-static const NumericDigit const_zero_point_five_data[1] = {5};
-#endif
-static const NumericVar const_zero_point_five =
-{1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_five_data};
-
-#if DEC_DIGITS == 4
 static const NumericDigit const_zero_point_nine_data[1] = {9000};
 #elif DEC_DIGITS == 2
 static const NumericDigit const_zero_point_nine_data[1] = {90};
@@ -518,6 +508,8 @@ static void div_var_fast(const NumericVa
 static int	select_div_scale(const NumericVar *var1, const NumericVar *var2);
 static void mod_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result);
+static void div_mod_var(const NumericVar *var1, const NumericVar *var2,
+						NumericVar *quot, NumericVar *rem);
 static void ceil_var(const NumericVar *var, NumericVar *result);
 static void floor_var(const NumericVar *var, NumericVar *result);
 
@@ -7712,6 +7704,7 @@ div_var_fast(const NumericVar *var1, con
 			 NumericVar *result, int rscale, bool round)
 {
 	int			div_ndigits;
+	int			load_ndigits;
 	int			res_sign;
 	int			res_weight;
 	int		   *div;
@@ -7766,9 +7759,6 @@ div_var_fast(const NumericVar *var1, con
 	div_ndigits += DIV_GUARD_DIGITS;
 	if (div_ndigits < DIV_GUARD_DIGITS)
 		div_ndigits = DIV_GUARD_DIGITS;
-	/* Must be at least var1ndigits, too, to simplify data-loading loop */
-	if (div_ndigits < var1ndigits)
-		div_ndigits = var1ndigits;
 
 	/*
 	 * We do the arithmetic in an array "div[]" of signed int's.  Since
@@ -7781,9 +7771,16 @@ div_var_fast(const NumericVar *var1, con
 	 * (approximate) quotient digit and stores it into div[], removing one
 	 * position of dividend space.  A final pass of carry propagation takes
 	 * care of any mistaken quotient digits.
+	 *
+	 * Note that div[] doesn't necessarily contain all of the digits from the
+	 * dividend --- the desired precision plus guard digits might be less than
+	 * the dividend's precision.  This happens, for example, in the square
+	 * root algorithm, where we typically divide a 2N-digit number by an
+	 * N-digit number, and only require a result with N digits of precision.
 	 */
 	div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
-	for (i = 0; i < var1ndigits; i++)
+	load_ndigits = Min(div_ndigits, var1ndigits);
+	for (i = 0; i < load_ndigits; i++)
 		div[i + 1] = var1digits[i];
 
 	/*
@@ -7844,9 +7841,15 @@ div_var_fast(const NumericVar *var1, con
 			maxdiv += Abs(qdigit);
 			if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
 			{
-				/* Yes, do it */
+				/*
+				 * Yes, do it.  Note that if var2ndigits is much smaller than
+				 * div_ndigits, we can save a significant amount of effort
+				 * here by noting that we only need to normalise those div[]
+				 * entries touched where prior iterations subtracted multiples
+				 * of the divisor.
+				 */
 				carry = 0;
-				for (i = div_ndigits; i > qi; i--)
+				for (i = Min(qi + var2ndigits - 2, div_ndigits); i > qi; i--)
 				{
 					newdig = div[i] + carry;
 					if (newdig < 0)
@@ -8095,6 +8098,76 @@ mod_var(const NumericVar *var1, const Nu
 
 
 /*
+ * div_mod_var() -
+ *
+ *	Calculate the truncated integer quotient (quot) and numeric remainder
+ *	(rem) of var1 / var2.  The remainder is precise to var2's dscale.
+ */
+static void
+div_mod_var(const NumericVar *var1, const NumericVar *var2,
+			NumericVar *quot, NumericVar *rem)
+{
+	NumericVar	q;				/* working quotient; copied to quot at end */
+	NumericVar	r;				/* working remainder; copied to rem at end */
+
+	init_var(&q);
+	init_var(&r);
+
+	/*
+	 * Use div_var_fast() with rscale 0 and no rounding to get an initial
+	 * estimate for the integer quotient.  This might be inaccurate (per the
+	 * warning in div_var_fast's comments), but we can correct it below.
+	 */
+	div_var_fast(var1, var2, &q, 0, false);
+
+	/* Corresponding remainder estimate: r = var1 - var2 * q. */
+	mul_var(var2, &q, &r, var2->dscale);
+	sub_var(var1, &r, &r);
+
+	/*
+	 * Adjust the results if necessary --- the remainder should have the same
+	 * sign as var1 (fixed by the first loop), and its absolute value should
+	 * be less than the absolute value of var2 (fixed by the second loop).
+	 */
+	while (r.ndigits != 0 && r.sign != var1->sign)
+	{
+		/* |q| is too large: step it back by one and fix up r to match */
+		if (var1->sign == var2->sign)
+		{
+			sub_var(&q, &const_one, &q);
+			add_var(&r, var2, &r);
+		}
+		else
+		{
+			add_var(&q, &const_one, &q);
+			sub_var(&r, var2, &r);
+		}
+	}
+
+	while (cmp_abs(&r, var2) >= 0)
+	{
+		/* |q| is too small: step it up by one and fix up r to match */
+		if (var1->sign == var2->sign)
+		{
+			add_var(&q, &const_one, &q);
+			sub_var(&r, var2, &r);
+		}
+		else
+		{
+			sub_var(&q, &const_one, &q);
+			add_var(&r, var2, &r);
+		}
+	}
+
+	set_var_from_var(&q, quot);
+	set_var_from_var(&r, rem);
+
+	free_var(&q);
+	free_var(&r);
+}
+
+
+/*
  * ceil_var() -
  *
  *	Return the smallest integer greater than or equal to the argument
@@ -8213,18 +8286,30 @@ gcd_var(const NumericVar *var1, const Nu
 /*
  * sqrt_var() -
  *
- *	Compute the square root of x using Newton's algorithm
+ *	Compute the square root of x using the Karatsuba Square Root algorithm.
+ *	NOTE: we allow rscale < 0 here, implying rounding before the decimal
+ *	point.
  */
 static void
 sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
 {
-	NumericVar	tmp_arg;
-	NumericVar	tmp_val;
-	NumericVar	last_val;
-	int			local_rscale;
 	int			stat;
-
-	local_rscale = rscale + 8;
+	int			res_weight;
+	int			res_ndigits;
+	int			src_ndigits;
+	int			step;
+	int			ndigits[32];
+	int			blen;
+	int64		arg_int64;
+	int			src_idx;
+	int64		s_int64;
+	int64		r_int64;
+	NumericVar	s_var;
+	NumericVar	r_var;
+	NumericVar	a0_var;
+	NumericVar	a1_var;
+	NumericVar	q_var;
+	NumericVar	u_var;
 
 	stat = cmp_var(arg, &const_zero);
 	if (stat == 0)
@@ -8243,43 +8328,440 @@ sqrt_var(const NumericVar *arg, NumericV
 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
 				 errmsg("cannot take square root of a negative number")));
 
-	init_var(&tmp_arg);
-	init_var(&tmp_val);
-	init_var(&last_val);
+	init_var(&s_var);
+	init_var(&r_var);
+	init_var(&a0_var);
+	init_var(&a1_var);
+	init_var(&q_var);
+	init_var(&u_var);
 
-	/* Copy arg in case it is the same var as result */
-	set_var_from_var(arg, &tmp_arg);
+	/*
+	 * The result weight is half the input weight, rounded towards minus
+	 * infinity --- res_weight = floor(arg->weight / 2).
+	 */
+	if (arg->weight >= 0)
+		res_weight = arg->weight / 2;
+	else
+		res_weight = -((-arg->weight - 1) / 2 + 1);
 
 	/*
-	 * Initialize the result to the first guess
+	 * Number of NBASE digits to compute.  To ensure correct rounding, compute
+	 * at least 1 extra decimal digit.  We explicitly allow rscale to be
+	 * negative here, but must always compute at least 1 NBASE digit.  Thus
+	 * res_ndigits = res_weight + 1 + ceil((rscale + 1) / DEC_DIGITS) or 1.
 	 */
-	alloc_var(result, 1);
-	result->digits[0] = tmp_arg.digits[0] / 2;
-	if (result->digits[0] == 0)
-		result->digits[0] = 1;
-	result->weight = tmp_arg.weight / 2;
-	result->sign = NUMERIC_POS;
+	if (rscale + 1 >= 0)
+		res_ndigits = res_weight + 1 + (rscale + DEC_DIGITS) / DEC_DIGITS;
+	else
+		res_ndigits = res_weight + 1 - (-rscale - 1) / DEC_DIGITS;
+	res_ndigits = Max(res_ndigits, 1);
 
-	set_var_from_var(result, &last_val);
+	/*
+	 * Number of source NBASE digits logically required to produce a result
+	 * with this precision --- every digit before the decimal point, plus 2
+	 * for each result digit after the decimal point (or minus 2 for each
+	 * result digit we round before the decimal point).
+	 */
+	src_ndigits = arg->weight + 1 + (res_ndigits - res_weight - 1) * 2;
+	src_ndigits = Max(src_ndigits, 1);
 
-	for (;;)
+	/* ----------
+	 * From this point on, we treat the input and the result as integers and
+	 * compute the integer square root and remainder using the Karatsuba
+	 * Square Root algorithm, which may be written recursively as follows:
+	 *
+	 *	SqrtRem(n = a3*b^3 + a2*b^2 + a1*b + a0):
+	 *		[ for some base b, and coefficients a0,a1,a2,a3 chosen so that
+	 *		  0 <= a0,a1,a2 < b and a3 >= b/4 ]
+	 *		Let (s,r) = SqrtRem(a3*b + a2)
+	 *		Let (q,u) = DivRem(r*b + a1, 2*s)
+	 *		Let s = s*b + q
+	 *		Let r = u*b + a0 - q^2
+	 *		If r < 0 Then
+	 *			Let r = r + s
+	 *			Let s = s - 1
+	 *			Let r = r + s
+	 *		Return (s,r)
+	 *
+	 * See "Karatsuba Square Root", Paul Zimmermann, INRIA Research Report
+	 * RR-3805, November 1999.  At the time of writing this was available
+	 * on the net at <https://hal.inria.fr/inria-00072854>.
+	 *
+	 * The way to read the assumption "n = a3*b^3 + a2*b^2 + a1*b + a0" is
+	 * "choose a base b such that n requires at least four base-b digits to
+	 * express; then those digits are a3,a2,a1,a0, with a3 possibly larger
+	 * than b".  For optimal performance, b should have approximately a
+	 * quarter the number of digits in the input, so that the outer square
+	 * root computes roughly twice as many digits as the inner one.  For
+	 * simplicity, we choose b = NBASE^blen, an integer power of NBASE.
+	 *
+	 * We implement the algorithm iteratively rather than recursively, to
+	 * allow the working variables to be reused.  With this approach, each
+	 * digit of the input is read precisely once --- src_idx tracks the number
+	 * of input digits used so far.
+	 *
+	 * The array ndigits[] holds the number of NBASE digits of the input that
+	 * will have been used at the end of each iteration, which roughly doubles
+	 * each time.  Note that the array elements are stored in reverse order,
+	 * so if the final iteration requires src_ndigits = 37 input digits, the
+	 * array will contain [37,19,11,7,5,3], and we would start by computing
+	 * the square root of the 3 most significant NBASE digits.
+	 *
+	 * In each iteration, we choose blen to be the largest integer for which
+	 * the input number has a3 >= b/4, when written in the form above.  In
+	 * general, this means blen = src_ndigits / 4 (truncated), but if
+	 * src_ndigits is a multiple of 4, that might lead to the coefficient a3
+	 * being less than b/4 (if the first input digit is less than NBASE/4), in
+	 * which case we choose blen = src_ndigits / 4 - 1.  The number of digits
+	 * in the inner square root is then src_ndigits - 2*blen.  So, for
+	 * example, if we have src_ndigits = 26 initially, the array ndigits[]
+	 * will be either [26,14,8,4] or [26,14,8,6,4], depending on the size of
+	 * the first input digit.
+	 *
+	 * Additionally, we can put an upper bound on the number of steps required
+	 * as follows --- suppose that the number of source digits is an n-bit
+	 * number in the range [2^(n-1), 2^n-1], then blen will be in the range
+	 * [2^(n-3)-1, 2^(n-2)-1] and the number of digits in the inner square
+	 * root will be in the range [2^(n-2), 2^(n-1)+1].  In the next step, blen
+	 * will be in the range [2^(n-4)-1, 2^(n-3)] and the number of digits in
+	 * the next inner square root will be in the range [2^(n-3), 2^(n-2)+1].
+	 * This pattern repeats, and in the worst case the array ndigits[] will
+	 * contain [2^n-1, 2^(n-1)+1, 2^(n-2)+1, ... 9, 5, 3], and the computation
+	 * will require n steps.  Therefore, since all digit array sizes are
+	 * signed 32-bit integers, the number of steps required is guaranteed to
+	 * be less than 32.
+	 * ----------
+	 */
+	step = 0;
+	while ((ndigits[step] = src_ndigits) > 4)
 	{
-		div_var_fast(&tmp_arg, result, &tmp_val, local_rscale, true);
+		/* Choose b so that a3 >= b/4, as described above */
+		blen = src_ndigits / 4;
+		if (blen * 4 == src_ndigits && arg->digits[0] < NBASE / 4)
+			blen--;
 
-		add_var(result, &tmp_val, result);
-		mul_var(result, &const_zero_point_five, result, local_rscale);
+		/* Number of digits in the next step (inner square root) */
+		src_ndigits -= 2 * blen;
+		step++;
+	}
 
-		if (cmp_var(&last_val, result) == 0)
-			break;
-		set_var_from_var(result, &last_val);
+	/*
+	 * First iteration (innermost square root and remainder):
+	 *
+	 * Here src_ndigits <= 4, and the input fits in an int64.  Its square root
+	 * has at most 9 decimal digits, so estimate it using double precision
+	 * arithmetic, which will in fact almost certainly return the correct
+	 * result with no further correction required.
+	 */
+	arg_int64 = arg->digits[0];
+	for (src_idx = 1; src_idx < src_ndigits; src_idx++)
+	{
+		arg_int64 *= NBASE;
+		if (src_idx < arg->ndigits)
+			arg_int64 += arg->digits[src_idx];
 	}
 
-	free_var(&last_val);
-	free_var(&tmp_val);
-	free_var(&tmp_arg);
+	s_int64 = (int64) sqrt((double) arg_int64);
+	r_int64 = arg_int64 - s_int64 * s_int64;
 
-	/* Round to requested precision */
+	/*
+	 * Use Newton's method to correct the result, if necessary.
+	 *
+	 * This uses integer division with truncation to compute the truncated
+	 * integer square root by iterating using the formula x -> (x + n/x) / 2.
+	 * This is known to converge to isqrt(n), unless n+1 is a perfect square.
+	 * If n+1 is a perfect square, the sequence will oscillate between the two
+	 * values isqrt(n) and isqrt(n)+1, so we can be assured of convergence by
+	 * checking the remainder.
+	 */
+	while (r_int64 < 0 || r_int64 > 2 * s_int64)
+	{
+		s_int64 = (s_int64 + arg_int64 / s_int64) / 2;
+		r_int64 = arg_int64 - s_int64 * s_int64;
+	}
+
+	/*
+	 * Iterations with src_ndigits <= 8:
+	 *
+	 * The next 1 or 2 iterations compute larger (outer) square roots with
+	 * src_ndigits <= 8, so the result still fits in an int64 (even though the
+	 * input no longer does) and we can continue to compute using int64
+	 * variables to avoid more expensive numeric computations.
+	 *
+	 * It is fairly easy to see that there is no risk of the intermediate
+	 * values below overflowing 64-bit integers.  In the worst case, the
+	 * previous iteration will have computed a 3-digit square root (of a
+	 * 6-digit input less than NBASE^6 / 4), so at the start of this
+	 * iteration, s will be less than NBASE^3 / 2 = 10^12 / 2, and r will be
+	 * less than 10^12.  In this case, blen will be 1, so numer will be less
+	 * than 10^17, and denom will be less than 10^12 (and hence u will also be
+	 * less than 10^12).  Finally, since q^2 = u*b + a0 - r, we can also be
+	 * sure that q^2 < 10^17.  Therefore all these quantities fit comfortably
+	 * in 64-bit integers.
+	 */
+	step--;
+	while (step >= 0 && (src_ndigits = ndigits[step]) <= 8)
+	{
+		int			b;
+		int			a0;
+		int			a1;
+		int			i;
+		int64		numer;
+		int64		denom;
+		int64		q;
+		int64		u;
+
+		blen = (src_ndigits - src_idx) / 2;
+
+		/* Extract a1 and a0, and compute b */
+		a0 = 0;
+		a1 = 0;
+		b = 1;
+
+		for (i = 0; i < blen; i++, src_idx++)
+		{
+			b *= NBASE;
+			a1 *= NBASE;
+			if (src_idx < arg->ndigits)
+				a1 += arg->digits[src_idx];
+		}
+
+		for (i = 0; i < blen; i++, src_idx++)
+		{
+			a0 *= NBASE;
+			if (src_idx < arg->ndigits)
+				a0 += arg->digits[src_idx];
+		}
+
+		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+		numer = r_int64 * b + a1;
+		denom = 2 * s_int64;
+		q = numer / denom;
+		u = numer - q * denom;
+
+		/* Compute s = s*b + q and r = u*b + a0 - q^2 */
+		s_int64 = s_int64 * b + q;
+		r_int64 = u * b + a0 - q * q;
+
+		if (r_int64 < 0)
+		{
+			/* s is too large by 1; set r += s, s--, r += s */
+			r_int64 += s_int64;
+			s_int64--;
+			r_int64 += s_int64;
+		}
+
+		Assert(src_idx == src_ndigits); /* All input digits consumed */
+		step--;
+	}
+
+	/*
+	 * On platforms with 128-bit integer support, we can further delay the
+	 * need to use numeric variables.
+	 */
+#ifdef HAVE_INT128
+	if (step >= 0)
+	{
+		int128		s_int128;
+		int128		r_int128;
+
+		s_int128 = s_int64;
+		r_int128 = r_int64;
+
+		/*
+		 * Iterations with src_ndigits <= 16:
+		 *
+		 * The result fits in an int128 (even though the input doesn't) so we
+		 * use int128 variables to avoid more expensive numeric computations.
+		 */
+		while (step >= 0 && (src_ndigits = ndigits[step]) <= 16)
+		{
+			int64		b;
+			int64		a0;
+			int64		a1;
+			int64		i;
+			int128		numer;
+			int128		denom;
+			int128		q;
+			int128		u;
+
+			blen = (src_ndigits - src_idx) / 2;
+
+			/* Extract a1 and a0, and compute b */
+			a0 = 0;
+			a1 = 0;
+			b = 1;
+
+			for (i = 0; i < blen; i++, src_idx++)
+			{
+				b *= NBASE;
+				a1 *= NBASE;
+				if (src_idx < arg->ndigits)
+					a1 += arg->digits[src_idx];
+			}
+
+			for (i = 0; i < blen; i++, src_idx++)
+			{
+				a0 *= NBASE;
+				if (src_idx < arg->ndigits)
+					a0 += arg->digits[src_idx];
+			}
+
+			/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+			numer = r_int128 * b + a1;
+			denom = 2 * s_int128;
+			q = numer / denom;
+			u = numer - q * denom;
+
+			/* Compute s = s*b + q and r = u*b + a0 - q^2 */
+			s_int128 = s_int128 * b + q;
+			r_int128 = u * b + a0 - q * q;
+
+			if (r_int128 < 0)
+			{
+				/* s is too large by 1; set r += s, s--, r += s */
+				r_int128 += s_int128;
+				s_int128--;
+				r_int128 += s_int128;
+			}
+
+			Assert(src_idx == src_ndigits); /* All input digits consumed */
+			step--;
+		}
+
+		/*
+		 * All remaining iterations require numeric variables.  Convert the
+		 * integer values to NumericVar and continue.  Note that in the final
+		 * iteration we don't need the remainder, so we can save a few cycles
+		 * there by not fully computing it.
+		 */
+		int128_to_numericvar(s_int128, &s_var);
+		if (step >= 0)
+			int128_to_numericvar(r_int128, &r_var);
+	}
+	else
+	{
+		int64_to_numericvar(s_int64, &s_var);
+		/* step < 0, so we certainly don't need r */
+	}
+#else							/* !HAVE_INT128 */
+	int64_to_numericvar(s_int64, &s_var);
+	if (step >= 0)
+		int64_to_numericvar(r_int64, &r_var);
+#endif							/* HAVE_INT128 */
+
+	/*
+	 * The remaining iterations with src_ndigits > 8 (or 16, if have int128)
+	 * use numeric variables.
+	 */
+	while (step >= 0)
+	{
+		int			tmp_len;
+
+		src_ndigits = ndigits[step];
+		blen = (src_ndigits - src_idx) / 2;
+
+		/* Extract a1 and a0 */
+		if (src_idx < arg->ndigits)
+		{
+			tmp_len = Min(blen, arg->ndigits - src_idx);
+			alloc_var(&a1_var, tmp_len);
+			memcpy(a1_var.digits, arg->digits + src_idx,
+				   tmp_len * sizeof(NumericDigit));
+			a1_var.weight = blen - 1;
+			a1_var.sign = NUMERIC_POS;
+			a1_var.dscale = 0;
+			strip_var(&a1_var);
+		}
+		else
+		{
+			zero_var(&a1_var);
+			a1_var.dscale = 0;
+		}
+		src_idx += blen;
+
+		if (src_idx < arg->ndigits)
+		{
+			tmp_len = Min(blen, arg->ndigits - src_idx);
+			alloc_var(&a0_var, tmp_len);
+			memcpy(a0_var.digits, arg->digits + src_idx,
+				   tmp_len * sizeof(NumericDigit));
+			a0_var.weight = blen - 1;
+			a0_var.sign = NUMERIC_POS;
+			a0_var.dscale = 0;
+			strip_var(&a0_var);
+		}
+		else
+		{
+			zero_var(&a0_var);
+			a0_var.dscale = 0;
+		}
+		src_idx += blen;
+
+		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+		set_var_from_var(&r_var, &q_var);
+		q_var.weight += blen;
+		add_var(&q_var, &a1_var, &q_var);
+		add_var(&s_var, &s_var, &u_var);
+		div_mod_var(&q_var, &u_var, &q_var, &u_var);
+
+		/* Compute s = s*b + q */
+		s_var.weight += blen;
+		add_var(&s_var, &q_var, &s_var);
+
+		/*
+		 * Compute r = u*b + a0 - q^2.
+		 *
+		 * In the final iteration, we don't actually need r; we just need to
+		 * know whether it is negative, so that we know whether to adjust s.
+		 * So instead of the final subtraction we can just compare.
+		 */
+		u_var.weight += blen;
+		add_var(&u_var, &a0_var, &u_var);
+		mul_var(&q_var, &q_var, &q_var, 0);
+
+		if (step > 0)
+		{
+			/* Need r for later iterations */
+			sub_var(&u_var, &q_var, &r_var);
+			if (r_var.sign == NUMERIC_NEG)
+			{
+				/* s is too large by 1; set r += s, s--, r += s */
+				add_var(&r_var, &s_var, &r_var);
+				sub_var(&s_var, &const_one, &s_var);
+				add_var(&r_var, &s_var, &r_var);
+			}
+		}
+		else
+		{
+			/* Don't need r anymore, except to test if s is too large by 1 */
+			if (cmp_var(&u_var, &q_var) < 0)
+				sub_var(&s_var, &const_one, &s_var);
+		}
+
+		Assert(src_idx == src_ndigits); /* All input digits consumed */
+		step--;
+	}
+
+	/*
+	 * Construct the final result, rounding it to the requested precision.
+	 */
+	set_var_from_var(&s_var, result);
+	result->weight = res_weight;
+	result->sign = NUMERIC_POS;
+
+	/* Round to target rscale (and set result->dscale) */
 	round_var(result, rscale);
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
+
+	free_var(&s_var);
+	free_var(&r_var);
+	free_var(&a0_var);
+	free_var(&a1_var);
+	free_var(&q_var);
+	free_var(&u_var);
 }
 
 
@@ -8530,12 +9012,18 @@ ln_var(const NumericVar *arg, NumericVar
 	 * Each sqrt() will roughly halve the weight of x, so adjust the local
 	 * rscale as we work so that we keep this many significant digits at each
 	 * step (plus a few more for good measure).
+	 *
+	 * Note that we allow local_rscale < 0 during this input reduction
+	 * process, which implies rounding before the decimal point.  sqrt_var()
+	 * explicitly supports this, and it significantly reduces the work
+	 * required to reduce very large inputs to the required range.  Once the
+	 * input reduction is complete, x.weight will be 0 and its display scale
+	 * will be non-negative again.
 	 */
 	nsqrt = 0;
 	while (cmp_var(&x, &const_zero_point_nine) <= 0)
 	{
 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
 		sqrt_var(&x, &x, local_rscale);
 		mul_var(&fact, &const_two, &fact, 0);
 		nsqrt++;
@@ -8543,7 +9031,6 @@ ln_var(const NumericVar *arg, NumericVar
 	while (cmp_var(&x, &const_one_point_one) >= 0)
 	{
 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
 		sqrt_var(&x, &x, local_rscale);
 		mul_var(&fact, &const_two, &fact, 0);
 		nsqrt++;
diff --git a/src/test/regress/expected/numeric.out b/src/test/regress/expected/numeric.out
new file mode 100644
index 23a4c6d..c7fe63d
--- a/src/test/regress/expected/numeric.out
+++ b/src/test/regress/expected/numeric.out
@@ -1580,6 +1580,57 @@ select div(12345678901234567890, 123) *
 (1 row)
 
 --
+-- Test some corner cases for square root
+--
+select sqrt(1.000000000000003::numeric);
+       sqrt        
+-------------------
+ 1.000000000000001
+(1 row)
+
+select sqrt(1.000000000000004::numeric);
+       sqrt        
+-------------------
+ 1.000000000000002
+(1 row)
+
+select sqrt(96627521408608.56340355805::numeric);
+        sqrt         
+---------------------
+ 9829929.87811248648
+(1 row)
+
+select sqrt(96627521408608.56340355806::numeric);
+        sqrt         
+---------------------
+ 9829929.87811248649
+(1 row)
+
+select sqrt(515549506212297735.073688290367::numeric);
+          sqrt          
+------------------------
+ 718017761.766585921184
+(1 row)
+
+select sqrt(515549506212297735.073688290368::numeric);
+          sqrt          
+------------------------
+ 718017761.766585921185
+(1 row)
+
+select sqrt(8015491789940783531003294973900306::numeric);
+       sqrt        
+-------------------
+ 89529278953540017
+(1 row)
+
+select sqrt(8015491789940783531003294973900307::numeric);
+       sqrt        
+-------------------
+ 89529278953540018
+(1 row)
+
+--
 -- Test code path for raising to integer powers
 --
 select 10.0 ^ -2147483648 as rounds_to_zero;
diff --git a/src/test/regress/sql/numeric.sql b/src/test/regress/sql/numeric.sql
new file mode 100644
index c5c8d76..41475a9
--- a/src/test/regress/sql/numeric.sql
+++ b/src/test/regress/sql/numeric.sql
@@ -883,6 +883,19 @@ select div(12345678901234567890, 123);
 select div(12345678901234567890, 123) * 123 + 12345678901234567890 % 123;
 
 --
+-- Test some corner cases for square root
+--
+
+select sqrt(1.000000000000003::numeric);
+select sqrt(1.000000000000004::numeric);
+select sqrt(96627521408608.56340355805::numeric);
+select sqrt(96627521408608.56340355806::numeric);
+select sqrt(515549506212297735.073688290367::numeric);
+select sqrt(515549506212297735.073688290368::numeric);
+select sqrt(8015491789940783531003294973900306::numeric);
+select sqrt(8015491789940783531003294973900307::numeric);
+
+--
 -- Test code path for raising to integer powers
 --
 

Reply via email to