#!/usr/bin/env python

import logging, sys, optparse
from collections import defaultdict
from os.path import join, basename, dirname, isfile

# ==== functions =====
    
def parseArgs():
    """Set up logging and parse the command line arguments and options.

    Arguments are read from sys.argv by optparse; -h shows the auto-generated
    help page. If fewer than two positional arguments are given (at least one
    input file plus the output file name are required), the help page is
    printed and the program exits with status 1.

    The root logger is configured at DEBUG level when -d/--debug is given,
    otherwise at INFO level.

    Returns:
        A tuple (args, options): the list of positional arguments and the
        optparse options object (options.debug is set by -d/--debug).
    """
    parser = optparse.OptionParser("usage: %prog [options] file1 file2 ... outFname - merge expression matrices with one line per gene into a big matrix. Matrices have to be of identical size and sorted by geneId.")

    parser.add_option("-d", "--debug", dest="debug", action="store_true", help="show debug messages")
    (options, args) = parser.parse_args()

    # the last argument is the output file, so a single argument leaves no
    # input file to merge: treat it like a missing argument list
    if len(args) < 2:
        parser.print_help()
        # sys.exit instead of the exit() builtin, which only exists when the
        # site module has been loaded
        sys.exit(1)

    logging.basicConfig(level=logging.DEBUG if options.debug else logging.INFO)
    return args, options

def getAllFields(ifhs):
    """Read the header line of every file handle and collect the column names.

    Exactly one line (the header) is consumed from each handle. The first
    column of every header must be "#gene"; it is kept only once, at the
    start of the merged header. All other column names must be unique
    across all files.

    Args:
        ifhs: list of open, readable text file handles. Each needs a
            readline() method and a name attribute (used in error messages).

    Returns:
        A tuple (allFields, colCounts): the merged list of header names
        ("#gene" followed by the value columns of all files, in file order)
        and the number of columns (including the gene column) of each file.
        Both lists are empty if ifhs is empty.

    Raises:
        ValueError: if a header does not start with "#gene" (this includes
            empty files) or if a column name occurs more than once.
    """
    fieldToFname = {}
    allFields = []
    colCounts = []
    for ifh in ifhs:
        fields = ifh.readline().rstrip("\n").split("\t")
        geneField = fields[0]
        # explicit check instead of an assert: asserts are removed under -O
        # and give the user no hint about which file is broken
        if geneField != "#gene":
            raise ValueError("file %s: header line has to start with '#gene', found %r" % (ifh.name, geneField))
        colCounts.append(len(fields))

        # only use the gene column name from the first file
        if not allFields:
            allFields.append(geneField)
        # still register it, so a value column called "#gene" is reported as a duplicate
        fieldToFname[geneField] = ifh.name

        for field in fields[1:]:
            # make sure we have no field name overlaps
            if field in fieldToFname:
                raise ValueError("field %s seen twice, in %s and %s" % (field, ifh.name, fieldToFname[field]))
            fieldToFname[field] = ifh.name
            allFields.append(field)

    return allFields, colCounts
# ----------- main --------------
def _mergeRows(ifhs, colCounts, ofh):
    """Merge the data rows of all input files into ofh, one output line per gene.

    One line is read from every input handle per output row. The gene ID
    (first column) is written once, followed by the value columns of all
    files in file order.

    Args:
        ifhs: open input file handles, positioned just after their header line.
        colCounts: expected number of columns (including the gene column)
            for each handle in ifhs.
        ofh: open, writable output file handle.

    Returns:
        The number of data rows written.

    Raises:
        Exception: if the files do not have the same number of rows, a row
            has an unexpected number of fields, or the gene IDs of a row
            differ between files (rows are not in the same order).
    """
    progressStep = 1000
    lineCount = 0

    while True:
        # read the next line of every file before writing anything, so that
        # EOF never leads to a duplicated or partial output row
        lines = [ifh.readline() for ifh in ifhs]
        if not any(lines):
            # all files are at EOF (or there are no input files at all)
            break

        rowNum = lineCount + 1
        allVals = []
        for ifh, lineStr, colCount in zip(ifhs, lines, colCounts):
            if lineStr == '':
                raise Exception("File %s has fewer rows than the other input files: row %d is missing" %
                    (ifh.name, rowNum))

            fields = lineStr.rstrip("\n").split("\t")
            if len(fields) != colCount: # check number of columns against header count
                raise Exception("Illegal number of fields: file %s, line %d, field count: %d, expected %d" %
                    (ifh.name, rowNum, len(fields), colCount))

            if not allVals:
                # the gene ID is taken from the first file
                allVals.append(fields[0])
            elif fields[0] != allVals[0]:
                # if this fails, then the rows are not in the same order
                raise Exception("Gene ID mismatch: file %s, line %d, found %s, expected %s" %
                    (ifh.name, rowNum, fields[0], allVals[0]))

            allVals.extend(fields[1:])

        ofh.write("\t".join(allVals))
        ofh.write("\n")

        lineCount = rowNum
        if lineCount % progressStep == 0:
            logging.info("Processed %d rows", lineCount)

    return lineCount

def main():
    """Merge the expression matrices named on the command line into one file.

    All positional arguments except the last are input files, the last one
    is the output file. The merged header is written first, then one line
    per gene. All files are closed again, also when an error occurs.
    """
    args, options = parseArgs()

    if len(args) < 2:
        sys.exit("error: need at least one input file and the output file name")

    inFnames = args[:-1]
    outFname = args[-1]

    ifhs = []
    ofh = None
    try:
        # open the inputs first: a missing input file should not leave a
        # truncated output file behind
        for fn in inFnames:
            ifhs.append(open(fn))
        ofh = open(outFname, "w")

        headers, colCounts = getAllFields(ifhs)
        ofh.write("\t".join(headers))
        ofh.write("\n")

        for ifh, colCount in zip(ifhs, colCounts):
            logging.info("File %s: %d columns with values", ifh.name, colCount - 1)

        lineCount = _mergeRows(ifhs, colCounts, ofh)
    finally:
        if ofh is not None:
            ofh.close()
        for ifh in ifhs:
            ifh.close()

    logging.info("Wrote %d lines (not counting header)", lineCount)

# only run the merge when executed as a script, so that importing this module
# does not parse sys.argv or touch any files
if __name__ == "__main__":
    main()
