#!/usr/bin/env python
#
# Copyright (c) 2005 The Regents of the University of California. 
# This material was produced under U.S. Government contract W-7405-ENG-36 
# for Los Alamos National Laboratory, which is operated by the University 
# of California for the U.S. Department of Energy. The U.S. Government has 
# rights to use, reproduce, and distribute this software. NEITHER THE 
# GOVERNMENT NOR THE UNIVERSITY MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR 
# ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is modified 
# to produce derivative works, such modified software should be clearly marked, 
# so as not to confuse it with the version available from LANL. LA-CC 04-115
# 
# Additionally, this program and the accompanying materials 
# are made available under the terms of the Eclipse Public License v1.0
# which accompanies this distribution, and is available at
# http://www.eclipse.org/legal/epl-v10.html
# 

"""
usage: ptp_mpich2_proxy [[--host=hostname] --port=port]
"""

from time import ctime
__author__ = "Greg Watson"
__date__ = ctime()
__version__ = "2.0"
__credits__ = ""

import sys, os
import signal, socket, binascii, pwd, grp
import imp

from  sets     import  Set
from  re       import  sub
from  urllib   import  unquote

from  ptplib   import  PTPProxy, ptp_print, \
                       CMD_QUIT, CMD_INIT, CMD_MODELDEF, CMD_STARTEVENTS, \
                       CMD_STOPEVENTS, CMD_SUBMITJOB, CMD_TERMINATEJOB, \
                       ELEMENT_ID_ATTR, ELEMENT_NAME_ATTR, \
                       MACHINE_STATE_UP, \
                       NODE_NUMBER_ATTR, NODE_STATE_ATTR, NODE_STATE_UP, \
                       QUEUE_STATE_NORMAL, \
                       JOB_STATE_INIT, JOB_STATE_RUNNING, JOB_STATE_ERROR, \
                       JOB_STATE_TERMINATED, JOB_SUB_ID_ATTR, JOB_EXEC_NAME_ATTR, \
                       JOB_EXEC_PATH_ATTR, JOB_NUM_PROCS_ATTR, JOB_WORKING_DIR_ATTR, \
                       JOB_PROG_ARGS_ATTR, JOB_ENV_ATTR, JOB_DEBUG_EXEC_NAME_ATTR, \
                       JOB_DEBUG_EXEC_PATH_ATTR, JOB_DEBUG_ARGS_ATTR, JOB_DEBUG_FLAG_ATTR, \
                       JOB_ID_ATTR, \
                       PROC_STATE_STARTING, PROC_STATE_RUNNING, PROC_STATE_ERROR, \
                       PROTOCOL_VERSION_ATTR, PROTOCOL_VERSION, \
                       MSG_LEVEL_FATAL, MSG_LEVEL_ERROR, MSG_LEVEL_WARNING, \
                       BASE_ID_ATTR

                       
# Echo the search path so that a failed lookup below can be diagnosed.
# An unset PATH is treated as empty instead of raising KeyError.
_mpd_search_path = os.environ.get('PATH', '')
ptp_print(_mpd_search_path)
#
# Try and locate MPICH2 installation
#
# mpdlib is looked up in the directories listed in PATH (not sys.path) and
# registered under the name 'mpdlib' so the "from mpdlib import" below works.
try:
    if _mpd_search_path:
        _mpd_search_dirs = _mpd_search_path.split(os.pathsep)
    else:
        _mpd_search_dirs = []
    fp, path, desc = imp.find_module('mpdlib', _mpd_search_dirs)
    try:
        imp.load_module('mpdlib', fp, path, desc)
    finally:
        # find_module() hands back an open file; the caller has to close it.
        if fp:
            fp.close()
except ImportError:
    ptp_print('Could not locate MPICH2 installation. Please check your PATH.')
    sys.exit(1)

from  mpdlib   import  mpd_set_my_id, mpd_get_my_id, mpd_uncaught_except_tb, \
                       mpd_handle_signal, mpd_get_my_username, mpd_version, \
                       MPDSock, MPDParmDB, MPDListenSock, \
                       MPDStreamHandler, mpd_get_my_username, \
                       mpd_get_groups_for_username
                     
# Module-wide state shared by the command handlers and stream callbacks below.
# (A "global" statement at module level has no effect, so every shared name is
# given an explicit initial value here instead; functions that rebind one of
# these names declare it "global" themselves.)

recvTimeout = 20  # const: seconds to wait for MPD replies during job submission

exit = 0            # set to 1 to make the main loop terminate
sigOccurred = 0     # last signal number recorded by sig_handler(), 0 if none
manSocks = []       # sockets to the MPD managers of running jobs

parmdb = None         # MPDParmDB instance, created in ptp_mpd_proxy()
streamHandler = None  # MPDStreamHandler instance, created in ptp_mpd_proxy()
myAddr = None         # local host name
myIP = None           # local IP address
proxy = None          # PTPProxy connection; stays None in interactive mode
debug = 0

jobProcCount = {}     # PTP job ID -> number of processes not yet finished
ptpJobIDToJobID = {}  # PTP job ID -> MPD job ID string ('jobnum@mpdid')
procIDs = {}          # PTP job ID -> {rank: PTP process ID}
nodeIDs = {}          # node name -> PTP node ID

lastID = 0      # last element ID handed out by generate_id()
rmID = 0        # base ID supplied by PTP in CMD_INIT
machineID = 0   # ID of the machine element, assigned in start_events()
queueID = 0     # ID of the queue element, assigned in start_events()
gTransID = -1   # transaction ID of the active event stream, -1 when none

def ptp_mpd_proxy():
    """Entry point: parse arguments, connect to PTP, set up MPD, run the loop.

    Recognised arguments are ``--host=<name>``, ``--port=<n>``, ``--debug`` and
    ``--proxy`` (accepted and ignored); anything else calls usage().  With a
    non-zero port the proxy connects to PTP and dispatches its commands to the
    handlers in this file; without a port it reads commands interactively from
    stdin through do_input().

    The loop runs until the module-level ``exit`` flag is set or the PTP
    connection reports a negative result.
    """
    global exit, parmdb, myAddr, myIP, streamHandler
    global manSocks, jobProcCount, proxy, debug
    # sigOccurred has to be the module-level flag written by sig_handler();
    # as a local it would stay 0 forever and signals would never be acted on.
    global sigOccurred

    sys.excepthook = mpd_uncaught_except_tb

    ptpHost = 'localhost'
    ptpPort = 0
    sigOccurred = 0
    manSocks = []
    exit = 0
    debug = 0

    # PTP command -> handler(tid, attrs, proxy)
    ptpProxyCmds = {
        CMD_QUIT         : finish,
        CMD_INIT         : initialize,
        CMD_MODELDEF     : model_def,
        CMD_SUBMITJOB    : submit_job,
        CMD_TERMINATEJOB : terminate_job,
        CMD_STARTEVENTS  : start_events,
        CMD_STOPEVENTS   : stop_events,
    }

    if len(sys.argv) > 1 and sys.argv[1] in ('-h', '--help'):
        usage()

    for option in sys.argv[1:]:
        args = option.split('=')
        if args[0] == '--host':
            if len(args) < 2:
                usage()
            ptpHost = args[1]
        elif args[0] == '--port':
            if len(args) < 2 or not args[1].isdigit():
                usage()
            ptpPort = int(args[1])
            ptp_print('ptpport = %d' % ptpPort)
        elif args[0] == '--debug':
            debug = 1
        elif args[0] == '--proxy':
            pass
        else:
            usage()

    streamHandler = MPDStreamHandler()

    #
    # Connect to PTP
    #
    if ptpPort != 0:
        ptp_print('about to connect')
        proxy = PTPProxy(ptpHost, ptpPort, ptpProxyCmds, debug=debug)
        try:
            proxy.connect()
        except Exception:
            ptp_print('ptp_mpich2_proxy: could not connect to %s:%d' % (ptpHost,ptpPort))
            sys.exit(-1)
        streamHandler.set_handler(proxy.getsocket(),proxy.read_command,args=())

    #
    # Set up MPD stuff
    #
    if hasattr(signal,'SIGINT'):
        signal.signal(signal.SIGINT, sig_handler)
    if hasattr(signal,'SIGALRM'):
        signal.signal(signal.SIGALRM,sig_handler)

    mpd_set_my_id(myid='ptp_mpich2_proxy')

    myAddr = socket.gethostname()
    try:
        hostinfo = socket.gethostbyname_ex(myAddr)
    except socket.error:
        ptp_print('ptp_mpich2_proxy failed: gethostbyname_ex failed for %s' % (myAddr))
        sys.exit(-1)
    myIP = hostinfo[2][0]  # first address in the returned address list

    parmdb = MPDParmDB(orderedSources=['cmdline','xml','env','rcfile','thispgm'])
    parmsToOverride = {
                        'MPD_USE_ROOT_MPD'            :  0,
                        'MPD_SECRETWORD'              :  '',
                      }
    for (k,v) in parmsToOverride.items():
        parmdb[('thispgm',k)] = v
    parmdb.get_parms_from_env(parmsToOverride)
    parmdb.get_parms_from_rcfile(parmsToOverride)

    # Defaults for the remaining 'thispgm' parameters.
    thispgm_defaults = [
        ('mship', ''),
        ('rship', ''),
        ('userpgm', ''),
        ('nprocs', 0),
        ('ecfn_format', ''),
        ('gdb_attach_jobid', ''),
        ('singinitpid', 0),
        ('singinitport', 0),
        ('ignore_rcfile', 0),
        ('ignore_environ', 0),
        ('inXmlFilename', ''),
        ('print_parmdb_all', 0),
        ('print_parmdb_def', 0),
    ]
    for (k,v) in thispgm_defaults:
        parmdb[('thispgm',k)] = v

    if ptpPort == 0:
        # No PTP connection requested: take commands from the terminal.
        sys.stdout.write('>>> ')
        sys.stdout.flush()
        streamHandler.set_handler(sys.stdin,do_input,args=())

    #
    # Main Loop
    #

    while not exit:
        if sigOccurred:
            # With no job running there is no manager socket; pass None so
            # that handle_sig_occurred() still gets to request shutdown.
            for sock in (manSocks or [None]):
                handle_sig_occurred(sock)
        rv = streamHandler.handle_active_streams(timeout=0.1)
        if rv[0] < 0:  # will handle some sigs at top of next loop
            pass       # may have to handle some err conditions here
        # proxy is None in interactive mode, so only poll a real connection.
        if proxy is not None and proxy.getResult() < 0:
            proxy.close()
            exit = 1

def do_input(fd):
    """Read one interactive command line from *fd* and execute it.

    End of file, or one of ``exit``/``quit``/``q``, sets the module-level
    ``exit`` flag.  An empty line just re-prints the prompt.  After every
    other command the ``>>> `` prompt is written again.

    NOTE(review): procattrs(), kill() and run() are not defined in the visible
    part of this file; confirm they exist before relying on those commands.
    """
    global exit
    line = fd.readline()
    if line == '':
        # EOF on the input stream
        exit = 1
        return
    if line == '\n':
        sys.stdout.write('>>> ')
        sys.stdout.flush()
        return
    res = ()
    line = line.rstrip(' \n').split(' ')
    if line[0] == 'exit' or line[0] == 'quit' or line[0] == 'q':
        exit = 1
    elif line[0] == 'getnodes':
        # get_nodes() needs the proxy object for error reporting; it is the
        # module-level connection, which is None in interactive mode.
        sys.stdout.write('%s\n' % (get_nodes(proxy),))
    elif line[0] == 'listjobs':
        res = listjobs([])
    elif line[0] == 'listprocs':
        if len(line) < 2:
            ptp_print('listprocs <jobid>')
        else:
            res = listprocs([line[1]])
    elif line[0] == 'procattrs':
        if len(line) < 2:
            ptp_print('procattrs <jobid>')
        else:
            res = procattrs([line[1]])
    elif line[0] == 'kill':
        if len(line) < 2:
            ptp_print('kill <jobid>')
        else:
            res = kill([line[1]])
    elif line[0] == 'run':
        if len(line) < 3:
            ptp_print('run <nprocs> <cmd> [<args>]')
        else:
            # Only build the request when both <nprocs> and <cmd> are present;
            # otherwise line[2] below would raise IndexError.
            args = ['execName', line[2],
                    'numOfProcs', line[1],
                    'workingDir', os.path.abspath(os.getcwd())]
            for prog_arg in line[3:]:
                args.append('progArg')
                args.append(prog_arg)
            res = run(args)
    else:
        ptp_print('Unknown command: %s' % line[0])

    sys.stdout.write('>>> ')
    sys.stdout.flush()

#
# Open a connection to the mpd console
#
def open_mpd_console():
    """Connect to the MPD console.

    When running as uid 0, or when the MPD_USE_ROOT_MPD parameter is set, the
    connection is made with the 'mpdroot' helper located next to this script.

    Returns the connected PTPConClientSock, or 0 if connect() reported failure.
    """
    global parmdb
    running_as_root = hasattr(os,'getuid') and os.getuid() == 0
    if running_as_root or parmdb['MPD_USE_ROOT_MPD']:
        script_dir = os.path.abspath(os.path.dirname(sys.argv[0]))  # normalize
        sock_args = {
            'mpdroot'    : os.path.join(script_dir,'mpdroot'),
            'secretword' : parmdb['MPD_SECRETWORD'],
        }
    else:
        sock_args = { 'secretword' : parmdb['MPD_SECRETWORD'] }
    conSock = PTPConClientSock(**sock_args)
    if not conSock.connect():
        return 0
    return conSock

#
# Create a jobid string
#
def make_jobid(str):
    """Return 'jobnum@mpdid' built from a double-space separated job string.

    The input holds jobnum, mpdid and optionally an alias, separated by two
    spaces; only the first two fields are used.
    """
    fields = str.split('  ')
    jobnum, mpdid = fields[0], fields[1]
    return jobnum + '@' + mpdid

#
# Generate a new ID
#
def generate_id():
    """Advance the module-level element ID counter and return the new value."""
    global lastID
    lastID = lastID + 1
    return lastID
#
# Query mpd for node information
#
def get_nodes(proxy):
    """Ask MPD (mpdtrace) for its nodes and build a PTP node attribute list.

    Every reported node gets a fresh element ID, which is also recorded in the
    module-level nodeIDs map under the node name (the mpd id up to its last
    underscore).

    Returns ``[count, id1, 3, 'name=..', 'number=..', 'state=..', id2, 3, ...]``
    on success, or ``[]`` after reporting an error event through *proxy*.
    """
    global nodeIDs, gTransID

    def fail(text, sock=None):
        # Errors are reported on the event transaction; the original code
        # referenced an undefined name "tid" here.
        proxy.send_message_event(gTransID, MSG_LEVEL_ERROR, 0, text)
        if sock is not None:
            sock.close()
        return []

    conSock = open_mpd_console()
    if conSock == 0:
        return fail('Could not connect to MPD')
    conSock.send_dict_msg({ 'cmd' : 'mpdtrace' })
    attrs = []
    node_number = 0
    done = 0
    while not done:
        msg = conSock.recv_dict_msg(timeout=5.0)
        if not msg:    # also get this on ^C
            return fail('Did not receive message from MPD', conSock)
        if 'cmd' not in msg:
            return fail('Invalid message from MPD', conSock)
        if msg['cmd'] == 'mpdtrace_info':
            pos = msg['id'].rfind('_')
            node_name = str(msg['id'][:pos])
            node_id = generate_id()
            nodeIDs[node_name] = node_id
            # element ID, then the number of attribute strings that follow
            attrs += [
                node_id, 3,
                '%s=%s' % (ELEMENT_NAME_ATTR, node_name),
                '%s=%d' % (NODE_NUMBER_ATTR, node_number),
                '%s=%s' % (NODE_STATE_ATTR, NODE_STATE_UP),
            ]
            node_number += 1
        elif msg['cmd'] == 'mpdtrace_trailer':
            done = 1
    conSock.close()
    return [node_number] + attrs

#
# Terminate proxy
#
def finish(tid, attrs, proxy):
    """CMD_QUIT: flag the main loop to stop and shut the PTP connection down.

    If an event transaction is still open (gTransID >= 0) it is acknowledged
    first; then the shutdown event is sent for *tid* and the proxy is closed.
    Always returns 0.
    """
    global exit, gTransID
    exit = 1
    events_open = gTransID >= 0
    if events_open:
        proxy.send_ok_event(gTransID)
    proxy.send_shutdown_event(tid)
    proxy.close()
    return 0

#
# CMD_INIT
#
def initialize(tid, attrs, proxy):
    """CMD_INIT: validate the protocol version and base ID, then probe MPD.

    On success the base ID becomes both rmID and the starting point of the ID
    counter, gTransID is reset to -1 and an OK event is sent.

    Returns 0 on success, or -1 after sending an error event.
    """
    global rmID, lastID, gTransID
    if PROTOCOL_VERSION_ATTR not in attrs:
        proxy.send_error_event(tid, 0, 'no protocol version supplied')
        return -1
    version = attrs[PROTOCOL_VERSION_ATTR]
    if version != PROTOCOL_VERSION:
        proxy.send_error_event(tid, 0, 'wire protocol \'%s\' not supported' % version)
        return -1
    if BASE_ID_ATTR not in attrs:
        proxy.send_error_event(tid, 0, 'no base ID supplied')
        return -1
    rmID = int(attrs[BASE_ID_ATTR])
    lastID = int(rmID)
    gTransID = -1
    #
    # Test if mpd is running
    #
    conSock = open_mpd_console()
    if conSock == 0:
        proxy.send_error_event(tid, 0, 'Could not connect to MPD. Check it is running.')
        return -1
    # The connection was only a liveness probe; release it so it does not
    # linger for the lifetime of the proxy.
    conSock.close()
    proxy.send_ok_event(tid)
    return 0

#
# CMD_MODELDEF
#
def model_def(tid, attrs, proxy):
    """CMD_MODELDEF: nothing to define; acknowledge *tid* and return 0."""
    status = 0
    proxy.send_ok_event(tid)
    return status

#
# CMD_STARTEVENTS
#
def start_events(tid, attrs, proxy):
    """CMD_STARTEVENTS: announce the machine, its nodes and the default queue.

    *tid* is remembered in gTransID as the transaction on which later events
    are sent.  Always returns 0.
    """
    global machineID, queueID, rmID, myAddr, gTransID
    gTransID = tid

    # One machine element, named after the local host.
    machineID = generate_id()
    proxy.send_new_machine_event(tid, rmID, machineID, myAddr, MACHINE_STATE_UP)

    # Nodes are only reported when MPD returned any.
    node_attrs = get_nodes(proxy)
    if node_attrs:
        proxy.send_new_node_event(tid, machineID, node_attrs)

    # A single queue called "default".
    queueID = generate_id()
    proxy.send_new_queue_event(tid, rmID, queueID, "default", QUEUE_STATE_NORMAL)
    return 0

#
# CMD_STOPEVENTS
#
def stop_events(tid, attrs, proxy):
    """CMD_STOPEVENTS: close the open event transaction, then acknowledge *tid*.

    Always returns 0.
    """
    global gTransID
    open_tid = gTransID
    if open_tid >= 0:
        proxy.send_ok_event(open_tid)
        gTransID = -1
    proxy.send_ok_event(tid)
    return 0

#
# Deprecated
#
def listjobs(args):
    """Deprecated: return the distinct MPD job IDs known to the local mpd.

    *args* is unused.  Returns ``(RTEV_OK, [jobid, ...])`` on success or
    ``(RTEV_ERROR_BASE, message)`` on failure.

    NOTE(review): RTEV_OK and RTEV_ERROR_BASE are not defined or imported in
    the visible part of this file; confirm they exist before calling this.
    """
    conSock = open_mpd_console()
    if conSock == 0:
        return (RTEV_ERROR_BASE, 'Could not connect to MPD')
    conSock.send_dict_msg({ 'cmd' : 'mpdlistjobs' })
    msg = conSock.recv_dict_msg(timeout=5.0)
    if not msg:
        conSock.close()
        return (RTEV_ERROR_BASE, 'No message recvd from MPD before timeout')
    if msg['cmd'] != 'local_mpdid':     # get full id of local mpd for filters later
        conSock.close()
        return (RTEV_ERROR_BASE, 'Did not recv local_mpdid msg from local mpd; instead, recvd: %s' % msg)
    jobids = []
    done = 0
    while not done:
        msg = conSock.recv_dict_msg()
        # An empty reply (e.g. connection closed) has no 'cmd' either; test
        # for it first so it is reported instead of raising AttributeError.
        if not msg or 'cmd' not in msg:
            conSock.close()
            return (RTEV_ERROR_BASE, 'Invalid message: %s:' % (msg))
        if msg['cmd'] == 'mpdlistjobs_info':
            jid = make_jobid(msg['jobid'])
            if jid not in jobids:
                jobids.append(jid)
        else:  # mpdlistjobs_trailer
            done = 1
    conSock.close()
    return (RTEV_OK, jobids)

#
# Deprecated
#
def listprocs(args):
    """Deprecated: return the ranks of the processes belonging to one job.

    *args* must hold exactly one element, the PTP job ID.  Returns
    ``(RTEV_PROCS, [rank, ...])`` (ranks as strings) on success or
    ``(RTEV_ERROR_PROCS, message)`` on failure.

    NOTE(review): RTEV_PROCS and RTEV_ERROR_PROCS are not defined or imported
    in the visible part of this file; confirm they exist before calling this.
    """
    if len(args) != 1:
        return (RTEV_ERROR_PROCS, 'Invalid arguments')
    try:
        jid = int(args[0])
    except ValueError:
        return (RTEV_ERROR_PROCS, 'Invalid arguments')
    # ptpJobIDToJobID is a dict keyed by PTP job ID: test membership and index
    # it (the original compared against its length and tried to call it).
    if jid not in ptpJobIDToJobID:
        return (RTEV_ERROR_PROCS, 'No such job')
    jobid = ptpJobIDToJobID[jid]
    conSock = open_mpd_console()
    if conSock == 0:
        return (RTEV_ERROR_PROCS, 'Could not connect to MPD')
    conSock.send_dict_msg({ 'cmd' : 'mpdlistjobs' })
    msg = conSock.recv_dict_msg(timeout=5.0)
    if not msg:
        conSock.close()
        return (RTEV_ERROR_PROCS, 'No message recvd from MPD before timeout')
    if msg['cmd'] != 'local_mpdid':     # get full id of local mpd for filters later
        conSock.close()
        return (RTEV_ERROR_PROCS, 'Did not recv local_mpdid msg from local mpd; instead, recvd: %s' % msg)
    procs = []
    done = 0
    while not done:
        msg = conSock.recv_dict_msg()
        # Report an empty reply instead of raising AttributeError on it.
        if not msg or 'cmd' not in msg:
            conSock.close()
            return (RTEV_ERROR_PROCS, 'Invalid message: %s:' % (msg))
        if msg['cmd'] == 'mpdlistjobs_info':
            if make_jobid(msg['jobid']) == jobid:
                procs.append(str(msg['rank']))
        else:  # mpdlistjobs_trailer
            done = 1
    conSock.close()
    return (RTEV_PROCS, procs)

#
# Query mpd to obtain process information
# for the given job
#
def get_procs(jobid, proxy):
    """Query MPD (mpdlistjobs) for the processes of PTP job *jobid*.

    For every process of the job a new element ID is generated, a new-process
    event is sent on the gTransID transaction, and the rank -> ID mapping is
    stored in procIDs[jobid].

    Returns 0 on success, or -1 after sending an error message event.
    """
    global ptpJobIDToJobID, procIDs, gTransID

    def fail(text, sock=None):
        # Close the console (if any) first, then report on the event stream.
        if sock is not None:
            sock.close()
        proxy.send_message_event(gTransID, MSG_LEVEL_ERROR, 0, text)
        return -1

    conSock = open_mpd_console()
    if conSock == 0:
        return fail('Could not connect to MPD')
    conSock.send_dict_msg({ 'cmd' : 'mpdlistjobs' })
    msg = conSock.recv_dict_msg(timeout=5.0)
    if not msg:
        return fail('No message recvd from MPD before timeout', conSock)
    if msg['cmd'] != 'local_mpdid':     # get full id of local mpd for filters later
        return fail('Did not recv local_mpdid msg from local mpd; instead, recvd: %s' % msg, conSock)
    procs = {}
    done = 0
    mpich_jobid = ptpJobIDToJobID[jobid]
    while not done:
        msg = conSock.recv_dict_msg()
        # An empty reply must be reported, not dereferenced.
        if not msg or 'cmd' not in msg:
            return fail('Invalid message: %s:' % (msg), conSock)
        if msg['cmd'] == 'mpdlistjobs_info':
            if make_jobid(msg['jobid']) == mpich_jobid:
                pid = generate_id()
                rank = int(msg['rank'])
                procs[rank] = pid
                # NOTE(review): assumes msg['host'] matches a node name that
                # get_nodes() recorded in nodeIDs; a mismatch raises KeyError.
                proxy.send_new_process_event(gTransID, jobid, pid, rank, nodeIDs[msg['host']], rank, int(msg['clipid']))
        else:  # mpdlistjobs_trailer
            done = 1
    conSock.close()
    procIDs[jobid] = procs
    return 0

def terminate_job(tid, attrs, proxy):
    """CMD_TERMINATEJOB: ask MPD to kill the job named by JOB_ID_ATTR.

    Failures are reported with a terminate-job error event on *tid* (plus a
    message event on the gTransID transaction for MPD-level problems).  An OK
    event is sent once MPD acknowledged the request, with a warning message
    if MPD did not find the job.  Always returns 0.
    """
    global ptpJobIDToJobID

    def runtime_error(detail):
        # Detail goes to the event stream, the generic failure to the command.
        proxy.send_message_event(gTransID, MSG_LEVEL_ERROR, 0, detail)
        proxy.send_terminatejob_error_event(tid, 0, 'Could not contact runtime')
        return 0

    if JOB_ID_ATTR not in attrs:
        proxy.send_terminatejob_error_event(tid, 0, 'Invalid arguments to terminate job command')
        return 0
    ptp_jobid = int(attrs[JOB_ID_ATTR])
    if ptp_jobid not in ptpJobIDToJobID:
        proxy.send_terminatejob_error_event(tid, 0, 'No such job: %s' % ptp_jobid)
        return 0

    # MPD job IDs have the form 'jobnum@mpdid'; the mpdid part is optional.
    id_parts = ptpJobIDToJobID[ptp_jobid].split('@')
    jobnum = id_parts[0]
    mpdid = ''
    if len(id_parts) > 1:
        mpdid = id_parts[1]

    conSock = open_mpd_console()
    if conSock == 0:
        return runtime_error('Could not connect to MPD')

    kill_request = { 'cmd':'mpdkilljob', 'jobnum' : jobnum, 'mpdid' : mpdid,
                     'jobalias' : '', 'username' : mpd_get_my_username() }
    conSock.send_dict_msg(kill_request)
    msg = conSock.recv_dict_msg(timeout=5.0)
    if not msg:
        conSock.close()
        return runtime_error('No message recvd from MPD before timeout')

    if msg['cmd'] != 'mpdkilljob_ack':
        if msg['cmd'] == 'already_have_a_console':
            err_msg = 'Someone already connected to the MPD console'
        else:
            err_msg = 'Unexpected message from mpd: %s' % (msg)
        conSock.close()
        return runtime_error(err_msg)

    conSock.close()
    if not msg['handled']:
        proxy.send_message_event(gTransID, MSG_LEVEL_WARNING, 0, 'job not found')
    proxy.send_ok_event(tid)
    return 0

def submit_job(tid, attrs, proxy):
    """CMD_SUBMITJOB: launch a program through MPD and register it with PTP.

    Steps: validate the job attributes, send an 'mpdrun' request over the MPD
    console, acknowledge the submission on *tid*, complete the handshake with
    the MPD manager that connects back, announce the new job and its processes
    on the gTransID event transaction, and register stream handlers for the
    manager, stdout and stderr sockets.

    Validation and console failures are reported with a submit-job error event
    on *tid*; failures after the acknowledgement are reported as message events
    on the event transaction.  Always returns 0.

    The JOB_DEBUG_* attributes may be present in *attrs* but are not acted on.
    """
    global parmdb, streamHandler, manSocks
    global jobProcCount, ptpJobIDToJobID, gTransID

    # The umask can only be read by setting it, so set and restore at once.
    currumask = os.umask(0)
    os.umask(currumask)

    def attr(key, default):
        # Value of a job attribute, or *default* when it was not supplied.
        if key in attrs:
            return attrs[key]
        return default

    jobsub_id = attr(JOB_SUB_ID_ATTR, '')
    pgm_name = attr(JOB_EXEC_NAME_ATTR, '')
    exec_path = attr(JOB_EXEC_PATH_ATTR, '')
    cwd = attr(JOB_WORKING_DIR_ATTR, '')
    pgm_args = attr(JOB_PROG_ARGS_ATTR, [])
    try:
        nprocs = int(attr(JOB_NUM_PROCS_ATTR, 0))
    except (ValueError, TypeError):
        nprocs = 0  # rejected below as an invalid number of processes
    pgm_env = {}
    for kv in attr(JOB_ENV_ATTR, []):
        # Split on the first '=' only, so values that themselves contain '='
        # are kept instead of being silently dropped.
        env = kv.split('=', 1)
        if len(env) == 2:
            pgm_env[env[0]] = env[1]

    def reject(reason):
        # Invalid request: nothing has been opened yet.
        proxy.send_submitjob_error_event(tid, jobsub_id, 0, reason)
        return 0

    def console_failure(detail):
        # MPD console exchange failed: release both sockets, then report.
        conSock.close()
        listenSock.close()
        proxy.send_message_event(tid, MSG_LEVEL_ERROR, 0, detail)
        proxy.send_submitjob_error_event(tid, jobsub_id, 0, 'Could not submit job to runtime')
        return 0

    def handshake_failure(detail):
        # Manager handshake failed after the submission was acknowledged:
        # report on the event stream and release everything tied to the job.
        proxy.send_message_event(gTransID, MSG_LEVEL_ERROR, 0, detail)
        streamHandler.del_handler(manSock)
        manSock.close()
        listenSock.close()
        jobProcCount.pop(ptp_jobid, None)
        return 0

    if jobsub_id == '':
        return reject('missing ID on job submission')
    if pgm_name == '':
        return reject('Must specify a program name')

    if exec_path != '':
        pgm_path = exec_path + '/' + pgm_name
    else:
        pgm_path = pgm_name

    if nprocs <= 0:
        return reject('Invalid number of processes')
    if cwd == '':
        return reject('Must specify a working directory')

    conSock = open_mpd_console()
    if conSock == 0:
        proxy.send_message_event(tid, MSG_LEVEL_ERROR, 0, 'Could not connect to MPD')
        proxy.send_submitjob_error_event(tid, jobsub_id, 0, 'Could not submit job to runtime')
        return 0

    # The MPD manager and the stdout/stderr streams connect back to this port.
    listenSock = MPDListenSock('',0,name='socket_to_listen_for_man')
    listenPort = listenSock.getsockname()[1]

    parmdb[('thispgm','nprocs')] = nprocs

    # Every per-rank setting applies to the whole rank range.
    all_ranks = (0,nprocs-1)
    msgToMPD = { 'cmd'             : 'mpdrun',
                 'conhost'         : myAddr,
                 'conip'           : myIP,
                 'conport'         : listenPort,
                 'conifhn'         : '',
                 'spawned'         : 0,
                 'nstarted'        : 0,
                 'nprocs'          : nprocs,
                 'hosts'           : { all_ranks : '_any_' },
                 'execs'           : { all_ranks : pgm_path },
                 'users'           : { all_ranks : mpd_get_my_username() },
                 'cwds'            : { all_ranks : cwd },
                 'umasks'          : { all_ranks : str(currumask) },
                 'paths'           : { all_ranks : os.environ['PATH'] },
                 'args'            : { all_ranks : pgm_args },
                 'limits'          : { all_ranks : {} },
                 'envvars'         : { all_ranks : pgm_env },
                 'ifhns'           : {},
                 'jobalias'        : '',
                 'try_1st_locally' : 1,
                 'line_labels'     : '%r',
                 'stdin_dest'      : '0',
                 'gdb'             : 0,
                 'gdba'            : '',
                 'totalview'       : 0,
                 'singinitpid'     : parmdb['singinitpid'],
                 'singinitport'    : parmdb['singinitport'],
                 'host_spec_pool'  : [],
               }

    # Check that the running mpd matches the mpdlib version in use.
    conSock.send_dict_msg({ 'cmd' : 'get_mpdrun_values' })
    msg = conSock.recv_dict_msg(timeout=recvTimeout)
    if not msg:
        return console_failure('Communication with MPD failed')
    if msg['cmd'] != 'response_get_mpdrun_values':
        return console_failure('Unexpected msg from MPD :%s:' % (msg))
    if msg['mpd_version'] != mpd_version():
        return console_failure('MPD version %s does not match mpiexec version %s' % \
                               (msg['mpd_version'],mpd_version()))

    conSock.send_dict_msg(msgToMPD)
    msg = conSock.recv_dict_msg(timeout=recvTimeout)
    if not msg:
        return console_failure('No msg recvd from MPD when expecting ack of request')
    if msg['cmd'] != 'mpdrun_ack':
        if msg['cmd'] == 'already_have_a_console':
            err_msg = 'Someone already connected to the MPD console'
        elif msg['cmd'] == 'job_failed':
            if  msg['reason'] == 'some_procs_not_started':
                err_msg = 'Unable to start all procs; may have invalid machine names'
            elif  msg['reason'] == 'invalid_username':
                err_msg =  'Invalid username %s at host %s' % \
                      (msg['username'],msg['host'])
            else:
                err_msg = 'Job failed; reason=:%s:' % (msg['reason'])
        else:
            err_msg = 'Unexpected message from mpd: %s' % (msg)
        return console_failure(err_msg)
    currRingSize = msg['ringsize']
    currRingNCPUs = msg['ring_ncpus']

    conSock.close()

    #
    # Acknowledge job submission
    #
    proxy.send_ok_event(tid)

    (manSock,addr) = listenSock.accept()
    if not manSock:
        listenSock.close()
        proxy.send_message_event(gTransID, MSG_LEVEL_ERROR, 0, 'Failed to obtain sock from MPD manager')
        return 0

    ptp_jobid = generate_id()
    jobProcCount[ptp_jobid] = nprocs
    streamHandler.set_handler(manSock,handle_man_input,args=(streamHandler,ptp_jobid,proxy))

    # first, do handshaking with man
    msg = manSock.recv_dict_msg()
    if not msg or 'cmd' not in msg or msg['cmd'] != 'man_checking_in':
        return handshake_failure('Invalid handshake msg: %s' % (msg))

    manSock.send_dict_msg({ 'cmd' : 'ringsize', 'ring_ncpus' : currRingNCPUs,
                            'ringsize' : currRingSize })
    msg = manSock.recv_dict_msg()
    if not msg or 'cmd' not in msg:
        return handshake_failure('Invalid reply to ringsize msg: %s' % (msg))
    if msg['cmd'] != 'job_started':
        return handshake_failure('Unknown msg: %s' % (msg))

    mpich_jobid = make_jobid(msg['jobid'])
    ptpJobIDToJobID[ptp_jobid] = mpich_jobid
    debug_print('ptp_mpich2_proxy: job %s started' % mpich_jobid)

    proxy.send_new_job_event(gTransID, queueID, ptp_jobid, mpich_jobid, JOB_STATE_RUNNING, jobsub_id)
    if get_procs(ptp_jobid, proxy) < 0:
        # The job itself was started, so the manager connection is left open
        # and the ID mapping is kept; only the handler is withdrawn.
        proxy.send_job_state_change_event(gTransID, ptp_jobid, JOB_STATE_ERROR)
        streamHandler.del_handler(manSock)
        return 0

    (manCliStdoutSock,addr) = listenSock.accept(name='stdout_sock')
    streamHandler.set_handler(manCliStdoutSock,
                              handle_cli_stdout,
                              args=(streamHandler,ptp_jobid,proxy))
    (manCliStderrSock,addr) = listenSock.accept(name='stderr_sock')
    streamHandler.set_handler(manCliStderrSock,
                              handle_cli_stderr,
                              args=(streamHandler,ptp_jobid,proxy))

    manSocks.append(manSock)

    return 0

#
# Find lowest and highest element IDs and
# return an ID range
#
def get_proc_id_range(jobid):
    """Return the process ID range of job *jobid* as the string 'low-high'.

    The bounds are the smallest and largest element IDs stored for the job in
    procIDs.  Returns '0' when the job is unknown or has no recorded processes.
    """
    global procIDs, jobProcCount
    if jobid not in procIDs or jobid not in jobProcCount:
        # Format the message here; ptp_print is called with one string.
        ptp_print('fatal error: could not find job %s' % jobid)
        return '0'
    ids = list(procIDs[jobid].values())
    if not ids:
        return '0'
    # Take the bounds from the values themselves rather than assuming that an
    # entry for rank 0 exists.
    return '%s-%s' % (min(ids), max(ids))
    
def _retire_job(sock,streamHandler,jobid):
    """Forget job *jobid* and release the manager socket that served it."""
    ptpJobIDToJobID.pop(jobid, None)
    procIDs.pop(jobid, None)
    jobProcCount.pop(jobid, None)
    streamHandler.del_handler(sock)
    if sock in manSocks:
        manSocks.remove(sock)
    sock.close()

def handle_man_input(sock,streamHandler,jobid,proxy):
    """Process one message from the MPD manager of job *jobid*.

    Exit and abort messages are translated into PTP process/job events on the
    gTransID transaction.  When the last process of the job has been accounted
    for, or the manager closes the connection, the final job state is sent and
    the job's bookkeeping and socket are released.
    """
    global manSocks, gTransID, procIDs, jobProcCount
    msg = sock.recv_dict_msg()
    if not msg:
        # When the job is terminated, we will enter here. This doesn't
        # seem right, but it's what appears to happen.
        proxy.send_process_signalled_event(gTransID, get_proc_id_range(jobid), signal.SIGKILL)
        proxy.send_job_state_change_event(gTransID, jobid, JOB_STATE_TERMINATED)
        # The manager is gone: clean up here and stop, instead of falling
        # through to the completion check with no job state defined.
        _retire_job(sock,streamHandler,jobid)
        return
    if 'cmd' not in msg:
        ptp_print('ptp_mpich2_proxy: from man, invalid msg=:%s:' % (msg))
        return

    cmd = msg['cmd']
    if cmd == 'startup_status':
        if msg['rc'] != 0:
            ptp_print('rank %d (%s) in job %s failed to find executable %s' % \
                  ( msg['rank'], msg['src'], msg['jobid'], msg['exec'] ))
            host = msg['src'].split('_')[0]
            reason = unquote(msg['reason'])
            ptp_print('problem with execution of %s  on  %s:  %s ' % \
                  (msg['exec'],host,reason))
            # keep going until all man's finish
            proxy.send_process_state_change_event(gTransID, get_proc_id_range(jobid), PROC_STATE_ERROR)
            proxy.send_job_state_change_event(gTransID, jobid, JOB_STATE_ERROR)
        return
    elif cmd == 'job_aborted_early':
        ptp_print('rank %d in job %s caused collective abort of all ranks' % \
              ( msg['rank'], msg['jobid'] ))
        status = msg['exit_status']
        if hasattr(os,'WIFSIGNALED')  and  os.WIFSIGNALED(status):
            killed_status = status & 0x007f  # AND off core flag
            ptp_print('  exit status of rank %d: killed by signal %d ' % \
                  (msg['rank'],killed_status))
            proxy.send_process_signalled_event(gTransID, procIDs[jobid][int(msg['rank'])], killed_status)
        elif hasattr(os,'WEXITSTATUS'):
            exit_status = os.WEXITSTATUS(status)
            ptp_print('  exit status of rank %d: return code %d ' % \
                  (msg['rank'],exit_status))
            proxy.send_process_exited_event(gTransID, procIDs[jobid][int(msg['rank'])], exit_status)
        job_state = JOB_STATE_ERROR
        jobProcCount[jobid] -= 1
    elif cmd == 'job_aborted':
        ptp_print('job aborted; reason = %s' % (msg['reason']))
        proxy.send_process_state_change_event(gTransID, get_proc_id_range(jobid), PROC_STATE_ERROR)
        job_state = JOB_STATE_ERROR
        jobProcCount[jobid] = 0
    elif cmd == 'client_exit_status':
        ptp_print("exit info: rank=%d  host=%s  pid=%d  status=%d" % \
              (msg['cli_rank'],msg['cli_host'],
               msg['cli_pid'],msg['cli_status']))
        status = msg['cli_status']
        if hasattr(os,'WIFSIGNALED')  and  os.WIFSIGNALED(status):
            killed_status = status & 0x007f  # AND off core flag
            ptp_print('exit status of rank %d: killed by signal %d ' % \
                   (msg['cli_rank'],killed_status))
            proxy.send_process_signalled_event(gTransID, procIDs[jobid][int(msg['cli_rank'])], killed_status)
        elif hasattr(os,'WEXITSTATUS'):
            exit_status = os.WEXITSTATUS(status)
            ptp_print('exit status of rank %d: return code %d ' % \
                  (msg['cli_rank'],exit_status))
            proxy.send_process_exited_event(gTransID, procIDs[jobid][int(msg['cli_rank'])], exit_status)
        job_state = JOB_STATE_TERMINATED
        jobProcCount[jobid] -= 1
    else:
        ptp_print('unrecognized msg from manager :%s:' % msg)
        return

    # All processes accounted for: report the final state and clean up.
    if jobProcCount[jobid] <= 0:
        proxy.send_job_state_change_event(gTransID, jobid, job_state)
        _retire_job(sock,streamHandler,jobid)

def handle_cli_stdout(sock,streamHandler,jobid,proxy):
    """Forward one line of job stdout to PTP as a process output event.

    Lines are expected in the form '<rank>:<text>'.  A line without a numeric
    rank prefix is attributed, whole, to rank 0.  Output for a job or rank that
    is not recorded in procIDs is dropped.  An empty read removes the handler.
    """
    global procIDs
    msg = sock.recv_one_line()
    if not msg:
        streamHandler.del_handler(sock)
        return
    try:
        # ValueError covers both a missing ':' (unpacking) and a non-numeric
        # rank (int); other errors are no longer swallowed.
        (rank,rest) = msg.rstrip().split(':',1)
        rank = int(rank)
    except ValueError:
        rest = msg.rstrip()
        rank = 0
    if jobid in procIDs:
        procs = procIDs[jobid]
        if rank in procs:
            proxy.send_process_output_event(gTransID, procs[rank], rest)

def handle_cli_stderr(sock,streamHandler,jobid,proxy):
    """Copy up to 1024 bytes of client stderr onto this process's stderr.

    An empty read is treated as end of stream and sock is removed from
    streamHandler.  jobid and proxy are accepted but not used.
    """
    chunk = sock.recv(1024)
    if chunk:
        # pass the data through unmodified and push it out immediately
        sys.stderr.write(chunk)
        sys.stderr.flush()
        return
    streamHandler.del_handler(sock)

def handle_sig_occurred(manSock):
    """Act on the signal number currently stored in sigOccurred.

    SIGINT  -- send a 'signal' message with signo 'SIGINT' over manSock.
    SIGALRM -- send a 'signal' message with signo 'SIGKILL' over manSock
               and print the MPIEXEC_TIMEOUT value from the environment.
    Any other value is ignored and nothing is changed.

    manSock may be false, in which case nothing is sent; otherwise it is
    closed once the message has gone out.  For both handled signals the
    module-level exit flag is set to 1.
    """
    global sigOccurred, exit
    if sigOccurred == signal.SIGINT:
        forwardAs = 'SIGINT'
    elif sigOccurred == signal.SIGALRM:
        forwardAs = 'SIGKILL'
    else:
        return
    if manSock:
        manSock.send_dict_msg({ 'cmd' : 'signal', 'signo' : forwardAs })
        manSock.close()
    if forwardAs == 'SIGKILL':
        # only the alarm path reports the timeout that triggered it
        ptp_print('job ending due to env var MPIEXEC_TIMEOUT=%s' % \
                  os.environ['MPIEXEC_TIMEOUT'])
    exit = 1

def sig_handler(signum,frame):
    """Record the signal number, then pass the signal on to mpd_handle_signal().

    Takes the (signum, frame) argument pair of a Python signal handler.
    signum is stored in the module-level sigOccurred, which
    handle_sig_occurred() compares against SIGINT and SIGALRM.
    """
    global sigOccurred
    sigOccurred = signum
    mpd_handle_signal(signum,frame)

def debug_print(str):
    """Pass str to ptp_print() only while the module-level debug flag is true."""
    global debug
    if not debug:
        return
    ptp_print(str)

#
# Replacement for MPDConClientSock() that doesn't
# call sys.exit()
#
class PTPConClientSock(MPDSock):
    """Console connection to the local mpd that reports failure by return value.

    connect() returns 1 once the console connection is established and 0
    otherwise, after printing the reason with ptp_print(); the calling
    process is never terminated by this class.  The connection itself is a
    separate MPDSock stored in self.sock (0 while not connected).
    """

    def __init__(self,name='console_to_mpd',mpdroot='',secretword='',**kargs):
        """Store the connection settings; no connection is opened here.

        name       -- passed as name= to every MPDSock that connect() creates
        mpdroot    -- path of the mpdroot helper program; when non-empty,
                      connect() runs it instead of connecting directly
        secretword -- prepended to the mpd's random number to build the
                      challenge response on the TCP console path
        kargs      -- accepted and ignored
        """
        MPDSock.__init__(self)
        self.sock = 0
        # MPD_CON_EXT, when set, becomes a suffix of the console file name
        if 'MPD_CON_EXT' in os.environ:
            self.conExt = '_'  + os.environ['MPD_CON_EXT']
        else:
            self.conExt = ''
        self.secretword = secretword
        self.mpdroot = mpdroot
        self.name = name

    def connect(self):
        """Open the console connection to the local mpd.

        With mpdroot set, the mpdroot program is run on a fresh Unix-domain
        socket.  Otherwise the per-user console is contacted directly: over
        a Unix-domain socket when the platform has AF_UNIX and
        MPD_CON_INET_HOST_PORT is not set, else over TCP followed by the
        con_init / con_challenge handshake.

        Returns 1 on success, 0 on failure.
        """
        if self.mpdroot:
            if not self._connect_via_mpdroot():
                return 0
        else:
            self.conFilename = '/tmp/mpd2.console_' + mpd_get_my_username() + self.conExt
            if 'MPD_CON_INET_HOST_PORT' in os.environ:
                # an explicit "host:port" always selects TCP
                (conHost,conPort) = os.environ['MPD_CON_INET_HOST_PORT'].split(':')
                conPort = int(conPort)
                useUnix = 0
            else:
                (conHost,conPort) = ('',0)
                useUnix = hasattr(socket,'AF_UNIX')
            if useUnix:
                self._connect_unix_console()
            elif not self._connect_inet_console(conHost,conPort):
                return 0
        if not self.sock:
            self._report_no_mpd()
            return 0
        return 1

    def _connect_via_mpdroot(self):
        """Fork and exec mpdroot on a new Unix socket; return 1 if it exits with 0."""
        self.conFilename = '/tmp/mpd2.console_root' + self.conExt
        self.sock = MPDSock(family=socket.AF_UNIX,name=self.name)
        rootpid = os.fork()
        if rootpid == 0:
            # Child.  execvpe only comes back by raising, so the failure
            # has to be handled in an except clause.
            try:
                os.execvpe(self.mpdroot,[self.mpdroot,self.conFilename,str(self.sock.fileno())],{})
            except OSError:
                ptp_print('failed to exec mpdroot (%s)' % self.mpdroot )
            # Leave immediately: returning would let this child continue
            # running as a second copy of the calling program.  The parent
            # sees the non-zero status below.
            os._exit(1)
        (pid,status) = os.waitpid(rootpid,0)
        if os.WIFSIGNALED(status):
            status = status & 0x007f  # AND off core flag
        else:
            status = os.WEXITSTATUS(status)
        if status != 0:
            ptp_print('forked process failed; status=%s' % status)
            return 0
        return 1

    def _connect_unix_console(self):
        """Connect self.sock to the console file; leave self.sock at 0 on failure."""
        self.sock = MPDSock(family=socket.AF_UNIX,socktype=socket.SOCK_STREAM,name=self.name)
        if self._timed_connect(self.conFilename):
            # this is done by mpdroot otherwise
            msgToSend = 'realusername=%s secretword=UNUSED\n' % \
                        mpd_get_my_username()
            self.sock.send_char_msg(msgToSend)

    def _connect_inet_console(self,conHost,conPort):
        """Connect over TCP and answer the mpd's challenge; return 1 or 0."""
        if not conPort:
            try:
                conPort = self._read_console_port()
            except IOError:
                # the console file cannot be read, so there is no port to try
                self._report_no_mpd()
                return 0
        if conHost:
            conIfhn = socket.gethostbyname_ex(conHost)[2][0]
        else:
            conIfhn = 'localhost'
        self.sock = MPDSock(name=self.name)
        if not self._timed_connect((conIfhn,conPort)):
            ptp_print("failed to connect to host %s port %d" % \
                      (conIfhn,conPort) )
            self._report_no_mpd()
            return 0
        msgToSend = { 'cmd' : 'con_init' }
        self.sock.send_dict_msg(msgToSend)
        msg = self.sock.recv_dict_msg()
        if not msg:
            ptp_print('expected con_challenge from mpd; got eof')
            return 0
        if msg['cmd'] != 'con_challenge':
            ptp_print('expected con_challenge from mpd; got msg=:%s:' % (msg) )
            return 0
        # the response is the digest of secretword followed by the mpd's number
        randVal = self.secretword + str(msg['randnum'])
        response = md5new(randVal).digest()
        msgToSend = { 'cmd' : 'con_challenge_response', 'response' : response,
                      'realusername' : mpd_get_my_username() }
        self.sock.send_dict_msg(msgToSend)
        msg = self.sock.recv_dict_msg()
        if not msg  or  msg['cmd'] != 'valid_response':
            ptp_print('expected valid_response from mpd; got msg=:%s:' % (msg) )
            return 0
        return 1

    def _read_console_port(self):
        """Return the port= value from the console file, or 0 if it has none.

        Raises IOError when the file cannot be opened.
        """
        conPort = 0
        conFile = open(self.conFilename)
        try:
            for line in conFile:
                line = line.strip()
                if '=' not in line:
                    continue    # blank or malformed line
                (k,v) = line.split('=',1)
                if k == 'port':
                    conPort = int(v)
        finally:
            conFile.close()
        return conPort

    def _timed_connect(self,address):
        """Connect self.sock to address under an 8 second limit.

        Returns 1 on success.  On failure self.sock is closed and set to 0
        and 0 is returned.  The previous alarm / default timeout is always
        restored.
        """
        useAlarm = hasattr(signal,'alarm')
        if useAlarm:
            oldAlarmTime = signal.alarm(8)
        else:    # assume python2.3 or later
            # NOTE(review): a default timeout only applies to sockets
            # created after it is set; confirm it affects self.sock.
            oldTimeout = socket.getdefaulttimeout()
            socket.setdefaulttimeout(8)
        try:
            try:
                self.sock.connect(address)
                connected = 1
            except Exception:
                self.sock.close()
                self.sock = 0
                connected = 0
        finally:
            if useAlarm:
                signal.alarm(oldAlarmTime)
            else:
                socket.setdefaulttimeout(oldTimeout)
        return connected

    def _report_no_mpd(self):
        """Print the standard 'cannot connect to local mpd' explanation."""
        ptp_print('%s: cannot connect to local mpd (%s); possible causes:' % \
              (mpd_get_my_id(),self.conFilename))
        ptp_print('  1. no mpd is running on this host')
        ptp_print('  2. an mpd is running but was started without a "console" (-n option)')

def usage():
    """Print the module docstring (the usage text) and exit with status -1."""
    ptp_print(__doc__)
    sys.exit(-1)

# Script entry point: call ptp_mpd_proxy() and, once it returns, exit with
# status 0.
if __name__ == '__main__':
    ptp_mpd_proxy()
    sys.exit(0)
