import java.util.*;
import java.io.*;
import java.net.*;

import com.sun.mail.util.LineInputStream;

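/*
 This class implements a simple TCP proxy that listens on a local port
 and forwards connections and data to a remote host and port. It creates two
 threads for each connection (one per direction); the threads die after the
 connections close. If the proxySet, http.proxyHost and http.proxyPort system
 properties are set, connections go to that HTTP proxy instead, and the client's
 HTTP requests are rewritten so that they target the remote host.
 */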

public class TCPProxy implements Runnable {

    ServerSocket listenSocket;
    String remoteHost;
    int remotePort;

    /**
     * Creates a proxy that forwards accepted connections to aRemoteHost:aRemotePort
     * and starts the accept loop ({@link #run()}) on a new thread.
     *
     * Binding to a particular port is not guaranteed to work (port numbers are a
     * scarce resource on a host basis), so a range of local ports is accepted: each
     * port from firstLocalPortToTry to lastLocalPortToTry is tried in order until
     * one binds.
     *
     * @param aRemoteHost         host that accepted connections are forwarded to
     * @param aRemotePort         port that accepted connections are forwarded to
     * @param firstLocalPortToTry first local port to try to listen on
     * @param lastLocalPortToTry  last local port to try to listen on (inclusive)
     * @throws IOException if no port in the range could be bound
     */
    public TCPProxy(String aRemoteHost, int aRemotePort, int firstLocalPortToTry, int lastLocalPortToTry) throws IOException {
        remoteHost = aRemoteHost;
        remotePort = aRemotePort;
        for (int portToTry = firstLocalPortToTry; portToTry <= lastLocalPortToTry; portToTry++) {
            try {
                listenSocket = new ServerSocket(portToTry);
                break;
            } catch (IOException ex) {
                // Port unavailable; listenSocket stays null, so try the next one.
            }
        }
        if (listenSocket == null) throw new IOException("can't create server socket");
        new Thread(this).start();
    }

    /**
     * Copies data from one stream to another on its own thread, then closes both
     * streams when the input ends or an I/O error occurs.
     *
     * When useProxy is true, the input is parsed as a sequence of HTTP requests.
     * Each request is rewritten for an upstream HTTP proxy: the GET/POST
     * request-target is made absolute (naming realHost:realPort) and the Host
     * header is replaced.
     */
    class PortForwardingThread implements Runnable {
        InputStream input;
        OutputStream output;
        private boolean _useProxy;
        private String _realHost;
        private int _realPort;
        private byte[] _realPortBytes;

        /**
         * Creates the forwarder and immediately starts its thread.
         *
         * @param is       stream to read from
         * @param os       stream to write to
         * @param useProxy whether to rewrite HTTP requests for an upstream proxy
         * @param realHost origin host written into rewritten requests; may be null if useProxy is false
         * @param realPort origin port written into rewritten requests; unused if useProxy is false
         */
        public PortForwardingThread(InputStream is, OutputStream os,
                                    boolean useProxy, String realHost,
                                    int realPort) {
            input = is;
            output = os;
            _useProxy = useProxy;
            _realHost = realHost;
            _realPort = realPort;
            System.out.println("PortForwardingThread: useProxy = " + useProxy + ", realHost = " + realHost + ", realPort = " + realPort);
            if (useProxy) {
                // sendIncomingToRemoteProxy() reads line by line through a LineInputStream.
                if (!(input instanceof LineInputStream))
                    input = new LineInputStream(input);
                try {
                    _realPortBytes = String.valueOf(realPort).getBytes("US-ASCII");
                } catch (UnsupportedEncodingException e) {
                    e.printStackTrace();
                }
            }
            new Thread(this).start();
        }

        /**
         * Forwards until end of stream or an I/O error, then closes both streams.
         * NOTE(review): for socket streams, closing either stream closes the whole
         * socket, which also ends the opposite direction's forwarder.
         */
        public void run() {
            System.out.println(new Date() + "PortForwardingThread.run() 1");
            byte[] bytes = new byte[256];
            try {
                boolean endOfStream = false;
                while (!endOfStream) {
                    if (_useProxy) {
                        endOfStream = sendIncomingToRemoteProxy();
                    } else {
                        int readCount = input.read(bytes);
                        if (readCount <= 0)
                            endOfStream = true;
                        else
                            output.write(bytes, 0, readCount);
                    }
                }
            } catch (IOException ex) {
                ex.printStackTrace();
            }
            try { input.close(); } catch (IOException ex) { /* best-effort cleanup */ }
            try { output.close(); } catch (IOException ex) { /* best-effort cleanup */ }
        }

        /**
         * Reads one HTTP request from input and writes it to output, rewritten for
         * the upstream proxy. The request consists of the request line, the headers
         * up to the blank line, and Content-Length bytes of body when that header
         * is present.
         *
         * @return true if end of stream was encountered, false otherwise
         */
        private boolean sendIncomingToRemoteProxy() throws IOException {
            LineInputStream incoming = (LineInputStream) input;
            boolean foundHostHeader = false;
            boolean foundRequestLine = false;
            boolean foundContentLength = false;
            int contentLength = 0;

            while (true) {
                String line = incoming.readLine();
                if (line == null)
                    return true;

                if (line.length() == 0) {
                    // End of the header block. Every line written so far already
                    // ended with CRLF, so add the Host header if the client sent
                    // none, then terminate with exactly one blank line.
                    if (!foundHostHeader) {
                        writeHostHeader();
                        output.write(END_OF_LINE_BYTES);
                    }
                    output.write(END_OF_LINE_BYTES);
                    break;
                }

                if (!foundHostHeader && hasHeaderPrefix(line, HOST_HEADER_PREFIX)) {
                    // Replace the client's Host header with the real origin.
                    foundHostHeader = true;
                    writeHostHeader();
                } else if (!foundRequestLine && isRequestLine(line)) {
                    foundRequestLine = true;
                    writeText(rewriteRequestLine(line));
                } else if (!foundContentLength && hasHeaderPrefix(line, CONTENT_LENGTH_STRING)) {
                    foundContentLength = true;
                    contentLength = parseContentLength(
                            line.substring(CONTENT_LENGTH_STRING_LENGTH));
                    writeText(line);
                } else {
                    writeText(line);
                }
                output.write(END_OF_LINE_BYTES);
            }

            if (contentLength > 0)
                return passThroughContent(contentLength);
            return false;
        }

        /**
         * Rewrites the request-target of a GET/POST request line into an absolute
         * http:// URI naming _realHost:_realPort, as an HTTP proxy expects.
         * An absolute http:// target has its authority replaced; a relative
         * target is prefixed with the scheme and authority.
         */
        private String rewriteRequestLine(String line) {
            int targetStart = line.startsWith(GET_STRING) ? GET_STRING_LENGTH : POST_STRING_LENGTH;
            int restStart = targetStart;
            if (line.regionMatches(true, targetStart, HTTP_PROTOCOL_START, 0, HTTP_PROTOCOL_START_LENGTH)) {
                // Skip the scheme and the original authority, which ends at the
                // start of the path, the query, or the space before the version.
                restStart = targetStart + HTTP_PROTOCOL_START_LENGTH;
                while (restStart < line.length() && "/? ".indexOf(line.charAt(restStart)) == -1)
                    restStart++;
            }
            return line.substring(0, targetStart)
                    + HTTP_PROTOCOL_START + _realHost + ":" + _realPort
                    + line.substring(restStart);
        }

        /**
         * Copies exactly contentLength body bytes from input to output, never
         * reading past the body (so a following pipelined request is left unread).
         *
         * @return true if end of stream was reached before the body was complete
         */
        private boolean passThroughContent(int contentLength) throws IOException {
            byte[] bytes = new byte[256];
            int remaining = contentLength;
            while (remaining > 0) {
                int readCount = input.read(bytes, 0, Math.min(bytes.length, remaining));
                if (readCount <= 0)
                    return true;
                output.write(bytes, 0, readCount);
                remaining -= readCount;
            }
            return false;
        }

        /** Writes "Host: realHost:realPort" without a line terminator. */
        private void writeHostHeader() throws IOException {
            output.write(HOST_BYTES);
            output.write(_realHost.getBytes("US-ASCII"));
            output.write((byte) ':');
            output.write(_realPortBytes);
        }

        /** Writes text without a line terminator, one byte per char (ISO-8859-1). */
        private void writeText(String text) throws IOException {
            output.write(text.getBytes(LINE_CHARSET));
        }
    }

    /** Returns the local port this proxy is listening on. */
    public int localListenPort() {
        return listenSocket.getLocalPort();
    }

    /**
     * Accepts one client connection and connects it to the remote host.
     * If the proxySet, http.proxyHost and http.proxyPort system properties are
     * set, it connects to that HTTP proxy instead. Two forwarding threads are
     * then started, one per direction. Both sockets are closed if setup fails
     * before the forwarders are running.
     *
     * @throws Exception if accepting fails or the remote side cannot be reached
     *                   after 10 attempts
     */
    public void forwardConnection() throws Exception {
        System.out.println(new Date() + "TCPProxy.forwardConnection() A");
        Socket acceptedPort = listenSocket.accept();
        System.out.println(new Date() + "TCPProxy.forwardConnection() B");

        Socket forwardedPort = null;
        boolean forwardersStarted = false;
        try {
            boolean haveProxyInfo = false;
            String connectHost = remoteHost;
            int connectPort = remotePort;
            if (Boolean.getBoolean("proxySet")) {
                String proxyHost = System.getProperty("http.proxyHost");
                // Integer.getInteger returns null for a missing or non-numeric value.
                Integer proxyPortInteger = Integer.getInteger("http.proxyPort");
                if (proxyHost != null && proxyPortInteger != null) {
                    haveProxyInfo = true;
                    connectHost = proxyHost;
                    connectPort = proxyPortInteger.intValue();
                    System.out.println(">>>connectHost " + connectHost + " connectPort " + connectPort);
                }
            }

            InputStream incomingStream = haveProxyInfo
                    ? new LineInputStream(acceptedPort.getInputStream())
                    : acceptedPort.getInputStream();

            IOException remoteSocketException = null;
            for (int attempt = 0; attempt < 10 && forwardedPort == null; attempt++) {
                try {
                    System.out.println(new Date() + "TCPProxy.forwardConnection() 1");
                    System.out.println("connectPort = " + connectPort);
                    forwardedPort = new Socket(connectHost, connectPort);
                    System.out.println(new Date() + "TCPProxy.forwardConnection() 2");
                } catch (IOException ex) {
                    remoteSocketException = ex;
                }
            }
            if (forwardedPort == null) {
                if (remoteSocketException != null)
                    throw remoteSocketException;
                throw new IOException("Can't connect to forward socket");
            }
            System.out.println(new Date() + "TCPProxy.forwardConnection() 3");

            // Client -> remote: rewritten into proxy requests when going through an HTTP proxy.
            new TCPProxy.PortForwardingThread(incomingStream,
                                              forwardedPort.getOutputStream(),
                                              haveProxyInfo, remoteHost, remotePort);
            System.out.println(new Date() + "TCPProxy.forwardConnection() 4");
            // Remote -> client: always copied verbatim.
            new TCPProxy.PortForwardingThread(forwardedPort.getInputStream(),
                                              acceptedPort.getOutputStream(),
                                              false, null, 0);
            forwardersStarted = true;
            System.out.println(new Date() + "TCPProxy.forwardConnection() 5");
        } finally {
            if (!forwardersStarted) {
                closeQuietly(acceptedPort);
                closeQuietly(forwardedPort);
            }
        }
    }

    /**
     * Accept loop: forwards connections until the listening socket is closed.
     * Errors on individual connections are printed and do not stop the loop.
     */
    public void run() {
        while (!listenSocket.isClosed()) {
            try {
                forwardConnection();
            } catch (Exception ex) {
                ex.printStackTrace();
            }
        }
    }

    /** Closes socket if non-null, ignoring any error. */
    private static void closeQuietly(Socket socket) {
        if (socket == null)
            return;
        try {
            socket.close();
        } catch (IOException ex) {
            // best-effort cleanup
        }
    }

    /** True if line starts (case-insensitively) with prefix and has a non-empty remainder. */
    private static boolean hasHeaderPrefix(String line, String prefix) {
        return line.length() > prefix.length()
                && line.regionMatches(true, 0, prefix, 0, prefix.length());
    }

    /** True if line is a GET or POST HTTP request line. */
    private static boolean isRequestLine(String line) {
        return line.startsWith(GET_STRING) || line.startsWith(POST_STRING);
    }

    /** Parses a Content-Length value, returning 0 (and printing the error) if malformed. */
    private static int parseContentLength(String value) {
        try {
            return Integer.parseInt(value.trim());
        } catch (NumberFormatException e) {
            e.printStackTrace();
            return 0;
        }
    }

    public static void main (String args[]) {
        System.out.println("Hello World!");
        try {
            TCPProxy proxy = new TCPProxy("www.myhost.com", 80, 8080, 8888);
            System.out.println("proxy port -> " + proxy.localListenPort());
            while (true) {
                Thread.sleep(10000);
            }
        } catch (IOException ex) {
            // Report why the proxy could not start instead of exiting silently.
            ex.printStackTrace();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
        }
        System.out.println("Goodbye World!");
    }

    private static final byte[] HOST_BYTES = {(byte)'H',
                                              (byte)'o',
                                              (byte)'s',
                                              (byte)'t',
                                              (byte)':',
                                              (byte)' '};
    private static final byte[] END_OF_LINE_BYTES = {(byte)'\r',(byte)'\n'};
    private static final String HOST_HEADER_PREFIX = "Host: ";
    /**
     * Charset for turning header lines back into bytes. ISO-8859-1 maps chars
     * 0-255 one-to-one onto bytes. This presumably matches LineInputStream.readLine's
     * one-char-per-byte decoding; confirm for the JavaMail version in use.
     */
    private static final String LINE_CHARSET = "ISO-8859-1";
    private static final String HTTP_PROTOCOL_START = "http://";
    private static final int HTTP_PROTOCOL_START_LENGTH
            = HTTP_PROTOCOL_START.length();
    private static final String POST_STRING = "POST ";
    private static final int POST_STRING_LENGTH
            = POST_STRING.length();
    private static final String GET_STRING = "GET ";
    private static final int GET_STRING_LENGTH
            = GET_STRING.length();
    private static final String CONTENT_LENGTH_STRING = "Content-Length: ";
    private static final int CONTENT_LENGTH_STRING_LENGTH
            = CONTENT_LENGTH_STRING.length();
}
