#!/usr/bin/env /hive/data/outside/genbank/bin/gbPython
#
# $Id: gbAlignSetup,v 1.24 2009/04/19 01:36:13 markd Exp $
#

# Command-line help text; passed to OptionParser as the usage message in
# main().
usage="""gbAlignSetup [options] database ...

Do setup for the alignment step.

Arguments:
  - database - databases to align

Obtains other parameters from etc/genbank.conf, see doc/config.html for
details.

Does:
   - Find updates that are missing alignments in the latest releases for
     each specified database. 
   - Extract tmp fasta file to align into the alignment work directory.
   - Produce a single parasol jobs file to do all of the alignments.
   - create a list of all expected PSL files."""


# Note: job-specific paths are created in a multi-level tree
# to avoid a large number of files in a single directory.
#
# An alignment of .../mrna.0.fa and .../chr10.nib will end up in:
#
#   refseq.12/hg17/full/psl/mrna.0/chr10/chr10.psl
#
# The way the directory names are derived depends on the genome file
# name:
#
#    chr22.nib -> chr22/chr22.psl
#    chr1.nib:207000000-217000000 -> chr1/2/207000000-217000000.psl
#    
# IMPORTANT NOTE: gbAlignFinish takes advantage of a one-to-one mapping between
# query file and the first directory level to make sorting more efficient, so
# don't change this without understanding what is going on.
#

import sys, os, os.path, re

myBinDir = os.path.normpath(os.path.dirname(sys.argv[0]))
sys.path.append(myBinDir + "/../lib/py")
from optparse import OptionParser
from genbank import procOps, fileOps, gbAbsBinDir, gbSetupPath
from genbank.Config import Config
from genbank.GbIndex import GbIndex, GenBank, RefSeq, MRNA, EST, Native, Xeno
from genbank.GenomeSeqs import GenomeSeqs
from genbank.GenomePartition import GenomePartition, UnplacedChroms
from genbank.fileOps import prLine, prRow, prRowv
import glob

# collect profile output from gbAlignGet runs; when true, the gmon.out left
# in the current directory after each gbAlignGet run is moved under
# gmonOutDir (see CDnaUnaligned.__saveGmonOut)
collectGmonOut = False
gmonOutDir = "profiles"

# NOTE(review): gbSetupPath is imported from the genbank package; presumably
# it sets up the search path for the genbank programs -- confirm in the
# library, its behavior is not visible from this file.
gbSetupPath()
# Path of the gbBlat program written into every line of the jobs file.
# It is hard-coded; the commented-out line below is the alternative that
# derives it from the installed bin directory.
#gbBlatPath = gbAbsBinDir() + "/gbBlat"
gbBlatPath =  "/hive/data/outside/genbank/bin/gbBlat"

# parsed command line options; assigned by main() and read throughout
cmdOpts = None

# global table to make sure we don't generate an output file more than
# once.  Holds the work-directory-relative path of every PSL file that has
# been assigned to a job.
definedPsls = set()

class GenomeTarget(GenomePartition):
    "genome of one database, partitioned into windows to align against"

    def __init__(self, db, conf):
        """Read the genome location and the align.* partitioning parameters
        of db from conf, then initialize the GenomePartition base with them.

        Raises Exception if align.unplacedChroms is configured for the
        database but no lift file is."""
        # nib dir or 2bit file on the cluster: a path ending in .nib is
        # reduced to the directory that contains it
        genomeSpec = conf.getDbStr(db, "serverGenome")
        self.clusterGenome = genomeSpec
        if genomeSpec.endswith(".nib"):
            self.clusterGenome = os.path.dirname(genomeSpec)

        lift = conf.getDbStrNo(db, "lift")
        unplacedSpecs = conf.getDbWordsNone(db, "align.unplacedChroms")
        if unplacedSpecs is None:
            unplaced = None
        else:
            # unplaced chromosome specs can't be used without a lift file
            if lift is None:
                raise Exception(db + ".align.unplacedChroms requires " + db + ".lift")
            unplaced = UnplacedChroms(unplacedSpecs)

        partitionArgs = (
            db,
            conf.getDbStr(db, "serverGenome"),
            conf.getDbInt(db, "align.window"),
            conf.getDbInt(db, "align.overlap"),
            conf.getDbInt(db, "align.maxGap"),
            conf.getDbIntDefault(db, "align.minUnplacedSize", 0),
            lift,
            unplaced,
        )
        GenomePartition.__init__(self, *partitionArgs)

        if cmdOpts.verbose >= 3:
            self.dump(sys.stderr)


class CDnaQuery(object):
    "record describing one extracted cDNA fasta file that is to be aligned"

    def __init__(self, procPart, cdnaType, orgCat, numSeqs, numBases, faPath):
        """Store the description of the fasta file at faPath.  The query
        name is the file name of faPath with its last extension removed."""
        baseName = os.path.basename(faPath)
        stem, _ext = os.path.splitext(baseName)
        self.name = stem
        self.faPath = faPath
        self.numBases = numBases
        self.numSeqs = numSeqs
        self.orgCat = orgCat
        self.cdnaType = cdnaType
        self.procPart = procPart

    def __str__(self):
        "one-line summary: partition, sequence count, base count and path"
        pieces = [str(self.procPart),
                  ": nseqs=", str(self.numSeqs),
                  ", nbases=", str(self.numBases),
                  ": ", self.faPath]
        return "".join(pieces)
        
                        
class CDnaUnaligned(object):
    """object used to find unaligned cDNA partitions and extract fasta files
    to query against genome.  Does all specified orgCats at once, to avoid
    multiple passes through the cDNA fasta files"""

    def __init__(self, db, rel, cdnaType, orgCats, querySplitSizeBytes, workDir, partitionsFh):
        """Load the aligned-partition information of release rel for db and
        record which processed partitions of cdnaType still need alignment.

        orgCats is the collection of organism categories to extract;
        querySplitSizeBytes is passed to gbAlignGet as -fasize when > 0;
        partitionsFh is an open file that receives the "partitions" rows
        reported by gbAlignGet."""
        self.db = db
        self.rel = rel
        self.cdnaType = cdnaType
        self.orgCats = orgCats
        self.querySplitSizeBytes = querySplitSizeBytes
        self.workDir = os.path.normpath(workDir)
        self.partitionsFh = partitionsFh
        # a plain int is sufficient: ints are promoted to arbitrary precision
        # automatically, so the Python 2-only long literal is not needed
        self.numAlign = 0
        self.unaligned = []  # list of GbProcessedPart objs
        self.queries = {}  # lists of CDnaQuery objs, by orgCat
        for orgCat in orgCats:
            self.queries[orgCat] = []

        rel.loadAlignedParts(db)
        self.__findUnaligned()
        if cmdOpts.verbose >= 3:
            sys.stderr.write("=========== GENBANK INDEX ===========\n")
            self.rel.dump(sys.stderr)
            sys.stderr.flush()

    def __needsAligned(self, procPart):
        "determine if procPart is all or partially unaligned"
        if procPart.cdnaType != self.cdnaType:
            return False  # not interested
        alnPart = procPart.getAlignedPart()
        if alnPart is None:
            return True
        # partially aligned if any requested orgCat is missing
        return not alnPart.orgCats.issuperset(self.orgCats)

    def __findUnaligned(self):
        "find unaligned processed partitions and add them to the object"
        for upd in self.rel.updates:
            for procPart in upd.procParts:
                if self.__needsAligned(procPart):
                    self.unaligned.append(procPart)

    def __parseAlignGetRow(self, procPart, row):
        """parse one row of output from gbAlignGet; row is the list of
        tab-separated columns, the first of which names the record type.
        Raises Exception on an unknown record type."""
        if row[0] == "alignFa":
            # alignFa: path orgCat numSeqs numBases
            cdnas = CDnaQuery(procPart, self.cdnaType, row[2], int(row[3]), int(row[4]), row[1])
            self.queries[cdnas.orgCat].append(cdnas)
        elif row[0] == "partitions":
            # copied through to the partitions file without the type column
            self.partitionsFh.write("\t".join(row[1:]) + "\n")
        elif row[0] == "alignCnt":
            self.numAlign += int(row[1])
        else:
            raise Exception("unexpected output from gbAlignGet: " + "\t".join(row))

    def __parseAlignGet(self, procPart, lines):
        "parse output from gbAlignGet"
        for line in lines:
            self.__parseAlignGetRow(procPart, line.split("\t"))

    def __saveGmonOut(self, procPart):
        """move the gmon.out profile found in the current directory to a
        name under gmonOutDir that identifies the partition"""
        saveName = gmonOutDir + "/gbAlignGet/" + self.db + "/" + self.rel + "/" + procPart.update + "." + procPart.cdnaType
        if procPart.accPrefix is not None:
            saveName += "." + procPart.accPrefix
        saveName += ".gmon.out"
        fileOps.ensureFileDir(saveName)
        os.rename("gmon.out", saveName)

    def __buildPart(self, cmdBase, procPart):
        "run gbAlignGet to get fasta for one partition"
        # positional arguments: release update type[.accPrefix] database
        cmd = list(cmdBase)
        cmd.append(self.rel)
        cmd.append(procPart.update)
        if procPart.accPrefix is None:
            cmd.append(self.cdnaType)
        else:
            cmd.append(self.cdnaType + "." + procPart.accPrefix)
        cmd.append(self.db)

        if cmdOpts.verbose >= 2:
            prLine(sys.stderr, procOps.fmtCmd(cmd))
        lines = procOps.callProcLines(cmd, inheritStderr=True)
        if collectGmonOut:
            self.__saveGmonOut(procPart)
        # process results
        self.__parseAlignGet(procPart, lines)

    def buildQueries(self):
        """build cDNA fasta files for unaligned partitions; return count of
        unaligned sequences"""

        # build first part of gbAlignGet command with options
        cmdBase = ["gbAlignGet",
                   "-workdir=" + self.workDir,
                   "-orgCats=" + ",".join(self.orgCats),
                   "-polyASizes"]
        if self.querySplitSizeBytes > 0:
            cmdBase.append("-fasize=" + str(self.querySplitSizeBytes))
        if cmdOpts.verbose:
            cmdBase.append("-verbose=" + str(cmdOpts.verbose))
        if cmdOpts.noMigrate:
            cmdBase.append("-noMigrate")
        for procPart in self.unaligned:
            self.__buildPart(cmdBase, procPart)
        return self.numAlign

class Job(list):
    """Information about a job.  This contains a single query sequence, and
    multiple target windows.  The list elements are the target windows."""
    # Note: can't contain multiple queries due to gbAlignFinish assumptions

    __slots__ = ("cdnaQuery", "targetBases")

    def __init__(self, cdnaQuery):
        "create an empty job for the query fasta described by cdnaQuery"
        self.cdnaQuery = cdnaQuery
        # total size of the target windows added so far
        self.targetBases = 0

    def addTargetWin(self, win):
        "add a target window to the job and account for its size"
        self.append(win)
        self.targetBases += win.size

    def __getJobPath(self, jobGenDb):
        "generate the path relative to the work directory that is unique to this job"
        # follow the old conventions of a psl directory under the directory
        # with the fasta file.  The job file goes in the psl dir too!
        win = self[0]  # use first window in names
        update = self.cdnaQuery.procPart.update
        path = [update.rel, jobGenDb.db, update, "psl", self.cdnaQuery.name, win.seq.id]
        if win.isFullSeq:
            path.append(win.seq.id)
        else:
            # create subdir from the first digit of the start of the range
            s = str(win.start)
            path.append(s[0])
            path.append(s + "-" + str(win.end))
        return "/".join(path)

    def __makeWorkDirRel(self, jobGen, path):
        "remove the leading workdir directories (and separator) from path"
        assert path.startswith(jobGen.workDir)
        return path[len(jobGen.workDir)+1:]

    def write(self, jobGenDb, preFilterOpts):
        """write this job spec: creates the .job spec file under the work
        directory, then adds the job to the jobs file and its PSL to the
        expected-file list.

        Raises Exception if the PSL path of this job was already assigned
        to another job."""
        jobGen = jobGenDb.jobGen
        jobPath = self.__getJobPath(jobGenDb)

        # create job spec file
        relJobFile = jobPath + ".job"
        jobFile = jobGen.workDir + "/" + relJobFile
        fileOps.ensureFileDir(jobFile)
        fh = open(jobFile, "w")
        try:
            prRowv(fh, "type", self.cdnaQuery.orgCat, self.cdnaQuery.cdnaType)
            if jobGenDb.ooc is not None:
                prRowv(fh, "ooc", jobGenDb.ooc)

            prRowv(fh, "tdb", jobGenDb.genome.clusterGenome)
            # single row listing the spec of every target window
            fh.write("tseq")
            for win in self:
                fh.write("\t")
                fh.write(win.getSpec())
            fh.write("\n")
            prRowv(fh, "qdb", self.__makeWorkDirRel(jobGen, self.cdnaQuery.faPath))
            if preFilterOpts is not None:
                prRowv(fh, "preFilterOpts", preFilterOpts)
        finally:
            # always release the handle so the spec is flushed to disk
            # before the job is listed and no descriptor is left open
            fh.close()

        # add job
        relPslFile = jobPath + ".psl"
        if relPslFile in definedPsls:
            raise Exception("BUG: psl file already used: " + relPslFile)
        definedPsls.add(relPslFile)
        fileOps.prLine(jobGen.jobFh, gbBlatPath, " ", relJobFile, " {check out exists ", relPslFile, "}")
        fileOps.prLine(jobGen.expectFh, relPslFile)

class Jobs(list):
    """object to partition target sequence windows into jobs.  This builds
    a list of Job objects for a given update and query fasta and packs
    multiple targets smaller than the window size into jobs.

    The list elements are the jobs that are full; jobs that could still
    accept windows are kept in self.pending.  Both are written by write()."""

    def __init__(self, jobGenDb, cdnaQuery):
        "build the jobs aligning cdnaQuery against every window of the genome"
        self.jobGenDb = jobGenDb
        self.winSize = jobGenDb.genome.windows.winSize
        self.overlap = jobGenDb.genome.windows.overlap
        self.maxJobTargetSize = jobGenDb.maxJobTargetSize
        self.cdnaQuery = cdnaQuery
        self.pending = []

        for win in jobGenDb.genome.windows:
            self.__addWindow(win)

    def __findJobWithRoom(self, win):
        """find a job with room for this window, allowing jobs to overflow
        by the overlap amount. Returns None if none found, or index in
        pending list """
        # NOTE(review): room is measured against winSize while a job only
        # counts as full at maxJobTargetSize -- confirm this is intended.
        need = win.size - self.overlap
        for i, job in enumerate(self.pending):
            if job.targetBases + need < self.winSize:
                return i
        return None

    def __newJob(self, win):
        """add a new job for the window"""
        # add to pending unless it's already full
        job = Job(self.cdnaQuery)
        job.addTargetWin(win)
        if job.targetBases >= self.maxJobTargetSize:
            self.append(job)
        else:
            self.pending.append(job)

    def __moveFromPending(self, i, job):
        """move a full job from the pending list to being ready"""
        # overwrite slot i with the last pending job, then drop the last
        # slot; avoids shifting the whole list
        self.pending[i] = self.pending[-1]
        self.pending.pop()
        self.append(job)

    def __addWindow(self, win):
        """add a window to the object.  Performance sensitive due to
        pseudo-chromosomes with unplaced sequence"""
        i = self.__findJobWithRoom(win)
        if i is None:
            self.__newJob(win)
        else:
            # pack into existing job
            job = self.pending[i]
            job.addTargetWin(win)
            if job.targetBases >= self.maxJobTargetSize:
                self.__moveFromPending(i, job)

    def write(self, preFilterOpts):
        "write jobs, the full ones first and then the ones still pending"
        for job in self:
            job.write(self.jobGenDb, preFilterOpts)
        for job in self.pending:
            job.write(self.jobGenDb, preFilterOpts)

    def prStats(self):
        """print stats about jobs for debugging purposes; all of the
        minimums and maximums are reported as 0 when there are no jobs"""
        allJobs = self + self.pending
        if allJobs:
            bases = [job.targetBases for job in allJobs]
            seqs = [len(job) for job in allJobs]
            minBases, maxBases = min(bases), max(bases)
            minSeqs, maxSeqs = min(seqs), max(seqs)
        else:
            # a genome without windows produces no jobs at all
            minBases = maxBases = minSeqs = maxSeqs = 0

        pp = self.cdnaQuery.procPart
        prLine(sys.stderr, "Jobs: ",pp.update.rel, " ", pp.update, " ", pp, " ", self.jobGenDb.db, " jobs:", len(allJobs), " minBases:", minBases, " maxBases:", maxBases, " minSeqs:", minSeqs, " maxSeqs:", maxSeqs)

class JobGenDb(object):
    "object to generate alignment jobs for a database"

    def __init__(self, jobGen, db, partitionsFh):
        """read the alignment parameters of db from the configuration of
        jobGen; partitionsFh is handed on to each CDnaUnaligned"""
        self.jobGen = jobGen
        self.conf = jobGen.conf
        self.db = db
        self.partitionsFh = partitionsFh
        # align.querySplitSize is multiplied by 1024*1024 to get bytes
        self.querySplitSizeBytes = 1024*1024*self.conf.getDbInt(db, "align.querySplitSize")
        self.ooc = self.conf.getDbStrNo(db, "ooc")
        self.maxJobTargetSize = self.conf.getDbInt(db, "align.maxJobTargetSize")
        self.genome = None  # GenomeTarget, created when first needed
        self.numAlign = 0

    # matches pre-filter options
    preFilterOptsRe = re.compile("(-minId=)|(-minCover=)|(-maxRepMatch=)|(-minQSize=)|(-polyASizes=)")

    def __getPreFilterOpts(self, rel, cdnaType, orgCat):
        """select non-comparative pslCDnaFilter options for blat job
        filtering.  Returns the selected options joined by spaces, or None
        if pre-filtering is disabled or no option applies."""

        # pre-filtering is on unless align.prefilter is explicitly false;
        # an unset value (None) counts as enabled
        if self.conf.getDbBoolNone(self.db, "align.prefilter") == False:
            return None

        filterOpts = self.conf.getvDbStrNo(self.db, rel.srcDb, cdnaType, orgCat, "pslCDnaFilter")
        if filterOpts is None:
            return None
        preOpts = [opt for opt in filterOpts.split()
                   if JobGenDb.preFilterOptsRe.match(opt)]
        if len(preOpts) == 0:
            return None
        return " ".join(preOpts)

    def __addOrgCatJobs(self, rel, cdnaType, orgCat, queries):
        "make query jobs for a given orgCat"
        # the genome is only partitioned once something needs aligning
        if self.genome is None:
            self.genome = GenomeTarget(self.db, self.conf)
        preFilterOpts = self.__getPreFilterOpts(rel, cdnaType, orgCat)
        for cdnaQuery in queries:
            jobs = Jobs(self, cdnaQuery)
            jobs.write(preFilterOpts)
            if cmdOpts.verbose >= 4:
                jobs.prStats()

    def __addJobSet(self, rel, cdnaType, orgCats):
        "extract the unaligned sequences of one cdnaType and create their jobs"
        cdnaQueries = CDnaUnaligned(self.db, rel, cdnaType, orgCats,
                                    self.querySplitSizeBytes,
                                    self.jobGen.workDir, self.partitionsFh)
        if cdnaQueries.buildQueries() > 0:
            for orgCat in orgCats:
                self.__addOrgCatJobs(rel, cdnaType, orgCat, cdnaQueries.queries[orgCat])
            self.numAlign += cdnaQueries.numAlign

    def __getOrgCats(self, rel, cdnaType):
        """determine orgCats that are to be loaded: those of the requested
        ones that have either load or align enabled in the configuration"""
        orgCats = set()
        for orgCat in cmdOpts.orgCats:
            if self.conf.getvDbBoolNone(self.db, rel.srcDb, cdnaType, orgCat, "load") or self.conf.getvDbBoolNone(self.db, rel.srcDb, cdnaType, orgCat, "align"):
                orgCats.add(orgCat)
        return orgCats

    def addJobs(self):
        """add jobs for the latest release of each requested source
        database, returning the total number of sequences needing align
        (also available as self.numAlign)"""
        for srcDb in cmdOpts.srcDbs:
            rel = self.jobGen.gbIndex[srcDb].getLatestRel()
            if rel is not None:
                for cdnaType in cmdOpts.cdnaTypes:
                    if cdnaType in rel.cdnaTypes:
                        orgCats = self.__getOrgCats(rel, cdnaType)
                        if len(orgCats) > 0:
                            self.__addJobSet(rel, cdnaType, orgCats)
        return self.numAlign
        
class JobGen(object):
    """object to generate alignment jobs.  The jobs, expected-PSL and
    partitions files are written under temporary names and only moved to
    their final names by finish()."""

    def __init__(self, conf, gbIndex, workDir):
        """create the work directory if needed and open the three temporary
        output files in it.

        Raises Exception if a jobs file, or a temporary jobs file left by
        another run, already exists in the work directory."""
        self.conf = conf
        self.gbIndex = gbIndex
        self.workDir = os.path.normpath(workDir)
        self.jobsFile = workDir + "/align.jobs"
        self.jobsFileTmp = self.jobsFile + ".tmp"
        self.expectFile = workDir + "/align.expected"
        self.expectFileTmp = self.expectFile + ".tmp"
        self.partitionsFile = workDir + "/partitions.lst"
        self.partitionsFileTmp = self.partitionsFile + ".tmp"
        self.numAlign = 0

        fileOps.ensureDir(workDir)
        if os.path.exists(self.jobsFile):
            raise Exception("job file exists, not overwriting: " + self.jobsFile)
        if os.path.exists(self.jobsFileTmp):
            raise Exception("tmp job file exists, ensure no other alignment is running, then remove working directory: " + self.workDir)

        self.jobFh = open(self.jobsFileTmp, "w")
        self.expectFh = open(self.expectFileTmp, "w")
        self.partitionsFh = open(self.partitionsFileTmp, "w")

    def addJobs(self, db):
        "add jobs for unaligned sequences in a database"
        j = JobGenDb(self, db, self.partitionsFh)
        j.addJobs()
        self.numAlign += j.numAlign

    def finish(self):
        """finish jobs file creation: close the output files, then either
        move them to their final names or, when nothing needs aligning,
        remove them and leave an align.none flag file instead"""
        # every handle must be closed, the partitions file included, so that
        # all buffered rows are on disk before the files are published
        for fh in (self.expectFh, self.jobFh, self.partitionsFh):
            fh.close()

        if self.numAlign == 0:
            # leave flag indicating nothing to align
            open(self.workDir + "/align.none", "w").close()
            os.remove(self.expectFileTmp)
            os.remove(self.partitionsFileTmp)
            os.remove(self.jobsFileTmp)
        else:
            os.rename(self.expectFileTmp, self.expectFile)
            os.rename(self.partitionsFileTmp, self.partitionsFile)
            os.rename(self.jobsFileTmp, self.jobsFile) # must be last, flag file check in __init__

def convertOptToSet(optName, optValue, valid):
    """Turn the values given for one of the restriction options into a set.

    optValue is the list of values collected for the option optName, or
    None if the option was not given, in which case all of the valid values
    are returned.  Raises Exception on the first value that is not a member
    of valid."""
    if optValue == None:
        return set(valid)

    selected = set(optValue)
    for value in optValue:
        if value in valid:
            selected.add(value)
            continue
        raise Exception("\"" + value + "\" is not a valid value for " + optName
                        + ", expected one of: " + ", ".join(valid))
    return selected

def main():
    """entry point: parse the command line, then generate the alignment
    jobs for each database given as an argument"""
    global cmdOpts

    # (flag, add_option keywords), in the order they appear in the help
    optSpecs = (
        ("--workdir",
         dict(action="store", dest="workdir", default="work/align",
              help="directory where alignment is built. See gbAlignStep for details. Defaults to work/align")),
        ("--clusterWorkDir",
         dict(action="store", dest="clusterWorkDir", default=None,
              help="location of work directory on cluster.")),
        ("--verbose",
         dict(action="store", dest="verbose", type="int", default=0,
              help="print details")),
        ("--srcDb",
         dict(action="append", dest="srcDbs", default=None,
              help='Restrict the source database to either "genbank" or "refseq". Maybe repeated.')),
        ("--type",
         dict(action="append", dest="cdnaTypes", default=None,
              help='Restrict the type of sequence processeed to either "mrna" or "est". Maybe repeated.')),
        ("--orgCat",
         dict(action="append", dest="orgCats", default=None,
              help='Restrict type of sequences to process to "native" or "xeno". If not speicifed, categories that are to be loaded are used from genbank.conf. Maybe repeated.')),
        ("--noMigrate",
         dict(action="store_true", dest="noMigrate", default=False,
              help="Don't migrate alignments.")),
    )
    parser = OptionParser(usage=usage)
    for flag, keywords in optSpecs:
        parser.add_option(flag, **keywords)

    cmdOpts, dbNames = parser.parse_args()
    if not dbNames:
        parser.error("wrong number of arguments")

    # convert restriction options to sets; an option that was not given
    # selects all of its valid values
    cmdOpts.srcDbs = convertOptToSet("--srcDb", cmdOpts.srcDbs, (GenBank, RefSeq))
    cmdOpts.cdnaTypes = convertOptToSet("--type", cmdOpts.cdnaTypes, (MRNA, EST))
    cmdOpts.orgCats = convertOptToSet("--orgCat", cmdOpts.orgCats, (Native, Xeno))

    # the configuration is located relative to the current directory
    confFile = "etc/genbank.conf"
    if not os.path.exists(confFile):
        raise Exception(confFile + " not found, must be run from genbank root directory")
    conf = Config(confFile)

    jobGen = JobGen(conf, GbIndex(), cmdOpts.workdir)
    for dbName in dbNames:
        jobGen.addJobs(dbName)
    jobGen.finish()

# Only run when executed as a script, so that loading this file as a module
# (e.g. for testing) does not parse the command line or create files.
if __name__ == "__main__":
    main()

# Local Variables:
# mode: python
# End:
