#!/usr/bin/env python2.7

import logging, sys, optparse, glob, itertools, os, tempfile, gzip, operator, re, cPickle, gc
import marshal
from collections import defaultdict, namedtuple, Counter
from itertools import chain, combinations
from os.path import join, basename, dirname, isfile, expanduser
import ujson

# TODO: change indexPairs so it corrects synonyms to the current gene symbol

# saves 20% of time when loading graph marshal
gc.disable()

#LTPMAXMEMBERS=5 # maximum number of proteins in a complex for an interaction to qualify as low throughput
LTPMAX=5 # maximum number of interactions a PMID can have to be declared low throughput

# don't even write out links with fewer than this number of documents. 2 = weeds out many false positives.
# Not used right now, because a few pathway databases do not annotate ANY PMIDs. Increasing this filter
# would remove all interactions from these pathway databases.
MINSUPP=0

# a cutoff on the number of documents required for text mined documents to show up in the UI
UI_MINSUPP=2

outFields = ["gene1", "gene2", "flags", "refs", "fwWeight", "revWeight", "snip"]

# directory with autoSql descriptions of output tables
autoSqlDir = expanduser("~/kent/src/hg/lib/")

# file with all of medline in short form
allArtFname = "textInfo.tab"
# file with just pmids and events
#pmidEventFname = "temp.pmidToEvent.tab"

# RE to split sentences into words
wordRe = re.compile("[a-zA-Z0-9]+")
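# e.g. wordRe.findall("p53-binding protein") returns ['p53', 'binding', 'protein']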

# === COMMAND LINE INTERFACE, OPTIONS AND HELP ===
parser = optparse.OptionParser("""usage: %prog [options] build|load pathwayDir ppiDir textDir outDir - given various tab-sep files with text-mining, gene interaction or pathway information, build the tables ggLink, ggDoc, ggDb and ggText

run it like this:

Reduce the big medline table to something smaller, only needed once:
    %prog medline

Slowest part: build the big table of interactions, mysql/ggLink.tmp.txt
    %prog build pathways ppi text mysql

Create mysql/ggDocs.tab, very slow
    %prog docs mysql

Add the "context" (aka mesh terms) to mysql/ggLink.tmp.tab 
and write to mysql/ggLink.tab file
    %prog context mysql  
format is:
gene1, gene2, flags, forwDocCount, revDocCount, allDocCount, databaseList, minimalPairCountPaper, snippet

Load all tables in mysql/ into MySql:
    %prog load mysql hgFixed

Create the bigBed File
    %prog bigBed outDir bigBedFile db
""")

parser.add_option("-d", "--debug", dest="debug", action="store_true", help="show debug messages") 
#parser.add_option("-t", "--test", dest="test", action="store_true", help="run tests") 
parser.add_option("-t", "--textDir", dest="textDir", action="store", help="directory with the parsed copy of medline, default %default", default="/hive/data/inside/pubs/text/medline")
parser.add_option("-m", "--meshFile", dest="meshFname", action="store", help="An mtrees<year>.bin file, default %default", default="/hive/data/outside/ncbi/mesh/mtrees2015.bin")
parser.add_option("-j", "--journalInfo", dest="journalInfo", action="store", help="tab-sep file with journal info from the NLM Catalog converted by 'pubPrepCrawl publishers'. Used to shorten the journal names. Optional and not used if file is not found. Default %default", default="/cluster/home/max/projects/pubs/tools/data/journals/journals.tab")
parser.add_option("-b", "--wordFname", dest="wordFname", action="store", help="a file with common English words", default="/hive/data/outside/pubs/wordFrequency/bnc/bnc.txt")
#parser.add_option("-f", "--file", dest="file", action="store", help="run on file") 
#parser.add_option("", "--test", dest="test", action="store_true", help="do something") 
(options, args) = parser.parse_args()

if options.debug:
    logging.basicConfig(level=logging.DEBUG)
else:
    logging.basicConfig(level=logging.INFO)
# ==== FUNCTIONs =====
    
def parseAutoSql(asFname):
    " parse auto sql file and return list of field names "
    headers = []
    for line in open(asFname):
        line = line.replace("(","").replace(")","").strip()
        if line.startswith('"') or line.startswith("\n"):
            continue
        if line.startswith("table"):
            continue
        if not ";" in line:
            continue
        parts = line.split(";")[0]
        parts = parts.split()
        assert(len(parts)>=2)
        headers.append(parts[1])
    return headers
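
# For illustration, a hypothetical .as file like
#   table ggLink
#   "gene interaction links"
#   (
#   string gene1;  "first gene symbol"
#   string gene2;  "second gene symbol"
#   )
# would yield the headers ["gene1", "gene2"].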

def lineFileNext(fh, headers=None, asFname=None):
    """ parse a tab-sep file and yield collections.namedtuples. Field names are
        taken from the headers argument, from an autoSql file (asFname) or, if
        neither is given, from the first line of the file.
    """
    if headers==None and asFname!=None:
        headers = parseAutoSql(asFname)
    elif headers==None:
        # only consume the first line of the file if it is needed as a header line
        line1 = fh.readline()
        line1 = line1.strip("\n").lstrip("#")
        headers = line1.split("\t")
        headers = [h.replace(" ", "_") for h in headers]
        headers = [h.replace("(", "") for h in headers]
        headers = [h.replace(")", "") for h in headers]
    Record = namedtuple('tsvRec', headers)

    for line in fh:
        line = line.rstrip("\n")
        fields = line.split("\t")
        try:
            rec = Record(*fields)
        except Exception, msg:
            logging.error("Exception occured while parsing line, %s" % msg)
            logging.error("Filename %s" % fh.name)
            logging.error("Line was: %s" % repr(line))
            logging.error("Does number of fields match headers?")
            logging.error("Headers are: %s" % headers)
            #raise Exception("wrong field count in line %s" % line)
            continue
        yield rec

def loadFiles(inDir, prefix=None):
    """ load .tab files into dict fType -> list of rows and return tuple (ppiRows, textRows)
    """
    #typeRows = defaultdict(list)
    #jppiRows = list()
    #jtextRows = list()
    inFnames = glob.glob(join(inDir, "*.tab"))
    rows = list()
    for inFname in inFnames:
        logging.info("Loading %s" % inFname)
        #fType = None
        for row in lineFileNext(open(inFname)):
            if prefix!=None:
                row = row._replace(eventId=prefix+row.eventId)
            rows.append(row)
    return rows

def getResultCounts(pairs):
    """
    Input is a pair -> rows dictionary
    for each PMID, count how many pairs are assigned to it. This is
    something like the "resultCount" of a paper; the lower, the better.
    Then, for each pair, get the minimum resultCount and return as a dict
    pair -> resultCount
    """
    logging.info("Getting low-throughput studies in PPI and pathway data")

    # create dict doc -> set of gene pairs
    docToPairs = defaultdict(set)
    for pair, rows in pairs.iteritems():
        #print "pp", pair
        for row in rows:
            #members = row.themeGenes.split("|")
            #members.extend(row.causeGenes.split("|"))
            # complexes with more than 5 proteins are not low-throughput anyways
            # skip these right away
            #if len(members)>LTPMAXMEMBERS:
                #continue
            docIds = row.pmids.split("|")
            for docId in docIds:
                if docId=="":
                    continue
                docToPairs[docId].add(pair)

    #print "d2p", docToPairs

    pairMinResultCounts = {}
    for pair, rows in pairs.iteritems():
        #print "pp", pair
        resultCounts = []
        # get all docIds in rows
        docIds = []
        for row in rows:
            docIds.extend(row.pmids.split("|"))

        # get the minimal resultCount of all docIds 
        for docId in set(docIds):
            if docId=="":
                continue
            #print "di2", docId
            #print "d2p", docToPairs[docId]
            resCount = len(docToPairs[docId])
            #print "rc", resCount
            resultCounts.append(resCount)
        if len(resultCounts)!=0:
            minResCount = min(resultCounts)
        else:
            minResCount = 0
        #print "min", minResCount
        pairMinResultCounts[pair] = minResCount
        
    #ltPairs = set()
    #ltDocs = []
    #for pmid, pairList in docToPairs.iteritems():
        #if len(pairList) <= LTPMAX:
            #ltPairs.update(pairList)
            #ltDocs.append(pmid)

    #logging.info("Found %d low-throughput studies out of %d" % (len(ltDocs), len(pmidToPairs)))
    #logging.info("Found %d low-throughput interactions out of %d" % (len(ltPairs), len(pairs)))
    #return ltPairs, ltDocs
    return pairMinResultCounts, docToPairs
            
def allSubPairs(pair):
    """ given a pair of two strings, where each can be a _-separate list of genes (a family), 
    return all combinations of each member 
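    Example:
    >>> list(allSubPairs(("A_B", "C")))
    [('A', 'C'), ('B', 'C')]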
    """
    x, y = pair
    xs = x.split("_")
    ys = y.split("_")
    for subPair in [(a, b) for a in xs for b in ys]:
        a, b = subPair
        if a=="" or b=="" or a=="-" or b=="-" or \
                a.startswith("compound") or b.startswith("compound"):
            continue
        yield subPair
    
def iterAllPairs(row):
    """ yield all pairs of interacting genes for a given row. Handles families 
    >>> list(iterAllPairs("gene", ["TP1","TP2"], "complex", ["OMG1","OMG2"]))
    [('OMG1', 'OMG2'), ('TP1', 'OMG1'), ('TP1', 'OMG2'), ('TP2', 'OMG1'), ('TP2', 'OMG2')]
    >>> list(iterAllPairs("complex", ["TP1_TEST2","TP2"], "complex", ["OMG1","OMG2"]))
    """

    type1 = row.causeType
    type2 = row.themeType
    genes1 = set(row.causeGenes.split("|"))
    genes2 = set(row.themeGenes.split("|"))
    # all genes of complexes interact in some way
    if type1=="complex":
        for pair in itertools.combinations(genes1, 2):
            # a complex can contain families
            for subPair in allSubPairs(pair):
                yield tuple(subPair)
    if type2=="complex":
        for pair in itertools.combinations(genes2, 2):
            # a complex can contain families
            for subPair in allSubPairs(pair):
                yield tuple(subPair)

    # all genes from the left and the right side interact
    if type2!="":
        pairs = list([(aa, bb) for aa in genes1 for bb in genes2])
        for pair in pairs:
            for subPair in allSubPairs(pair):
                gene1, gene2 = subPair
                if gene1=="-" or gene1=="" or gene2=="" or gene2=="-":
                    #skipCount += 1
                    continue
                if gene1.startswith("compound") or gene2.startswith("compound"):
                    continue
                yield subPair
    
def indexPairs(ppiRows, desc):
    """ given rows with theme and cause genes, return 
    a dict with sorted (gene1, gene2) -> list of eventIds """
    logging.info("enumerating all interacting pairs: %s" % desc)
    pairs = defaultdict(list)
    for row in ppiRows:
        for pair in iterAllPairs(row):
            gene1, gene2 = pair
            if gene1.startswith("compound") or gene2.startswith("compound"):
                continue
            pairs[tuple(sorted(pair))].append(row)
    logging.info("got %d pairs" % len(pairs))
    return pairs

def mergePairs(dicts):
    " merge a list of defaultdict(list) into one defaultdict(set) "
    logging.info("Merging pairs")
    data = defaultdict(set)
    for defDict in dicts:
        for key, valList in defDict.iteritems():
            data[key].update(valList)
    return data

def directedPairToDocs(rows):
    """ get documents of text mining pairs. a dict with pair -> text rows
    create a dict pair -> set of document IDs. Note that these pairs are
    DIRECTED - so they can be used to infer the direction of the interaction
    """
    # create a dict with pair -> pmids
    pairPmids = defaultdict(set)
    for row in rows:
        genes1 = set(row.causeGenes.split("|"))
        genes2 = set(row.themeGenes.split("|"))
        pairs = list([(aa, bb) for aa in genes1 for bb in genes2])
        for cause, theme in pairs:
            pairPmids[(cause, theme)].add(row.pmid)
    
    return pairPmids

def writeGraphTable(allPairs, pairDocs, pairToDbs, pairMinResCounts, pwDirPairs, bestSentences, outFname, outFname2):
    " write the ggLink table "
    logging.info("writing merged graph to %s" % outFname)
    rows = []
    rows2 = []
    allSyms = set()
    for pair,pairRows in allPairs.iteritems():
        gene1, gene2 = pair

        flags = []
        if pair in dbPairs:
            flags.append("ppi")
        if pair in pwPairs:
            flags.append("pwy")
        if pair in textPairs:
            flags.append("text")
        refs = [row.eventId for row in pairRows]
        #if pairMinResultCounts:
            #flags.append("low")
        # direction of interaction - only based on pathways
        if pair in pwDirPairs:
            flags.append("fwd")
        if tuple(reversed(pair)) in pwDirPairs:
            flags.append("rev")
            
        forwDocs = pairDocs.get(pair, [])
        revDocs = pairDocs.get(tuple(reversed(pair)), [])
        allDocs = set(forwDocs).union(set(revDocs))

        if len(allDocs)<MINSUPP and "pwy" not in flags and "ppi" not in flags:
            # if it's text-mining only and less than X documents, just skip it
            continue
        pairMinResCount = pairMinResCounts.get(pair, 0)

        pairDbs = "|".join(pairToDbs.get(pair, []))
        snippet = bestSentences.get(pair, "")
        row = [gene1, gene2, ",".join(flags), str(len(forwDocs)), str(len(revDocs)), \
            str(len(allDocs)), pairDbs, str(pairMinResCount), snippet]
        rows.append(row)

        allSyms.add(gene1)
        allSyms.add(gene2)

        refs = list(refs)
        refs.sort()
        for ref in refs:
            #row2 = [gene1, gene2, ",".join(refs)]
            row = [gene1, gene2, ref]
            rows2.append(row)

    ofh = open(outFname, "w")
    rows.sort()
    for row in rows:
        ofh.write("\t".join(row))
        ofh.write("\n")
    ofh.close()

    ofh2 = open(outFname2, "w")
    rows2.sort()
    for row in rows2:
        ofh2.write("\t".join(row))
        ofh2.write("\n")
    ofh2.close()

    return allSyms

def runCmd(cmd):
    """ run command in shell, exit if not successful """
    msg = "Running shell command: %s" % cmd
    logging.debug(msg)
    ret = os.system(cmd)
    if ret!=0:
        raise Exception("Could not run command (Exitcode %d): %s" % (ret, cmd))
    return ret

def asToSql(table, sqlDir):
    " given a table name, return the name of a .sql file with CREATE TABLE for it"

    asPath = join(autoSqlDir, table+".as")
    #tempBase = tempfile.mktemp()
    outBase = join(sqlDir, table)
    cmd = "autoSql %s %s" % (asPath, outBase)
    runCmd(cmd)
    #sql = open("%s.sql" % sqlFname).read()

    # delete the files that are not needed
    #assert(len(tempBase)>5) # paranoia check
    #cmd = "rm -f %s.h %s.c"  % (tempBase, tempBase)
    #runCmd(cmd)

    return outBase+".sql"

def loadTable(db, tableDir, table):
    " load table into mysql, using autoSql "
    #sqlFname = join(tableDir, table+".sql")
    tmpSqlFname = asToSql(table, tableDir)
    tabFname = join(tableDir, table+".tab")

    cmd = "hgLoadSqlTab %s %s %s %s" % (db, table, tmpSqlFname, tabFname)
    try:
        runCmd(cmd)
    except:
        # make sure that the temp file gets deleted
        os.remove(tmpSqlFname)
        raise

    os.remove(tmpSqlFname)

def hgsql(db, query):
    " run a query against a database with the hgsql command-line tool "
    assert('"' not in query)
    cmd = "hgsql %s -NBe '%s'" % (db, query)
    runCmd(cmd)

def addIndexes(db):
    " add the indexes for mysql "
    query = "ALTER TABLE ggLinkEvent ADD INDEX gene12Idx (gene1, gene2);"
    hgsql(db, query)

    query = "ALTER TABLE ggEventText ADD INDEX docIdIdx (docId);"
    hgsql(db, query)

    query = "alter table ggDocEvent add index eventIdIdx (eventId);"
    hgsql(db, query)

def loadTables(tableDir, db):
    " load graph tables into mysql "

    loadTable(db, tableDir, "ggSymbol")
    loadTable(db, tableDir, "ggDoc")
    loadTable(db, tableDir, "ggDocEvent")
    loadTable(db, tableDir, "ggEventDb")
    loadTable(db, tableDir, "ggEventText")
    loadTable(db, tableDir, "ggLink")
    loadTable(db, tableDir, "ggLinkEvent")

    addIndexes(db)

def indexPmids(rowList, textRows):
    " return dict pmid -> list of event Ids "
    pmidToIds = defaultdict(set) 
    for rows in rowList:
        for row in rows:
            pmidStr = row.pmids
            if pmidStr=="":
                continue
            pmids = pmidStr.split("|")
            rowId = row.eventId
            for pmid in pmids:
                if pmid=="":
                    continue
                pmidToIds[pmid].add(rowId)

    for row in textRows:
        if row.pmid=="":
            continue
        pmidToIds[row.pmid].add(row.eventId)

    return pmidToIds

def writeDocEvents(pmidToId, outFname):
    " write a table with PMID -> list of event Ids "
    logging.info("Writing docId-eventId %s" % outFname)
    ofh = open(outFname, "w")
    for docId, eventIds in pmidToId.iteritems():
        eventIds = sorted(list(eventIds))
        for eventId in eventIds:
            ofh.write("%s\t%s\n" % (docId, eventId))
    ofh.close()

def writeEventTable(rowList, outFname, colCount=None):
    " write the event table with event details "
    logging.info("Writing events to %s" % outFname)
    ofh = open(outFname, "w")
    for rows in rowList:
        for row in rows:
            if colCount:
                if len(row)+1 == colCount:
                    row = list(row)
                    row.append("")
                #row = row[:colCount]
                assert(len(row)==colCount)
            ofh.write("%s\n" % ("\t".join(row)))
    ofh.close()

def pairToDbs(pairs):
    """ given pairs and data rows, return a dict pair -> int 
    that indicates how many DBs a pair is referenced in 
    """
    # first make dict event -> source dbs
    #eventDbs = defaultdict(set)
    #for row in pwRows:
        #eventDbs[row.eventId].add(row.sourceDb)
    #for row in dbRows:
        #sourceDbs = row.sourceDbs.split("|")
        #eventDbs[row.eventId].update(sourceDbs)

    # construct a dict pair -> source dbs
    pairDbs = defaultdict(set)
    for pair, rows in pairs.iteritems():
        for row in rows:
            pairDbs[pair].add(row.sourceDb)

    return pairDbs

def parseMeshContext(fname):
    " given a medline trees file, return the list of disease and pathway names in it "
    # ex. filename is mtrees2013.bin (it's ascii)
    # WAGR Syndrome;C10.597.606.643.969
    terms = set()
    lines = open(fname)
    for line in lines:
        line = line.strip()
        term, code = line.split(";")
        term = term.strip()
        # all disease terms start with C
        if code.startswith("C"):
            terms.add(term)
        # all signalling pathways are below one very specific branch
        elif code.startswith("G02.149.115.800") and not code=="G02.149.115.800":
            terms.add(term)
    logging.info("Read %d disease/context MESH terms from %s" % (len(terms), fname))
    return terms


def getDirectedPairs(pwRows):
    " get the set of directed gene pairs from the rows, keep the direction "
    pairs = set()
    for row in pwRows:
        for pair in iterAllPairs(row):
            pairs.add(pair)
    return pairs
            
def writeAllDocInfo(textDir, outFname):
    " get all author/year/journal/title as tab-sep from a pubtools-style input directory, ~5GB big "
    mask = join(textDir, "*.articles.gz")

    ofh = open(outFname, "w")
    fnames = glob.glob(mask)
    doneDocs = set()
    for i, fname in enumerate(fnames):
        if i % 10 == 0:
            logging.info("%d out of %d files" % (i, len(fnames)))

        for row in lineFileNext(gzip.open(fname)):
            # skip duplicates
            if row.pmid in doneDocs:
                continue
            doneDocs.add(row.pmid)

            if row.year.isdigit() and int(row.year)>1975:
                newRow = (row.pmid, row.authors, row.year, row.journal, row.printIssn, \
                        row.title, row.abstract, row.keywords)
                ofh.write("\t".join(newRow))
                ofh.write("\n")
    ofh.close()
    logging.info("Article info written to %s" % outFname)

def parseShortNames(journalFname):
    " return a dict ISSN -> short journal name, parsed from the NLM journal info file "
    shortNames = {}
    if isfile(journalFname):
        for row in lineFileNext(open(journalFname)):
            if row.medlineTA!="" and row.pIssn!="":
                shortNames[row.pIssn] = row.medlineTA
        logging.info("Read a short journal name for %d ISSNs from  %s" % (len(shortNames), journalFname))
    else:
        logging.info("%s not found, not shortening journal names" % journalFname)
    return shortNames

def writeDocsTable(pmidEventPath, medlinePath, shortNames, contextFilter, resCounts, outFname):
    """ join pmid-Event info and our shortened medline version 
    resCount is a set of docIds with low-throughput data (fewer than LTPMAX interactions per doc)
    """
    # parse the PMIDs to export
    docIds = set()
    for row in lineFileNext(open(pmidEventPath), headers=["docId", "eventId"]):
        docIds.add(row.docId)
    logging.info("read %d document IDs from %s" %  (len(docIds), pmidEventPath))

    docContexts = {}

    logging.info("Writing to %s" % outFname)
    ofh = open(outFname, "w")
    # fields are: docId, authors, year, journal, printIssn, title, abstract, keywords
    foundIds = set()
    for line in open(medlinePath):
        fields = line.rstrip("\n").split("\t")
        docId = fields[0]
        if docId in docIds:
            issn = fields[4]
            shortName = shortNames.get(issn)
            if shortName!=None:
                fields[3] = shortName

            newKeywords = []
            for kw in fields[7].split("/"):
                if kw in contextFilter:
                    newKeywords.append(kw)
            docContext = "|".join(newKeywords)
            fields[7] = docContext
            if docContext!="":
                docContexts[docId] = docContext

            # add a field: how many gene-pairs are associated to this paper
            fields.append(resCounts.get(docId, "0"))

            line = "\t".join(fields)+"\n"

            ofh.write(line)
            foundIds.add(docId)
    ofh.close()

    notFoundIds = docIds - foundIds
    logging.info("No info for %d documents" % len(notFoundIds))
    logging.debug("No info for these documents: %s" % notFoundIds)

    return docContexts

def sumBasic(sentences, commonWords):
    """ given probabilities of words, rank sentences by average prob
    (removing commonWords).
    Sentences is a list of list of words
    Algorithm is described in http://ijcai.org/papers07/Papers/IJCAI07-287.pdf

    Returns sentence with highest score and shortest length, if several have a highest score
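    A minimal example with invented gene names:
    >>> sumBasic(["AKT1 binds FOXO3", "AKT1 binds FOXO3 in some cells"], set(["in", "some"]))
    'AKT1 binds FOXO3'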
    """

    if len(sentences)==0:
        return ""
    sentWordsList = [set(wordRe.findall(sentence)) for sentence in sentences]
    words = list(chain.from_iterable(sentWordsList))
    wordProbs = {word: float(count)/len(words) for word, count in Counter(words).items()}

    scoredSentences = []
    for sentWords, sentence in zip(sentWordsList, sentences):
        mainWords = sentWords - commonWords
        if len(mainWords)==0:
            continue
        avgProb = sum([wordProbs[word] for word in mainWords]) / len(mainWords)
        scoredSentences.append((avgProb, sentence, sentWords))

    # happens rarely: all words are common English words
    if len(scoredSentences)==0:
        return ""

    # get sentences with equally good top score
    scoredSentences.sort(key=operator.itemgetter(0), reverse=True)
    topScore = scoredSentences[0][0]
    topSents = [(sent, words) for score, sent, words in scoredSentences if score >= topScore]

    # sort these by length and pick shortest one
    topSentLens = [(len(s), s, w) for s, w in topSents]
    topSentLens.sort(key=operator.itemgetter(0))
    topLen, topSent, topWords = topSentLens[0]

    # the SumBasic update step: square the probabilities of the words just used.
    # As this function returns only a single sentence, the update has no effect here.
    for word in topWords:
        wordProbs[word] *= wordProbs[word]

    return topSent

def runSumBasic(textPairs, wordFname):
    """ Get all sentences for an interaction and use sumBasic to pick the best one 
    """
    # get list of very common English words
    bncWords = set([line.split()[0] for line in open(wordFname).read().splitlines()])
    logging.info("Loaded %d common English words from %s" % (len(bncWords), wordFname))

    logging.info("Running SumBasic on sentences")
    bestSentences = {}
    for pair, rows in textPairs.iteritems():
        sentSet = set()
        for row in rows:
            if row.sentence!="":
                sentSet.add(row.sentence)

        sentences = list(sentSet)
        bestSentences[pair] = sumBasic(sentences, bncWords)
    return bestSentences

def readDictList(fname, reverse=False):
    " read a key-value tab-sep table and return as a dict of key -> values"
    logging.info("reading %s" % fname)
    data = defaultdict(list)
    for line in open(fname):
        key, val = line.rstrip("\n").split("\t")
        if reverse:
            key, val = val, key
        data[key].append(val)
    return data

def readDict(fname, reverse=False):
    " read a key-value tab-sep table and return as a dict of key -> value"
    logging.info("reading %s" % fname)
    data = {}
    for line in open(fname):
        key, val = line.rstrip("\n").split("\t")
        if reverse:
            key, val = val, key
        data[key] = val
    return data

def readPairEvent(fname):
    """ read a tab-sep table in format (gene1, gene2, eventId) and return as
    dict (gene1,gene2) -> list of eventId""" 
    logging.info("reading %s" % fname)
    data = defaultdict(list)
    for line in open(fname):
        gene1, gene2, val = line.rstrip("\n").split("\t")
        data[(gene1, gene2)].append(val)
    return data

def addContext(ctFname, docEventFname, linkEventFname, linkFname):
    " read the data from the first three files, and put it into the last field of linkFname "
    docContext = readDict(ctFname)
    eventDocs = readDictList(docEventFname, reverse=True)
    pairEvents = readPairEvent(linkEventFname)
    
    logging.info("Reading %s" % linkFname)
    newLines = []
    for line in open(linkFname):
        #print line
        fields = line.rstrip("\n").split("\t")
        pair = (fields[0], fields[1])

        contextCounts = Counter()
        for eventId in pairEvents[pair]:
            #print "pair %s, event %s" % (pair, eventId)
            for docId in eventDocs.get(eventId, []):
                #print "doc %s" % docId
                contexts = docContext.get(docId, "")
                for context in contexts.split("|"):
                    #print "context %s" % context
                    if context=="" or context==" ":
                        continue
                    contextCounts[context]+=1
        # keep the three most common contexts and reformat as a string,
        # marking with "..." that there were more
        suffix = ""
        if len(contextCounts)>3:
            suffix = ", ..."
        bestContexts = contextCounts.most_common(3)
        contextStrings = ["%s (%d)" % (ct, count) for ct, count in bestContexts]
        contextStr = ", ".join(contextStrings)+suffix
        fields.append(contextStr)
        newLines.append("\t".join(fields))
    return newLines

def convGraph(outDir):
    import networkx as nx
    fname = join(outDir, "ggLink.tab")
    G=nx.Graph()
    for line in open(fname):
        g1, g2 = line.split()[:2]
        G.add_edge(g1, g2)
    outFname = join(outDir, "graph.bin")
    #nx.write_adjlist(G, outFname)
    cPickle.dump(G, open(outFname, "w"), cPickle.HIGHEST_PROTOCOL)
    logging.info("Wrote graph to %s" % outFname)

def convGraph2(outDir):
    import igraph as ig
    fname = join(outDir, "ggLink.tab")
    G=ig.Graph()
    geneToId = {}
    nextId = 0
    edges = []
    for line in open(fname):
        g1, g2 = line.split()[:2]

        if g1 not in geneToId:
            id1 = geneToId[g1] = nextId
            nextId += 1
        else:
            id1 = geneToId[g1]

        if g2 not in geneToId:
            id2 = geneToId[g2] = nextId
            nextId += 1
        else:
            id2 = geneToId[g2]

        edges.append( (id1, id2) )
    #print nextId
    #print edges
    G.add_vertices(nextId) # vertex IDs run from 0 to nextId-1
    G.add_edges(edges)
    #outFname = join(outDir, "graph2.bin")
    #nx.write_adjlist(G, outFname)
    #cPickle.dump(G, open(outFname, "w"), cPickle.HIGHEST_PROTOCOL)
    #logging.info("Wrote graph to %s" % outFname)
    outFname = join(outDir, "graph.lgl")
    G.write(outFname, "lgl")
    logging.info("Wrote graph to %s" % outFname)

    outFname = join(outDir, "graph.genes.txt")
    ofh = open(outFname, "w")
    for gene, geneId in geneToId.iteritems():
        ofh.write("%s\t%s\n" % (gene, geneId))
    ofh.close()
    logging.info("Wrote nodeId-symbol mapping to %s" % outFname)
    #outFname = join(outDir, "graph.adj")
    #G.write(outFname, "adjacency")
    #outFname = join(outDir, "graph.leda")
    #G.write(outFname, "leda")
    #outFname = join(outDir, "graph.dot")
    #G.write(outFname, "dot")
    #outFname = join(outDir, "graph.pajek")
    #G.write(outFname, "pajek")
    #logging.info("Wrote graph to %s" % outFname)

def convGraph3(outDir):
    """ convert graph to a compact format gene -> list of connected genes
    """
    fname = join(outDir, "ggLink.tab")
    geneLinks = defaultdict(set)
    #genes = set()
    # format: gene1, gene2, flags, forwDocs, revDocs, allDocs, pairDbs, pairMinResCount, snippet
    for line in open(fname):
        fields = line.split("\t")

        g1, g2, flags = fields[:3]
        #flags = flags.split(",")
        pairMinResCount = int(fields[7])
        docCount = int(fields[5])

        # require many documents or manually curated database
        if docCount < 3 and not "ppi" in flags and not "pwy" in flags:
            continue
        #pairDbs = fields[6]

        # ignore all interactions that are derived from papers 
        # that reported more than 100 interactions, e.g. big complexes
        if pairMinResCount > 100:
            continue

        #geneLinks[g1].add((g2, flags))
        #geneLinks[g2].add((g1, flags))
        #geneLinks[g1].add(g2)
        #geneLinks[g2].add(g1)
        geneLinks[g1].add((g2, docCount))
        geneLinks[g2].add((g1, docCount))

        #genes.add(g1)
        #genes.add(g2)
    geneLinks = dict(geneLinks)

    # rewrite to dict str -> list
    #geneLinks = {k:list(v) for k,v in geneLinks.iteritems()}

    # map gene -> integer
    #geneToId = {}
    #for geneId, gene in enumerate(genes):
        #geneToId[gene] = geneId

    # this is only slightly slower - probably best format
    #outFname = join(outDir, "graph.txt")
    #ofh = open(outFname, "w")
    #for gene, neighbors in geneLinks.iteritems():
        #ofh.write("%s\t%s\n" % (gene, ",".join(neighbors)))
    #ofh.close()
    #logging.info("Wrote links to %s" % outFname)

    # this is not faster
    #intLinks = {}
    #for gene, neighbors in geneLinks.iteritems():
        #intLinks[geneToId[gene]] = [geneToId[n] for n in neighbors]

    # fastest when taking into account set-building time
    outFname = join(outDir, "graph.marshal")
    marshal.dump(geneLinks, open(outFname, "w"), 2) # prot 0, 1 not faster
    logging.info("Wrote links to %s" % outFname)

    #outFname = join(outDir, "graph.ujson")
    #ujson.dump(geneLinks, open(outFname, "w"))
    #logging.info("Wrote json strings links to %s" % outFname)

    #outFname = join(outDir, "graphInt.marshal")
    #idToGene = {v: k for k, v in geneToId.iteritems()}
    #data = (idToGene, intLinks)
    #marshal.dump(data, open(outFname, "w"), 2)
    #logging.info("Wrote integer links to %s" % outFname)

    #outFname = join(outDir, "graphInt.ujson")
    #ujson.dump(data, open(outFname, "w"))
    #logging.info("Wrote json integer links to %s" % outFname)

def loadGraph(outDir):
    import networkx as nx
    inFname = join(outDir, "graph.bin")
    gc.disable() # no need for GC here, saves 2 seconds
    G = cPickle.load(open(inFname))
    gc.enable()
    geneList = ["OTX2", "PITX2", "APOE", "TP53", "TNF", "SP1"]
    foundPairs = set()
    G2 = G.subgraph(geneList)
    print nx.nodes(G)
    #for g1, g2 in combinations(geneList, 2):
        #path = nx.shortest_path(G, g1, g2)
        #print path
        #for i in range(0, len(path)-1):
            #pair = tuple(sorted(path[i:i+2]))
            #foundPairs.add(pair)
    #print foundPairs
        
def parseGeneToId(inFname):
    geneToId = {}
    idToGene = {}
    for l in open(inFname):
        gene, geneId = l.rstrip("\n").split("\t")
        geneId = int(geneId)
        geneToId[gene] = geneId
        idToGene[geneId] = gene
    return geneToId, idToGene

def loadGraph2(outDir):
    import igraph as ig
    inFname = join(outDir, "graph.lgl")
    #gc.disable() # no need for GC here, saves 2 seconds
    #G = cPickle.load(open(inFname))
    #gc.enable()
    #G=ig.Graph()
    G = ig.load(inFname)
    #print G

    inFname = join(outDir, "graph.genes.txt")
    geneToId, idToGene = parseGeneToId(inFname)

    geneList = ["OTX2", "PITX2", "APOE", "TP53", "TNF", "SP1","ABCA1", "CD4", "BRCA2", "APP", "SRY", "GAST", "MYOD1"]
    idList = [geneToId[g] for g in geneList]
    allSyms = set()
    for i in range(0, len(idList)-1):
        fromId = idList[i]
        fromGene = geneList[i]
        toIds = idList[i+1:]
        toGenes = geneList[i+1:]
        print fromId, fromGene, toIds, toGenes
        paths = G.get_shortest_paths(fromId, toIds, mode=ig.ALL) 
        genePaths = []
        for idPath in paths:
            genePath = [idToGene[i] for i in idPath]
            genePaths.append(genePath)
            allSyms.update(genePath)
        print "paths", genePaths
        print ",".join(allSyms)
        #for i in range(0, len(path)-1):
            #pair = tuple(sorted(path[i:i+2]))
            #foundPairs.add(pair)
    #print foundPairs
    #geneList = ["OTX2", "PITX2", "APOE", "TP53", "TNF", "SP1"]
    #G2 = G.subgraph(geneList)

def loadGraph3(outDir):
    inFname = join(outDir, "graph.marshal")
    #idToGene, idLinks = marshal.load(open(inFname))
    #idToGene, idLinks = ujson.load(open(inFname))
    #geneLinks = ujson.load(open(inFname))
    geneLinks = marshal.load(open(inFname))
    return geneLinks

    #idToGene = {int(k):v for k,v in idToGene.iteritems()}

    #geneLinks = {}
    #for geneId, linkedIds in idLinks.iteritems():
        #geneLinks[idToGene[int(geneId)]] = [idToGene[i] for i in linkedIds]

    # reversing the list was 25% slower than reading it all from disk
    # 1.1 seconds, so 20% slower than the marshal version
    #inFname = join(outDir, "graph.links.txt")
    #ofh = open(outFname, "w")
    #graph = {}
    #for line in open(inFname):
        #gene, neighbors = line.rstrip("\n").split("\t")
        #neighbors = set(neighbors.split(","))
        #graph[gene] = neighbors
    ##logging.info("Wrote links to %s" % outFname)

def parseLinkTargets(outDir, validSyms):
    """ parse the ggLink table in outDir and return a dict gene -> Counter() of targetGenes -> count.
    Count is either the article count or, if there is no text mining hit, the count of databases 
    """
    errFh = open("ggLink.errors.tab", "w")

    inFname = join(outDir, "ggLink.tab")
    logging.info("Parsing %s" % inFname)

    asPath = join(autoSqlDir, "ggLink.as")
    targets = defaultdict(Counter)
    for row in lineFileNext(open(inFname), asFname=asPath):
        gene1, gene2 = row.gene1, row.gene2
        count = int(row.docCount)
        # text-mining documents need some minimum support
        if row.linkTypes=="text" and count<UI_MINSUPP:
            continue
        if count==0:
            count = len(row.dbList.split("|"))

        if gene1 not in validSyms and gene2 not in validSyms:
            errFh.write("BothSymsInvalid\t"+"\t".join(row)+"\n")
        elif gene1 not in validSyms:
            errFh.write("sym1Invalid\t"+"\t".join(row)+"\n")
        elif gene2 not in validSyms:
            errFh.write("sym2Invalid\t"+"\t".join(row)+"\n")

        targets[gene1][gene2]=count
        targets[gene2][gene1]=count

    errFh.close()
    logging.info("Wrote rows from ggLink.tab with invalid symbols to ggLink.errors.tab")

    return targets

def makeBigBed(inDir, outDir, bedFname, db):
    " create a file geneInteractions.<db>.bb in outDir from bedFname "
    validSyms = set()
    for line in open(bedFname):
        sym = line.split("\t")[3].rstrip("\n")
        validSyms.add(sym)

    # get interactors from our ggLink table
    geneCounts = parseLinkTargets(inDir, validSyms)

    # get genes from knownGenes table and write to bed
    #bedFname = join(outDir, "genes.%s.bed" % db)
    #logging.info("Writing genes to %s" % bedFname)
    #cmd = "hgsql %s -NBe 'select chrom, chromStart, chromEnd, geneSymbol from knownCanonical JOIN kgXref ON kgId=transcript' > %s" % (db, bedFname)
    #runCmd(cmd)
    #bedFname = "geneModels/gencode19.median.bed"

    # rewrite the bed file and fill in the interaction counts
    bedOutFname = join(outDir, "geneInteractions.%s.bed" % db)
    ofh = open(bedOutFname, "w")
    logging.info("Rewriting %s to %s" % (bedFname, bedOutFname))
    doneSymbols = set()
    for line in open(bedFname):
        row = line.rstrip("\n").split("\t")
        gene = row[3]
        counts = geneCounts.get(gene, None)
        if counts==None:
            # skip gene if not found
            continue

        # create the new name field: the gene plus its top 10 interaction partners.
        # docCount sums the document counts over ALL partners, not just the top 10.
        docCount = 0
        strList = []
        geneCount = 0
        for targetGene, count in counts.most_common():
            #strList.append("%s:%d" % (targetGene, count))
            if geneCount < 10:
                strList.append("%s" % (targetGene))
            docCount += count
            geneCount += 1

        score = min(docCount, 1000)

        targetGenes = ",".join(strList)
        row[3] = gene+": "+targetGenes # why a space? see linkIdInName trackDb statement

        row.append( str(score) )
        row.append(".")
        row.append(row[1])
        row.append(row[2])

        if docCount > 100:
            color = "0,0,0"
        elif docCount > 10:
            color = "0,0,128"
        else:
            color = "173,216,230"

        row.append(color)
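        # the row is now BED9: chrom, start, end, name, score, strand, thickStart, thickEnd, itemRgb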

        ofh.write("\t".join(row))
        ofh.write("\n")
        doneSymbols.add(gene)
    ofh.close()

    missingSyms = set(geneCounts) - set(doneSymbols)
    logging.info("%d symbols in ggLink not found in BED file" % len(missingSyms))
    logging.info("missing symbols written to missSym.txt")

    ofh= open("missSym.txt", "w")
    ofh.write("\n".join(missingSyms))
    ofh.close()

    cmd = "bedSort %s %s" % (bedOutFname, bedOutFname)
    runCmd(cmd)

    bbFname = join(outDir, "geneInteractions.%s.bb" % db)
    chromSizeFname = "/hive/data/genomes/%s/chrom.sizes" % db
    cmd = "bedToBigBed -tab %s %s %s"  % (bedOutFname, chromSizeFname, bbFname)
    runCmd(cmd)
    logging.info("bigBed file written to %s" % bbFname)
    
def findBestPaths(genes, geneLinks):
    " find best paths of max length 2 between genes using geneLinks. return pairs. "
    pairs = set()
    links = defaultdict(list) # dict (from, to) -> list of (docCountSum, path)
    for gene1 in genes:
        # search at distance 1
        for gene2, docCount2 in geneLinks.get(gene1, []):
            if gene2 in genes and gene2!=gene1:
                # stop if found
                print "%s-%s" % (gene1, gene2)
                pairs.add( tuple(sorted((gene1, gene2))) )
                links[ (gene1, gene2) ].append( (docCount2, [gene1, gene2]) )
                continue

            # search at distance 2
            for gene3, docCount3 in geneLinks.get(gene2, []):
                if gene3 in genes and gene3!=gene2 and gene3!=gene1:
                    # distance = 2
                    print "%s-%s-%s" % (gene1, gene2, gene3)
                    pairs.add( tuple(sorted((gene1, gene2))) )
                    pairs.add( tuple(sorted((gene2, gene3))) )
                    links[ (gene1, gene3) ].append( ((docCount2+docCount3)/2, [gene1, gene2, gene3]) )

    for genePair, paths in links.iteritems():
        paths.sort(reverse=True)
        print genePair, paths
        
    return pairs

# ----------- MAIN --------------
#if options.test:
    #import doctest
    #doctest.testmod()
    #sys.exit(0)

if args==[]:
    parser.print_help()
    exit(1)

cmd = args[0]
if cmd == "build":
    wordFname = options.wordFname

    pathwayDir, dbDir, textDir, outDir = args[1:]
    # load the input files into memory
    dbRows = loadFiles(dbDir, prefix="ppi_") 
    pwRows = loadFiles(pathwayDir)
    textRows = loadFiles(textDir)

    # index and merge them
    dbPairs   = indexPairs(dbRows, "ppi databases")
    pwPairs   = indexPairs(pwRows, "pathways")
    textPairs = indexPairs(textRows, "text mining")
    pwDirPairs = getDirectedPairs(pwRows)

    curatedPairs = mergePairs([dbPairs, pwPairs])
    pairMinResultCounts, docToPairs = getResultCounts(curatedPairs)

    bestSentences = runSumBasic(textPairs, wordFname)
    allPairs = mergePairs([curatedPairs, textPairs])

    #ltPairs, ltDocs = getResultCounts(curatedPairs)
    # keep result counts for the "docs" step
    ofh = open(join(outDir, "resultCounts.tmp.txt"), "w")
    for docId, pairs in docToPairs.iteritems():
        ofh.write("%s\t%d\n" % (docId, len(pairs)))
    ofh.close()

    pairDirDocs = directedPairToDocs(textRows)
    pairDbs = pairToDbs(curatedPairs)

    outFname = join(outDir, "ggLink.tmp.txt") # needs the addContext step to complete it
    eventFname = join(outDir, "ggLinkEvent.tab")
    allSyms = writeGraphTable(allPairs, pairDirDocs, pairDbs, pairMinResultCounts, pwDirPairs, \
        bestSentences, outFname, eventFname)

    pmidToId = indexPmids([dbRows,pwRows], textRows)
    outFname = join(outDir, "ggDocEvent.tab")
    writeDocEvents(pmidToId, outFname)

    outFname = join(outDir, "ggEventDb.tab")
    writeEventTable([dbRows, pwRows], outFname, colCount=14)

    outFname = join(outDir, "ggEventText.tab")
    writeEventTable([textRows], outFname)

    # make sure we don't forget to update the link table with context
    linkFname = join(outDir, "ggLink.tab")
    if isfile(linkFname):
        os.remove(linkFname)

    # hgGene does not like it if the gene symbols are in two different
    # columns, so we create a very simple table with just the gene symbols
    symFname = join(outDir, "ggSymbol.tab")
    logging.info("Writing %s" % symFname)
    open(symFname, "w").write("\n".join(allSyms))

elif cmd == "medline":
    outDir = args[1]
    textDir = options.textDir
    medlineFname = join(outDir, allArtFname)
    writeAllDocInfo(textDir, medlineFname)

elif cmd == "docs":
    outDir = args[1]
    outFname = join(outDir, "ggDoc.tab")
    pmidEventPath = join(outDir, "ggDocEvent.tab")

    medlineFname = join(outDir, allArtFname)
    meshTerms = parseMeshContext(options.meshFname)

    shortNames = parseShortNames(options.journalInfo)

    resCountFname = join(outDir, "resultCounts.tmp.txt")
    resCounts = readDict(resCountFname)

    docContext = writeDocsTable(pmidEventPath, medlineFname, shortNames, meshTerms, resCounts, outFname)

    # write docContext to file
    ctFname = join(outDir, "docContext.txt")
    ofh = open(ctFname, "w")
    for docId, context in docContext.iteritems():
        ofh.write("%s\t%s\n" % (docId, context))
    ofh.close()
    logging.info("Written document contexts to %s for %d documents" % (ctFname, len(docContext)))

elif cmd == "context":
    outDir = args[1]
    ctFname = join(outDir, "docContext.txt")
    docEventFname = join(outDir, "ggDocEvent.tab")
    linkEventFname = join(outDir, "ggLinkEvent.tab")
    linkFname = join(outDir, "ggLink.tmp.txt")
    newLines = addContext(ctFname, docEventFname, linkEventFname, linkFname)

    outFname = join(outDir, "ggLink.tab")
    ofh = open(outFname, "w")
    for l in newLines:
        ofh.write("%s\n" % l)
    ofh.close()
    logging.info("appended document context to %s" % outFname)

elif cmd == "bigBed":
    inDir = args[1]
    outDir = args[2]
    geneBedFile = args[3]
    db = args[4]
    makeBigBed(inDir, outDir, geneBedFile, db)

elif cmd == "load":
    inDir = args[1]
    db = args[2]
    loadTables(inDir, db)

# --- DEBUGGING / TESTING ----

elif cmd=="sumBasic": # for debugging
    inFname = args[1]
    rows = []
    for row in lineFileNext(open(inFname)):
        rows.append(row)
    textPairs = indexPairs(rows, "text mining")
    for pair, sent in runSumBasic(textPairs, options.wordFname).iteritems():
        print pair, sent
    
elif cmd == "graph":
    outDir = args[1]
    convGraph(outDir)

elif cmd == "graph2":
    outDir = args[1]
    convGraph2(outDir)

elif cmd == "graph3":
    outDir = args[1]
    convGraph3(outDir)

elif cmd == "loadgraph":
    outDir = args[1]
    loadGraph(outDir)

elif cmd == "load2":
    outDir = args[1]
    loadGraph2(outDir)

elif cmd == "load3":
    outDir = args[1]
    loadGraph3(outDir)

elif cmd == "subnet":
    outDir, geneFile = args[1:]
    geneLinks = loadGraph3(outDir)
    genes = set(open(geneFile).read().splitlines())

    print len(findBestPaths(genes, geneLinks))

else:
    logging.error("unknown command %s" % cmd)
    sys.exit(1)

