import sys, os, subprocess, string, math, time
import seq_io, tools

# Command line interface
class Main:
    """Align a FASTA file with the bundled MUSCLE binary and score the result.

    ``execute`` is the entry point: it reads a FASTA file, aligns it with
    MUSCLE and returns ``(distance_matrix, seqlist)``.
    """

    def __init__(self):
        self.path = ""
        self.IO = seq_io.IO()
        self.oMatrix = matrix()  # pairwise distance scorer (BLOSUM62 table by default)
        self.cwd = ""
        # When run as a script, bundled resources (lib/bin/...) are resolved
        # from the parent directory. NOTE(review): presumably this module sits
        # one level below the project root -- confirm.
        if __name__ == "__main__":
            self.cwd = ".."

    def execute(self, path, tmp_path=''):
        """Align the FASTA file at ``path`` and return its distance matrix.

        Args:
            path: FASTA file, resolved relative to ``self.cwd``.
            tmp_path: directory for MUSCLE's temporary files; when empty,
                ``<cwd>/lib/bin/tmp`` is used.

        Returns:
            ``(matrix, seqlist)`` as produced by ``calculate_distance_matrix``.
        """
        fasta, seqlist = self.IO.openFasta(os.path.join(self.cwd, path))
        aln, seqlist = self.align(fasta, seqlist, tmp_path)
        return self.calculate_distance_matrix(aln, seqlist)

    @staticmethod
    def _muscle_binary(platform_name, pointer_bits):
        """Return the bundled MUSCLE file name for a platform, or None if unsupported.

        Args:
            platform_name: a ``sys.platform`` value ("win32", "linux", "linux2", "darwin", ...).
            pointer_bits: interpreter pointer width, 32 or 64.
        """
        if platform_name == "win32":
            return "muscle"
        if platform_name.startswith("linux"):
            family = "linux"
        elif platform_name == "darwin":
            family = "darwin"
        else:
            return None
        if pointer_bits not in (32, 64):
            return None
        return "muscle3.8.31_i86%s%d" % (family, pointer_bits)

    def align(self, fasta, seqlist, tmp_path=''):
        """Run MUSCLE on the sequences and return the parsed alignment.

        The sequences are written to ``tmp.fas`` and MUSCLE writes ``tmp.aln``
        in the same directory. The call blocks until MUSCLE exits.

        Args:
            fasta, seqlist: sequences as returned by ``self.IO.openFasta``.
            tmp_path: directory for the temporary files; when empty,
                ``<cwd>/lib/bin/tmp`` is used.

        Returns:
            ``self.IO.openFasta(<aln file>)``, or ``([], [])`` when no MUSCLE
            binary is bundled for this platform.

        Raises:
            OSError: if the MUSCLE binary cannot be started.
            RuntimeError: if MUSCLE exits without writing the alignment file.
        """
        if not tmp_path:
            tmp_path = os.path.join(self.cwd, "lib", "bin", "tmp", "tmp.fas")
        else:
            tmp_path = os.path.join(tmp_path, "tmp.fas")
        aln_path = tmp_path[:-3] + "aln"
        self.IO.saveFasta(tmp_path, fasta, seqlist)
        # sys.maxsize reflects the interpreter's pointer width (32- vs 64-bit).
        pointer_bits = 64 if sys.maxsize > 2 ** 32 else 32
        binary = self._muscle_binary(sys.platform, pointer_bits)
        if binary is None:
            print()
            print("The program does not work under this OS %s" % sys.platform)
            return [], []
        # Drop output from a previous run so a stale alignment is never read.
        if os.path.exists(aln_path):
            os.remove(aln_path)
        # Argument list without a shell: paths with spaces or shell
        # metacharacters are passed through verbatim.
        cline = [os.path.join(self.cwd, "lib", "bin", binary),
                 "-in", tmp_path, "-out", aln_path, "-quiet"]
        process = subprocess.Popen(cline, stdin=subprocess.PIPE,
                                   stdout=subprocess.PIPE)
        # Closes stdin, drains stdout and waits for MUSCLE to finish writing.
        process.communicate()
        if not os.path.exists(aln_path):
            raise RuntimeError("MUSCLE (exit code %s) did not produce %s"
                               % (process.returncode, aln_path))
        return self.IO.openFasta(aln_path)

    def calculate_distance_matrix(self, aln, seqlist=None):
        """Return the pairwise distance matrix of an alignment.

        Args:
            aln: mapping from sequence name to aligned sequence string.
            seqlist: sequence names defining row/column order; when empty or
                None, the stripped keys of ``aln`` are used.

        Returns:
            ``(matrix, seqlist)`` where ``matrix[i][j]`` is
            ``self.oMatrix`` applied to sequences i and j (0 on the diagonal).
        """
        if not seqlist:
            seqlist = [item.strip() for item in aln.keys()]
        size = len(seqlist)
        distances = tools.matrix([size, size])
        for i in range(size):
            for j in range(size):
                if i == j:
                    distances[i][j] = 0
                else:
                    distances[i][j] = self.oMatrix(
                        "%s\n%s" % (aln[seqlist[i]], aln[seqlist[j]]))
        return distances, seqlist

class matrix:
    """Score-based evolutionary distance between two aligned sequences.

    Calling an instance with a two-line string (one aligned sequence per line)
    returns ``-ln(dN / dU) * calibration``. Here dN is the observed pair score
    minus its expectation, and dU is the smaller of the two self-scores minus
    the same expectation. All scores come from the selected substitution table.
    """

    def __init__(self, table="blosum62"):
        self.aln = ["", ""]  # last formatted alignment scored by __call__
        self.tables = ["blosum62"]  # names of the supported substitution tables
        self.oTable = None
        if table == "blosum62":
            self.oTable = blosum62()

    def ls_tables(self):
        """Return the names of the supported substitution tables."""
        return self.tables

    def _format(self, aln):
        """Normalize a two-line alignment string into two equal-length rows.

        Each of the first two lines is whitespace-stripped and upper-cased.
        Both rows are truncated to the shorter length. Leading and trailing
        columns in which either row has a gap ("-") are then trimmed. The rows
        come back as two empty strings if nothing survives.

        Raises:
            ValueError: if ``aln`` contains fewer than two lines.
        """
        lines = [line.strip().upper() for line in aln.split("\n")]
        if len(lines) < 2:
            raise ValueError("expected two aligned sequences separated by a newline")
        width = min(len(lines[0]), len(lines[1]))
        first, second = lines[0][:width], lines[1][:width]
        # Rows are equal length here, so checking `first` guards both.
        while first and (first[0] == "-" or second[0] == "-"):
            first, second = first[1:], second[1:]
        while first and (first[-1] == "-" or second[-1] == "-"):
            first, second = first[:-1], second[:-1]
        return [first, second]

    def __call__(self, aln):
        """Return the calibrated distance between two aligned sequences.

        Args:
            aln: one string holding two aligned sequences, one per line,
                with gaps written as "-".

        Returns:
            ``-ln(dN / dU) * calibration`` as a float.

        Raises:
            ValueError: if ``aln`` has fewer than two lines, or no columns
                remain after trimming terminal gaps. ``math.log`` also raises
                ValueError when dN / dU <= 0. A zero dU raises ZeroDivisionError.
        """
        self.aln = self._format(aln)
        if not self.aln[0]:
            raise ValueError("the two sequences share no columns after trimming terminal gaps")
        expectation = self.oTable.calculate_expectation(self.aln)
        dN = float(self.oTable.calculate_distance(self.aln) - expectation)
        dU = min(self.oTable.calculate_distance([self.aln[0], self.aln[0]]),
                 self.oTable.calculate_distance([self.aln[1], self.aln[1]])) - expectation
        return -math.log(dN / dU) * self.oTable.calibration
        
class amc_substitutions:
    def __init__(self):
        self.table = {}
        self.calibration = 1.0
        self.amc_list = ['C', 'S', 'T', 'P', 'A', 'G', 'N', 'D', 'E', 'Q', 'H', 'R', 'K', 'M', 'I', 'L', 'V', 'F', 'Y', 'W', "-"]
        
    def odd_number(self,pair):
        pair = pair.strip("-")
        c = 1.0
        if not pair:
            return 0
        if len(pair)==1:
            '''
            pair *= 2
            c = -1.0
            '''
            return 0
        if pair not in self.table:
            return 0
        return c*self.table[pair]
        #return c*math.exp(self.table[pair])
        
    def calculate_expectation(self,aln):
        expectation = 0
        for amc1 in self.amc_list:
            for amc2 in self.amc_list:
                if amc1 == "-" and amc2 == "-":
                    continue
                expectation += float(self.odd_number(amc1+amc2))*len(aln[0])*float(aln[0].count(amc1))*float(aln[1].count(amc2))/len(aln[0])/len(aln[1])
        return expectation
    
    def calculate_distance(self,aln):
        distance = 0
        for i in range(len(aln[0])):
            distance += self.odd_number(aln[0][i]+aln[1][i])
        return distance

class blosum62(amc_substitutions):
    """Substitution table labelled BLOSUM62, plus its distance calibration.

    ``table`` maps two-letter residue pairs such as "GW" to integer scores.
    ``amc_substitutions.odd_number`` looks pairs up verbatim and scores any
    missing pair as 0. Zero-valued pairs (e.g. "AG") are therefore omitted,
    and pairs appear in both orders (e.g. "GW" and "WG").

    NOTE(review): a number of entries, mostly involving T, appear to differ
    from the published NCBI BLOSUM62 matrix. Examples (BLOSUM62 value in
    parentheses):
      - TT is 4 here (5);
      - AT/TA -1 (0); AV/VA -2 (0); TV/VT -2 (0);
      - GN/NG -2 (0); GT/TG 1 (-2); DH/HD 1 (-1);
      - DT/TD 1 (-1); PT/TP 1 (-1);
      - IT/TI and LT/TL -2 (-1); WT/TW -3 (-2);
      - TE, TQ, TK and TH are absent (scored 0), but are -1, -1, -1 and -2.
    Confirm whether this is intentional before changing it: ``calibration``
    may have been fitted to these exact values.
    """
    def __init__(self):
        amc_substitutions.__init__(self)
        # Multiplier applied to -ln(dN/dU) by matrix.__call__.
        self.calibration = 0.9835
        # Residue-pair scores; pairs not listed score 0 via odd_number.
        self.table = {'GW': -2, 'GV': -3, 'GT': 1, 'GR': -2, 'GQ': -2, 'GP': -2, 'GY': -3, 'GG': 6, 'GF': -3, 'GE': -2, 'GD': -1, 'GC': -3, 
            'GN': -2, 'GM': -3, 'GL': -4, 'GK': -2, 'GI': -4, 'GH': -2, 'ME': -2, 'MD': -3, 'MG': -3, 'MA': -1, 'MC': -1, 'MM': 5, 'ML': 2, 
            'MN': -2, 'MI': 1, 'MH': -2, 'MK': -1, 'MT': -1, 'MW': -1, 'MV': 1, 'MP': -2, 'MS': -1, 'MR': -1, 'MY': -1, 'FP': -4, 'FQ': -3, 
            'FR': -3, 'FS': -2, 'FT': -2, 'FV': -1, 'FW': 1, 'FY': 3, 'FA': -2, 'FC': -2, 'FD': -3, 'FE': -3, 'FF': 6, 'FG': -3, 'FH': -1, 
            'FK': -3, 'FN': -3, 'SY': -2, 'SS': 4, 'SR': -1, 'SP': -1, 'SW': -3, 'SV': -2, 'ST': 1, 'SI': -2, 'SH': -1, 'SN': 1, 'SM': -1, 
            'SL': -2, 'SC': -1, 'SA': 1, 'SF': -2, 'YI': -1, 'YH': 2, 'YK': -2, 'YM': -1, 'YL': -1, 'YN': -2, 'YA': -2, 'YC': -2, 'YE': -2, 
            'YD': -3, 'YG': -3, 'YF': 3, 'YY': 7, 'YQ': -1, 'YP': -3, 'YS': -2, 'YR': -2, 'YT': -2, 'YW': 2, 'YV': -1, 'LG': -4, 'LD': -4, 
            'LE': -3, 'LC': -1, 'LA': -1, 'LN': -3, 'LL': 4, 'LM': 2, 'LK': -2, 'LH': -3, 'LI': 2, 'LV': 1, 'LW': -2, 'LT': -2, 'LR': -2, 
            'LS': -2, 'LP': -3, 'LQ': -2, 'LY': -1, 'RT': -1, 'RV': -3, 'RW': -3, 'RP': -2, 'RQ': 1, 'RR': 5, 'RS': -1, 'RY': -2, 'RD': -2, 
            'RF': -3, 'RG': -2, 'RA': -1, 'RC': -3, 'RL': -2, 'RM': -1, 'RI': -3, 'RK': 2, 'VH': -3, 'IP': -3, 'EM': -2, 'EL': -3, 'IR': -3, 
            'EI': -3, 'EK': 1, 'EE': 5, 'ED': 2, 'EG': -2, 'EF': -3, 'EA': -1, 'EC': -4, 'IT': -2, 'EY': -2, 'VN': -3, 'EW': -3, 'EV': -2, 
            'EQ': 2, 'EP': -1, 'II': 4, 'VQ': -2, 'VR': -3, 'VT': -2, 'IN': -3, 'KC': -3, 'KA': -1, 'KG': -2, 'KF': -3, 'KE': 1, 'KD': -1, 
            'KK': 5, 'KI': -3, 'KH': -1, 'KM': -1, 'KL': -2, 'KR': 2, 'KQ': 1, 'KP': -1, 'KW': -3, 'KV': -2, 'KY': -2, 'DN': 1, 'DL': -4, 
            'DM': -3, 'DK': -1, 'DH': 1, 'DI': -3, 'DF': -3, 'DG': -1, 'DD': 6, 'DE': 2, 'DC': -3, 'DA': -2, 'DY': -3, 'DV': -3, 'DW': -4, 
            'DT': 1, 'DR': -2, 'DP': -1, 'QQ': 5, 'QP': -1, 'QR': 1, 'QW': -2, 'QV': -2, 'QY': -1, 'QA': -1, 'QC': -3, 'QE': 2, 'QG': -2, 
            'QF': -3, 'QI': -3, 'QK': 1, 'QL': -2, 'WG': -2, 'WF': 1, 'WE': -3, 'WD': -4, 'WC': -2, 'WA': -3, 'WN': -4, 'WM': -1, 'WL': -2, 
            'WK': -3, 'WI': -3, 'WH': -2, 'WW': 11, 'WV': -3, 'WT': -3, 'WS': -3, 'WR': -3, 'WQ': -2, 'WP': -4, 'WY': 2, 'PR': -2, 'PS': -1, 
            'PP': 7, 'PQ': -1, 'PV': -2, 'PW': -4, 'PT': 1, 'PY': -3, 'PC': -3, 'PA': -1, 'PF': -4, 'PG': -2, 'PD': -1, 'PE': -1, 'PK': -1, 
            'PH': -2, 'PI': -3, 'PN': -2, 'PL': -3, 'PM': -2, 'CK': -3, 'CI': -1, 'CH': -3, 'CN': -3, 'CM': -1, 'CL': -1, 'CC': 9, 'CG': -3, 
            'CF': -2, 'CE': -4, 'CD': -3, 'CY': -2, 'CS': -1, 'CR': -3, 'CQ': -3, 'CP': -3, 'CW': -2, 'CV': -1, 'CT': -1, 'IY': -1, 'VA': -2, 
            'VC': -1, 'VD': -3, 'VE': -2, 'VF': -1, 'VG': -3, 'IQ': -3, 'VI': 3, 'IS': -2, 'VK': -2, 'VL': 1, 'VM': 1, 'IW': -3, 'IV': 3, 
            'VP': -2, 'IH': -3, 'IK': -3, 'VS': -2, 'IM': 1, 'IL': 2, 'VV': 4, 'VW': -3, 'IA': -1, 'VY': -1, 'IC': -1, 'IE': -3, 'ID': -3, 
            'IG': -4, 'HY': 2, 'HS': -1, 'HP': -2, 'HV': -3, 'HW': -2, 'HK': -1, 'HH': 8, 'HI': -3, 'HN': 1, 'HL': -3, 'HM': -2, 'HC': -3, 
            'HA': -2, 'HF': -1, 'HG': -2, 'HD': 1, 'NH': 1, 'NI': -3, 'NL': -3, 'NM': -2, 'NN': 6, 'NA': -2, 'NC': -3, 'ND': 1, 'NF': -3, 
            'NG': -2, 'NY': -2, 'NP': -2, 'NS': 1, 'NV': -3, 'NW': -4, 'TY': -2, 'TV': -2, 'TW': -3, 'TT': 4, 'TR': -1, 'TS': 1, 'TP': 1, 
            'TL': -2, 'TM': -1, 'TI': -2, 'TF': -2, 'TG': 1, 'TD': 1, 'TC': -1, 'TA': -1, 'AA': 4, 'AE': -1, 'AD': -2, 'AF': -2, 'AI': -1, 
            'AH': -2, 'AK': -1, 'AM': -1, 'AL': -1, 'AN': -2, 'AQ': -1, 'AP': -1, 'AS': 1, 'AR': -1, 'AT': -1, 'AW': -3, 'AV': -2, 'AY': -2}
