My attempt uses a different approach: create two sorted arrays, n^2
elements each; and then iterate over them looking for matching
elements (only one pass is required). I managed to get 58,2250612857 s
on my 1,7 MHz machine. It requires numpy for decent performance,
though.

import numpy
import time

def parse_input():
        al, bl, cl, dl = [], [], [], []
        for i in xrange(int(raw_input())):
                a, b, c, d = map(int, raw_input().split())
                al.append(a)
                bl.append(b)
                cl.append(c)
                dl.append(d)
        return al, bl, cl, dl

def count_zero_sums(al, bl, cl, dl):
        n = len(al) # Assume others are equal

        # Construct al extended (every element is repeated n times)
        ale = numpy.array(al).repeat(n)
        del al
        # Construct bl extended (whole array is repeated n times)
        ble = numpy.zeros((n*n,), int)
        for i in xrange(n): ble[i*n:(i+1)*n] = bl
        del bl
        # Construct abl - sorted list of all sums of a, b for a, b in al, bl
        abl = numpy.sort(ale + ble)
        del ale, ble

        # Construct cl extended (every element is repeated n times)
        cle = numpy.array(cl).repeat(n)
        del cl
        # Construct dl extended (whole array is repeated n times)
        dle = numpy.zeros((n*n,), int)
        for i in xrange(n): dle[i*n:(i+1)*n] = dl
        del dl
        # Construct cdl - sorted list of all negated sums of a, b for a, b in
cl, dl
        cdl = numpy.sort(-(cle + dle))
        del cle, dle

        # Iterate over arrays, count matching elements
        result = 0
        i, j = 0, 0
        n = n*n
        try:
                while True:
                        while abl[i] < cdl[j]:
                                i += 1
                        while abl[i] > cdl[j]:
                                j += 1
                        if abl[i] == cdl[j]:
                                # Found matching sequences
                                ii = i + 1
                                while ii < n and abl[ii] == abl[i]: ii += 1
                                jj = j + 1
                                while jj < n and cdl[jj] == cdl[j]: jj += 1
                                result += (ii - i)*(jj - j)
                                i, j = ii, jj
        except IndexError:
                pass

        return result

t = time.clock()
print count_zero_sums(*parse_input())
print time.clock() - t

-- 
http://mail.python.org/mailman/listinfo/python-list

Reply via email to