/**********************************************************************
 * Copyright (c) 2005 Scapa Technologies Limited and others
 * 
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 * 
 * Contributors: 
 * Scapa Technologies Limited - Initial API and implementation
 **********************************************************************/

package org.eclipse.stp.b2j.core.jengine.internal.mutex;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;

import org.eclipse.stp.b2j.core.jengine.internal.message.Message;

/**
 * A message passing class used in the local engine implementation.
 * <p>
 * A message is posted under a single key (a conversation) with {@link #put},
 * and received with {@link #get}, which may listen on several keys at once.
 * A posted message is handed straight to the longest-registered receiver
 * blocked on that key, or queued (FIFO, per key) when nobody is waiting.
 * Whenever a message is handed to a receiver, the key it was posted under is
 * appended to it with {@code Message.append}.
 * <p>
 * Time-outs are reported by throwing {@link InterruptedException}, so callers
 * cannot tell a time-out from a genuine thread interrupt by type alone.
 * <p>
 * All shared state is guarded by one internal "world" lock, so an instance
 * may be shared between threads.
 *
 * @author amiguel
 */
public class MultiQueuedBlockingMap {

///////////////////////////////////////////////////////////
//
// internal single mutex implementation
//
	/** Monitor guarding {@link #mutex_held}; threads waiting for the world lock park on it. */
	private final Object mutex_LOCK = new Object();
	/** True while some thread holds the world lock. Guarded by {@link #mutex_LOCK}. */
	private boolean mutex_held = false;

	/**
	 * Acquires the non-reentrant "world" lock that guards {@link #pool} and every
	 * {@link Lock} and {@link Recipient} reachable from it.
	 * <p>
	 * Acquisition is deliberately uninterruptible. Every critical section in this
	 * class is short and never blocks, and the time-out/interrupt clean-up paths
	 * must always be able to take the lock to undo their registrations. If the
	 * thread is interrupted while waiting, its interrupt status is restored once
	 * the lock has been acquired.
	 */
	private void mutex_lock() {
		boolean interrupted = false;
		synchronized(mutex_LOCK) {
			//loop: wait() may return spuriously, and another thread may grab the
			//lock between our notification and our wake-up
			while (mutex_held) {
				try {
					mutex_LOCK.wait();
				} catch (InterruptedException e) {
					//remember it, keep waiting - see method comment
					interrupted = true;
				}
			}
			mutex_held = true;
		}
		if (interrupted) {
			Thread.currentThread().interrupt();
		}
	}

	/** Releases the world lock and wakes one thread waiting for it, if any. */
	private void mutex_release() {
		synchronized(mutex_LOCK) {
			mutex_held = false;
			mutex_LOCK.notify();
		}
	}
//
///////////////////////////////////////////////////////////

	/**
	 * Per-conversation state: maps each key to its {@link Lock}. Entries are
	 * created on demand and removed once their Lock becomes empty.
	 * Guarded by the world lock.
	 */
	final HashMap pool = new HashMap();

	/** Drops the {@link Lock} for key from the pool. Caller must hold the world lock. */
	private void removeLock(Object key) {
		pool.remove(key);
	}

	/**
	 * Returns the {@link Lock} for key, creating and registering an empty one if
	 * there is none yet. Caller must hold the world lock.
	 */
	private Lock getLock(Object key) {
		Lock notifier = (Lock)pool.get(key);
		if (notifier == null) {
			notifier = new Lock();
			pool.put(key,notifier);
		}
		return notifier;
	}

	/**
	 * Returns the {@link Lock} for key, or null if there is none. Unlike
	 * {@link #getLock}, this never creates a pool entry. Caller must hold the
	 * world lock.
	 */
	private Lock peekLock(Object key) {
		return (Lock)pool.get(key);
	}

	/**
	 * Unregisters recipient from every conversation it listens on, dropping any
	 * Lock that becomes empty as a result. Caller must hold the world lock.
	 */
	private void removeRecipient(Recipient recipient) {
		Object[] keys = recipient.keys;
		for (int i = 0; i < keys.length; i++) {
			Lock lock = peekLock(keys[i]);
			if (lock == null) {
				continue;
			}

			//remove the recipient from conversation [i]
			lock.recipients.remove(recipient);

			if (lock.isEmpty()) {
				removeLock(keys[i]);
			}
		}
	}

	/**
	 * Posts value under key. If a recipient is waiting on key, the oldest one gets
	 * the message, after key has been appended to it. Otherwise the message is
	 * queued on key. Takes and releases the world lock itself.
	 *
	 * @return true if the message was handed directly to a waiting recipient,
	 *         false if it was queued
	 */
	private boolean deliver(String key, Message value) {
		//XXX LOCK WORLD
		mutex_lock();
		try {
			Lock notifier = getLock(key);

			if (notifier.recipients.isEmpty()) {
				//nothing is waiting for a message, queue it
				notifier.messages.addLast(value);
				return false;
			}

			//something is waiting for a message already
			Recipient recipient = (Recipient)notifier.recipients.getFirst();

			//tag the message before touching any shared state, so a failure here
			//leaves the recipient registered
			value.append(key);

			//We need to remove it now otherwise another PUT might come along and override us
			removeRecipient(recipient);

			//world lock is taken before the recipient monitor - the same order get() uses
			synchronized(recipient) {
				recipient.value = value;
				recipient.notify();
			}
			return true;
		} finally {
			//XXX UNLOCK WORLD
			mutex_release();
		}
	}

	/**
	 * Removes value from key's queue if it is still there, dropping the Lock if
	 * that empties it. Messages are matched by identity.
	 *
	 * @return true if the message was still queued and has been removed, false if
	 *         it is not queued (it has already been taken by a receiver)
	 */
	private boolean withdraw(String key, Message value) {
		//XXX LOCK WORLD
		mutex_lock();
		try {
			Lock lock = peekLock(key);
			if (lock == null) {
				return false;
			}
			for (Iterator it = lock.messages.iterator(); it.hasNext();) {
				if (it.next() == value) {
					it.remove();
					if (lock.isEmpty()) {
						removeLock(key);
					}
					return true;
				}
			}
			return false;
		} finally {
			//XXX UNLOCK WORLD
			mutex_release();
		}
	}

	/**
	 * Posts value under key (as {@link #put} does), then waits for a reply on
	 * any of getkeys.
	 * <p>
	 * If the message was handed straight to a waiting recipient, it is certain to
	 * be processed, so the reply is awaited without a time-out. Otherwise the
	 * reply is awaited for at most timeout ms. If the wait fails while the message
	 * is still queued, the message is withdrawn and the failure is rethrown. If
	 * the message has already been taken by then, a reply is on its way, so it is
	 * awaited without a time-out.
	 *
	 * @param key      conversation to post the message on
	 * @param value    message to post
	 * @param getkeys  conversations on which the reply may arrive
	 * @param timeout  ms to wait for the reply while the message is still queued;
	 *                 0 means wait forever
	 * @return the reply, with the key it arrived on appended
	 * @throws InterruptedException if the wait timed out or was interrupted while
	 *         the posted message was still queued (it is then withdrawn), or if
	 *         the final unbounded wait is interrupted
	 */
	public Message putAndGet(String key, Message value, String[] getkeys, long timeout) throws InterruptedException {

		//we know a directly delivered message will be processed - so never time out
		boolean no_timeout = deliver(key, value);

		try {
			return get(getkeys, no_timeout ? 0 : timeout);
		} catch (InterruptedException e) {
			//get() timed out (or was interrupted)
			if (withdraw(key, value)) {
				//message IS NOT being processed - removed from the IN queue, so time out
				throw e;
			}

			//couldn't find our message - it therefore IS being processed,
			//so do a get and don't time out
			return get(getkeys, 0);
		}
	}

	/**
	 * Non-blocking put. Hands value to the oldest recipient waiting on key, with
	 * key appended to it, or queues it on key if nobody is waiting.
	 *
	 * @throws InterruptedException declared for source compatibility; the world
	 *         lock is acquired uninterruptibly, so this method does not block
	 */
	public void put(String key, Message value) throws InterruptedException {
		deliver(key, value);
	}

	/**
	 * Receives one message from any of the given conversations.
	 * <p>
	 * If messages are already queued, the first one found is removed and returned.
	 * The keys are scanned from a random starting point to prevent starvation of
	 * conversations. Otherwise the caller registers on every key and blocks until
	 * a put hands it a message or the timeout elapses.
	 * <p>
	 * If the thread is interrupted after a message has already been handed to it,
	 * the message is returned rather than lost, and the interrupt status is left
	 * set.
	 *
	 * @param keylist conversations to receive from
	 * @param timeout maximum time to block in ms; 0 means wait forever
	 * @return the message, with the key it was posted under appended
	 * @throws InterruptedException if the timeout elapsed ("Timed out (...ms)"),
	 *         or the thread was interrupted while waiting
	 * @throws IllegalArgumentException if timeout is negative and no message is
	 *         already queued
	 */
	public Message get(String[] keylist, long timeout) throws InterruptedException {

		//TODO could get this from a pool rather than suffering object creation/deletion all the time
		Recipient recipient = new Recipient();
		recipient.keys = keylist;
		recipient.value = null;

		//XXX LOCK WORLD
		mutex_lock();
		try {
			//look for existing messages in any of the conversations
			Message queued = takeQueued(keylist);
			if (queued != null) {
				return queued;
			}

			//validate before registering so a bad argument cannot leave us registered
			if (timeout < 0) {
				throw new IllegalArgumentException("timeout value is negative");
			}

			//we found no existing message: add ourselves as a listener on each lock
			for (int k = 0; k < keylist.length; k++) {
				getLock(keylist[k]).recipients.add(recipient);
			}
		} finally {
			//XXX UNLOCK WORLD
			mutex_release();
		}

		//wait (without the world lock) for a put to hand us a message
		InterruptedException interruption = null;
		try {
			awaitValue(recipient, timeout);
		} catch (InterruptedException e) {
			//still registered - must clean up below before rethrowing
			interruption = e;
		}

		//We lock the world before the recipient (never the other way round)
		//because we MUST always lock them in the same order otherwise we could get deadlock
		Message delivered;
		//XXX LOCK WORLD
		mutex_lock();
		try {
			synchronized(recipient) {
				delivered = recipient.value;
				if (delivered == null) {
					//nothing was delivered, so we must remove ourselves from all conversations
					removeRecipient(recipient);
				}
			}
		} finally {
			//XXX UNLOCK WORLD
			mutex_release();
		}

		if (delivered != null) {
			if (interruption != null) {
				//a message was handed to us before we noticed the interrupt: keep it
				//rather than lose it, and leave the interrupt pending for the caller
				Thread.currentThread().interrupt();
			}
			return delivered;
		}
		if (interruption != null) {
			throw interruption;
		}
		throw new InterruptedException("Timed out ("+timeout+"ms)");
	}

	/**
	 * Removes and returns the first message queued on any of keylist's
	 * conversations, or null if none is queued.
	 * <p>
	 * The scan starts at a random position to prevent starvation of conversations
	 * (~8 million Math.random calls per second means no loss of performance here).
	 * The returned message has its key appended. A Lock emptied by the removal is
	 * dropped. This method never creates pool entries. Caller must hold the world
	 * lock.
	 */
	private Message takeQueued(String[] keylist) {
		int n = keylist.length;
		if (n == 0) {
			return null;
		}
		int start = (int)(Math.random() * n);
		for (int i = 1; i <= n; i++) {
			String key = keylist[(start + i) % n];
			Lock lock = peekLock(key);
			if (lock != null && !lock.messages.isEmpty()) {
				//get (and remove) the first message in this queue
				Message message = (Message)lock.messages.removeFirst();
				message.append(key);

				//if the lock is empty remove it from the map
				if (lock.isEmpty()) {
					removeLock(key);
				}
				return message;
			}
		}
		return null;
	}

	/**
	 * Blocks on recipient's monitor until a message has been delivered to it or
	 * timeout ms have elapsed (0 = no limit). Returns normally in both cases;
	 * callers inspect recipient.value. It loops because Object.wait may return
	 * spuriously.
	 */
	private static void awaitValue(Recipient recipient, long timeout) throws InterruptedException {
		synchronized(recipient) {
			if (timeout == 0) {
				while (recipient.value == null) {
					recipient.wait();
				}
				return;
			}
			long deadline = System.currentTimeMillis() + timeout;
			if (deadline < 0) {
				//overflow: the deadline is effectively never
				deadline = Long.MAX_VALUE;
			}
			long remaining = timeout;
			while (recipient.value == null && remaining > 0) {
				recipient.wait(remaining);
				remaining = deadline - System.currentTimeMillis();
			}
		}
	}

	/** Per-conversation state. All access is guarded by the world lock. */
	static class Lock {
		/** Messages posted on this key that no recipient has taken yet, oldest first. */
		LinkedList messages = new LinkedList();
		/** Recipients blocked in get() on this key, oldest first. */
		LinkedList recipients = new LinkedList();

		/** @return true if this conversation has neither queued messages nor waiting recipients */
		public boolean isEmpty() {
			return messages.isEmpty() && recipients.isEmpty();
		}
	}

	/** One blocked get() call; its monitor is used to wake that caller. */
	static class Recipient {
		/** The conversations this get() listens on. */
		Object[] keys;
		/** The delivered message, or null until a put hands one over (written under the world lock and this monitor). */
		Message value;
	}
	/*
//////////////////////////////////////////////////////
//
// TESTING 
//
	private static final int TEST_PUT=0;
	private static final int TEST_GET=1;

	public static void main(String[] args) {
//		testPickPutTimeout();
		testThroughput();
//		testPick();
//		testQueueing();
	}
	
	private static void testThroughput() {
		
		//
		// T threads get and put on 1 conversation alternately N times
		//
		
		MultiQueuedBlockingMap map = new MultiQueuedBlockingMap();
		
		int t = 4;
		int n = 50000;
		
		ArrayList list = new ArrayList();
		
		for (int i = 0; i < t; i++) {
			Thread tput = new TestThroughputThread(map,i%2,n);
			list.add(tput);
		}
		
		long T = System.currentTimeMillis();
		
		for (int i = 0; i < t; i++) {
			Thread tput = (Thread)list.get(i);
			tput.start();
		}
		
		for (int i = 0; i < t; i++) {
			Thread tput = (Thread)list.get(i);
			try {
				tput.join();
			} catch (Exception e) {
				e.printStackTrace();
			}
		}
		
		T = System.currentTimeMillis()-T;
		
		int count = t * n / 2;
		
		System.out.println(T+"ms to do "+count+" put/get pairs");
		System.out.println((count/(((double)T)/1000.0d)) + "pairs per second");
		
	}
	
	private static class TestThroughputThread extends Thread {
		int swch;
		int count;
		MultiQueuedBlockingMap map;
		public TestThroughputThread(MultiQueuedBlockingMap map, int i, int n) {
			swch = i;
			count = n;
			this.map = map;
			System.out.println("Throughput thread created "+swch);
		}
		public void run() {
			String[] conversations = new String[]{"CONV"};
			
			try {
				for (int i = 0; i < count; i++) {
					if (swch == 1) {
//						System.out.println("GET");
						map.get(conversations,0);
					} else if (swch == 0) {
//						System.out.println("PUT");
						map.put(conversations[0],new Message());
					}
//					swch = 1-swch;
				}
			} catch (Exception e) {
				e.printStackTrace();
			}
		}
	}
	
	private static void testPickPutTimeout() {
		//
		// T1 threads get() and then put() on C conversations
		// T2 threads putAndGet() on C conversations (timeouts are looped)
		//
		
		int t1 = 4;
		int t2 = 1000;
		int c = 8;
		
		ArrayList list = new ArrayList();
		
		MultiQueuedBlockingMap map = new MultiQueuedBlockingMap();
		
		String[] conversations = new String[c];
		for (int i = 0; i < c; i++) {
			conversations[i] = "CONVERSATION_"+i;
		}
		
		for (int i = 0; i < t1; i++) {
			Thread tmp = new TestPickPutThread(map,TEST_GET,t2/t1,conversations);
			list.add(tmp);
		}
		
		for (int i = 0; i < t2; i++) {
			Thread tmp = new TestPickPutThread(map,TEST_PUT,t2,conversations);
			list.add(tmp);
		}
		
		for (int i = 0; i < list.size(); i++) {
			Thread tmp = (Thread)list.get(i);
			tmp.start();
		}

		for (int i = 0; i < list.size(); i++) {
			Thread tmp = (Thread)list.get(i);
			try {
				tmp.join();
			} catch (Exception e) {
				e.printStackTrace();
			}
		}

		long total = 0;
		
		for (int i = 0; i < list.size(); i++) {
			TestPickPutThread tmp = (TestPickPutThread)list.get(i);
			
			total += tmp.total_timeouts;
			
			if (tmp.notes != null) {
				System.out.println(tmp.notes);
			}
		}
		
		System.out.println("Average timeouts = "+(((double)total) / ((double)t2)));

		System.out.println(map.pool.size()+" entries remaining");
		
		System.out.println("FINISHED");
		
	}
	
	private static class TestPickPutThread extends Thread {
		int type = 0;
		int threads;
		String[] conversations;
		MultiQueuedBlockingMap map;
		
		long total_timeouts = 0;
		String notes;
		
		public TestPickPutThread(MultiQueuedBlockingMap map, int type, int threads, String[] conversations) {
			this.type = type;
			this.conversations = conversations;
			this.threads = threads;
			this.map = map;
		}
		public void run() {
			try {
			
				if (type == TEST_GET) {
					for (int i = 0; i < threads * conversations.length; i++) {
						try {
							System.out.println("PICKING "+(i+1)+" of "+(threads * conversations.length)+"...");
							Message m = map.get(conversations,0);
							
							Thread.sleep(100);
							
							String conv = (String)m.pop();
							
							map.put(">>"+conv,m);
//							System.out.println("PICKED "+m);
						} catch (InterruptedException x) {
//							System.out.println("TIMEOUT");
							i--;
							continue;
						}
					}
				} else if (type == TEST_PUT) {
					for (int i = 0; i < conversations.length; i++) {
//						System.out.println("PUT "+i);
						
						try {
							map.putAndGet(conversations[i],new Message(i),new String[]{">>"+conversations[i]},1000);
						} catch (InterruptedException x) {
							total_timeouts++;
							i--;
							continue;
						}
						
//						map.put(conversations[i],new Message(i));
					}
					
					if (total_timeouts > 0) {
						notes = "PutAndGet timed out "+total_timeouts+" times";
					}
				}
				
			} catch (Exception e) {
				e.printStackTrace();
			}
		}
	}
	
		
	private static void testPick() {
		//
		// T1 threads pick on C conversations (including ignored timeouts)
		// T2 threads post on C conversations
		//
		
		int t1 = 4;
		int t2 = 1000;
		int c = 5;
		
		ArrayList list = new ArrayList();
		
		MultiQueuedBlockingMap map = new MultiQueuedBlockingMap();
		
		String[] conversations = new String[c];
		for (int i = 0; i < c; i++) {
			conversations[i] = "CONVERSATION_"+i;
		}
		
		for (int i = 0; i < t1; i++) {
			Thread tmp = new TestPickThread(map,TEST_GET,t2/t1,conversations);
			list.add(tmp);
		}
		
		for (int i = 0; i < t2; i++) {
			Thread tmp = new TestPickThread(map,TEST_PUT,t2,conversations);
			list.add(tmp);
		}
		
		for (int i = 0; i < list.size(); i++) {
			Thread tmp = (Thread)list.get(i);
			tmp.start();
		}

		for (int i = 0; i < list.size(); i++) {
			Thread tmp = (Thread)list.get(i);
			try {
				tmp.join();
			} catch (Exception e) {
				e.printStackTrace();
			}
		}
		
		System.out.println(map.pool.size()+" entries remaining");
		
		System.out.println("FINISHED");
		
	}
	
	private static class TestPickThread extends Thread {
		int type = 0;
		int threads;
		String[] conversations;
		MultiQueuedBlockingMap map;
		public TestPickThread(MultiQueuedBlockingMap map, int type, int threads, String[] conversations) {
			this.type = type;
			this.conversations = conversations;
			this.threads = threads;
			this.map = map;
		}
		public void run() {
			try {
			
				if (type == TEST_GET) {
					for (int i = 0; i < threads * conversations.length; i++) {
						try {
							System.out.println("PICKING "+(i+1)+" of "+(threads * conversations.length)+"...");
							Message m = map.get(conversations,0);
//							System.out.println("PICKED "+m);
						} catch (InterruptedException x) {
//							System.out.println("TIMEOUT");
							i--;
							continue;
						}
					}
				} else if (type == TEST_PUT) {
					for (int i = 0; i < conversations.length; i++) {
//						System.out.println("PUT "+i);
						map.put(conversations[i],new Message(i));
					}
				}
				
			} catch (Exception e) {
				e.printStackTrace();
			}
		}
	}
	
	private static void testQueueing() {
		MultiQueuedBlockingMap timeoutmap = new MultiQueuedBlockingMap();

		MultiQueuedBlockingMap map = new MultiQueuedBlockingMap();
		
		//start N pairs of threads reading and writing
		//run each thread N times
		int NTHREADS = 200;
		int NRUNS = 5000;
		
		for (int i = 0; i < NTHREADS; i++) {
		
//			String key = "key"+(i%20);
			String key = "key"+(i%NTHREADS);
			Message value = new Message();
			value.append("value"+(i%NTHREADS));
//			value.append("value"+(i%20));
			
			try {
				timeoutmap.get(new String[]{"MUMBATO"},100);
				System.out.println("Fetched ok "+i+" (map size = "+timeoutmap.pool.size()+")");
			} catch (Exception e) {
				System.out.println("Timed out ok "+i+" (map size = "+timeoutmap.pool.size()+")");
				try {
					Message m = new Message();
					m.append("MUMBATO");
					timeoutmap.put("MUMBATO",m);
				} catch (Exception x) {
					x.printStackTrace();
				}
			}
		
//			try {
//				Thread.sleep((int)(Math.random()*1));
//			} catch (Exception e) {
//			}
			
			
			Thread t1 = new TestThread(NRUNS,map,TEST_GET,key,value);
			Thread t2 = new TestThread(NRUNS,map,TEST_PUT,key,value);
			
			t1.start();
			t2.start();
		}
	}

	private static class TestThread extends Thread {
		int type;
		String key;
		Message value;
		MultiQueuedBlockingMap map;
		int count;
		public TestThread(int count, MultiQueuedBlockingMap map, int type, String key, Message value) {
			this.type = type;
			this.key = key;
			this.value = value;
			this.map = map;
			this.count = count;
		}
		public void run() {
			for (int i = 0; i < count; i++) {
				try {
//					Thread.sleep((int)(Math.random()*50));
//					Thread.sleep((int)(Math.random()*1));
			
					if (type == TEST_PUT) {
//						System.out.println("PUT "+key+" -> "+value);
						Message m = (Message)value.clone();
						m.setType(i);
						map.put(key,m);
//						System.out.println("PUT RETURN "+key+" -> "+value);
					} else if (type == TEST_GET) {
//						System.out.println("GET "+key+" -> "+value);
						
						Message val = null;
						try {
							val = map.get(new String[]{key},1);
						} catch (Exception x) {
							System.out.println("TIMEOUT");
							i--;
							continue;
						}
						
						val.pop();
						
						value.setType(i);
						
//						System.out.println(map.get(key)+" == "+value);
						if (!val.toString().equals(value.toString())) {
							System.out.println("MISMATCH! "+value+" != "+val);
						}
					}
				} catch (Throwable t) {
					t.printStackTrace();
				}
			}
		}
	}*/
}
