#!/usr/bin/env python2.7
# Summarize all patent documents of each unique sequence (fingerprint) into a single bigBed line.

# ==== functions =====
def parseArgs():
    """ Set up logging and parse the command line arguments and options.

    -h shows the auto-generated help page. When no positional argument is
    given, the help page is printed and the program exits with status 1.
    Logging is configured at DEBUG level when -d/--debug is set and at INFO
    level otherwise.

    Returns a tuple (args, options): the list of positional arguments and the
    optparse options object.
    """
    parser = optparse.OptionParser("usage: %prog [options] filename - reduce patLens doc information to just one patent info row per unique sequence (=fingerPrint). Input needs to be sorted on first field")
    parser.add_option("-d", "--debug", dest="debug", action="store_true", help="show debug messages")
    #parser.add_option("-f", "--file", dest="file", action="store", help="run on file") 
    #parser.add_option("", "--test", dest="test", action="store_true", help="do something") 
    (options, args) = parser.parse_args()

    if not args:
        parser.print_help()
        # sys.exit instead of the bare exit(): the latter is only injected by
        # the site module for interactive use and is absent under "python -S"
        sys.exit(1)

    logLevel = logging.DEBUG if options.debug else logging.INFO
    logging.basicConfig(level=logLevel)
    return args, options

# ----------- imports, constants and helper functions --------------
import logging, sys, optparse
import datetime, re
from collections import defaultdict, Counter
import operator

# Boilerplate phrases that are removed from titles before hashing them.
# The patterns are applied one after the other, in list order, each on the
# result of the previous one. A pattern that extends another one must therefore
# come first: "^methods of and analysis of" has to run before "^methods of",
# otherwise the shorter pattern eats the prefix and the longer one never matches.
stripRes = [
    re.compile("^METHOD OF ", re.I),
    re.compile("^method and composition for", re.I),
    re.compile("^methods for", re.I),
    re.compile("^methods of and analysis of", re.I),
    re.compile("^methods of", re.I),
    re.compile("^methods and compositions for", re.I),
    re.compile("METHOD OF DETECTION AND USE THEREOF", re.I),
    re.compile("METHOD OF DETECTION AND USES THEREOF", re.I),
    re.compile("METHODs OF DETECTION AND USES THEREOF", re.I),
]

def titleHash(sent):
    """ given a string, reduce it to something that is the same for similar sentences

    The boilerplate phrases in stripRes are removed, the rest is lowercased and
    only its alphanumeric characters are kept. Every letter 's' is dropped as
    well (not only trailing ones), so that singular and plural forms collapse
    to the same hash.
    """
    # remove common prefixes
    for sre in stripRes:
        sent = sre.sub("", sent)
    # lowercase BEFORE dropping the 's' characters: otherwise only lowercase
    # 's' is removed and "SEQUENCES" hashes differently from "Sequences"
    sent = sent.lower()
    # remove special chars and every 's'
    return ''.join(ch for ch in sent if ch.isalnum() and ch != 's')

def summarize(counter):
    """ turn a counter object into one string

    A counter with exactly one distinct value is rendered as that value alone,
    without a count. Otherwise the result looks like
    "text (count), text2 (count), ..." and covers at most the 10 most common
    values.
    """
    if len(counter) == 1:
        # the only key there is
        return next(iter(counter))

    topEntries = counter.most_common(10)
    return ", ".join("%s (%d)" % (value, freq) for value, freq in topEntries)
        
def summarizeTitles(titleClusters):
    """ summarize a dict with hash -> list of titles to a single title string

    Every cluster is represented by its first title together with the number
    of titles in the cluster, largest clusters first. If there is only one
    cluster and it holds a single title, that title is returned as it is.
    Otherwise the result looks like "title (count); title2 (count); ...". At
    most 10 clusters are listed; when some were left out, a note saying so is
    appended.
    """
    # take the first title of each cluster and create list (count, title)
    titleCounts = [(len(titles), titles[0]) for titles in titleClusters.values()]
    # biggest clusters first; equal counts are ordered by title, descending
    titleCounts.sort(reverse=True)

    # if we have only a single title, return it
    if len(titleCounts)==1 and titleCounts[0][0]==1:
        return titleCounts[0][1]

    # take top 10 and rewrite to list of strings
    # (iterate over the truncated list, not over all clusters)
    subCounts = titleCounts[:10]
    titleStrings = ["%s (%d)" % (titleString, titleCount) for titleCount, titleString in subCounts]

    # merge into a single string
    titleStr = "; ".join(titleStrings)
    if len(subCounts)!=len(titleCounts):
        titleStr += "; (only 10 titles are shown)"
    return titleStr

def summarizeRows(rows):
    """ summarize the documents linked to a sequence into a single output row

    rows is a non-empty list of rows, one per document. Each row is a list of
    11 strings in this order: fingerprint, seqIdNo, seqType, refInClaims, org,
    taxId, intDocId, extDocId, docType, docDate (YYYYMMDD), docTitle.
    All fields are stripped of surrounding whitespace before use.

    Returns a list of 12 strings:
      fingerprint, number of documents, number of documents that reference the
      sequence in the claims, number of documents whose type ends with "atent",
      summary of organisms, date or date range, comma-separated internal doc
      IDs, external doc IDs and seqIdNos, summary of titles, comma-separated
      0/1 grant flags, comma-separated 0/1 in-claims flags.
    The comma-separated lists are ordered by document date, oldest first.

    Raises ValueError if rows is empty. Rows with the wrong number of fields
    or an unparsable date are not handled here and make the function fail.
    """
    if not rows:
        raise ValueError("summarizeRows needs at least one row")

    # clean fields and sort by date (equal dates are ordered by field values)
    datedRows = []
    for row in rows:
        row = [x.strip() for x in row]
        docDateObj = datetime.datetime.strptime( row[9], "%Y%m%d" )
        datedRows.append( (docDateObj, row) )
    datedRows.sort()

    orgCounts = Counter()
    claimCount = 0
    grantCount = 0
    titleClusters = defaultdict(list)
    intDocIds = []
    extDocIds = []
    seqIdNos = []
    isGrantList = []
    isInClaimsList = []

    for docDateObj, row in datedRows:
        # unpacking also enforces that the row has exactly 11 fields
        fingPrint, seqIdNo, seqType, refInClaims, org, taxId, intDocId, extDocId, docType, docDate, docTitle = row

        if org=="":
            org = "no Organism"
        orgCounts[org] += 1

        if refInClaims=="1":
            claimCount += 1
            isInClaimsList.append("1")
        else:
            isInClaimsList.append("0")

        # matches any document type ending in "atent", whatever the case of the P
        if docType.endswith("atent"):
            grantCount += 1
            isGrantList.append("1")
        else:
            isGrantList.append("0")

        intDocIds.append(intDocId)
        extDocIds.append(extDocId)
        seqIdNos.append(seqIdNo)

        if docTitle=="":
            docTitle = "no title"
        titleClusters[titleHash(docTitle)].append(docTitle)

    # date range of docs; datedRows is sorted, so first and last are the extremes
    firstDate = datedRows[0][0]
    lastDate = datedRows[-1][0]
    if len(datedRows)>1:
        dateStr = firstDate.strftime('%d. %b %Y') + " - " + lastDate.strftime('%d. %b %Y')
    else:
        dateStr = firstDate.strftime('%d. %b %Y')

    summRow = [
        # fingPrint is the value left over from the last loop iteration
        fingPrint,
        # number of documents
        len(rows),
        # number of documents with claims
        claimCount,
        # number of docs with granted patents
        grantCount,
        # summary of organisms
        summarize(orgCounts),
        dateStr,
        # int and ext doc IDs, sequence ID numbers
        ",".join(intDocIds),
        ",".join(extDocIds),
        ",".join(seqIdNos),
        # summary of titles
        summarizeTitles(titleClusters),
        ",".join(isGrantList),
        ",".join(isInClaimsList),
    ]
    return [str(x) for x in summRow]
    
def summDocs(inFname):
    """ read the tab-separated file inFname and print one summary line per fingerprint

    Consecutive lines with the same first field (the fingerprint) form one
    group, so the input has to be sorted on the first field. Each group is
    passed to summarizeRows and the result is printed to stdout as one
    tab-separated line. An empty input file produces no output.
    """
    # example input line:
    #fingerprint            seqIdNo seqType inClaim    orgName         taxId   intDocId                extDocId                docType                 date            docTitle                            
    #----3JvtXDgvMmPpPhgKRg	203735	1	0	Homo sapiens	9606	US_A1_2006057564	US_2006_0057564_A1	Patent Application	20060316	Identification and mapping of single nucleotide polymorphisms in the human genome
    lastId = None
    rows = []
    # the with-block makes sure the input file is closed again
    with open(inFname) as inFile:
        for line in inFile:
            row = line.rstrip("\n").split('\t')
            # a new fingerprint starts: flush the group collected so far
            if row[0]!=lastId and lastId is not None:
                print("\t".join(summarizeRows(rows)))
                rows = []
            rows.append(row)
            lastId = row[0]

    # flush the last group; an empty file leaves nothing to summarize
    if rows:
        print("\t".join(summarizeRows(rows)))

# ----------- main --------------
def main():
    " parse the command line and summarize the file given as first argument "
    args, options = parseArgs()
    summDocs(args[0])

# only run when executed as a script, not when the file is imported
if __name__ == "__main__":
    main()
