#!/usr/bin/env python2

import logging, sys, optparse, itertools, json, codecs
from collections import defaultdict, namedtuple
from os.path import join, basename, dirname, isfile
#import xml.etree.ElementTree as et
import xml.etree.cElementTree as et

# default location of the HGNC tab-separated file (default of the --hgncFname option)
HGNCFILE = "/hive/data/outside/hgnc/current/hgnc_complete_set.txt"

# column names of the output file, in output order
headers = [
    "eventId",
    "causeType",
    "causeName",
    "causeGenes",
    "themeType",
    "themeName",
    "themeGenes",
    "relType",
    "relSubtype",
    "sourceDb",
    "sourceId",
    "sourceDesc",
    "pmids",
    "evidence",
]

# ID of event in output file, goes across input files, so global
eventId = 0

# === COMMAND LINE INTERFACE, OPTIONS AND HELP ===
# One positional argument (the BEL file) plus the options below; the parsed
# results are kept in the module-level names 'options' and 'args'.
parser = optparse.OptionParser(
    """usage: %prog [options] filename - convert .BEL files to tab-sep pathway format
"""
)

parser.add_option(
    "-d", "--debug",
    dest="debug",
    action="store_true",
    help="show debug messages",
)
parser.add_option(
    "-s", "--hgncFname",
    dest="hgncFname",
    action="store",
    help="the HGNC tab file on disk, default %default",
    default=HGNCFILE,
)
(options, args) = parser.parse_args()

# -d switches the root logger from INFO to DEBUG
logging.basicConfig(level=logging.DEBUG if options.debug else logging.INFO)
# ==== FUNCTIONs =====
def lineFileNext(fh):
    """
    Iterate over a tab-separated file whose first line holds the column names.

    The first line is read from fh, stripped of its newline and of "#"
    characters at either end, and split on tabs. Column names are turned into
    attribute names by replacing " " and "-" with "_" and by removing "(", ")"
    and ".". Every following line is split on tabs and yielded as a
    collections.namedtuple ('tsvRec') of strings.

    Lines whose number of fields does not match the header are reported with
    logging.error and skipped; they do not raise.

    fh: open, iterable file-like object. fh.name is only used for the error
        message and may be absent (e.g. in-memory buffers).
    """
    line1 = fh.readline()
    line1 = line1.strip("\n").strip("#")
    fieldNames = line1.split("\t")
    # make the column names usable as namedtuple attribute names
    for old, new in ((" ", "_"), ("(", ""), (")", ""), (".", ""), ("-", "_")):
        fieldNames = [h.replace(old, new) for h in fieldNames]
    Record = namedtuple('tsvRec', fieldNames)

    for line in fh:
        line = line.rstrip("\n")
        fields = line.split("\t")
        try:
            rec = Record(*fields)
        except TypeError as msg:
            # the namedtuple constructor raises TypeError on a wrong field count
            logging.error("Exception occured while parsing line, %s" % msg)
            # not every file-like object has a name, do not crash while reporting
            logging.error("Filename %s" % getattr(fh, "name", "<unknown>"))
            logging.error("Line was: %s" % repr(line))
            logging.error("Does number of fields match headers?")
            logging.error("Headers are: %s" % fieldNames)
            continue
        yield rec

def pipeSplitAddAll(string, dict, key, toLower=False):
    " map every '|'-separated part of string to key in dict; parts are lower-cased first if toLower is set "
    pieces = string.split("|")
    if toLower:
        pieces = [piece.lower() for piece in pieces]
    for piece in pieces:
        dict[piece] = key

def parseHgnc(fname, addEntrez=False):
    """
    Parse the HGNC tab-separated file and return (validSyms, upToSym, aliasToSym).

    validSyms:  set of all symbols that do not contain "withdrawn"
    upToSym:    dict, value of the uniprot_ids column -> symbol. Only the first
                symbol seen for an accession is kept, later ones are logged.
    aliasToSym: defaultdict(set), name -> set of symbols, where name is the
                symbol itself or one of its prev_symbol / alias_symbol entries.
                The first symbol that claims a name keeps it.

    Rows are read with lineFileNext and must provide the columns symbol,
    prev_symbol, alias_symbol and uniprot_ids.

    fname:     path of the HGNC file
    addEntrez: not used by this function, kept so existing calls still work
    """
    upToSym = {}
    skipSyms = set()
    aliasToSym = defaultdict(set)
    clashList = []
    validSyms = set()
    with open(fname) as hgncFh:
        for row in lineFileNext(hgncFh):
            sym = row.symbol
            if "withdrawn" in sym:
                continue
            validSyms.add(sym)

            aliasList = [sym]
            aliasList.extend(splitQuote(row.prev_symbol, isSym=True))
            aliasList.extend(splitQuote(row.alias_symbol, isSym=True))

            for n in aliasList:
                if n in aliasToSym:
                    oldSyms = aliasToSym[n]
                    # a name that already points to itself (an official symbol) is
                    # not reported. oldSyms is a set, so test membership instead of
                    # comparing the set with the string.
                    if n not in oldSyms:
                        clashList.append("%s->%s/%s" % (n, "|".join(sorted(oldSyms)), sym))
                else:
                    aliasToSym[n].add(sym)

            upAcc = row.uniprot_ids
            if upAcc == "" or upAcc == "-":
                continue
            if upAcc in upToSym:
                # accession is already assigned to another symbol, keep the first
                skipSyms.add(sym)
                continue
            upToSym[upAcc] = sym
    logging.info("Skipped these symbols due to duplicate uniprot IDs: %s" % ",".join(sorted(skipSyms)))
    logging.info("%d symbol clashes: %s" % (len(clashList),  clashList))
    return validSyms, upToSym, aliasToSym

def _resolveSym(sym, validSyms, aliasToSym):
    " return the approved symbols for sym: [sym] if it is valid, else its alias targets (sorted), else [] "
    if sym in validSyms:
        return [sym]
    # .get() instead of [] so a defaultdict does not grow with every unknown
    # name and a plain dict does not raise KeyError
    return sorted(aliasToSym.get(sym, ()))


def flattenCsv(member, validSyms, aliasToSym):
    """
    Convert a pybel csv member tuple to (type, label, genes), as in our format.

    member:     tuple such as ('Protein', 'HGNC', 'HOXD3'),
                ('Complex', (type, namespace, name), ...) or
                ('Abundance', 'CHEBI', 'metformin')
    validSyms:  collection of approved gene symbols
    aliasToSym: mapping name -> collection of approved symbols

    Returns:
      ("gene", label, genes)     for any member whose namespace is "HGNC"
      ("complex", label, genes)  for 'Complex' members; genes only covers
                                 the parts whose namespace is "HGNC"
      ("compound", name, "")     for 'Abundance' members
      None                       for everything else, including members that
                                 are too short to carry a name
    genes is a "|"-separated string of symbols. label holds the original
    name(s) and is empty when the names were not changed.
    """
    if len(member) < 2:
        return None
    mType = member[0]

    if member[1] == "HGNC":
        if len(member) < 3:
            return None
        sym = member[2]
        newSymStr = "|".join(_resolveSym(sym, validSyms, aliasToSym))
        # only output a label if we did make a change to the gene names
        label = "" if sym == newSymStr else sym
        return "gene", label, newSymStr

    if mType == "Complex":
        parts = member[1:]
        if not any(isinstance(p, (tuple, list)) for p in parts):
            # The complex is given by plain strings (presumably namespace and
            # name) instead of member tuples. Indexing those strings like
            # tuples would pick single characters, so use the name as label.
            name = member[2] if len(member) > 2 else ""
            return "complex", name, ""

        partSyms = []
        origNames = []
        for p in parts:
            if not isinstance(p, (tuple, list)) or len(p) < 3:
                # not a (type, namespace, name) part
                continue
            sDb = p[1]
            sym = p[2]
            origNames.append(sym)
            if sDb != "HGNC":
                continue
            partSyms.extend(_resolveSym(sym, validSyms, aliasToSym))

        # only show a label if we changed something
        label = "/".join(origNames)
        geneStr = "|".join(partSyms)
        if set(origNames) == set(partSyms):
            label = ""
        return "complex", label, geneStr

    if mType == "Abundance":
        # ('Abundance', 'CHEBI', 'metformin')
        if len(member) < 3:
            return None
        return "compound", member[2], ""

    return None

def parseBelCsv(fname, validSyms, aliasToSym):
    """
    Parse the BEL csv format and yield one output row (list of strings) per
    unique relation.

    Each input line has three tab-separated columns: the source member, the
    target member and a dict of attributes, all written as Python literals.
    Lines where any of the three columns cannot be parsed are counted and
    skipped; the count is logged once the input is exhausted. Relations of
    type "hasComponent", members that flattenCsv cannot convert and rows that
    were already yielded are skipped, too.

    The yielded row follows the output file headers; the target is written to
    the cause columns and the source to the theme columns.

    fname:      path of the input file, read as utf8
    validSyms:  approved gene symbols, passed on to flattenCsv
    aliasToSym: alias -> symbols mapping, passed on to flattenCsv
    """
    skipCount = 0
    doneRows = set()
    # example lines (evidence shortened):
    #('Protein', 'HGNC', 'HOXD3')	('RNA', 'HGNC', 'TSPAN7')	{'subject_effect_namespace': 'bel', 'citation_type': 'PubMed', 'citation_name': 'Oncogene 2002 Jan 24 21(5) 798-808', 'evidence': 'Table 1 - Gene regulated by ...', 'subject_modifier': 'Activity', 'relation': 'increases', 'line': 62002, 'subject_effect_name': 'tscript', 'citation_reference': '11850808'}
    #('Abundance', 'CHEBI', 'palmitic acid') ('Complex', ('Protein', 'MGI', 'Myd88'), ('Protein', 'MGI', 'Tlr2')) {'citation_type': 'PubMed', 'citation_name': 'J Biol Chem 2006 Sep 15 281(37) 26865-75', 'evidence': 'Treatment with palmitate ...', 'relation': 'increases', 'line': 205920, 'citation_reference': '16798732'}

    i = 0
    with codecs.open(fname, "r", encoding="utf8") as belFh:
        for line in belFh:
            line = line.rstrip("\n").rstrip(" ")
            fs = line.split("\t")
            # fields are handled as latin1 byte strings from here on; characters
            # outside latin1 raise UnicodeEncodeError
            fs = [x.encode("latin1") for x in fs]

            # some lines have invalid syntax, wrong quote characters
            # opened github issue, but skipping for now, only 18 lines
            # NOTE(review): eval() executes whatever expression the file contains,
            # so only run this on trusted input. If the columns are always plain
            # literals, ast.literal_eval would be the safe replacement.
            try:
                source = eval(fs[0])
                target = eval(fs[1])
                attrs = eval(fs[2])
            except Exception:
                # unparsable or missing column: skip the whole line, never go on
                # with values left over from a previous line
                skipCount += 1
                continue
            if not isinstance(attrs, dict):
                skipCount += 1
                continue

            # useless relationship
            if attrs.get("relation") == "hasComponent":
                continue

            g1 = flattenCsv(source, validSyms, aliasToSym)
            g2 = flattenCsv(target, validSyms, aliasToSym)
            if g1 is None or g2 is None:
                continue
            g1Type, g1Name, g1Genes = g1
            g2Type, g2Name, g2Genes = g2
            pmid = attrs.get("citation_reference", "")
            evid = attrs.get("evidence", "")
            if len(evid) > 500:
                evid = evid[:500]+"..."

            relType = attrs.get("relation", "")

            # conv modifiers to a single string
            parts = []
            # our format is theme <- cause but every other database
            # is cause -> theme. Too late to change our format now.
            # NOTE(review): object_modifier is combined with the genes of the
            # source and subject_modifier with the genes of the target, exactly
            # as before - confirm that this pairing is intended.
            mod1 = attrs.get("object_modifier")
            mod2 = attrs.get("subject_modifier")
            if mod1 is not None:
                g1Label = g1Genes.replace("|", "/") # '/' is a bit more readable
                parts.append("%s %s" % (g1Label, mod1))
            if mod2 is not None:
                g2Label = g2Genes.replace("|", "/")
                parts.append("%s %s" % (g2Label, mod2))
            relSubtype = ", ".join(parts)

            # skip duplicates
            mainData = (g2Type, g2Name, g2Genes, g1Type, g1Name, g1Genes, relType, relSubtype, pmid, evid)
            if mainData in doneRows:
                continue
            doneRows.add(mainData)

            # fields are:
            #"eventId,causeType,causeName,causeGenes,themeType,themeName,themeGenes,relType,relSubtype,sourceDb,sourceId,sourceDesc,pmids,evidence
            # openBel is source -> target
            # but our format is target <- source, because I copied Hoifung's format

            row = [
                    "belLarge"+str(i),g2Type, g2Name, g2Genes, g1Type, g1Name, g1Genes,
                    relType, relSubtype,
                    "belLarge","9ea3c170-01ad-11e5-ac0f-000c29cb28fb", "Selventa BEL large corpus",
                    pmid, evid
                  ]
            yield row
            # i only counts rows that were actually yielded
            i += 1

    logging.info("Could not parse %d lines" % skipCount)

def splitQuote(name, isSym=False):
    """
    Split a multi-value field into a list of names.

    The field is first split on "|". Every piece that contains a double quote
    is then split on '", "' (quoted, comma-separated names); a piece without
    quotes is split on plain commas. Surrounding quotes and whitespace are
    removed from every name and empty names are dropped.

    NOTE(review): the columns read by parseHgnc (symbol, prev_symbol,
    alias_symbol) are those of the newer HGNC download, which is understood to
    separate multiple values with "|" - confirm against the file in use.
    Input without "|" is split exactly as before.

    name:  the raw field value
    isSym: if set, the names are gene symbols and must not contain a comma;
           an AssertionError is raised otherwise
    """
    names = []
    for chunk in name.split("|"):
        if '"' in chunk:
            # first split quoted names
            pieces = chunk.split('", "')
        else:
            # if there are no quotes, also try to split on just commas
            pieces = chunk.split(",")
        for piece in pieces:
            piece = piece.strip('"').strip()
            if piece != "":
                names.append(piece)

    # make sure there are no commas left, if symbol
    if isSym:
        for n in names:
            if "," in n:
                # raised explicitly so that the check also runs under "python -O"
                raise AssertionError("symbol contains a comma: %r" % n)
    return names

# ----------- MAIN --------------
# Needs the input file as first positional argument; without it, show the help
# text and stop with exit code 1.
if not args:
    parser.print_help()
    # sys.exit instead of exit(): the latter is a helper added by the site
    # module for interactive use
    sys.exit(1)

filename = args[0]

# upToSym is returned by parseHgnc but not needed for the conversion below
validSyms, upToSym, aliasToSym = parseHgnc(options.hgncFname)

# header line, prefixed with "#", then one tab-separated line per row on stdout
sys.stdout.write("#" + "\t".join(headers) + "\n")

logging.debug(filename)
rows = parseBelCsv(filename, validSyms, aliasToSym)
for row in rows:
    sys.stdout.write("\t".join(row) + "\n")
