I noticed that gcc's div-by-constant optimization is a bit
suboptimal, and I want to improve it. I reported it in
Bugzilla:
http://gcc.gnu.org/bugzilla/show_bug.cgi?id=28417

but I think the "big guys" have no time for such a low-impact issue,
so I want to do it myself.

However, I need a little help in understanding
what the "precision" parameter of choose_multiplier() is for.
Unfortunately, the comments in the source are too terse for me.
My current understanding of the function's internals is below.
Read on, or jump to the [*] marker below if you are impatient.

So the routine tries to find an m such that:

x div d == x * m >> POST_SHIFT

and m is at most N+1 bits wide.
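
In C, for N = 32 and a multiplier that fits in 32 bits, the identity above is one 64-bit multiply and one shift. A minimal sketch, assuming POST_SHIFT is the total right shift of the 64-bit product, as in the text (mul_shift_div is a hypothetical helper, not a GCC function):

```c
#include <stdint.h>

/* x div d computed as (x * m) >> post_shift, with N = 32.
   Assumes m < 2^32 so the 64-bit product x * m cannot overflow, and
   post_shift <= 63. N+1-bit multipliers need extra fix-up code. */
static uint32_t mul_shift_div(uint32_t x, uint32_t m, unsigned post_shift)
{
    return (uint32_t)(((uint64_t)x * m) >> post_shift);
}
```

For example, mul_shift_div(3155365642u, 365384439u, 59) returns 2, the same as 3155365642 / 1577682821. The shift of 59 assumes that the *post_shift_ptr=27 reported below is applied after the high 32 bits of the product are taken, so the total shift is 32 + 27.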

As far as I can see, the routine first picks the largest possible
POST_SHIFT:

lgup = ceil(log2(d)); POST_SHIFT = N + lgup;
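
For a 32-bit d, lgup and the starting shift could be computed like this (a sketch; ceil_log2_u32 and initial_post_shift are hypothetical names, not GCC's own ceil_log2):

```c
#include <stdint.h>

/* lgup = ceil(log2(d)): the smallest lg with 2^lg >= d (0 for d <= 1). */
static int ceil_log2_u32(uint32_t d)
{
    int lg = 0;
    while (lg < 32 && ((uint64_t)1 << lg) < d)
        lg++;
    return lg;
}

/* The largest (total) shift, tried first: N + lgup. */
static int initial_post_shift(int n, uint32_t d)
{
    return n + ceil_log2_u32(d);
}
```

For d=1577682821 this gives lgup=31 and POST_SHIFT=63, the same values as in the trace further down.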

Then it finds mlow and mhigh so that:

x * mlow >> POST_SHIFT
    <= x * m >> POST_SHIFT
        <= x * mhigh >> POST_SHIFT

Using v >> POST_SHIFT == v div (1<<(N+lgup)), we can rewrite it as
("div" denotes integer division, unlike real division "/"):

x * mlow div (1<<(N+lgup))
    <= x * m div (1<<(N+lgup))
        <= x * mhigh div (1<<(N+lgup))

"Precise" m is equal to (1<<(N+lgup))/d (a real value, not integer).
Substitute it into the middle expression. Then obviously
mlow = floor(m) = (1<<(N+lgup)) div d.
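
When N + lgup <= 63, mlow is a single 64-bit division. Here is a sketch under that assumption. GCC itself uses the double-word div_and_round_double, as the excerpt further down shows. compute_mlow is a hypothetical name:

```c
#include <stdint.h>

/* mlow = 2^(N + lgup) div d.
   Assumes N + lgup <= 63 so the dividend fits in a uint64_t
   (true for N = 32 and d <= 2^31). */
static uint64_t compute_mlow(uint32_t d, int n, int lgup)
{
    return ((uint64_t)1 << (n + lgup)) / d;
}
```

compute_mlow(1577682821u, 32, 31) gives 5846151022, the mlow from the example.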

For mhigh: for arbitrary unsigned a, b: (a+b-1) div b >= a/b.
So we want to add a value of at least d-1 to the dividend.
With PRECISION == N, the code quoted below adds 1<<lgup
(pow2 = N + lgup - PRECISION = lgup), which is >= d:

x * ((1<<(N+lgup)) div d) div (1<<(N+lgup))
    <=  x * (1<<(N+lgup))/d div (1<<(N+lgup))
        <= x * (((1<<(N+lgup)) + (1<<lgup)) div d) div (1<<(N+lgup))
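
The excerpt further down computes mhigh with the extra term 2^(N + lgup - PRECISION), which is 2^lgup when PRECISION == N. Below is a sketch of that computation. It assumes everything fits in 64 bits, and compute_mhigh is a hypothetical name. It only reproduces the formula and does not interpret what PRECISION means:

```c
#include <stdint.h>

/* mhigh = (2^(N + lgup) + 2^(N + lgup - precision)) div d,
   mirroring pow/pow2 in the quoted choose_multiplier excerpt.
   Assumes N + lgup <= 63 and 1 <= precision <= N, so the sum
   stays below 2^64. */
static uint64_t compute_mhigh(uint32_t d, int n, int lgup, int precision)
{
    uint64_t num = ((uint64_t)1 << (n + lgup))
                 + ((uint64_t)1 << (n + lgup - precision));
    return num / d;
}
```

With d=1577682821, N=32, lgup=31 and PRECISION=32, it yields 5846151023, the mhigh from the trace. Lowering PRECISION to 31 adds 2^32 instead and gives 5846151025.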

Then the routine checks whether POST_SHIFT can be lowered.
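
I assume this check is the usual Granlund-Montgomery reduction. The routine keeps halving mlow and mhigh, and decrementing the shift, while the two values still differ after one more bit is dropped. A sketch under that assumption (reduce_shift is a hypothetical name):

```c
#include <stdint.h>

/* Drop low bits while mlow and mhigh remain distinct after the drop.
   Updates *mlow and *mhigh in place and returns the reduced shift. */
static int reduce_shift(uint64_t *mlow, uint64_t *mhigh, int shift)
{
    while (shift > 0 && (*mlow >> 1) < (*mhigh >> 1)) {
        *mlow >>= 1;
        *mhigh >>= 1;
        shift--;
    }
    return shift;
}
```

For mlow=5846151022 and mhigh=5846151023, both halve to 2923075511, so nothing can be dropped. The shift stays at 31, which is consistent with the *post_shift_ptr=31 result below.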

The inefficiency here is that the routine does not take into account
_at which value of x_ x * m >> POST_SHIFT will fail.
It will always fail at some x = n*d - 1, and for large d
these values are rather scarce, so we may be lucky and not hit
such a value at all!
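
Only the x = n*d - 1 points matter here, so checking a given (m, shift) pair for a large d takes just a handful of tests. The sketch below checks only those points. It assumes N = 32, m < 2^32, and m >= 2^shift / d rounded up, so that any error is an overestimate. candidate_count and first_failing_x are hypothetical names. The sketch does not examine the final partial block below 2^32, so treat it as a quick filter rather than a full proof:

```c
#include <stdint.h>

/* Number of x = n*d - 1 (n >= 1) below 2^32, i.e. n*d <= 2^32. */
static uint64_t candidate_count(uint32_t d)
{
    return ((uint64_t)1 << 32) / d;
}

/* First x = n*d - 1 < 2^32 with (x * m) >> shift != x / d,
   or UINT64_MAX if none of those points fails. */
static uint64_t first_failing_x(uint32_t d, uint32_t m, unsigned shift)
{
    for (uint64_t x = (uint64_t)d - 1; x <= UINT32_MAX; x += d)
        if (((x * m) >> shift) != x / d)
            return x;
    return UINT64_MAX;
}
```

For the d in the example, there are only two such points below 2^32 (d-1 and 2d-1). m=365384439 with a total shift of 59 passes both. m=182692220 (that is, 2^58 div d + 1) with a shift of 58 already fails at x = d-1.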

An example:

choose_multiplier(d=1577682821,n=32,precision=32) returns
        *post_shift_ptr=31,multiplier=5846151023
whereas optimal one is
        *post_shift_ptr=27,multiplier=365384439

and that is correct with respect to the algorithm: mlow=5846151022 < m < mhigh=5846151023.

But the catch is that 5846151024 will _also_ work for any 32-bit x
(by the lucky chance of d being large), and 5846151024/16 = 365384439
will work as well (because of all those zero low-order bits)!
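
Going from 5846151024 to 365384439 just drops trailing zero bits from the multiplier and lowers the shift by the same amount. The attached program does the same for K. The sketch below assumes total shifts of 32 + *post_shift_ptr (63 and 59 here) and requires the result to keep a shift of at least 32. strip_trailing_zeros is a hypothetical name:

```c
#include <stdint.h>

/* In exact arithmetic, (x * (m << k)) >> (s + k) == (x * m) >> s, so an
   even multiplier can be halved while the shift is decremented.
   Stops at an odd m or when the shift reaches min_shift. */
static void strip_trailing_zeros(uint64_t *m, int *shift, int min_shift)
{
    while (*m != 0 && (*m & 1) == 0 && *shift > min_shift) {
        *m >>= 1;
        (*shift)--;
    }
}
```

Starting from m=5846151024 and shift=63, this ends with m=365384439 and shift=59, i.e. *post_shift_ptr=27.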

IOW: the "mlow < m < mhigh" algorithm is not optimal.
It misses potentially better values.

[*]
I have a better algorithm. See the attachment if you are curious.

I'd like to replace the current one with it, but I don't
understand the role of the "precision" parameter in the current code.
The current comment says:

   Choose a minimal N + 1 bit approximation to 1/D that can be used to
   replace division by D, and put the least significant N bits of the result
   in *MULTIPLIER_PTR and return the most significant bit.

Ok, this one I understand.

   The width of operations is N (should be <= HOST_BITS_PER_WIDE_INT),

Width of _which_ ops? Is x * m N x N -> N bits? Or N x N -> 2N bits? Or what??

   the needed precision is in PRECISION (should be <= N).

The "needed" precision? Does this mean that x*m>>shift is allowed
to deviate by +/- 1<<(N-PRECISION) - 1 from the true result x/d? Or what??

I'd prefer much more verbose comments here... HEEEEEELP :)
Maybe an example of a call where N and PRECISION are
not equal to each other and not equal to 32 would be helpful.

Gory details of choose_multiplier(d=1577682821,n=32,precision=32)
are below:

  lgup = ceil_log2 (d); //// 31
  pow = n + lgup; //// 63
  pow2 = n + lgup - precision; //// 31

  /* mlow = 2^(N + lgup)/d */
  if (pow >= HOST_BITS_PER_WIDE_INT) //// yes
    {
      nh = (HOST_WIDE_INT) 1 << (pow - HOST_BITS_PER_WIDE_INT); //// 1<<31
      nl = 0;
    }
  else
    {
      nh = 0;
      nl = (unsigned HOST_WIDE_INT) 1 << pow;
    }
//// 1<<63 / d: mlow=5846151022
  div_and_round_double (TRUNC_DIV_EXPR, 1, nl, nh, d, (HOST_WIDE_INT) 0,
                        &mlow_lo, &mlow_hi, &dummy1, &dummy2);

  /* mhigh = (2^(N + lgup) + 2^(N + lgup - precision))/d */
  if (pow2 >= HOST_BITS_PER_WIDE_INT) //// no
    nh |= (HOST_WIDE_INT) 1 << (pow2 - HOST_BITS_PER_WIDE_INT);
  else
    nl |= (unsigned HOST_WIDE_INT) 1 << pow2; //// 1<<31
//// (1<<63 + 1<<31) / d: mhigh=5846151023 (5846151023.661466 in fp)
  div_and_round_double (TRUNC_DIV_EXPR, 1, nl, nh, d, (HOST_WIDE_INT) 0,
                        &mhigh_lo, &mhigh_hi, &dummy1, &dummy2);

--
vda

/*
[below: 'div' is unsigned integer division]
['/' is real division with infinite precision]
[A,B,C... - integers, a,b,c... - reals]

Theory: we want to calculate A div B, fast.
B is const > 2 and is not a power of 2.

In fp: A/B = A*(L/B)/L. (If L is a large power of 2,
say 2^32, then this can be done really fast.)
Let k := L/B, K := L div B + 1.
Then A/B = A*k/L.

Then this is true:

A div B <= A * K div L.
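
Example: B=641, L=2^32. Then k = 2^32/641 = 6700416.998...,
so K = 2^32 div 641 + 1 = 6700417. For A = 2^32-1:
A div B = 6700416, and
A * K div L = (2^32*6700417 - 6700417) div 2^32 = 6700416.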

For small enough A: A div B = A * K div L.
Let's find the first A for which this is not true.

Let's compare k/L and K/L. K/L is larger by a small value d:

d := K/L - k/L = (L div B + 1) / L - L/B/L =
= (L div B * B + B) / (L*B) - L/(L*B) =
= (L div B * B + B - L) / (L*B)
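
Example: B=641, L=2^32: L div B = 6700416, and
L div B * B + B - L = 4294966656 + 641 - 4294967296 = 1,
so d = 1/(2^32*641) = 1/2753074036736, the B=641 row in the table below.
(In general the numerator equals B - (L mod B), so it lies in 1..B.)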

A*K/L is larger than A*k/L by A*d.

A*k/L = A/B is closest to 'overflowing into the next integer'
when A = N*B-1: then A/B = N - 1/B, so the 'deficit' is exactly 1/B.
If A*d >= 1/B, then A*K/L will 'overflow'.

Thus bad_A >= (1/B) / d = (1/B) * (L*B)/(L div B * B + B - L) =
= L/(L div B * B + B - L). Then you need to 'walk up' to the next
A of the form N*B-1: bad_A2 = (bad_A div B) * B + B-1.
This is the first A which will give a wrong result
(i.e. for which (A*K div L) = (A div B)+1, not (A div B)).
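
Example: B=1577682821 (the d from the gcc example), L=2^58:
K = 182692220, L div B * B + B - L = 872640876,
bad_A = 2^58 div 872640876, about 3.3*10^8 < B, so bad_A2 = B-1 = 1577682820.
Indeed (B-1)*K div 2^58 = 1, while (B-1) div B = 0.
With L=2^59: K = 365384439 and the numerator is 167598931.
bad_A is about 3.44*10^9 (between 2B and 3B), so bad_A2 = 3B-1 = 4733048462.
That is above 2^32-1, so no 32-bit A fails.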

Practical use.

Suppose that A and B are 32-bit unsigned integers.

Unfortunately, there are only two B values in the 3..2^32-1 range
for which _any_ 32-bit unsigned A can be fast-divided using L=2^32
(because bad_A=2^32 and every A is less than that):

B=641 K=6700417 d=1/2753074036736 bad_A=4294967296 A=unlimited
B=6700417 K=641 d=1/28778071884562432 bad_A=4294967296 A=unlimited

We need to use larger powers of 2 for L if we want to handle
many more B values.

The code below prints fastdiv parameters for B values
in the range 3..2^32-1 (as posted, the loop starts at B=1726956429;
the full-range start is commented out).
*/

#include <stdint.h>
#include <stdio.h>

int main(void)
{
    /* Plain constants instead of enums: ISO C requires enumerator
       values to fit in an int, and these two do not. */
    const unsigned max_unsigned = ~0u;
    const uint64_t max_ull = ~(uint64_t)0;

    unsigned A, B, K, d_LB, fast, exact, count, mincount, bits;
    uint64_t L, KL, mask, bad_A, bad_A2;

    setbuf(stdout, NULL);

    mincount = max_unsigned;
    //for (B = 3; B; B++) {
    for (B = 1726956429; B; B++) {
//(B & 0xfffff) || printf("B=%u\r", B);
	if (((B-1)&B) == 0) { // B is a power of 2
		printf("B=%u - power of 2, division by shift\n", B);
		continue;
	}
	// We need max L for which max_unsigned * (L/B + 1) <= max_ull
	L = (max_ull/max_unsigned - 1) * B;
	bits = 63;
	mask = 1ULL << 63;
	while( !(L & mask))
		bits--, mask >>= 1;
	L = mask;
	while ( (KL = L/B + 1) > max_unsigned)
		bits--, L >>= 1;
	K = KL;
	while (!(K & 1) && bits > 32)
		bits--, L >>= 1, K = L/B + 1;
	
	d_LB = ((L/B) * B + B - L);
	bad_A = L / d_LB;
	bad_A2 = (bad_A / B) * B + B-1;
	if (bad_A2 > max_unsigned) {
//if(0)
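		// L is a power of 2, so (double)L * B is exact; L*B may overflow 64 bits
		printf("B=%u L=%llx K=%u d=%u/%.0f bad_A=%llu (*) A=unlimited\n",
			B, (unsigned long long)L, K, d_LB, (double)L * B,
			(unsigned long long)bad_A);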
		continue;
	}
	A = bad_A2;
//printf("A_start=%u\n", A);
	count = 0;
	while(1) {
		fast = (((uint64_t)A)*K) >> bits;
		exact = A / B;
		if (fast != exact) {
//if(count || !(A & 0xfffe0000))
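			// L is a power of 2, so (double)L * B is exact; L*B may overflow 64 bits
			printf("B=%u L=%llx K=%u d=%u/%.0f"
				" bad_A2=%u fast=%u exact=%u cnt=%u\n",
				B, (unsigned long long)L, K, d_LB, (double)L * B,
				A, fast, exact, count);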
			while(A >= B) {
				// testing: is bad_A formula exact?
				A -= B;
				fast = (((uint64_t)A)*K) >> bits;
				exact = A / B;
				if (fast != exact) {
					// should never happen
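					// L is a power of 2, so (double)L * B is exact; L*B may overflow 64 bits
					printf("WOW! B=%u L=%llx K=%u d=%u/%.0f"
						" bad_A2=%u fast=%u exact=%u\n",
						B, (unsigned long long)L, K, d_LB, (double)L * B,
						A, fast, exact);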
				} else
					break;
			}
			break;
		}
		if (A > max_unsigned - B) {
//printf("A_end=%u\n", A);
//if(0)
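			// L is a power of 2, so (double)L * B is exact; L*B may overflow 64 bits
			printf("B=%u L=%llx K=%u d=%u/%.0f bad_A=%llu"
				" A=unlimited\n",
				B, (unsigned long long)L, K, d_LB, (double)L * B,
				(unsigned long long)bad_A);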
			break;
		}
		count++;
		A += B;
	}
	// TODO random test....
	if(mincount > count) mincount = count;
    }
    printf("B=%u (0x%x). done. mincount=%u\n", B-1, B-1, mincount);
    
    return 0;
}
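
Going back to the theory in the comment block, the search for a larger L could also be done bottom-up. Try L = 2^s for s = 32, 33, ..., and keep the first s whose K still fits in 32 bits and divides every 32-bit A exactly. This is a different strategy from the program above, which starts from the largest L and strips trailing zero bits from K. Its exact check relies on K >= L/B, which makes A*K div L non-decreasing and never too small, rather than on the bad_A formula alone. A sketch; predicted_bad_A2, fastdiv_ok and min_fastdiv_shift are hypothetical names:

```c
#include <stdbool.h>
#include <stdint.h>

/* bad_A / bad_A2 from the comment block, for K = L div B + 1.
   Integer division stands in for the real-valued bound, as in the
   program. Assumes L <= 2^63. */
static uint64_t predicted_bad_A2(uint64_t L, uint32_t B)
{
    uint64_t d_LB = (L / B) * B + B - L;   /* = B - L mod B, in 1..B */
    uint64_t bad_A = L / d_LB;
    return (bad_A / B) * B + B - 1;
}

/* Exact check over all 32-bit A. Because K >= L/B, A*K div L never
   undershoots A div B and never decreases, so only the last A of each
   block of B values, plus A = 2^32 - 1, needs testing.
   Assumes K < 2^32 so that A*K fits in 64 bits. */
static bool fastdiv_ok(uint32_t B, uint64_t K, int s)
{
    for (uint64_t A = (uint64_t)B - 1; A <= UINT32_MAX; A += B)
        if (((A * K) >> s) != A / B)
            return false;
    return (((uint64_t)UINT32_MAX * K) >> s) == UINT32_MAX / B;
}

/* Smallest s in 32..63 such that K = 2^s div B + 1 fits in 32 bits
   and divides every 32-bit A exactly; returns -1 if there is none. */
static int min_fastdiv_shift(uint32_t B, uint64_t *K_out)
{
    for (int s = 32; s <= 63; s++) {
        uint64_t K = ((uint64_t)1 << s) / B + 1;
        if (K > UINT32_MAX)
            break;              /* K only grows with s */
        if (fastdiv_ok(B, K, s)) {
            *K_out = K;
            return s;
        }
    }
    return -1;
}
```

For B=641, it returns s=32 with K=6700417, matching the table in the comment. For B=1577682821, it returns some s <= 59, since s=59 with K=365384439 already passes. For that pair, predicted_bad_A2(1ULL << 59, 1577682821u) gives 4733048462, which is above 2^32-1.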
