#!/usr/bin/env python3

import sys
import os
import argparse
# Directory portion of the path this script was invoked with (sys.argv[0]),
# normalized; used to find libraries installed relative to the script.
myBinDir = os.path.normpath(os.path.dirname(sys.argv[0]))
# Make the sibling ../lib directory importable; this must run before the
# project imports (pycbio, gencode) further down.
sys.path.append(os.path.join(myBinDir, "../lib"))
# Also search a fixed absolute directory for pycbio.  The path does not start
# with "~", so expanduser() returns it unchanged.
sys.path.append(os.path.expanduser("/hive/groups/browser/pycbio/lib"))
from collections import defaultdict
try:
    from dataclasses import dataclass, field
except ModuleNotFoundError as ex:
    raise Exception(f"requires at least Python 3.7 for dataclasses: {sys.version}") from ex
from pycbio.sys import fileOps
from pycbio.tsv import TsvReader
from gencode import gencodeGxfParserFactory
from gencode.gencodeGxfParser import ParseException, GxfValidateException, GxfRecordValidator
from gencode.biotypes import BioType, getTranscriptFunction

def parseArgs():
    """Build the command-line parser and parse sys.argv.

    Returns the argparse namespace with the attributes debug,
    transcriptRanks, drop (a list, empty when --drop is not given),
    gencodeGxf and attrsTsv.
    """
    description = ("Extract attributes from a GENCODE GTF or GFF3 into a tab-separate file for\n"
                   "use by other programs.")
    cmdParser = argparse.ArgumentParser(description=description)

    # optional arguments
    cmdParser.add_argument("--debug",
                           action="store_true",
                           help="print stack trace on all error")
    cmdParser.add_argument("--transcriptRanks",
                           help="the gencode.v??.transcript_rankings.txt.gz file")
    cmdParser.add_argument("--drop",
                           action="append",
                           default=[],
                           help="drop genes or transcripts with this id, maybe repeated; "
                                "if dropping a gene, must specify all transcripts")

    # positional arguments: input annotation file, then output TSV
    cmdParser.add_argument("gencodeGxf", help="GENCODE GTF or GFF3 file")
    cmdParser.add_argument("attrsTsv", help="write attributes to this TSV file")

    return cmdParser.parse_args()


def emptyIfNone(v):
    """Return the empty string when v is None, otherwise v unchanged.

    Only None is replaced; other falsy values such as 0 or False pass through.
    """
    if v is None:
        return ""
    return v

def loadTranscriptRanks(transcriptRanksTab):
    """Load the transcript rankings file into a dict keyed by transcript id.

    Each value is the row object produced by TsvReader.  If a transcript id
    occurs more than once, the last row read is the one kept.
    """
    # Column names are passed explicitly rather than read from the file
    # (presumably the rankings file has no header row -- confirm against
    # TsvReader).
    columnNames = ("chrom", "geneId", "geneType", "rank", "transcriptId",
                   "transcriptType", "carsScore", "extraInfo")
    # per-column types handed to TsvReader for the numeric columns
    columnTypes = {"rank": int,
                   "carsScore": float}
    ranksById = {}
    for row in TsvReader(transcriptRanksTab, columns=columnNames, typeMap=columnTypes):
        ranksById[row.transcriptId] = row
    return ranksById

@dataclass
class Transcript:
    """Attributes collected from one transcript record.

    All fields up to and including transcriptRank are constructor arguments.
    transcriptClass, geneTags and transcriptTags are not constructor
    arguments; they are filled in after the object is created.
    """
    transcriptId: str
    transcriptName: str
    transcriptType: str
    geneId: str
    geneName: str
    geneType: str
    source: str
    havanaGeneId: str
    havanaTranscriptId: str
    ccdsId: str
    level: int
    proteinId: str
    transcriptSupportLevel: str
    transcriptRank: int
    # Result of getTranscriptFunction() for this transcript's gene and
    # transcript types; assigned by GencodeGxfToAttrs._assignFunction once
    # parsing is done.  Declared here so the attribute always exists and is
    # visible as part of the record, rather than being attached dynamically.
    transcriptClass: object = field(default=None, init=False)
    # tags of the owning gene and of the transcript itself; each instance
    # starts with its own empty set
    geneTags: set = field(default_factory=set, init=False)
    transcriptTags: set = field(default_factory=set, init=False)

class GencodeGxfToAttrs:
    """Collect per-transcript attributes from a GENCODE GxF parser and write
    them as a TSV.

    Typical use: construct, call parse() to read all records, then call
    writeAttrsTsv() to output one row per transcript.
    """

    def __init__(self, gxfParser, dropIds, transcriptRanks=None):
        """gxfParser must supply reader() and getFileName(); dropIds is a
        container of transcript ids to skip; transcriptRanks, if not None,
        maps transcript id to an object with a rank attribute."""
        self.gxfParser = gxfParser
        self.dropIds = dropIds
        self.transcriptRanks = transcriptRanks
        self.recValidator = GxfRecordValidator(sys.stderr)
        self.errorCnt = 0
        self.unknownGeneTypes = set()
        self.unknownTranscriptTypes = set()
        # gene_id -> set of tags seen on gene records
        self.geneTags = defaultdict(set)
        # transcript_id -> Transcript
        self.transcripts = {}

    def _processGeneRec(self, rec):
        "accumulate the tags of a gene record under its gene_id"
        self.geneTags[rec.attributes.gene_id].update(rec.attributes.tags)

    def _getTranscriptRank(self, transcriptId):
        """Return the rank for transcriptId, or 0 when no ranks were loaded.
        Raises ParseException if ranks were loaded but lack this id."""
        ranks = self.transcriptRanks
        if ranks is None:
            return 0
        entry = ranks.get(transcriptId)
        if entry is not None:
            return entry.rank
        raise ParseException(f"no transcript rank found for {transcriptId}")

    def _processTransRec(self, rec):
        "create the Transcript for this record if not yet seen, then add its tags"
        attrs = rec.attributes
        transId = attrs.transcript_id
        # Older GENCODE releases could list the same transcript id twice
        # because of PAR, so only build the Transcript the first time.
        if transId not in self.transcripts:
            self.transcripts[transId] = Transcript(
                transcriptId=transId,
                transcriptName=attrs.transcript_name,
                transcriptType=attrs.transcript_type,
                geneId=attrs.gene_id,
                geneName=attrs.gene_name,
                geneType=attrs.gene_type,
                source=rec.source,
                havanaGeneId=attrs.havana_gene,
                havanaTranscriptId=attrs.havana_transcript,
                ccdsId=attrs.ccdsid,
                level=attrs.level,
                transcriptSupportLevel=attrs.transcript_support_level,
                proteinId=attrs.protein_id,
                transcriptRank=self._getTranscriptRank(transId))
        # Because of PAR naming in older GENCODE releases, records sharing a
        # transcript_id may carry different tags (mainly the PAR tag) now
        # that the _PAR_Y extensions are removed from the ids, so tags are
        # always merged in, including for repeats.  In retrospect, GENCODE
        # having the _PAR_Y extensions and us removing them were both bad
        # ideas.
        self.transcripts[transId].transcriptTags.update(attrs.tags)

    def _processRec(self, rec):
        """Validate one record and dispatch gene and transcript features.
        A ParseException is reported on stderr and counted, not propagated."""
        try:
            self.recValidator.validate(rec)
            feature = rec.feature
            if feature == "gene":
                self._processGeneRec(rec)
            elif (feature == "transcript") and (rec.attributes.transcript_id not in self.dropIds):
                self._processTransRec(rec)
        except ParseException as ex:
            message = "Error: {}: {}\n".format(str(ex), rec)
            sys.stderr.write(message)
            self.errorCnt += 1

    def _mergeGeneTags(self):
        "point each transcript's geneTags at the tag set collected for its gene"
        # Transcripts of the same gene share one set object; a gene with no
        # collected tags yields an empty set from the defaultdict.
        for trans in self.transcripts.values():
            trans.geneTags = self.geneTags[trans.geneId]

    def _assignFunction(self, trans):
        "set trans.transcriptClass from its gene and transcript biotypes"
        geneBioType = BioType(trans.geneType)
        transBioType = BioType(trans.transcriptType)
        trans.transcriptClass = getTranscriptFunction(geneBioType, transBioType)

    def _assignFunctions(self):
        "assign transcriptClass on every collected transcript"
        for trans in self.transcripts.values():
            self._assignFunction(trans)

    def parse(self):
        """Read all records from the parser and finish the transcript table.

        Reading stops early once 100 parse errors have been counted.  Raises
        Exception if any parse errors occurred; otherwise the validator
        reports its errors, gene tags are merged and functions assigned.
        """
        for rec in self.gxfParser.reader():
            self._processRec(rec)
            if self.errorCnt >= 100:
                # give up rather than flooding stderr
                break
        if self.errorCnt > 0:
            raise Exception("{} parse errors in {}".format(self.errorCnt, self.gxfParser.getFileName()))
        self.recValidator.reportErrors()
        self._mergeGeneTags()
        self._assignFunctions()

    @staticmethod
    def _writeAttrsTsvRow(tr, fh):
        "write one transcript as a TSV row; tag sets are sorted and comma-joined"
        # column order must match the header written by writeAttrsTsv
        row = (
            tr.geneId, tr.geneName, tr.geneType,
            tr.transcriptId, tr.transcriptName, tr.transcriptType,
            tr.source, tr.havanaGeneId, tr.havanaTranscriptId,
            emptyIfNone(tr.ccdsId), tr.level,
            tr.transcriptClass, emptyIfNone(tr.proteinId),
            tr.transcriptRank, tr.transcriptSupportLevel,
            ",".join(str(tag) for tag in sorted(tr.geneTags)),
            ",".join(str(tag) for tag in sorted(tr.transcriptTags)),
        )
        fileOps.prRowv(fh, *row)

    def writeAttrsTsv(self, fh):
        "write the header row, then all transcripts ordered by (geneId, transcriptId)"
        header = ("geneId", "geneName", "geneType",
                  "transcriptId", "transcriptName", "transcriptType",
                  "source", "havanaGeneId", "havanaTranscriptId",
                  "ccdsId", "level", "transcriptClass", "proteinId",
                  "transcriptRank", "tsl", "geneTags", "transcriptTags")
        fileOps.prRowv(fh, *header)
        ordered = list(self.transcripts.values())
        ordered.sort(key=lambda t: (t.geneId, t.transcriptId))
        for trans in ordered:
            self._writeAttrsTsvRow(trans, fh)

def gencodeGxfToAttrs(opts):
    """Run the conversion described by the parsed command-line options.

    Loads the transcript ranks if opts.transcriptRanks was given, parses
    opts.gencodeGxf, and writes the attribute table to opts.attrsTsv.  The
    output file is only opened after parsing has succeeded.

    On a GxfValidateException the message is printed to stderr and the
    process exits with status 1, unless opts.debug is set, in which case the
    exception is re-raised so the stack trace is shown.  Other exceptions
    propagate unchanged.
    """
    try:
        transcriptRanks = None
        if opts.transcriptRanks is not None:
            transcriptRanks = loadTranscriptRanks(opts.transcriptRanks)
        parser = GencodeGxfToAttrs(gencodeGxfParserFactory(opts.gencodeGxf, ignoreUnknownAttrs=True),
                                   dropIds=frozenset(opts.drop), transcriptRanks=transcriptRanks)
        parser.parse()
        with open(opts.attrsTsv, "w") as fh:
            parser.writeAttrsTsv(fh)
    except GxfValidateException as ex:
        if opts.debug:
            raise
        print(str(ex), file=sys.stderr)
        # sys.exit rather than the builtin exit(): the latter is installed by
        # the site module for interactive use and is missing under `python -S`
        sys.exit(1)


# Only run the command-line program when executed as a script, so the module
# can be imported (e.g. for testing) without parsing sys.argv or doing I/O.
if __name__ == "__main__":
    gencodeGxfToAttrs(parseArgs())
