#!/usr/bin/env python
'''
Uncoded matrix-vector multiplication
'''

from mpi4py import MPI
import numpy as np
import random
import threading
import time

# If True, master and workers meet at a barrier once the input blocks have been
# delivered, which gives more accurate timing at some cost in performance.
barrier = True

##################### Parameters ########################
# Conditions that need to be met for the parameters:
# k | q^(k-1)
# kq | m
# (kq choose k) | q^(k-1)m
# k | n

# Use one master and N workers. Note that N should be k*q.
# NOTE: if you load the matrices (loadAb = 1), check these values against them.
k = 2
q = 2
N = k*q

# Number of jobs, i.e. number of matrix-vector multiplications c = A*b for different A, b.
# NOTE: if you load the matrices (loadAb = 1), check this value against them.
J = q**(k-1)

# Set to 0 to generate random A, b, or to 1 to load a pregenerated list of A, b.
# When loading, A_list.npz and b_list.npz must be in the folder
# "pregenerateAb_mat_vec" in the parent directory, and the parameters of this
# script and of pregenerateAb_mat_vec.py should match.
loadAb = 1

# Bound of the elements of the input matrices, i.e. entries are in [0, ..., B-1].
B = 3

# Input matrix size - A: m by n, b: n by 1.
# NOTE: if you load the matrices (loadAb = 1), check these values against them.
m = 4
n = 4

#########################################################

comm = MPI.COMM_WORLD

# Check for wrong number of MPI processes (one master plus N workers).
if comm.size != N+1:
	print("The number of MPI processes mismatches the number of workers.")
	comm.Abort(1)

# Check the divisibility conditions the block splitting below relies on
# (A is split into q row blocks and k column blocks, b into k blocks).
# A violation would otherwise raise inside np.split on the master only,
# which can leave the workers waiting in their receives.
if m % q != 0 or n % k != 0:
	if comm.rank == 0:
		print("Invalid parameters: q must divide m and k must divide n.")
	comm.Abort(1)
	
if comm.rank == 0:
	# Master: builds the J jobs, sends every worker its blocks of A and b,
	# then collects the q reduced row blocks of each job and saves c = A*b.
	print("Running with %d processes:" % comm.Get_size())

	print("CAMR, N=%d workers, k=%d, q=%d, J=%d, m=%d, n=%d, B=%d" % (N, k, q, J, m, n, B))

	# Create random matrices (loadAb == 0) or load pregenerated ones (loadAb == 1).
	A = []
	b = []
	if loadAb == 0:
		for i in range(J):
			A.append(np.random.randint(0, B, (m, n)))
			b.append(np.random.randint(0, B, (n, 1)))
	elif loadAb == 1:
		with np.load('../pregenerateAb_mat_vec/A_list.npz') as A_file:
			for i in range(J):
				A.append(A_file['arr_%d' % i])
		with np.load('../pregenerateAb_mat_vec/b_list.npz') as b_file:
			for i in range(J):
				b.append(b_file['arr_%d' % i])
	else:
		print("loadAb must be 0 (generate) or 1 (load).")
		comm.Abort(1)

	# Normalize every job to np.int_ with the shapes the workers expect.
	# The workers receive into np.int_ buffers and every send/receive below
	# lets mpi4py infer the MPI datatype from the buffer, so the dtypes must
	# agree on both sides. The .npz contents are produced elsewhere and may
	# use a different integer dtype, so they are converted here.
	for i in range(J):
		if np.shape(A[i]) != (m, n) or np.size(b[i]) != n:
			print("A[%d] or b[%d] does not match m=%d, n=%d." % (i, i, m, n))
			comm.Abort(1)
		A[i] = np.asarray(A[i], dtype=np.int_)
		b[i] = np.asarray(b[i], dtype=np.int_).reshape(n, 1)

	#test
	for i in range(J):
		print('A[%d] is: ' % i)
		print(A[i])
		print('b[%d] is: ' % i)
		print(b[i])

	# Ahv[i][r][s]: row block r (of q) and column block s (of k) of A[i].
	# bhv[i][s]: block s (of k) of b[i].
	Ahv = []
	bhv = []
	for i in range(J):
		Ahv.append([np.split(row_blk, k, axis=1) for row_blk in np.split(A[i], q)])
		bhv.append(np.split(b[i], k, axis=0))

	# Receive buffers for the results: only q reducers are used per job
	# (one for each block row), each returning an (m/q) x 1 block.
	Crtn = [[np.zeros((m // q, 1), dtype=np.int_) for j in range(q)] for i in range(J)]

	# Start requests to send and receive
	reqA = [None] * (J * N)
	reqb = [None] * (J * N)
	reqc = [None] * (J * q)
	# Keep the contiguous send copies referenced until the sends complete.
	send_bufs = []

	bp_start = time.time()

	for i in range(J):
		for j in range(N):
			# Worker j (rank j+1) gets block (j//k, j%k) of A[i] and block j%k of b[i].
			blkA = np.ascontiguousarray(Ahv[i][j // k][j % k])
			blkb = np.ascontiguousarray(bhv[i][j % k])
			send_bufs.append((blkA, blkb))
			reqA[i * N + j] = comm.Isend(blkA, dest=j + 1, tag=15)
			reqb[i * N + j] = comm.Isend(blkb, dest=j + 1, tag=29)

		# In our convention, the reducers of the i-th job are the machines
		# i%k+0*k, i%k+1*k, ... , i%k+(q-1)*k, if indexing starts from zero.
		# The +1 is due to the MPI-assigned ranks.
		for j in range(q):
			reqc[i * q + j] = comm.Irecv(Crtn[i][j], source=i % k + j * k + 1, tag=42)

	MPI.Request.Waitall(reqA)
	MPI.Request.Waitall(reqb)

	# Optionally wait for all workers to receive their submatrices, for more accurate timing
	if barrier:
		comm.Barrier()

	bp_sent = time.time()
	print("Time spent sending all messages is: %f" % (bp_sent - bp_start))

	MPI.Request.Waitall(reqc)
	bp_received = time.time()
	print("Time spent waiting for all workers is: %f" % (bp_received - bp_sent))

	# For each job, stack the q reduced row blocks into its m x 1 result and save.
	c = [np.concatenate(Crtn[i], axis=0) for i in range(J)]
	np.savez('c_list', *c)

else:
	# Worker me (0-based) owns block row me//k and block column me%k.
	me = comm.rank - 1
	my_row = me // k
	my_col = me % k

	# For each job, receive the (m/q) x (n/k) block of A and the (n/k) x 1 block of b.
	A = []
	b = []
	rA = [None] * J
	rb = [None] * J
	for i in range(J):
		A.append(np.empty((m // q, n // k), dtype=np.int_))
		b.append(np.empty((n // k, 1), dtype=np.int_))
		rA[i] = comm.Irecv(A[i], source=0, tag=15)
		rb[i] = comm.Irecv(b[i], source=0, tag=29)

	MPI.Request.Waitall(rA)
	MPI.Request.Waitall(rb)

	if barrier:
		comm.Barrier()
	wbp_received = time.time()

	# Local partial products, one (m/q) x 1 block per job.
	c = [np.dot(A[i], b[i]) for i in range(J)]

	wbp_done = time.time()
	print("Worker %d computing takes: %f\n" % (me, wbp_done - wbp_received))

	# This worker reduces exactly the jobs whose index matches its block column.
	# Counting them directly (rather than J*q//N) stays correct even when k
	# does not divide J.
	my_jobs = [i for i in range(J) if i % k == my_col]

	# k-1 partial products are received for each reduction.
	Crtn = [[np.zeros((m // q, 1), dtype=np.int_) for j in range(k - 1)] for i in my_jobs]

	red_so_far = 0
	reduced_blk = []
	for i in range(J):
		if i % k == my_col:
			# Reducer of job i for block row my_row: receive the partial products
			# of the other k-1 workers on the same block row.
			mappers = [my_row * k + col for col in range(k) if col != my_col]
			red_req = [comm.Irecv(Crtn[red_so_far][j], source=mappers[j] + 1, tag=5) for j in range(k - 1)]
			MPI.Request.Waitall(red_req)
			# Accumulate onto a copy of c[i]. This keeps the integer dtype even
			# when there is nothing to receive (k == 1), where summing an empty
			# list with np.sum would produce a float.
			blk = c[i].copy()
			for part in Crtn[red_so_far]:
				blk += part
			reduced_blk.append(blk)
			red_so_far += 1
		else:
			# Mapper: send the partial product to job i's reducer on this block row.
			# All jobs share tag 5. Messages between a given pair of ranks are
			# matched in the order they are sent, and every worker walks the jobs
			# in the same order, so each message reaches the right reduction.
			comm.Isend(c[i], dest=my_row * k + i % k + 1, tag=5).Wait()

	print("Worker %d finished MapReduce" % me)

	# Return all reductions to master for concatenation (in job order, matching
	# the order in which the master posted its receives for this rank).
	master_req = [comm.Isend(blk, dest=0, tag=42) for blk in reduced_blk]
	MPI.Request.Waitall(master_req)