package org.eclipse.stp.b2j.core.jengine.internal.portgen;

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;

/**
 * Generates a WSDL document describing the methods of a Java class.
 * <p>
 * The document contains an embedded XSD schema for non-built-in parameter and
 * return types, a Java binding, a service and a partnerLinkType. Generation
 * runs inside the constructor via {@link #init(String, ArrayList)}. The result
 * is returned by {@link #getWsdl()}.
 * <p>
 * As a {@link Comparator} it orders {@link Method}s by ascending parameter
 * count; this ordering is used for the generated operations.
 */
public class WsdlGenerator implements Comparator {
	
	/** Fallback loader used when this class's own loader cannot find the target class. */
	ClassLoader cl;
	/** The class being described. */
	Class clazz;
	
	// main WSDL document buffer
	ByteArrayOutputStream bout;
	PrintStream pout;

	// <types> section buffer; appended into the main document after the messages
	ByteArrayOutputStream bxout;
	PrintStream xout;
	
	/** Extra "xmlns:nsN" declaration lines, spliced into the definitions element at {@link #nsInsert}. */
	ArrayList nss = new ArrayList();
	
	/** Fully-qualified names of classes already emitted as complexTypes. */
	HashMap xsdMap = new HashMap();
	/** format:typeMap lines for every generated complexType. */
	ArrayList xsdTypes = new ArrayList();

	String wsdl;
	
	/** Character offset in the generated document where {@link #nss} entries are inserted. */
	int nsInsert;
	
	/**
	 * Generates WSDL for {@code classname}, falling back to the system class loader.
	 *
	 * @param classname  fully-qualified name of the class to describe
	 * @param allMethods method names to include; empty means all public methods
	 * @throws ClassNotFoundException if the class cannot be loaded
	 */
	public WsdlGenerator(String classname, ArrayList allMethods) throws ClassNotFoundException {
		cl = ClassLoader.getSystemClassLoader();
		init(classname,allMethods);
	}

	/**
	 * Generates WSDL for {@code classname}, falling back to the given class loader.
	 *
	 * @param classname  fully-qualified name of the class to describe
	 * @param allMethods method names to include; empty means all public methods
	 * @param cl         fallback class loader
	 * @throws ClassNotFoundException if the class cannot be loaded
	 */
	public WsdlGenerator(String classname, ArrayList allMethods, ClassLoader cl) throws ClassNotFoundException {
		this.cl = cl;
		init(classname,allMethods);
	}

	/**
	 * Generates WSDL for {@code classname}, falling back to a loader over the given jars.
	 * The jar loader's parent is the current thread's context class loader.
	 *
	 * @param classname  fully-qualified name of the class to describe
	 * @param allMethods method names to include; empty means all public methods
	 * @param jarurls    URLs searched when this class's own loader cannot find the class
	 * @throws ClassNotFoundException if the class cannot be loaded
	 */
	public WsdlGenerator(String classname, ArrayList allMethods, URL[] jarurls) throws ClassNotFoundException {
		cl = new URLClassLoader(jarurls,Thread.currentThread().getContextClassLoader());
		init(classname,allMethods);
	}
	
	/** @return the generated WSDL document */
	public String getWsdl() {
		return wsdl;
	}
	
	/**
	 * Loads {@code classname} and builds the WSDL document into {@link #wsdl}.
	 *
	 * @param classname  fully-qualified name of the class to describe
	 * @param allMethods method names to include; empty (or null) means all public methods
	 * @throws ClassNotFoundException if neither this class's loader nor {@link #cl} can load it
	 */
	public void init(String classname, ArrayList allMethods) throws ClassNotFoundException {
		// Try the loader of this generator first, then the configured fallback.
		try {
			clazz = WsdlGenerator.class.getClassLoader().loadClass(classname);
		} catch (Exception e) {
			clazz = cl.loadClass(classname);
		}

		// getPackage() can return null; use the same fallback namespace as the schema types.
		String namespace = namespaceFor(clazz);

		bout = new ByteArrayOutputStream();
		pout = new PrintStream(bout);

		bxout = new ByteArrayOutputStream();
		xout = new PrintStream(bxout);
		
		pout.println("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
		pout.println("<definitions xmlns=\"http://schemas.xmlsoap.org/wsdl/\""); 
		pout.println("			targetNamespace=\""+namespace+"\""); 
		pout.println("			xmlns:tns=\""+namespace+"\""); 
		pout.println("			xmlns:xsd=\"http://www.w3.org/2001/XMLSchema\""); 
		pout.println("			xmlns:wsdl=\"http://schemas.xmlsoap.org/wsdl/\"");
		pout.println("			xmlns:engine=\"http://www.eclipse.org/stp/b2j/2006/02\""); 
		pout.println("			xmlns:format=\"http://schemas.xmlsoap.org/wsdl/formatbinding/\""); 
		pout.println("			xmlns:java=\"http://schemas.xmlsoap.org/wsdl/java/\"");
		// The insertion point is later applied to the decoded String, so it must be a
		// character offset; bout.size() is a byte count and breaks on non-ASCII names.
		pout.flush();
		nsInsert = bout.toString().length();
		pout.println(">");

		{
			//HEAD the XSD file
			xout.println();
			xout.println("<!-- XSD types -->");
			xout.println("<types>");
			xout.println("	<schema xmlns=\"http://www.w3.org/2001/XMLSchema\"");
			xout.println("			targetNamespace=\""+namespace+"\""); 
			xout.println("			xmlns:tns=\""+namespace+"\""); 
			xout.println("			xmlns:xsd=\"http://www.w3.org/2001/XMLSchema\""); 
			xout.println("			xmlns:wsdl=\"http://schemas.xmlsoap.org/wsdl/\"");
			xout.println("			xmlns:engine=\"http://www.eclipse.org/stp/b2j/2006/02\""); 
			xout.println("			xmlns:format=\"http://schemas.xmlsoap.org/wsdl/formatbinding/\""); 
			xout.println("			xmlns:java=\"http://schemas.xmlsoap.org/wsdl/java/\"");
			xout.println("	>");
		}
		
		String portname = simpleName(clazz);
		
		Method[] methods = filterMethods(clazz.getDeclaredMethods(), allMethods);
		
		ArrayList opnames = new ArrayList();
		
		// one request and one response message per operation
		for (int i = 0; i < methods.length; i++) {
			Method m = methods[i];
			
			String opName = m.getName();
			
			opnames.add(opName);
			
			String reqName = opName+"Req";
			String resName = opName+"Res";
			
			pout.println();
			pout.println("<!-- "+portname+"."+opName+" -->");
			pout.println("<message name=\""+reqName+"\">");
			Class[] params = m.getParameterTypes();

			// parameter names are not available via reflection here, so parts are positional
			for (int k = 0; k < params.length; k++) {
				String partName = "part"+k;
				String partType = writeXsd(params[k]);
				pout.println("	<part name=\""+partName+"\" type=\""+partType+"\" />");
			}
			pout.println("</message>");
			
			pout.println("<message name=\""+resName+"\">");
			Class ret = m.getReturnType();
			if (!ret.equals(void.class)) {
				String partName = "ret";
				String partType = writeXsd(ret);
				pout.println("	<part name=\""+partName+"\" type=\""+partType+"\" />");
			}
			pout.println("</message>");
		}

		{
			//TAIL the XSD file
			xout.println("	</schema>");
			xout.println("</types>");
		}
		
		pout.println(bxout.toString());

		pout.println("<portType name=\""+portname+"PortType\">");
		for (int i = 0; i < opnames.size(); i++) {
			String opname = (String)opnames.get(i);
			pout.println("	<operation name=\""+opname+"\">");
			pout.println("		<input message=\"tns:"+opname+"Req\"/>");
			pout.println("		<output message=\"tns:"+opname+"Res\"/>");
			pout.println("	</operation>");
		}
		pout.println("</portType>");
		
		pout.println("");
		pout.println("<binding name=\""+portname+"JavaBinding\" type=\"tns:"+portname+"PortType\">");
		pout.println("	<java:binding/>");
		pout.println("	<format:typeMapping encoding=\"Java\" style=\"Java\">");
		pout.println("		<format:typeMap typeName=\"xsd:string\" formatType=\"java.lang.String\"/>");
		pout.println("		<format:typeMap typeName=\"xsd:dateTime\" formatType=\"java.util.GregorianCalendar\"/>");
		pout.println("		<format:typeMap typeName=\"xsd:double\" formatType=\"double\"/>");
		pout.println("		<format:typeMap typeName=\"xsd:float\" formatType=\"float\"/>");
		pout.println("		<format:typeMap typeName=\"xsd:long\" formatType=\"long\"/>");
		pout.println("		<format:typeMap typeName=\"xsd:int\" formatType=\"int\"/>");
		pout.println("		<format:typeMap typeName=\"xsd:boolean\" formatType=\"java.lang.Boolean\"/>");
		for (int i = 0; i < xsdTypes.size(); i++) {
			pout.println((String)xsdTypes.get(i));
		}
		pout.println("	</format:typeMapping>");
		
		for (int i = 0; i < methods.length; i++) {
			String name = methods[i].getName();
			int partCount = methods[i].getParameterTypes().length;
			boolean hasReturn = !methods[i].getReturnType().equals(void.class);
			
			pout.println("		<operation name=\""+name+"\">");
			pout.print("			<java:operation methodName=\""+name+"\"");

			if (partCount > 0) {
				pout.print(" parameterOrder=\"");
				for (int k = 0; k < partCount; k++) {
					if (k > 0) {
						pout.print(" ");
					}
					pout.print("part"+k);
				}
				pout.print("\"");
			}
			
			if (hasReturn) {
				pout.print(" returnPart=\"ret\"");
			}
			pout.println("/>");
			
			pout.println("		</operation>");
		}
		
		pout.println("</binding>");
		
		pout.println("");
		pout.println("<service>");
		pout.println("	<port name=\""+portname+"JavaService\" binding=\"tns:"+portname+"JavaBinding\">");
		pout.println("		<java:address className=\""+classname+"\"/>");
		pout.println("	</port>");
		pout.println("</service>");

		pout.println("");
		pout.println("<partnerLinkType name=\""+portname+"PartnerLink\">");
		pout.println("	<role name=\""+portname+"JavaService\" portType=\"tns:"+portname+"PortType\"/>");
		pout.println("</partnerLinkType>");
		
		pout.println("");
		pout.println("</definitions>");
		
		// Namespaces discovered while writing types are only known now; splice their
		// declarations into the <definitions> start tag.
		String doc = bout.toString();
		StringBuffer sb = new StringBuffer();
		sb.append(doc.substring(0,nsInsert));
		for (int i = 0; i < nss.size(); i++) {
			sb.append(nss.get(i));
		}
		sb.append(doc.substring(nsInsert));
		
		wsdl = sb.toString();
	}
	
	/**
	 * Selects the methods to expose and keeps one method per name.
	 * <p>
	 * Compiler-generated (synthetic, e.g. bridge or lambda) methods are skipped.
	 * When {@code all} is null or empty, only public methods are used. Otherwise
	 * only methods whose names are listed are used. When a name is overloaded,
	 * the overload with the most parameters wins; ties keep the first one seen.
	 *
	 * @param methods candidate methods
	 * @param all     method names to include; empty or null means all public methods
	 * @return the chosen methods, ordered by ascending parameter count
	 */
	private Method[] filterMethods(Method[] methods, ArrayList all) {
		boolean selectAll = all == null || all.size() == 0;

		HashMap best = new HashMap();    // name -> chosen overload
		ArrayList names = new ArrayList(); // names in first-seen order

		for (int i = 0; i < methods.length; i++) {
			Method m = methods[i];
			if (m.isSynthetic()) {
				continue;
			}
			boolean selected = selectAll
				? Modifier.isPublic(m.getModifiers())
				: all.contains(m.getName());
			if (!selected) {
				continue;
			}

			Method current = (Method)best.get(m.getName());
			if (current == null) {
				names.add(m.getName());
				best.put(m.getName(), m);
			} else if (m.getParameterTypes().length > current.getParameterTypes().length) {
				best.put(m.getName(), m);
			}
		}

		ArrayList list = new ArrayList();
		for (int i = 0; i < names.size(); i++) {
			list.add(best.get(names.get(i)));
		}
		// Output ordering: ascending parameter count (stable sort).
		Collections.sort(list,this);
		
		Method[] tmp = new Method[list.size()];
		list.toArray(tmp);
		return tmp;
	}
	
	/** Namespace URI -> index N of its "nsN" prefix. */
	HashMap namespaces = new HashMap();
	/** Next free namespace prefix index. */
	int NS = 0;
	
	/**
	 * Returns the "nsN" prefix for a namespace URI.
	 * On first use, the namespace is assigned the next index and its xmlns
	 * declaration is queued in {@link #nss}.
	 */
	private String writeNamespace(String ns) {
		Integer i = (Integer)namespaces.get(ns);
		if (i == null) {
			i = Integer.valueOf(NS++);
			namespaces.put(ns,i);
			
			nss.add("			xmlns:ns"+i+"=\""+ns+"\"\n");
		}
		return "ns"+i;
	}

	/**
	 * Returns the built-in XSD type for {@code c}, or null if a complexType must be generated.
	 * Covers int, boolean, float, double, long, String and GregorianCalendar.
	 */
	private static String builtinXsdType(Class c) {
		String name = c.getName();
		if (name.equals("int")
			|| name.equals("boolean")
			|| name.equals("float")
			|| name.equals("double")
			|| name.equals("long")) {
			return "xsd:"+name;
		} else if (name.equals("java.lang.String")) {
			return "xsd:string";
		} else if (name.equals("java.util.GregorianCalendar")) {
			return "xsd:dateTime";
		}
		return null;
	}

	/** Derives the namespace URI for a class from its package; classes without a package share a fallback. */
	private static String namespaceFor(Class c) {
		Package p = c.getPackage();
		if (p == null) {
			return "http://java_empty_namespace";
		}
		return "http://"+p.getName().replace('.','_');
	}

	/** Returns the class name with any package prefix removed. */
	private static String simpleName(Class c) {
		String name = c.getName();
		return name.substring(name.lastIndexOf('.')+1);
	}

	/**
	 * Returns the qualified XSD type reference for {@code c} without writing a schema type.
	 * A reference to a non-built-in type may register a new namespace prefix.
	 */
	private String getRefFor(Class c) {
		String builtin = builtinXsdType(c);
		if (builtin != null) {
			return builtin;
		}
		return writeNamespace(namespaceFor(c))+":"+simpleName(c);
	}
	
	/**
	 * Returns the XSD type reference for {@code c}, emitting a complexType if needed.
	 * <p>
	 * A complexType is emitted the first time a non-built-in class is seen. It has
	 * one element per instance field, and the field types are then processed
	 * recursively. The class is recorded before recursing, so cyclic references
	 * terminate.
	 */
	private String writeXsd(Class c) {
		String ref = getRefFor(c);
		String fullclassname = c.getName();

		if (builtinXsdType(c) != null || xsdMap.containsKey(fullclassname)) {
			return ref;
		}

		xsdMap.put(fullclassname,fullclassname);
		xsdTypes.add("		<format:typeMap typeName=\""+ref+"\" formatType=\""+fullclassname+"\"/>");

		xout.println();
		xout.println("		<!-- conversion of class "+c.getName()+" -->");

		xout.println("		<complexType name=\""+simpleName(c)+"\">");
		xout.println("			<sequence>");

		ArrayList instanceFields = new ArrayList();
		Field[] fields = c.getDeclaredFields();
		for (int i = 0; i < fields.length; i++) {
			if (isInstanceField(fields[i])) {
				instanceFields.add(fields[i]);
			}
		}

		for (int i = 0; i < instanceFields.size(); i++) {
			Field field = (Field)instanceFields.get(i);
			String fieldType = getRefFor(field.getType());
			xout.println("				<element name=\""+field.getName()+"\" type=\""+fieldType+"\"/>");
		}

		xout.println("			</sequence>");
		xout.println("		</complexType>");

		// recurse only after this type's element list is closed
		for (int i = 0; i < instanceFields.size(); i++) {
			writeXsd(((Field)instanceFields.get(i)).getType());
		}

		return ref;
	}
	
	/**
	 * Tests whether the field is an instance (non-static) field.
	 * The check reads the modifiers directly; it does not probe the field
	 * reflectively, so it causes no class initialization and needs no access.
	 */
	private static boolean isInstanceField(Field field) {
		return !Modifier.isStatic(field.getModifiers());
	}

	/**
	 * Command-line entry point: prints the WSDL for {@code args[0]}.
	 * The optional {@code args[1]} names a file with one method name per line.
	 */
	public static void main(String[] args) throws Exception {
		if (args.length < 1) {
			System.out.println("Usage: WsdlGenerator <classname> [methodfile]");
			System.out.println("Note: the JVM classpath should include the class you wish to generate WSDL for");
			System.out.println("    methodfile      a file containing a list of methods to convert");
			System.out.println("                    no methodfile means convert all public methods");
		} else {
			ArrayList allMethods = new ArrayList();
			if (args.length > 1) {
				BufferedReader bread = new BufferedReader(new InputStreamReader(new FileInputStream(args[1])));
				try {
					for (String line = bread.readLine(); line != null; line = bread.readLine()) {
						line = line.trim();
						if (line.length() > 0) {
							allMethods.add(line);
						}
					}
				} finally {
					bread.close();
				}
			}
			WsdlGenerator gen = new WsdlGenerator(args[0],allMethods);
			System.out.println(gen.getWsdl());
		}
	}

	/** Orders {@link Method}s by ascending parameter count. */
	public int compare(Object arg0, Object arg1) {
		int n1 = ((Method)arg0).getParameterTypes().length;
		int n2 = ((Method)arg1).getParameterTypes().length;
		return n1 < n2 ? -1 : (n1 == n2 ? 0 : 1);
	}
}