#!/usr/bin/env python

import logging, sys, optparse, gzip, subprocess
from collections import defaultdict
from os.path import join, basename, dirname, isfile

# ==== functions =====

# One row per rule: (rule number, "R,G,B" color string, description text).
# Both lookup dicts below are derived from this table so they cannot drift apart.
_RULE_TABLE = (
    (1, "255,0,0", "Rule 1 - last 50bp of last coding exon junction"),   # red
    (2, "255,140,0", "Rule 2 - intronless transcript"),                   # orange
    (3, "139,0,0", "Rule 3 - first 100bp of coding nucleotides"),         # dark red
)

# rule number -> color used for the output BED items
RULE_COLORS = {ruleNum: rgb for ruleNum, rgb, _ in _RULE_TABLE}

# rule number -> text used for the mouseover of the output BED items
RULE_DESCRIPTIONS = {ruleNum: desc for ruleNum, _, desc in _RULE_TABLE}

def parseArgs():
    """Set up logging and parse command line arguments and options.

    Reads sys.argv through optparse; -h shows the auto-generated help page.

    Returns:
        (args, options): args is the list of the three positional arguments
        (inFname, outDecoFname, outCollapsedFname); options has the
        attributes .debug and .format ("genePredExt" or "bigGenePred").

    Exits with status 1 (after printing the help page) if the number of
    positional arguments is not three. An unknown --format value is rejected
    by optparse with a usage error before any file is touched.
    """
    parser = optparse.OptionParser("usage: %prog [options] inFname outDecoFname outCollapsedFname - "
        "Output BEDs with NMD escape regions. First output is decorator format, "
        "second is collapsed bigGenePred with gene symbols and transcript lists.")

    parser.add_option("-d", "--debug", dest="debug", action="store_true",
        help="show debug messages")
    # type="choice" makes optparse validate the value up front, so a typo does
    # not surface only later while the input is being parsed
    parser.add_option("-f", "--format", dest="format", action="store", type="choice",
        choices=["genePredExt", "bigGenePred"], default="genePredExt",
        help="Input format: 'genePredExt' (with bin column, e.g. ncbiRefSeq) or "
             "'bigGenePred' (e.g. gencode .bb file, will use bigBedToBed). Default: genePredExt")
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.print_help()
        # sys.exit instead of the site-provided exit(), which is meant for the
        # interactive interpreter and is absent when running with -S
        sys.exit(1)

    level = logging.DEBUG if options.debug else logging.INFO
    logging.basicConfig(level=level)
    # basicConfig does nothing if the root logger already has handlers, so set
    # the level explicitly as well
    logging.getLogger().setLevel(level)

    return args, options

def _intList(commaStr):
    "convert a comma-separated string (trailing comma allowed) to a list of ints"
    return [int(x) for x in commaStr.strip(",").split(",") if x]


def _parseGenePredExt(fields):
    "convert the fields of one genePredExt line (with leading bin column) to a transcript dict"
    # bin, name, chrom, strand, txStart, txEnd, cdsStart, cdsEnd,
    # exonCount, exonStarts, exonEnds, score, name2, cdsStartStat, cdsEndStat, exonFrames
    return {
        "name": fields[1],
        "chrom": fields[2],
        "strand": fields[3],
        "txStart": int(fields[4]),
        "txEnd": int(fields[5]),
        "cdsStart": int(fields[6]),
        "cdsEnd": int(fields[7]),
        "exonCount": int(fields[8]),
        "exonStarts": _intList(fields[9]),
        "exonEnds": _intList(fields[10]),
        "geneSym": fields[12] if len(fields) > 12 else "",
        "cdsStartStat": fields[13] if len(fields) > 13 else "none",
        "cdsEndStat": fields[14] if len(fields) > 14 else "none",
        "exonFrames": fields[15].strip(",") if len(fields) > 15 else "",
    }


def _parseBigGenePred(fields):
    "convert the fields of one bigGenePred BED line to a transcript dict"
    # chrom, chromStart, chromEnd, name, score, strand, thickStart, thickEnd,
    # color, blockCount, blockSizes, chromStarts, name2, cdsStartStat, cdsEndStat,
    # exonFrames, type, geneName, geneName2, geneType
    chromStart = int(fields[1])
    blockSizes = _intList(fields[10])
    blockStarts = _intList(fields[11])
    return {
        "name": fields[3],
        "chrom": fields[0],
        "strand": fields[5],
        "txStart": chromStart,
        "txEnd": int(fields[2]),
        "cdsStart": int(fields[6]),
        "cdsEnd": int(fields[7]),
        "exonCount": int(fields[9]),
        # block starts are relative to chromStart, convert to absolute coordinates
        "exonStarts": [chromStart + s for s in blockStarts],
        "exonEnds": [chromStart + s + sz for s, sz in zip(blockStarts, blockSizes)],
        "geneSym": fields[17] if len(fields) > 17 else "",
        "cdsStartStat": fields[13] if len(fields) > 13 else "none",
        "cdsEndStat": fields[14] if len(fields) > 14 else "none",
        "exonFrames": fields[15].strip(",") if len(fields) > 15 else "",
    }


def openInput(fname, fmt):
    """Open input file and yield parsed transcript dicts.
    Both formats yield dicts with keys: name, chrom, strand, txStart, txEnd,
    cdsStart, cdsEnd, exonCount, exonStarts, exonEnds, geneSym, cdsStartStat, cdsEndStat, exonFrames

    fname is the input file name; for fmt "genePredExt" a name ending in .gz
    is read through gzip, for fmt "bigGenePred" the file is piped through the
    external bigBedToBed program. Empty lines are skipped.

    This is a generator, so the errors below are raised while iterating:
    ValueError if fmt is not a known format (checked before anything is
    opened), subprocess.CalledProcessError if bigBedToBed exits with a
    non-zero status. The input is closed when iteration ends for any reason.
    """
    parsers = {
        "genePredExt": _parseGenePredExt,
        "bigGenePred": _parseBigGenePred,
    }
    if fmt not in parsers:
        raise ValueError("Unknown format: " + fmt)
    parseFields = parsers[fmt]

    proc = None
    cmd = None
    if fmt == "bigGenePred":
        # pipe through bigBedToBed. stderr is deliberately not captured: an
        # unread stderr pipe can fill up and block the child, and the tool's
        # error messages should reach the user
        cmd = ["bigBedToBed", fname, "stdout"]
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True)
        fh = proc.stdout
    elif fname.endswith(".gz"):
        fh = gzip.open(fname, "rt")
    else:
        fh = open(fname)

    try:
        for line in fh:
            line = line.rstrip("\n\r")
            if not line:
                continue
            yield parseFields(line.split("\t"))
    finally:
        # runs on normal exhaustion, on parse errors and when the consumer
        # stops iterating early
        fh.close()
        if proc is not None:
            proc.wait()

    # only reached after the whole input was consumed: a failed conversion
    # must not look like an empty input
    if proc is not None and proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, cmd)

def bedOut(row, txStart, txEnd, ofh, rule):
    """Write one decorator BED line to ofh.

    row is (chrom, start, end, name) of the region to write, txStart/txEnd
    are the coordinates of the decorated transcript and rule is the key into
    RULE_COLORS / RULE_DESCRIPTIONS that selects color and mouseover.
    """
    chrom, start, end, name = (str(val) for val in row)
    # chrom:txStart-txEnd:name identifies the item this line is attached to
    decorated = f"{chrom}:{txStart}-{txEnd}:{name}"
    rgb = RULE_COLORS[rule]
    label = RULE_DESCRIPTIONS[rule]
    size = int(end) - int(start)

    fields = []
    fields.extend([chrom, start, end, name, "0", "."])
    # thickStart/thickEnd repeat start/end, the item is a single block
    fields.extend([start, end, rgb, "1", str(size), "0"])
    fields.extend([decorated, "block", rgb, "", label])

    ofh.write("\t".join(fields))
    ofh.write("\n")


def outputExonsUpTo(from3Prime, cdsExons, chrom, txStart, txEnd, name, n, ofh, rule):
    """ given a list of (start, end), output start-end BEDs to ofh that cover n nucleotides.
        Returns list of (chrom, start, end) regions output.

        from3Prime: if true, walk cdsExons from the last list element backwards
            and trim the final region at its start; otherwise walk forwards and
            trim at its end.
        cdsExons: list of (start, end) tuples, not modified. An empty list
            gives an empty result.
        chrom, txStart, txEnd, name, rule: passed on to bedOut for every region.
        n: number of nucleotides to cover. A negative value -k means "the
            whole first exon in walking order plus k more nucleotides", e.g.
            -50 covers the last exon and 50bp before the last junction.
        ofh: output file handle, or None to only compute the regions without
            writing anything.
    """
    if not cdsExons:
        # nothing to cover; also avoids indexing into an empty list below
        return []

    ordered = list(reversed(cdsExons)) if from3Prime else cdsExons

    # -k means "k from the last junction" so take length of first exon + k
    if n < 0:
        firstStart, firstEnd = ordered[0]
        n = (firstEnd - firstStart) - n

    regions = []
    doneNs = 0
    for start, end in ordered:
        missBps = n - doneNs
        if missBps <= 0:
            break
        exLen = end - start
        if exLen > missBps:
            # this exon is longer than what is still needed: keep only the
            # part adjacent to the previously output exon
            if from3Prime:
                start = end - missBps
            else:
                end = start + missBps
        if ofh is not None:
            bedOut([chrom, str(start), str(end), name], txStart, txEnd, ofh, rule)
        regions.append( (chrom, start, end) )
        # count the untrimmed length: after a trim doneNs exceeds n and the
        # loop stops on the next iteration
        doneNs += exLen
    return regions

# ----------- main --------------
def _codingExons(exonStarts, exonEnds, cdsStart, cdsEnd):
    "return the exons clipped to cdsStart-cdsEnd as a list of (start, end), dropping exons without any CDS"
    cdsExons = []
    for exStart, exEnd in zip(exonStarts, exonEnds):
        clipStart = max(exStart, cdsStart)
        clipEnd = min(exEnd, cdsEnd)
        # pure UTR exons and exons that only touch the CDS boundary end up
        # empty or inverted after clipping
        if clipStart < clipEnd:
            cdsExons.append( (clipStart, clipEnd) )
    return cdsExons


def _writeCollapsed(regionData, outCollapsedFname):
    "write one line per (chrom, start, end, rule) region as bed 9 + mouseover + transcripts"
    with open(outCollapsedFname, "w") as collOfh:
        for key in sorted(regionData):
            chrom, start, end, rule = key
            data = regionData[key]
            txList = sorted(set(data["transcripts"]))
            # fall back to the first transcript name if no record had a gene symbol
            geneSym = data["geneSym"] or txList[0]

            # pick strand: use the strand if all transcripts agree, else "."
            strands = data["strands"]
            strand = next(iter(strands)) if len(strands) == 1 else "."

            color = RULE_COLORS[rule]
            mouseover = RULE_DESCRIPTIONS[rule] + " (" + str(len(txList)) + " transcripts)"

            row = [chrom, str(start), str(end), geneSym, "0", strand,
                   str(start), str(end), color, mouseover, ",".join(txList)]
            collOfh.write("\t".join(row))
            collOfh.write("\n")


def main():
    """Read the transcripts, write the decorator BED and the collapsed BED.

    File names and input format come from parseArgs(). For every coding
    transcript, the regions selected by rules 1-3 are written to the
    decorator file; identical regions of the same rule are then merged
    across transcripts and written to the collapsed file.
    """
    args, options = parseArgs()

    inFname, outDecoFname, outCollapsedFname = args

    # collect regions for the collapsed output:
    # key = (chrom, start, end, rule) -> {"transcripts": [...], "strands": set(), "geneSym": str}
    regionData = defaultdict(lambda: {"transcripts": [], "strands": set(), "geneSym": ""})

    def addRegions(regions, rule, name, strand, geneSym):
        "register the regions of one transcript under the given rule"
        for regChrom, regStart, regEnd in regions:
            entry = regionData[(regChrom, regStart, regEnd, rule)]
            entry["transcripts"].append(name)
            entry["strands"].add(strand)
            # the first non-empty gene symbol seen for a region wins
            if geneSym and not entry["geneSym"]:
                entry["geneSym"] = geneSym

    # "with" closes the output even if parsing the input fails half-way
    with open(outDecoFname, "w") as decoOfh:
        for rec in openInput(inFname, options.format):
            name = rec["name"]
            chrom = rec["chrom"]
            strand = rec["strand"]
            txStart = rec["txStart"]
            txEnd = rec["txEnd"]
            cdsStart = rec["cdsStart"]
            cdsEnd = rec["cdsEnd"]
            geneSym = rec["geneSym"]

            # skip non-coding transcripts (cdsStart == cdsEnd)
            if cdsStart >= cdsEnd:
                continue

            # only keep exons that have CDS and cut around CDS
            cdsExons = _codingExons(rec["exonStarts"], rec["exonEnds"], cdsStart, cdsEnd)
            if not cdsExons:
                # inconsistent record: a CDS is annotated but no exon overlaps it
                logging.debug("Skipping %s: no exon overlaps the CDS", name)
                continue

            if len(cdsExons) == 1:
                # rule 2: intronless transcript
                # note: the test is on the number of coding exons, not on exonCount
                bedOut([chrom, str(cdsStart), str(cdsEnd), name], txStart, txEnd, decoOfh, 2)
                addRegions([(chrom, cdsStart, cdsEnd)], 2, name, strand, geneSym)
                continue

            # on the + strand the coding start is at the first cdsExon, on
            # any other strand value it is at the last one
            isPlus = (strand == "+")

            # rule 3: first 100bp of coding nucleotides
            regions = outputExonsUpTo(not isPlus, cdsExons, chrom, txStart, txEnd, name, 100, decoOfh, 3)
            addRegions(regions, 3, name, strand, geneSym)
            # rule 1: last 50bp of last coding junction
            regions = outputExonsUpTo(isPlus, cdsExons, chrom, txStart, txEnd, name, -50, decoOfh, 1)
            addRegions(regions, 1, name, strand, geneSym)

    _writeCollapsed(regionData, outCollapsedFname)

    # every decorator line added exactly one transcript entry to regionData
    decoCount = sum(len(d["transcripts"]) for d in regionData.values())
    logging.info("Wrote %d decorator regions to %s", decoCount, outDecoFname)
    logging.info("Wrote %d collapsed regions to %s", len(regionData), outCollapsedFname)

# run only when executed as a script, so the functions above can be imported
# without parsing sys.argv or writing files
if __name__ == "__main__":
    main()
