/**********************************************************************
 * Copyright (c) 2005 Scapa Technologies Limited and others
 * 
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 * 
 * Contributors: 
 * Scapa Technologies Limited - Initial API and implementation
 **********************	************************************************/

package org.eclipse.stp.b2j.core.jengine.internal.transport;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.ArrayList;
import java.util.HashMap;

import org.eclipse.stp.b2j.core.jengine.internal.message.BPlaneTransactionClient;
import org.eclipse.stp.b2j.core.jengine.internal.message.BPlaneTransactionServer;
import org.eclipse.stp.b2j.core.jengine.internal.message.Message;
import org.eclipse.stp.b2j.core.jengine.internal.message.TransactionListener;
import org.eclipse.stp.b2j.core.jengine.internal.transport.session.Session;
import org.eclipse.stp.b2j.core.jengine.internal.transport.session.SessionFactory;
import org.eclipse.stp.b2j.core.jengine.internal.utils.UIDPool;
import org.eclipse.stp.b2j.core.publicapi.transport.session.SessionAddress;

/**
 * Command-line TCP port forwarder that tunnels connections over a single
 * multi-stream {@link Session}.
 * <p>
 * One end is started with {@code -from PORT} and the other with
 * {@code -to HOST PORT}. Once the session is up, each side runs a
 * {@link BPlaneTransactionServer} on one session line and a
 * {@link BPlaneTransactionClient} on the other and uses them to exchange the
 * {@link #CREATE_SERVER}, {@link #CREATE_FORWARDER} and {@link #KILL_FORWARDERS}
 * control messages. Every TCP connection accepted by a {@link Server} is given a
 * session stream id (a pooled UID plus a per-side offset) and its bytes are
 * relayed over that stream by a pair of {@link Forwarder} threads on each side.
 */
public class TransportTunnel implements TransactionListener {
	
	/**
	 * Offset added to UIDs from {@link #pool} to form session stream ids:
	 * 10000 on the {@code -from} side, 20000 on the {@code -to} side
	 * (presumably so the two ends never pick the same stream id).
	 * Stream ids are narrowed to {@code short} when the session streams are
	 * looked up, so offset + UID must stay below 32768.
	 */
	private int uidShift = 10000;
	UIDPool pool = new UIDPool();
	
	SessionAddress address;
	Session session;
	/**
	 * Client for requests to the remote end. Volatile because it is assigned on
	 * the constructor thread but read by Server/Forwarder threads, which may be
	 * started (through {@link #tserver}) before the assignment happens.
	 */
	volatile BPlaneTransactionClient tclient;
	BPlaneTransactionServer tserver;
	
	/** Request: start a {@link Server}. Payload: local port, remote host, remote port. */
	static final int CREATE_SERVER = 10;
	/** Request: connect to a host and relay it over a stream. Payload: host, port, stream id. */
	static final int CREATE_FORWARDER = 20;
	/** Reply to CREATE_FORWARDER when the connection could not be set up. Payload: error text. */
	static final int CREATE_FORWARDER_FAILED = 21;
	/** Request: stop the forwarders of a stream. Payload: stream id. */
	static final int KILL_FORWARDERS = 30;
	
	/** Guards {@link #f_to}, {@link #f_from} and {@link #f_sock}. */
	Object f_LOCK = new Object();
	/** Stream id (Integer) to the forwarder copying local socket to session stream. */
	HashMap f_to = new HashMap();
	/** Stream id (Integer) to the forwarder copying session stream to local socket. */
	HashMap f_from = new HashMap();
	/** Stream id (Integer) to the local socket of that tunnelled connection. */
	HashMap f_sock = new HashMap();
	
	/**
	 * Handles a control message, either received from the remote end through
	 * {@link #tserver} or built locally (for {@code -ls}).
	 * <ul>
	 * <li>CREATE_SERVER [local port, remote host, remote port]: starts a
	 * {@link Server} listening on the local port.</li>
	 * <li>CREATE_FORWARDER [host, port, stream id]: connects to host:port and
	 * relays that socket over the given session stream. On failure the reply is
	 * a CREATE_FORWARDER_FAILED message carrying the error text.</li>
	 * <li>KILL_FORWARDERS [stream id]: stops the forwarders of that stream and
	 * closes its socket.</li>
	 * </ul>
	 *
	 * @param m the request message
	 * @return the reply: the request itself, unless a forwarder could not be created
	 */
	public Message doTransaction(Message m) {
		int type = m.getType();
		if (type == CREATE_SERVER) {
			Integer local_port = (Integer)m.get(0);
			String remote_host = (String)m.get(1);
			Integer remote_port = (Integer)m.get(2);
			
			System.out.println("CREATE SERVER: "+local_port+" -> "+remote_host+":"+remote_port);
			
			//create a server on the local port here,
			Server server = new Server(local_port.intValue(),remote_host,remote_port.intValue());
			server.start();
			
		} else if (type == CREATE_FORWARDER) {
			String remote_host = (String)m.get(0);
			Integer remote_port = (Integer)m.get(1);
			Integer uid = (Integer)m.get(2);
			
			Socket sock = null;
			try {
				sock = new Socket(remote_host,remote_port.intValue());
				
				createForwarder(sock,uid.intValue(),false);
				
			} catch (Exception e) {
				// don't leak a socket that connected but could not be wired up
				closeQuietly(sock);
				m = new Message(CREATE_FORWARDER_FAILED);
				m.append(""+e);
			}
		} else if (type == KILL_FORWARDERS) {
			int uid = ((Integer)m.get(0)).intValue();
			killForwarders(uid);
		}
		
		return m;
	}
	
	/**
	 * Stops both forwarders of a stream and closes its local socket.
	 * Safe to call for an unknown or already-killed stream id (it is then a no-op),
	 * which happens when both ends tear the same stream down.
	 *
	 * @param uid the session stream id (pooled UID plus offset)
	 */
	public void killForwarders(int uid) {
		Socket sock;
		synchronized(f_LOCK) {
			Integer n = new Integer(uid);
			Forwarder to = (Forwarder)f_to.remove(n);
			Forwarder from = (Forwarder)f_from.remove(n);
			sock = (Socket)f_sock.remove(n);
			
			stopForwarder(to);
			stopForwarder(from);
		}
		// interrupt() does not unblock a thread stuck in a socket read; closing
		// the socket makes such a read fail so the forwarder can exit
		closeQuietly(sock);
	}
	
	/**
	 * Asks a forwarder to stop: sets its finish flag and interrupts it.
	 *
	 * @param f the forwarder, or null (ignored)
	 */
	private void stopForwarder(Forwarder f) {
		if (f == null) {
			return;
		}
		f.finish = true;
		// Never interrupt the calling thread: the killer forwarder calls
		// killForwarders() itself, and a pending interrupt would make its later
		// Thread.sleep() return immediately (and may disturb the remote notify).
		if (f != Thread.currentThread()) {
			try {
				f.interrupt();
			} catch (SecurityException ignored) {
				// best effort only
			}
		}
	}
	
	/**
	 * Starts the two forwarder threads that relay a local socket over a session
	 * stream, and registers them (and the socket) for {@link #killForwarders}.
	 *
	 * @param sock the local TCP connection to relay
	 * @param uid the session stream id; narrowed to short for the session lookup
	 * @param direction true if the socket-reading forwarder should be the
	 *        "killer", i.e. the one that tears the connection down when it ends
	 * @throws IOException if the socket streams cannot be obtained
	 */
	public void createForwarder(Socket sock, int uid, boolean direction) throws IOException {
		synchronized(f_LOCK) {
			short stream = (short)uid;
			Forwarder to = new Forwarder(uid,direction,new BufferedInputStream(sock.getInputStream()),session.getOutputStream(stream));
			Forwarder from = new Forwarder(uid,false,session.getInputStream(stream),new BufferedOutputStream(sock.getOutputStream()));
			
			Integer n = new Integer(uid);
			f_to.put(n,to);
			f_from.put(n,from);
			f_sock.put(n,sock);
			
			from.start();
			to.start();
		}
	}
	
	/**
	 * Tells the remote end to kill the forwarders of a stream. Best effort: the
	 * link may already be gone, so any failure is ignored.
	 */
	private void notifyRemoteKill(int uid) {
		try {
			Message m = new Message(KILL_FORWARDERS);
			m.append(uid);
			tclient.doTransaction(m);
		} catch (Throwable ignored) {
			// nothing more we can do if the remote end is unreachable
		}
	}
	
	/** Closes a socket, ignoring null and any close error. */
	private static void closeQuietly(Socket s) {
		if (s != null) {
			try {
				s.close();
			} catch (IOException ignored) {
				// already closed or broken
			}
		}
	}
	
	/** Closes a server socket, ignoring null and any close error. */
	private static void closeQuietly(ServerSocket s) {
		if (s != null) {
			try {
				s.close();
			} catch (IOException ignored) {
				// already closed or broken
			}
		}
	}
	
	/**
	 * Thread copying bytes from one stream to another until end of stream, an
	 * I/O error, or {@link #finish} is set.
	 */
	class Forwarder extends Thread {
		InputStream in;
		OutputStream out;
		int uid;
		/** Set by another thread to request a stop, hence volatile. */
		volatile boolean finish = false;
		/** True for the single thread per connection that performs the teardown. */
		boolean killer = false;
		
		public Forwarder(int uid, boolean killer, InputStream in, OutputStream out) {
			this.uid = uid;
			this.in = in;
			this.out = out;
			this.killer = killer;
		}
		
		public void run() {
			if (killer) {
				System.out.println("Forwarder: >> "+(uid));
			} else {
				System.out.println("Forwarder: << "+(uid));
			}
			
			try {
				byte[] buf = new byte[4096];
				while (!finish) {
					int n = in.read(buf,0,buf.length);
					if (n == -1) {
						// end of stream: leave the loop so the teardown below runs
						break;
					}
					if (n > 0) {
						out.write(buf,0,n);
						out.flush();
					}
				}
			} catch (Exception ignored) {
				// the streams break when the connection ends or is killed;
				// either way forwarding simply stops
			}
			
			if (killer) {
				//we only have one thread on each side do this
				try {
					killForwarders(uid);
				} catch (Throwable ignored) {
					// best effort
				}
				
				notifyRemoteKill(uid);
				
				// presumably gives the remote end time to tear the stream down
				// before the UID can be handed out again
				try {
					Thread.sleep(5000);
				} catch (InterruptedException ignored) {
					// thread is about to exit anyway
				}
				
				//release the UID
				pool.releaseUID(uid-uidShift);
			}
		}
	}
	
	/**
	 * Listens on a local port and, for every accepted connection, asks the remote
	 * end to connect to rhost:rport, then relays the two over a new session stream.
	 */
	class Server extends Thread {
		int port;
		
		String rhost;
		int rport;
		
		public Server(int port, String rhost, int rport) {
			this.port = port;
			this.rhost = rhost;
			this.rport = rport;
		}
		
		public void run() {
			ServerSocket server = null;
			try {
				server = new ServerSocket(port);
				
				while (true) {
					Socket sock = server.accept();
					acceptConnection(sock);
				}
			} catch (Exception e) {
				e.printStackTrace();
			} finally {
				closeQuietly(server);
			}
		}
		
		/**
		 * Sets up the tunnel for one accepted connection. On failure the socket is
		 * closed, the remote forwarder (if already created) is killed and the UID
		 * is returned to the pool; on success the killer forwarder releases the
		 * UID when the connection ends.
		 */
		private void acceptConnection(Socket sock) {
			int uid = 0;
			boolean haveUid = false;
			boolean remoteCreated = false;
			try {
				//get a UID for this connection
				uid = pool.getUID();
				haveUid = true;
				int stream = uidShift+uid;
				
				//ask the remote end to connect to rhost:rport and listen on this stream
				Message m = new Message(CREATE_FORWARDER);
				m.append(rhost);
				m.append(rport);
				m.append(stream);
				
				m = tclient.doTransaction(m);
				
				if (m.getType() == CREATE_FORWARDER_FAILED) {
					throw new Exception("Forwarder create failed: "+m.get(0));
				}
				remoteCreated = true;
				
				//then start our own forwarder pair on the same stream
				createForwarder(sock,stream,true);
				
			} catch (Exception e) {
				System.out.println("Forwarder ended: "+e);
				closeQuietly(sock);
				if (remoteCreated) {
					notifyRemoteKill(uidShift+uid);
				}
				if (haveUid) {
					pool.releaseUID(uid);
				}
			}
		}
	}
	
	
	public static void main(String[] argsarray) {
		new TransportTunnel(argsarray);
	}
	
	/** Prints the command-line help and exits with status 1. */
	private void usage() {
		System.out.println("Usage: TransportTunnel <options>");
		System.out.println("Options:");
		System.out.println("  -h                  show this help");
		System.out.println("  -from PORT          accept incoming tunnel connections on PORT");
		System.out.println("  -to HOST PORT       create a new tunnel connection to HOST on PORT");
		System.out.println("  -ls PORT HOST PORT  create a local server forwarding data from local PORT to HOST:PORT");
		System.out.println("  -rs PORT HOST PORT  create a remote server forwarding data from remote PORT to HOST:PORT");
		System.out.println("  -rc TIMEOUT         automatic reconnect with a specified TIMEOUT in milliseconds");
		System.out.println("");
		System.out.println("E.g. TransportTunnel -from 10000");
		System.out.println("");
		System.out.println("     (then) TransportTunnel -to myhost 10000  -ls 80 www.google.com 80");
		System.exit(1);
	}
	
	/**
	 * Parses the command line, establishes the tunnel session and control
	 * channels, then creates the requested forwarding servers.
	 * Errors are reported on stdout/stderr rather than thrown.
	 *
	 * @param argsarray the command-line arguments (see {@link #usage()})
	 */
	public TransportTunnel(String[] argsarray) {
		try {
			ArrayList arglist = new ArrayList();
			for (int i = 0; i < argsarray.length; i++) {
				arglist.add(argsarray[i]);
			}
			
			short line = -1;
			// true for -to (we open the tunnel link), false for -from (we accept it)
			boolean connecting = false;
			boolean reconnect = false;
			long reconnectTimeout = 0;
			
			//
			// Pass 1: link options. The session is only created once all of them
			// have been read, so -rc may appear before or after -from/-to.
			//
			for (int i = 0; i < arglist.size(); i++) {
				String arg = (String)arglist.get(i);
				if (arg.equals("-from")) {
					if (address != null) usage();
					
					arglist.remove(i);
					int port = Integer.parseInt((String)arglist.remove(i));
					
					address = new SessionAddress("localhost",SessionAddress.TRANSPORT_PORT_ANY,SessionAddress.TRANSPORT_PORT_ANY,"localhost",port,port);
					address.setRequiresMultipleStreams(true);
					connecting = false;
					
					line = 0;
					
					uidShift = 10000;
					
					i--;
					
				} else if (arg.equals("-to")) {
					if (address != null) usage();
					
					arglist.remove(i);
					String host = (String)arglist.remove(i);
					int port = Integer.parseInt((String)arglist.remove(i));
					
					address = new SessionAddress("localhost",SessionAddress.TRANSPORT_PORT_ANY,SessionAddress.TRANSPORT_PORT_ANY,host,port,port);
					address.setRequiresMultipleStreams(true);
					connecting = true;
					
					line = 1;
					
					uidShift = 20000;
					
					i--;
					
				} else if (arg.equals("-rc")) {
					arglist.remove(i);
					reconnectTimeout = Long.parseLong((String)arglist.remove(i));
					reconnect = true;
					
					i--;
					
				} else if (arg.equals("-h")) {
					usage();
					
				}
			}
			
			if (address == null) {
				throw new Exception("You must specify one of '-from' or '-to'");
			}
			if (reconnect) {
				address.setRequiresLinkReconnection(true);
				address.setReconnectionFailureAbortTimeout(reconnectTimeout);
			}
			
			//
			// Start the session
			//
			session = SessionFactory.newSession(address,connecting);
			session.begin();
			
			//
			// We now have our connection: the two ends use opposite lines for
			// their transaction server and client
			//
			
			System.out.println("INCOMING ON LINE "+line);
			tserver = new BPlaneTransactionServer(session.getInputStream(line),session.getOutputStream(line),this);
			line = (short)(1-line);
			System.out.println("OUTGOING ON LINE "+line);
			tclient = new BPlaneTransactionClient(session.getInputStream(line),session.getOutputStream(line));
			
			//
			// We now have communications between ourselves and the remote host
			//
			
			System.out.println("CONNECTION ESTABLISHED");
			
			//
			// Pass 2: create any forwarding servers
			//
			
			for (int i = 0; i < arglist.size(); i++) {
				String arg = (String)arglist.get(i);
				
				if (arg.equals("-h")) {
					usage();
					
				} else if (arg.equals("-ls")) {
					
					arglist.remove(i);
					int lport = Integer.parseInt((String)arglist.remove(i));
					String rhost = (String)arglist.remove(i);
					int rport = Integer.parseInt((String)arglist.remove(i));
					
					System.out.println(lport+" ==> "+rhost+":"+rport);
					
					// local server: handle the request ourselves
					Message m = new Message(CREATE_SERVER);
					m.append(lport);
					m.append(rhost);
					m.append(rport);
					
					doTransaction(m);
					
					i--;
					
				} else if (arg.equals("-rs")) {
					
					arglist.remove(i);
					int rport = Integer.parseInt((String)arglist.remove(i));
					String lhost = (String)arglist.remove(i);
					int lport = Integer.parseInt((String)arglist.remove(i));
					
					System.out.println(lhost+":"+lport+" <== "+rport);
					
					// remote server: ask the other end to create it
					Message m = new Message(CREATE_SERVER);
					m.append(rport);
					m.append(lhost);
					m.append(lport);
					
					tclient.doTransaction(m);
					
					i--;
					
				} else {
					System.err.println("Unrecognised argument: "+arg);
					System.exit(1);
				}
			}
			
			System.out.println("SETUP COMPLETE");
			
		} catch (NumberFormatException e) {
			System.out.println("Invalid number, "+e);
		} catch (Exception e) {
			e.printStackTrace();
		}
	}
}