import string, os, io, re, random, time, math
import seq_io, protdist, oup, blast, verhulst, tree, tools

# Command line interface
class Interface:
    """Command line front end of the phylogenomics pipeline.

    Constructed with a non-empty ``options`` dictionary, the instance
    validates the options and runs the pipeline immediately, falling back to
    the interactive menu when validation fails.  Constructed without options,
    it opens the interactive menu with the default settings.

    Options are keyed by command-line style flags (``"-i"``, ``"-o"``, ...);
    see ``_default_options`` for the full list and the default values.
    """

    # Extensions probed (case-insensitively, in this order) when looking for
    # the reference gene file named after the "-g" option.
    _GENE_EXTENSIONS = ('FA','FAS','FST','FSA','FASTA','FAA','FNN','FNA','ALN')
    # Directory holding readme.txt, displayed by the "H" menu entry.
    _README_DIR = "/var/www/html/swphylo/swphylo"
    # Menu letters that flip a "Yes"/"No" option, and the option they flip.
    _TOGGLE_KEYS = {"E":"-e","N":"-n","D":"-d","M":"-m","T":"-t","L":"-l","P":"-p"}
    # Values of "-x" that are derived automatically from "-n".
    _DERIVED_PATTERNS = ("n0_4mer","n1_4mer")

    def __init__(self,options=None):
        """Set up collaborators, then run or show the menu.

        Args:
            options: optional dict of option flags.  Missing flags take the
                values from ``_default_options``.  ``None`` or an empty dict
                opens the interactive menu.
        """
        self.reference_genes = {"gyra":1.91}
        self.oValidator = Validator(list(self.reference_genes.keys()))
        self.IO = seq_io.IO()
        self.oup = None
        self.cwd = ""
        if __name__ == "__main__":
            self.cwd = ".."
        # Start from the defaults so that a partial dictionary passed by the
        # caller cannot cause a KeyError further down the pipeline.
        self.options = self._default_options()
        if options:
            self.options.update(options)
            self._resolve_pattern()
            if self.oValidator.validate(self.options):
                self.execute()
            else:
                self.main_menu()
        else:
            self.main_menu()

    @staticmethod
    def _default_options():
        """Return a fresh dictionary with the default value of every option."""
        return {
            "-i":"zip",          # input folder
            "-o":"output",       # output folder
            "-x":"",             # pattern
            "-f":"",             # folder
            "-n":"No",           # GC-normalization
            "-g":"",             # reference gene
            "-r":0,              # clustering number
            "-c":1,              # N Protein Distance Contribution
            "-e":"No",           # default parameter
            "-d":"Yes",          # save distance table
            "-m":"Yes",          # save clustering matrix
            "-t":"Yes",          # save phylogenetic tree
            "-l":"Yes",          # save cladogram
            "-p":"Yes",          # save Verhulst plot
        }

    def _resolve_pattern(self):
        """Derive the "-x" pattern from "-n" when no pattern is set."""
        if not self.options.get("-x"):
            if self.options.get("-n") == "Yes":
                self.options["-x"] = "n1_4mer"
            else:
                self.options["-x"] = "n0_4mer"

    def _input_dir(self):
        """Return the folder that holds the input files of the current run."""
        return os.path.join(self.cwd,self.options["-i"],self.options["-f"])

    def _output_path(self,suffix):
        """Return the output file path for the current folder name plus ``suffix``."""
        return os.path.join(self.cwd,self.options["-o"],self.options["-f"]+suffix)

    def _find_reference_gene_file(self,input_dir):
        """Return the path of the reference gene file, or "" when there is none.

        The file name is the "-g" option plus one of ``_GENE_EXTENSIONS``,
        compared case-insensitively.  The name is returned with its on-disk
        spelling so that the path is also valid on case-sensitive filesystems.
        """
        gene = self.options["-g"]
        if gene == "":
            return ""
        on_disk = {}
        for name in os.listdir(input_dir):
            # Keep the first entry for each upper-cased name.
            on_disk.setdefault(name.upper(),name)
        for ext in self._GENE_EXTENSIONS:
            wanted = ("%s.%s" % (gene,ext)).upper()
            if wanted in on_disk:
                return os.path.join(input_dir,on_disk[wanted])
        return ""

    @staticmethod
    def _pairs(seqlist):
        """Yield ``(i, j, name_i, name_j)`` for every index pair with i < j."""
        for i in range(len(seqlist)-1):
            for j in range(i+1,len(seqlist)):
                yield i,j,seqlist[i].strip(),seqlist[j].strip()

    @staticmethod
    def _select_fit(oFitting,cluster_number):
        """Return the fit result matching ``cluster_number``, or None.

        ``oFitting.minchisqr`` selects the model: 0 uses ``result1`` for every
        pair, 1 uses ``result21``/``result22`` and 2 uses
        ``result31``/``result32``/``result33``, indexed by cluster number.
        """
        if oFitting.minchisqr == 0:
            return oFitting.result1
        if oFitting.minchisqr == 1:
            names = ("result21","result22")
        elif oFitting.minchisqr == 2:
            names = ("result31","result32","result33")
        else:
            return None
        for number,name in enumerate(names):
            if cluster_number == number:
                return getattr(oFitting,name)
        return None

    # Execute selected program
    def execute(self):
        """Run the pipeline with the current options and save the results.

        Steps: compute the protein distance matrix of the reference gene (when
        a gene is selected and its file exists), compute the OUP distance
        matrix, combine both through the Verhulst fit, build the tree and
        write the requested output files.
        """
        input_dir = self._input_dir()
        # Create temporary folder (including missing parent folders)
        tmp_path = os.path.join(self.cwd,"lib","bin","tmp",self.options["-f"])
        if not os.path.exists(tmp_path):
            os.makedirs(tmp_path)
        # Create WGS table
        prot_matrix = None
        prot_seqlist = None
        gene_path = self._find_reference_gene_file(input_dir)
        if gene_path:
            oProtdist = protdist.Main()
            prot_matrix,prot_seqlist = oProtdist.execute(gene_path,tmp_path)
        #### Without protein distances, calculate the OUP tree only
        use_prot = self.options["-g"] != "" and prot_seqlist is not None
        if self.options["-g"] != "" and not use_prot:
            print("Reference gene file for '%s' not found; calculating OUP tree only" % self.options["-g"])
        # Create OUP table
        fasta_path = self._prepare_fasta_input(tmp_path)
        oOUP = oup.OUP(self.options["-x"],"d-matrix",os.path.dirname(fasta_path),os.path.basename(fasta_path))
        oOUP.execute()
        oup_matrix,oup_seqlist = oOUP.get_distance_matrix()
        variances,oup_seqlist = oOUP.get_OUV()
        oup_seqlist = [item.strip() for item in oup_seqlist]
        d_matrix = tools.matrix([len(oup_seqlist),len(oup_seqlist)])
        # Population variance of the OUV values; its square root is the
        # outlier cutoff used when the tree is built.
        mean_variance = sum(variances)/len(variances)
        varRes = sum([(xi - mean_variance)**2 for xi in variances]) / len(variances)
        oFitting = None
        oCladogram = None
        OUP_WGS_table = []
        if use_prot:
            gene_factor = self.reference_genes[self.options["-g"].lower()]
        for i,j,query,sbjct in self._pairs(oup_seqlist):
            if use_prot:
                #### Joined OUP + protdist tree
                l = prot_seqlist.index(query)
                m = prot_seqlist.index(sbjct)
                OUP_WGS_table.append([query,sbjct,oup_matrix[i][j],prot_matrix[l][m]*gene_factor])
            else:
                #### OUP tree only (a negative protein distance returns D_oup)
                d_matrix[i][j] = self._calculate_distance(oup_matrix[i][j],-1,oup_seqlist)
        if OUP_WGS_table:
            # Cluster OTUs
            oFitting = verhulst.Verhulst(OUP_WGS_table,self.options["-r"])
            if self.options["-e"] == "Yes":
                oFitting.assign_default(0.0775227,1.33786626)
                # NOTE(review): the protein distance is used unscaled here,
                # while the table above multiplies it by the gene factor --
                # confirm this is intended.
                for i,j,query,sbjct in self._pairs(oup_seqlist):
                    l = prot_seqlist.index(query)
                    m = prot_seqlist.index(sbjct)
                    d_matrix[i][j] = self._calculate_distance(oup_matrix[i][j],prot_matrix[l][m],oup_seqlist,query,sbjct,oFitting.g_param,oFitting.k_param)
            elif self.options["-e"] == "No":
                oFitting.set_matrix()
                oFitting.verhulst_fitting()
                oFitting.chisqr_criterion()
                oFitting.fill_matrix()
                # Drop the first row and the first column of the data matrix.
                cluster_matrix = [row[1:] for row in oFitting.data_matrix[1:]]
                oCladogram = tree.Cluster(self.options["-f"])
                oCladogram.execute(cluster_matrix,oup_seqlist)
                oCladogram._format_phylip_matrix()
                g = None
                k = None
                for i,j,query,sbjct in self._pairs(oup_seqlist):
                    l = prot_seqlist.index(query)
                    m = prot_seqlist.index(sbjct)
                    fit = self._select_fit(oFitting,cluster_matrix[i][j])
                    if fit is not None:
                        g = fit.best_values['mut']
                        k = fit.best_values['k']
                    # When no fit matches, the parameters of the previous pair
                    # are reused; fail clearly if there are none yet.
                    if g is None or k is None:
                        raise ValueError("No Verhulst fit parameters for cluster %r" % (cluster_matrix[i][j],))
                    d_matrix[i][j] = self._calculate_distance(oup_matrix[i][j],prot_matrix[l][m],oup_seqlist,query,sbjct,g,k)
        else:
            d_matrix = oup_matrix
        # calculate trees, marking OUV outliers
        oPhyloTree = tree.Cladogram(self.options["-f"])
        outliers = self._get_OUV_outliers(variances,oup_seqlist,math.sqrt(varRes))
        oPhyloTree.execute(d_matrix,oup_seqlist,True,outliers)
        oPhyloTree._format_phylip_matrix()
        self.IO.save(oPhyloTree.tree,self._output_path("_phylo.tre"))
        # save outputs
        self.save_output(oFitting,oCladogram,oPhyloTree)

    def save_output(self,oFitting,oCladogram,oPhyloTree):
        """Write the output files enabled by the "-d/-m/-t/-l/-p" options.

        Args:
            oFitting: Verhulst fitting object, or None when no fit was made.
            oCladogram: cluster tree object, or None when none was built.
            oPhyloTree: tree object providing ``phylip_matrix``, ``tree`` and
                ``svg()``.
        """
        if self.options["-d"] == "Yes" and oPhyloTree.phylip_matrix:
            self.IO.save(oPhyloTree.phylip_matrix,self._output_path("_table.txt"))
        if self.options["-m"] == "Yes" and oCladogram is not None:
            self.IO.save(oCladogram.phylip_matrix,self._output_path("_matrix.txt"))
        if self.options["-t"] == "Yes" and oPhyloTree.tree:
            self.IO.save(oPhyloTree.tree,self._output_path(".tre"))
            self.IO.save(oPhyloTree.svg(),self._output_path(".svg"))
        if self.options["-l"] == "Yes" and oCladogram is not None:
            self.IO.save(oCladogram.tree,self._output_path("_cladogram.tre"))
            self.IO.save(oCladogram.svg(),self._output_path("_cladogram.svg"))
        if self.options["-p"] == "Yes" and oFitting is not None:
            self.IO.save(oFitting.svg(oFitting),self._output_path("_plot.svg"))

    def _prepare_fasta_input(self,tmp_path):
        """Write ``tmp.fasta`` into ``tmp_path`` and return its path.

        When the input folder contains ``.gb``/``.gbk`` files, each of them is
        converted into one FASTA record.  Otherwise the first directory entry
        is copied as is.
        """
        input_dir = self._input_dir()
        entries = os.listdir(input_dir)
        target = os.path.join(tmp_path,"tmp.fasta")
        genbank = [f for f in entries if "." in f and f[f.rfind("."):].upper() in (".GB",".GBK")]
        if genbank:
            fasta = []
            for fname in genbank:
                dataset,seq,path = self.IO.openGBK(os.path.join(input_dir,fname),"ALL")
                # First non-empty identifier wins; the file name without its
                # extension is the last resort.
                name = (dataset['Accession']
                        or dataset['Sequence name']
                        or dataset['Sequence description']
                        or fname[:fname.rfind(".")])
                fasta.append(">%s\n%s" % (name,seq))
            self.IO.save("\n".join(fasta),target)
            return target
        # NOTE(review): the first directory entry is taken in os.listdir()
        # order, which is arbitrary; an empty folder raises IndexError and the
        # entry could be the reference gene file -- confirm the intended file.
        with open(os.path.join(input_dir,entries[0])) as f:
            content = f.read()
        self.IO.save(content,target)
        return target

    def _calculate_distance(self,D_oup,D_prot,seqlist,query="",sbjct="",g=0,k=0):
        """Combine an OUP distance with a protein distance.

        Args:
            D_oup: OUP distance of the pair.
            D_prot: protein distance of the pair; a negative value means
                "no protein distance" and returns ``D_oup`` unchanged.
            seqlist, query, sbjct: unused, kept for interface compatibility.
            g, k: parameters of the fitted curve.

        Returns:
            ``(D_oup + n*E)/(n + 1)`` with ``E = -log(|2k/(D_prot+k) - 1|)/g``
            and ``n`` taken from the "-c" option, or ``D_oup`` when ``D_prot``
            is negative.
        """
        n = self.options["-c"]
        if D_prot >= 0:
            emp_gyra = -math.log(abs(2*k/(D_prot+k)-1))/g
            return ((D_oup+float(n)*emp_gyra)/(float(n)+1))
        return D_oup

    #### Get OUV outliers
    def _get_OUV_outliers(self,variances,seqlist,cutoff = 3):
        """Map every sequence name to an outlier mark.

        The mark is ``"*"`` when the variance exceeds the mean by more than
        ``cutoff``, ``"**"`` when it is below the mean by more than ``cutoff``
        and ``""`` otherwise.  For a name that occurs more than once, the
        variance of its first occurrence is used.
        """
        def get_values(v,avr,cutoff):
            if v - avr > cutoff:
                return "*"
            if v - avr < -cutoff:
                return "**"
            return ""

        avr_OUV = float(sum(variances))/len(variances)
        first_index = {}
        for position,seqname in enumerate(seqlist):
            first_index.setdefault(seqname,position)
        return {seqname: get_values(variances[first_index[seqname]],avr_OUV,cutoff)
                for seqname in seqlist}

    def _print_settings(self):
        """Print the menu header and the current value of every setting."""
        print("\nSeqWord Phylogenomics 2017/04/01")
        print()
        print("Settings for this run:\n")
        print("  F    Folder to process\t: " + self.options["-f"])
        print("  G    Reference gene\t\t: " + self.options["-g"])
        print("  R    Cluster number\t\t: " + str(self.options["-r"]))
        print("  C    N-Prot Contribution\t: " + str(self.options["-c"]))
        print("  E    Default parameter\t: " + self.options["-e"])
        print("  D    Save distance table\t: %s" % self.options["-d"])
        print("  N    GC-normalization\t\t: %s" % self.options["-n"])
        print("  M    Save cluster matrix\t: %s" % self.options["-m"])
        print("  T    Save phylogenetic tree\t: %s" % self.options["-t"])
        print("  L    Save cladogram\t\t: %s" % self.options["-l"])
        print("  P    Save Verhulst plot\t: %s" % self.options["-p"])
        print("  H    for help;")
        print("  Q    to quit;")
        print()
        print("Y to accept these settings, type the letter for one to change or Q to quit")
        print()

    def _ask_text(self,key,prompt):
        """Read a text option; reset it to "" when the validator rejects it."""
        self.options[key] = input(prompt)
        if not self.oValidator.validate(self.options,key):
            self.options[key] = ""

    def _ask_number(self,key,prompt,allow_float):
        """Read a numeric option; keep the previous value on invalid input.

        The answer is parsed as a number instead of being passed to eval():
        evaluating console input executes arbitrary code and raises on
        anything that is not a valid expression.
        """
        previous = self.options.get(key)
        text = input(prompt).strip()
        try:
            value = int(text)
        except ValueError:
            if not allow_float:
                return
            try:
                value = float(text)
            except ValueError:
                return
        self.options[key] = value
        if not self.oValidator.validate(self.options,key):
            self.options[key] = previous

    # show command prompt interface
    def main_menu(self):
        """Run the interactive menu until the user enters Q or stdin ends."""
        response = ''
        while response != "Q":
            self._print_settings()
            try:
                response = input("?").strip().upper()
                print()
            except EOFError:
                # stdin is closed: no answer can ever arrive, so leave the
                # menu instead of redrawing it forever.
                break
            except (KeyboardInterrupt,UnicodeError):
                continue
            if response == "Y":
                self._resolve_pattern()
                if self.oValidator.validate(self.options):
                    self.execute()
            elif response == "F":
                self._ask_text("-f","Enter folder name? ")
            elif response == "G":
                self._ask_text("-g","Enter reference gene? ")
            elif response == "R":
                self._ask_number("-r","Enter cluster number between 1 and 3? ",False)
            elif response == "C":
                self._ask_number("-c","Enter N-Protein Contribution between 1 and 3? ",True)
            elif response in self._TOGGLE_KEYS:
                key = self._TOGGLE_KEYS[response]
                self.options[key] = "No" if self.options[key] == "Yes" else "Yes"
                # A pattern that was derived from "-n" has to be derived again
                # after "-n" changes; an explicitly chosen pattern is kept.
                if key == "-n" and self.options.get("-x") in self._DERIVED_PATTERNS:
                    self.options["-x"] = ""
            elif response == "H":
                self.show_help()

    def show_help(self):
        """Print readme.txt from the installation directory.

        The file is opened by its full path, so the working directory of the
        process (and with it every relative input/output path) stays
        untouched.  A missing or unreadable file is reported, not raised.
        """
        readme = os.path.join(self._README_DIR,"readme.txt")
        try:
            with open(readme,"r") as f:
                for lines in f:
                    print (lines)
        except OSError as err:
            print("Cannot open help file %s: %s" % (readme,err))
    
# Validator
class Validator:
    """Checks option values before the pipeline is run.

    Every ``validate_*`` method returns ``True`` or ``False`` and does not
    raise on malformed values (empty strings, ``None``, non-numeric text).
    """

    def __init__(self,reference_genes):
        """Store the accepted reference gene names (expected in lower case)."""
        self.reference_genes = reference_genes
        self.cwd = ""
        if __name__ == "__main__":
            self.cwd = ".."

    def validate(self,options,field=""):
        """Validate one option, or all of them when ``field`` is empty.

        Args:
            options: dictionary of option flags and values.
            field: flag to check (e.g. ``"-g"``).  Flags without a dedicated
                check are always accepted.

        Returns:
            True when the checked value(s) are acceptable, otherwise False.
        """
        if not field:
            return self.validate_all(options)
        elif field in ("-i","-o"):
            return self.validate_path(options[field])
        elif field == "-x":
            return self.validate_pattern(options)
        elif field == "-g":
            return self.validate_gene(options)
        elif field == "-r":
            return self.validate_cluster(options)
        elif field == "-c":
            return self.validate_N(options)
        elif field == "-f":
            return self.validate_path(options[field],options["-i"])
        else:
            return True

    def validate_all(self,options):
        """Validate every listed option; print and reject the first failure.

        NOTE(review): "c" and "r" are spelled without the leading dash, so
        they fall through to the accept-everything branch of ``validate`` and
        "-c"/"-r" are not checked here.  This looks like a typo, but the
        default "-r" value used by the interface is 0, which
        ``validate_cluster`` rejects -- confirm the intended range before
        changing the spelling.
        """
        for p in ("-i","-o","-g","-p","-f","-e","c","r","-d","-m","-t","-l","-n","-x"):
            valid = self.validate(options,p)
            if not valid:
                print (p)
                return False
        return True

    def validate_pattern(self,options):
        """Return True when "-x" starts with ``n<0-6>_<1-7>mer``."""
        if "-x" not in options:
            return False
        pattern = options["-x"]
        if not isinstance(pattern,str):
            return False
        return re.match(r"^n[0-6]_[1-7]mer",pattern) is not None

    def validate_gene(self,options):
        """Return True when "-g" names a known reference gene (any case)."""
        return options["-g"].lower() in self.reference_genes

    def validate_cluster(self,options):
        """Return True when "-r" converts to the integer 1, 2 or 3."""
        try:
            return int(options["-r"]) in (1,2,3)
        except (KeyError,TypeError,ValueError,OverflowError):
            # Missing, empty or non-numeric value.
            return False

    def validate_N(self,options):
        """Return True when "-c" converts to a number between 1 and 3."""
        try:
            return 1 <= float(options["-c"]) <= 3
        except (KeyError,TypeError,ValueError,OverflowError):
            return False

    def validate_coefficient(self,v):
        """Return True when ``v`` can be converted to a float."""
        try:
            float(v)
            return True
        except (TypeError,ValueError,OverflowError):
            return False

    def validate_path(self,path,basepath=""):
        """Return True when ``basepath``/``path`` exists on disk."""
        try:
            return os.path.exists(os.path.join(basepath,path))
        except TypeError:
            # A non-string path (for example None) cannot exist.
            return False
        
        
###############################################################################

if __name__ == "__main__":
    # Run as a script: no options are passed, so the constructor opens the
    # interactive menu.
    oInterface = Interface()
    
