#!/usr/bin/env python2.7
# condenseMatrix
"""
Condense an expression matrix from transcript rows to human readable gene rows. The transcripts within
a gene are totalled for each gene. 

The environment variable CIRM must be set to the base cirm directory, on hgwdev this is /hive/groups/cirm
The script looks for annotation/ensToHugo/encode79HumanTrscrptToGene.tab and annotation/ensToHugo/encode79MouseTrscrptToGene.tab within this directory.
"""
import os, sys, collections, argparse

# import the UCSC kent python library
sys.path.append(os.path.join(os.path.dirname(__file__), 'pyLib'))
import common
import tempfile
import subprocess


# Module-level global that carries the header line of the input matrix.
# convertNames() stores the first line of the input here and collapseRows()
# writes it out as the first line of the output.
matrixHeader = "" 


def parseArgs(args):
    """
    Parse the command line arguments.
    Input:
        args - a list of strings laid out like sys.argv: args[0] is the
               program name and the remaining items are the arguments.
               When args is empty or None argparse falls back to sys.argv.
    Output:
        The argparse namespace with the attributes inputFile (opened for
        reading), outputFile (opened for writing) and verbose (a boolean).
    """
    parser = argparse.ArgumentParser(description = __doc__)
    parser.add_argument ("inputFile",
    help = " The input file. ",
    type = argparse.FileType("r"))
    parser.add_argument ("outputFile",
    help = " The output file. ",
    type = argparse.FileType("w"))
    parser.add_argument ("--verbose",
    help = " Spit out messages during runtime. ",
    action = "store_true")

    parser.set_defaults(verbose = False)
    # Parse the list the caller handed in (minus the program name) rather
    # than silently re-reading sys.argv; None keeps the argparse default.
    options = parser.parse_args(args[1:] if args else None)
    return options

def convertNames(inputFile, intermediateFile, humanMap, mouseMap, verbose):
    """
    Input:
        inputFile - an opened file (r).
        intermediateFile - a namedTemporaryFile.
        humanMap - a dict, human transcript names map to gene names.
        mouseMap - a dict, mouse transcript names map to gene names.
        verbose - a boolean, when set a warning is printed for every
                  transcript found in neither map.
    Takes in an input matrix and converts the transcript names (first
    column) into gene names.  The new matrix is written to intermediateFile.
    humanMap is consulted first, then mouseMap; a transcript found in neither
    is renamed to "unknown" + the transcript name.  The header line is copied
    through untouched and also stored in the global matrixHeader.
    Exits with status 1 if a row does not have as many columns as the header.
    """
    global matrixHeader
    expCol = 0
    lineNumber = 0  # 1-based line number in inputFile, used for error reports only.
    for line in inputFile:
        lineNumber += 1
        if lineNumber == 1: # The header is passed through as is.
            intermediateFile.write(line)
            expCol = len(line.split())
            matrixHeader = line
            continue
        splitLine = line.split()
        if not splitLine or len(splitLine) != expCol:
            print ("The matrix has been corrupted on line %i. Aborting"%(lineNumber))
            sys.exit(1)
        transcript = splitLine[0]
        # An empty mapped name counts as 'not found', hence the truthiness test.
        geneName = humanMap.get(transcript) or mouseMap.get(transcript)
        if not geneName: # The transcript wasn't found
            if (verbose): print ("Warning! The transcript, %s, was not found in the hash table. "%(transcript))
            geneName = "unknown" + transcript
        # Every field is followed by a tab (including the last one); this keeps
        # the intermediate file format byte for byte what it has always been.
        intermediateFile.write("\t".join([geneName] + splitLine[1:]) + "\t\n")

def collapseRows(intermediateFile, outputFile):
    """
    Input:
        intermediateFile - a namedTemporaryFile.
        outputFile - an opened file (w).
    Goes through the intermediateFile and collapses the rows. This
    function assumes that the intermediate file has been sorted so that all
    rows sharing a name (first column) are adjacent, and that it holds no
    header line.  The global matrixHeader is written first, then one row per
    block of identically named rows holding the column wise sum of the block.
    Sums are written as floats, joined by tabs.
    Exits with status 1 if a row does not have as many columns as the first
    row.  A non numeric value raises ValueError.
    """
    outputFile.write(matrixHeader)
    expCols = None   # Column count of the first row, every later row must match it.
    geneName = None  # Name of the block currently being summed.
    totals = []      # Running column sums of the current block.
    lineNumber = 0
    for line in intermediateFile:
        lineNumber += 1
        splitLine = line.split()
        if expCols is None:
            expCols = len(splitLine)
        if not splitLine or len(splitLine) != expCols:
            print ("The matrix has been corrupted on line %i. Aborting."%(lineNumber))
            sys.exit(1)
        values = [float(item) for item in splitLine[1:]]
        if splitLine[0] == geneName: # Inside the same block, keep a summation of the items in this block.
            totals = [total + value for total, value in zip(totals, values)]
            continue
        # A new name starts a new block; flush the finished one (if any) first.
        if geneName is not None:
            _writeCollapsedRow(outputFile, geneName, totals)
        geneName = splitLine[0]
        totals = values
    # The last block has no following row to trigger its flush, write it here.
    if geneName is not None:
        _writeCollapsedRow(outputFile, geneName, totals)

def _writeCollapsedRow(outputFile, geneName, totals):
    """
    Write one collapsed row, the name followed by its float totals, tab separated.
    """
    outputFile.write("\t".join([geneName] + [str(total) for total in totals]) + "\n")

def main(args):
    """
    Initialized options and calls other functions. Read in a hash table that maps the transcripts
    to genes. Go through the matrix and convert every transcript name to gene names. Then call the
    unix sort (fast) to sort the new matrix.  This generates 'blocks' of transcripts where all the 
    transcripts for a given gene are next to each other.  Go through the blocked transcript  matrix
    and compress it one block at a time, printing the compressed gene matrix one row at a time.  

    Input:
        args - the command line, laid out like sys.argv. 
    Exits with status 1 if the external sed/sort step fails. Raises KeyError if the environment
    variable CIRM is not set. 
    """
    options = parseArgs(args)
    if (options.verbose): print ("Start condensing the expression matrix from transcripts to genes (condenseMatrix).")
    
    # The environment variable CIRM must be set to the base cirm directory, on hgwdev this is /hive/groups/cirm 
    # os.path.join makes this work with and without a trailing slash on CIRM. 
    annotationDir = os.path.join(os.environ['CIRM'], "annotation", "ensToHugo")
    humanMap = common.dictFromTwoTabFile(os.path.join(annotationDir, "encode79HumanTrscrptToGene.tab"), False)
    mouseMap = common.dictFromTwoTabFile(os.path.join(annotationDir, "encode79MouseTrscrptToGene.tab"), False)
    
    # Making use of python's temporary files. 
    intermediateFile = tempfile.NamedTemporaryFile(mode="w+")
    finalFile = tempfile.NamedTemporaryFile(mode="w+")
    try:
        # Convert names
        if (options.verbose): print ("Converting names...")
        convertNames(options.inputFile, intermediateFile, humanMap, mouseMap, options.verbose)
        # sed reads the file by name, so everything has to be on disk first. 
        intermediateFile.flush()
        
        # Sort
        if (options.verbose): print ("Sorting...")
        # Key on the name column only and compare bytes (LC_ALL=C): collapseRows needs rows with 
        # byte-identical names to be adjacent, which a locale aware whole-line sort does not guarantee. 
        sortEnv = dict(os.environ, LC_ALL="C")
        dropHeader = subprocess.Popen(["sed", "1d", intermediateFile.name], stdout=subprocess.PIPE)
        sorter = subprocess.Popen(["sort", "-k1,1", "-r"], stdin=dropHeader.stdout, stdout=finalFile, env=sortEnv)
        dropHeader.stdout.close() # Only sort holds the read end now, so sed sees a broken pipe if sort dies. 
        sortStatus = sorter.wait()
        sedStatus = dropHeader.wait()
        if (sedStatus != 0 or sortStatus != 0):
            print ("Sorting the intermediate matrix failed. Aborting.")
            sys.exit(1)
        # sort wrote through our file descriptor, rewind before reading the result back. 
        finalFile.seek(0)
        
        # Collapse
        if (options.verbose): print ("Collapsing matrix...")
        collapseRows(finalFile, options.outputFile)
    finally:
        # Closing the temporary files also deletes them. 
        intermediateFile.close()
        finalFile.close()
        # argparse hands back sys.stdin/sys.stdout for "-", leave those open. 
        for handle in (options.inputFile, options.outputFile):
            if handle is not sys.stdin and handle is not sys.stdout:
                handle.close()

    if (options.verbose): print ("Completed condensing the expression matrix from transcripts to genes (condenseMatrix).")
    
# Script entry point: run main on the real command line and hand its return
# value to sys.exit.
if __name__ == "__main__":
    exitStatus = main(sys.argv)
    sys.exit(exitStatus)
