#!/usr/bin/env python
# compare two genepred files and output some stats about the differences

import sys, gzip, operator, logging
from collections import defaultdict

# Two coordinates that differ by fewer than this many bp are counted as "close"
CLOSE = 10

# If True, everything from the first "." on is removed from transcript ids
# while parsing (e.g. a trailing ".<version>" suffix)
STRIPVERSION = True

def main():
    """Command-line entry point: parse two genePred files and compare them.

    Reads the two input file names from the first two positional arguments
    in ``sys.argv``; any further arguments are ignored. When fewer than two
    file names are given, a usage message is printed and the program exits
    (status 0 if no argument was given at all, status 1 otherwise).
    """
    args = sys.argv[1:]
    if len(args) < 2:
        print("genePredComp file1 file2: compare two gene pred files and output some stats")
        # no argument at all is a plain request for help; a single file name
        # is a usage error and must not fall through to an IndexError
        sys.exit(0 if not args else 1)

    logging.basicConfig(level=logging.INFO)
    fname1, fname2 = args[0], args[1]
    gp1 = parseGp(fname1)
    gp2 = parseGp(fname2)
    genePredComp(fname1, fname2, gp1, gp2)

def parseGp(inFname):
    """Parse a tab-separated genePred file into a dict.

    Args:
        inFname: path of the input file; names ending in ".gz" are read
            through gzip.

    Returns:
        A defaultdict(list) mapping each transcript id (first column) to a
        list with one tuple per input line carrying that id. Each tuple
        holds columns 2-10 of the line as strings, with leading/trailing
        commas removed from every field.

    Notes:
        - Lines starting with "name" are treated as header lines and
          skipped; blank lines are skipped as well.
        - If STRIPVERSION is set, the id is cut at its first ".".
        - Chromosome names that are purely numeric, "X" or "Y" get a "chr"
          prefix; "MT" becomes "chrM".
    """
    logging.info("Parsing %s..." % inFname)
    if inFname.endswith(".gz"):
        # gzip.open defaults to binary mode; request text mode so that the
        # lines are str, like the ones read from an uncompressed file
        ifh = gzip.open(inFname, "rt")
    else:
        ifh = open(inFname)

    data = defaultdict(list)
    # the with-block makes sure the input file is closed again
    with ifh:
        for line in ifh:
            if line.startswith("name"):
                continue
            line = line.rstrip("\n")
            if not line.strip():
                # blank line: there is no chromosome column to look at
                continue
            row = [field.strip(",") for field in line.split("\t")][:10]

            name = row[0]
            if STRIPVERSION:
                name = name.split(".")[0]

            # normalize chromosome names to the "chr"-prefixed style
            if row[1].isdigit() or row[1] in ("X", "Y"):
                row[1] = "chr" + row[1]
            if row[1] == "MT":
                row[1] = "chrM"

            data[name].append(tuple(row[1:]))
    return data

#def countVerDiff(ids1, ids2):
    #for id in ids1:
        #base, ver = id.split(".")

def coordOverlap(start1, end1, start2, end2):
    """Tell whether the range start1-end1 overlaps the range start2-end2.

    All four coordinates are passed through int() first, so numeric strings
    are accepted as well as integers. Returns a bool.
    """
    s1, e1, s2, e2 = (int(coord) for coord in (start1, end1, start2, end2))

    # the start of range 1 lies inside range 2
    if s2 <= s1 < e2:
        return True
    # the end of range 1 lies inside range 2
    if s2 < e1 <= e2:
        return True
    # range 1 is contained in range 2
    if s2 <= s1 and e1 <= e2:
        return True
    # range 2 is contained in range 1
    return s1 <= s2 and e2 <= e1
    # modified: end2 > end1

def genePredComp(fname1, fname2, gp1, gp2):
    keys1 = gp1.keys()
    keys2 = gp2.keys()
    print("%d IDs in %s" % (len(keys1), fname1))
    print("%d IDs in %s" % (len(keys2), fname2))
    commonIds = set(keys1).intersection(keys2)
    print("%d common IDs" % (len(commonIds)))

    print("For the common IDs:")
    diffTransCount = 0
    diffTransList = []
    exactMatch = 0
    overlapCount = 0
    noOverlap = 0
    txMatch = 0
    txMatchList = []
    txNonMatch = 0
    txClose = 0
    txStartClose = 0
    txEndClose = 0
    cdsMatch = 0

    ofh = open("noOverlap.gp", "w")
    diffOfh = open("different.gp", "w")

    for id in commonIds:
        tList1 = gp1[id]
        tList2 = gp2[id]
        if len(tList1)!=1 or len(tList2)!=1:
            diffTransCount +=1
            diffTransList.append(id)
        t1 = tList1[0]
        t2 = tList2[0]
        
        #print id
        #print t1
        #print t2
        #print
        if t1==t2:
            exactMatch += 1
        else:
            txStart1, txEnd1 = int(t1[2]), int(t1[3])
            txStart2, txEnd2 = int(t2[2]), int(t2[3])
            if txStart1==txStart2 and txEnd1==txEnd2:
                txMatch += 1
                txMatchList.append(id)
            else:
                txNonMatch +=1
                diffOfh.write(id+"\t"+"\t".join(t1)+"\n")
                diffOfh.write(id+"\t"+"\t".join(t2)+"\n")

                cdsStart1, cdsEnd1 = int(t1[2]), int(t1[3])
                cdsStart2, cdsEnd2 = int(t2[2]), int(t2[3])
                if cdsStart1==cdsStart2 and cdsEnd1==cdsEnd2:
                    cdsMatch +=1

                if abs(txStart1-txStart2)<CLOSE and abs(txEnd1-txEnd2)<CLOSE:
                    txClose += 1
                if abs(txStart1-txStart2)<CLOSE and txEnd1==txEnd2:
                    txStartClose += 1
                if txStart1==txStart2 and abs(txEnd1-txEnd2)<CLOSE:
                    txEndClose += 1
                if coordOverlap(t1[2], t1[3], t2[2], t2[3]):
                    overlapCount += 1
                else:
                    noOverlap += 1
                    ofh.write("%s\t%s\n" % (id, "\t".join(t1)))
                    ofh.write("%s\t%s\n" % (id, "\t".join(t2)))

    print("  - %d IDs ignored because number of transcripts/id is > 1, e.g. %s" % (diffTransCount, diffTransList[:3]))
    #print "  - %d IDs ignored because different version" % diffVerCount
    if len(commonIds)==0:
        print("Error: Nothing in common. ")
        print("First set of keys: %s " % list(keys1)[:3])
        print("Second set of keys: %s " % list(keys2)[:3])
        sys.exit(1)

    exactMatchShare = 100*float(exactMatch)/float(len(commonIds))
    print("  - %d transcripts (%0.00f%%) are exactly identical" % (exactMatch, exactMatchShare))
    print("For those %d that are not exactly identical:" % (len(commonIds)-exactMatch))
    print("  - %d transcripts have identical txStart,txEnd (e.g. %s)" % (txMatch, txMatchList[:3]))
    print("  - %d transcripts have a different txStart,txEnd (written to different.gp)" % txNonMatch)
    print("For those %d that don't have identical txStart,txEnd:" % txNonMatch)
    print("  - %d transcripts have identical cdsStart,cdsEnd" % cdsMatch)
    print("  - %d transcripts have txStart or txEnd off by <%dbp" % (txClose, CLOSE))
    print("  - %d transcripts have only txStart off by <%d bp" % (txStartClose, CLOSE))
    print("  - %d transcripts have only txEnd off by <%d bp" % (txEndClose, CLOSE))
    print("  - %d transcripts overlap" % overlapCount)
    print("  - %d transcripts have no overlap at all (written to noOverlap.gp)" % noOverlap)

def printData(exonSpans):
    """Print one tab-separated genePred-style line per transcript to stdout.

    Args:
        exonSpans: dict mapping (contig, transId) -> list of
            (chrom, chromStart, chromEnd, strand, typeGroup) tuples, one per
            exon piece. Each list is sorted in place by chromStart.
            Entries with an empty list are skipped.

    Output columns: transId, "chr"+chrom, strand, txStart, txEnd, cdsStart,
    cdsEnd, number of exons, comma-joined exon starts, comma-joined exon ends.
    """
    # .items() instead of the Python-2-only .iteritems()
    for (_contig, transId), exonList in exonSpans.items():
        if not exonList:
            # no exons, so there are no coordinates to report
            continue

        # sort on start pos
        exonList.sort(key=operator.itemgetter(1))

        # cdsStart/cdsEnd span the pieces typed "CDS"; without any such
        # piece both are set to the start of the first exon
        cdsExons = [e for e in exonList if e[4] == "CDS"]
        if cdsExons:
            cdsStart = cdsExons[0][1]
            cdsEnd = cdsExons[-1][2]
        else:
            cdsStart = cdsEnd = exonList[0][1]

        # convert to [start, end] spans, fusing directly adjacent pieces
        spans = []
        for _chrom, chromStart, chromEnd, _strand, _typeGroup in exonList:
            if spans and chromStart == spans[-1][1]:
                spans[-1][1] = chromEnd
            else:
                spans.append([chromStart, chromEnd])

        # chromosome and strand are taken from the last piece
        chrom, strand = exonList[-1][0], exonList[-1][3]

        exonStarts = [str(s[0]) for s in spans]
        exonEnds = [str(s[1]) for s in spans]

        txStart = exonStarts[0]
        txEnd = exonEnds[-1]

        row = [
            transId, "chr"+chrom, strand,
            txStart, txEnd,
            cdsStart, cdsEnd,
            len(exonStarts),
            ",".join(exonStarts),
            ",".join(exonEnds),
        ]
        print("\t".join(str(s) for s in row))

# only run the command-line interface when executed as a script, so that
# the module can be imported without side effects
if __name__ == "__main__":
    main()
