#!/usr/bin/python
import string, math
import tools

class DendrogramBasics:
    """Shared state and helpers for the SVG dendrogram renderers.

    Trees are nested Python lists of taxon names, e.g. ``['A', ['B', 'C']]``.
    Pairwise distances are read from ``self.matrix`` at the positions the
    taxa occupy in ``self.taxaList``.
    """
    def __init__(self,width=500,row_height=20,border=20,legend=35,taxa_field=50):
        self.taxaList = []          # taxon names; positions index rows/columns of self.matrix
        self.title = ""
        self.tree_width = 1.0
        self.matrix = []            # pairwise distances aligned with taxaList
        self.oTree = [] # ['A', ['B', 'C']]
        self.flg_clustering = False
        self.outliers = []
        self.width = width
        self.top_border = legend
        self.left_border = border
        self.right_border = width-taxa_field   # right edge of the plot; labels go past it
        self.bottom_border = legend
        self.row_height = row_height
        self.height = 0
        self.font_size = 14
        self.flg_print_indices=False
        self.flg_print_branch_lengths=True
        self.outlier_list = []

    def get_distance(self,ls):
        """Return the mean distance between the children of node *ls*.

        For every pair of children (i < j) the mean distance between all of
        their leaf members is computed; the result is the mean of those pair
        means.  A node with fewer than two children has distance 0.0.  All
        leaves must be present in ``self.taxaList`` (see get_value).
        """
        if len(ls) < 2:
            # Used to return [0] (a list), which broke every caller doing
            # arithmetic, ``== 0`` or truth tests on the result.
            return 0.0
        # Expand each child to its leaves once, not once per pair.
        groups = [self.get_members(child) for child in ls]
        output = []
        for i in range(len(groups)-1):
            for j in range(i+1,len(groups)):
                pairs = [(q,s) for q in groups[i] for s in groups[j]]
                if not pairs:
                    # An empty sub-list has no members; skip instead of 0/0.
                    continue
                total = 0
                for q,s in pairs:
                    total += self.get_value(q,s)
                output.append(total/len(pairs))
        if not output:
            return 0.0
        return sum(output)/len(output)

    def get_value(self,query,sbjct):
        """Return matrix distance between two taxa as float, or None if either is unknown."""
        if query not in self.taxaList or sbjct not in self.taxaList:
            return None
        i = self.taxaList.index(query)
        j = self.taxaList.index(sbjct)
        return float(self.matrix[i][j])

    def size(self,ls,size=0):
        """Return the number of leaves in *ls*.

        A non-list is a single leaf and yields ``size + 1``; for lists the
        *size* argument is ignored (kept for backward compatibility).
        """
        if not isinstance(ls, list):
            return size+1
        return sum(self.size(item) for item in ls)

    def get_members(self,ls,members=None):
        """Return the leaves of *ls* as a flat list, left to right.

        *members* is ignored; it is kept only for interface compatibility
        (it used to be a mutable ``[]`` default).
        """
        if not isinstance(ls, list):
            return [ls]
        flat = []
        for item in ls:
            flat.extend(self.get_members(item))
        return flat

    def index(self,taxon):
        """Return the 1-based position of *taxon* in taxaList, or 0 if absent."""
        if taxon not in self.taxaList:
            return 0
        return self.taxaList.index(taxon)+1

    def _is_end_node(self,ls):
        """Return True when no element of *ls* is itself a list."""
        return not any(isinstance(item, list) for item in ls)

    def _format_num(self,n,d=0):
        """Return *n* rounded half away from zero to *d* decimals, as a string.

        The result is ``str()`` of a float, e.g. ``_format_num(0.125, 2) == "0.13"``.
        """
        from decimal import Decimal, ROUND_HALF_UP
        n *= float(10**d)
        # Decimal understands exponent notation ("5e-05") and negative values,
        # both of which the old "first digit after the dot" test misread.
        rounded = int(Decimal(str(n)).to_integral_value(rounding=ROUND_HALF_UP))
        return str(float(rounded)/10**d)

    def _format_str(self,title,n=25):
        """Truncate *title* to *n* characters, ending with '...' when shortened."""
        if len(title) > n:
            return title[:n-3]+"..."
        return title
    
class CladogramGraph(DendrogramBasics):
    """Render a nested-list tree as a left-to-right cladogram in SVG.

    Taxon labels are drawn in a column just right of ``self.right_border``.
    With clustering enabled, leaf-level groups are drawn as filled triangles
    and the distance of each split is printed next to it.
    """
    def __init__(self):
        DendrogramBasics.__init__(self)

    def set(self,matrix=None,taxaList=None,flg_clustering = False,outliers=None):
        """Load the distance matrix, taxon names and display options.

        matrix -- rows of distances aligned with taxaList; empty/None keeps
                  the current matrix.
        taxaList -- taxon names; also determines the drawing height.
        flg_clustering -- enable cluster drawing.
        outliers -- mapping taxon -> suffix appended to its label (it is
                    indexed by taxon name in draw_terminal_branch).
        """
        # Fresh container per call: the former ``[]`` default was shared.
        self.outliers = outliers if outliers is not None else []
        if matrix:
            self.matrix = list(matrix)
            self.tree_width = max(max(row) for row in matrix)
            self.phylip_matrix = self._format_phylip_matrix()
        if taxaList:
            self.taxaList = list(taxaList)
            self.height = self.row_height*len(self.taxaList)+self.top_border+self.bottom_border
        self.flg_clustering = flg_clustering

    def setTree(self,ls_tree,taxaList,flg_clustering = False):
        """Store a copy of the nested-list tree and the taxon names."""
        self.oTree = tools.copy_multilevel_ls(ls_tree)
        self.taxaList = list(taxaList)
        self.flg_clustering = flg_clustering

    def depth(self,ls,level=1):
        """Return the nesting depth of *ls*, counting the leaf level."""
        child_depths = [self.depth(item,level+1) for item in ls if type(item)==type([])]
        if child_depths:
            return max(child_depths)
        return level+1

    def cluster_depth(self,ls):
        """Return the divisor used to size a split's width in clustering mode."""
        return (self.tree_width+4)/(self.get_distance(ls)+1)

    def svg(self,flg_clustering=False):
        """Return the whole cladogram as an SVG document string."""
        self.flg_clustering = flg_clustering
        plot_width = self.right_border-self.left_border
        plot_height = self.row_height*len(self.taxaList)
        plot_left_indent = self.left_border
        plot_top_indent = self.top_border
        # SVG attribute names are case-sensitive: "viewbox" is ignored by XML parsers.
        svg = ["<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 %d %d\">" % (self.width,self.height)]
        # Invisible 1x1 dot (fill-opacity 0)
        svg.append("<rect x=\"1\" y=\"1\" width=\"1\" height=\"1\" style=\"fill:white;stroke-width:0;fill-opacity:0.0\" />")
        n_rows = 0
        coords = []
        for item in self.oTree:
            if type(item) == type([]):
                split,y = self.draw_split(item,plot_width,plot_height,plot_top_indent+n_rows*self.row_height,plot_left_indent)
                svg += split
                n_rows += self.size(item)
            else:
                y = plot_top_indent+n_rows*self.row_height
                svg += self.draw_terminal_branch(item,plot_width,plot_left_indent,y)
                n_rows += 1
            coords.append(y)
        # An empty tree has nothing to connect to a root.
        if coords:
            y1 = min(coords)
            y2 = max(coords)
            # Horizontal root branch
            svg.append(("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" %
                (0,(y1+y2)/2,self.left_border,(y1+y2)/2)))
            if y1 != y2:
                # Vertical branch
                svg.append(("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" %
                    (self.left_border,y1,self.left_border,y2)))
        svg.append("</svg>")
        return "\n".join(svg)

    def draw_split(self,item,plot_width,plot_height,plot_top_indent,plot_left_indent):
        """Draw internal node *item* and everything below it.

        plot_width -- horizontal space (px) available to this subtree.
        plot_height -- not used for drawing; kept for interface compatibility.
        plot_top_indent -- y of the first row occupied by this subtree.
        plot_left_indent -- x where this node's incoming branch starts.
        Returns (svg_lines, y) with y the vertical centre of the split line.
        """
        svg = []
        coords = []
        n_rows = 0
        if self.flg_clustering:
            width = plot_width/self.cluster_depth(item)
        else:
            width = plot_width/self.depth(item)

        # A zero-distance group collapses into one cluster of all its leaves.
        if self.flg_clustering and not self.get_distance(item) and len(item)>1:
            item = [self.get_members(item)]
        for child in item:
            # Check the child itself (was item[0], which turned string leaves
            # next to a list sibling into "clusters" of their characters).
            if self.flg_clustering and type(child)==type([]) and self._is_end_node(child):
                n = self.size(child)
                y = plot_top_indent+(n_rows+n/2)*self.row_height
                svg += self.draw_terminal_cluster(child,plot_width-width,plot_left_indent+width,y)
            elif type(child)==type([]):
                n = self.size(child)
                split,y = self.draw_split(child,plot_width-width,n*self.row_height,plot_top_indent+n_rows*self.row_height,plot_left_indent+width)
                svg += split
            else:
                n = 1
                y = plot_top_indent+n_rows*self.row_height+self.row_height/3
                svg += self.draw_terminal_branch(child,plot_width-width,plot_left_indent+width,y)
            coords.append(y)
            n_rows += n

        y1 = min(coords)
        y2 = max(coords)
        # Horizontal branch
        svg.append(("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" %
            (plot_left_indent,(y1+y2)/2,plot_left_indent+width,(y1+y2)/2)))
        # Vertical branch
        x = plot_left_indent+width
        svg.append(("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" %
            (x,y1,x,y2)))
        if self.flg_clustering:
            svg.append(("<text x=\"%d\" y=\"%d\" font-family=\"Times New Roman\" font-size=\"%d\" style=\"text-anchor:start\">%s</text>" %
                (x+5,(y1+y2)/2-5,self.font_size,str(self.get_distance(item))[:5])))

        return svg,(y1+y2)/2

    def draw_terminal_branch(self,title,width,x,y):
        """Draw a leaf branch from x to x+width at height y plus its label."""
        # Terminal horizontal branch
        svg = [("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" %
            (x,y,x+width,y))]
        # OTU title, with the outlier suffix if any
        if self.outliers and title in self.outliers:
            title += self.outliers[title]
        svg.append(("<text x=\"%d\" y=\"%d\" font-family=\"Times New Roman\" font-size=\"%d\" style=\"text-anchor:start\">%s</text>" %
            (self.right_border+5,y+self.font_size/3,self.font_size,title)))
        return svg

    def draw_terminal_cluster(self,ls,width,x,y):
        """Draw a group of leaves *ls* as a half-length branch plus filled triangle, centred on y."""
        indent = 2
        # Terminal horizontal branch
        svg = [("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" %
            (x,y,x+width/2,y))]
        # Cluster
        svg.append("<polygon points=\"%d,%d %d,%d %d,%d\" fill=\"black\" stroke=\"black\" stroke-linejoin=\"round\" />" %
            (x+width/2,y,x+width,y-len(ls)*self.row_height/2+indent,x+width,y+len(ls)*self.row_height/2-indent))
        for i in range(len(ls)):
            # OTU titles
            svg.append(("<text x=\"%d\" y=\"%d\" font-family=\"Times New Roman\" font-size=\"%d\" style=\"text-anchor:start\">%s</text>" %
                (self.right_border+5,y-len(ls)*self.row_height/2+indent+i*self.row_height+self.row_height/2+indent,self.font_size,ls[i])))
        return svg
    
class TreeAndClusterBasics(DendrogramBasics):
    """Common drawing code for TreeGraph and ClusterGraph.

    Branch lengths are drawn to scale (``self.scale`` pixels per distance
    unit).  Subclasses provide ``set_svg_parameters()`` and ``set_splits()``.
    """
    def __init__(self):
        DendrogramBasics.__init__(self)
        self.split_widths = [0,0,0] # length of the left branch, root branch and right branch
        self.branch_length_correction = 1.0
        self.scale = 1.0
        self.cluster_width = 50     # upper bound (px) on the width of a terminal cluster
        self.plot_width = self.right_border - self.left_border
        self.plot_height = 0
        self.set_svg_parameters()
        self.Y_coords = []          # y of every label drawn by draw_terminal_branch

    def set(self,matrix=None,taxaList=None,flg_clustering = True,outliers=None):
        """Load the distance matrix, taxon names and display options.

        outliers -- mapping taxon -> suffix appended to its label.
        Empty/None matrix or taxaList keep the current values.
        """
        if matrix:
            self.matrix = list(matrix)
            self.tree_width = max(max(row) for row in matrix)
            self.phylip_matrix = self._format_phylip_matrix()
        if taxaList:
            self.taxaList = list(taxaList)
            self.height = self.row_height*len(self.taxaList)+self.top_border+self.bottom_border
        self.flg_clustering = flg_clustering
        # Fresh container per call: the former ``[]`` default was shared.
        self.outliers = outliers if outliers is not None else []

    def setTree(self,oTree,taxaList,flg_clustering = True,clustering_cutoff=0):
        """Store a copy of the tree, collapse zero-distance groups and compute scale/size."""
        self.oTree = tools.copy_multilevel_ls(oTree)
        self.taxaList = list(taxaList)
        self.flg_clustering = flg_clustering
        self.cluster_neighbours(self.oTree,clustering_cutoff)
        self.set_splits()
        # A zero tree width (all-zero matrix) would divide by zero; keep the current scale then.
        if self.tree_width:
            self.scale = float(self.plot_width)/self.tree_width
        if self.oTree and len(self.oTree[0])==2:
            self.plot_height = max(self.size(self.oTree[0][0]),self.size(self.oTree[0][1]))*self.row_height
            self.height = self.plot_height + self.top_border + self.bottom_border + self.size(self.oTree[0][1])*self.row_height + 2*self.row_height

    def cluster_neighbours(self,ls,cutoff=0):
        """Replace, in place, every sub-list whose distance is 0 by its flat member list.

        Returns the flat member list when *ls* itself has distance 0, else None.
        NOTE(review): *cutoff* is accepted but not used; only exact zero
        distances are collapsed.
        """
        if not ls or type(ls) != type([]):
            return
        if len(ls)==1:
            self.cluster_neighbours(ls[0],cutoff)
            return
        if self.get_distance(ls)==0:
            return self.get_members(ls)
        for i in range(len(ls)):
            subbranch = self.cluster_neighbours(ls[i],cutoff)
            if subbranch:
                ls[i] = subbranch

    def get_branch_length(self,ls,length=0):
        """Return the longest accumulated half-distance path from *ls* down to a leaf."""
        if not ls or type(ls) != type([]):
            return length
        # A single child adds no distance (and avoids get_distance on it).
        if len(ls) > 1:
            length += self.get_distance(ls)/2.0
        return max(self.get_branch_length(child,length) for child in ls)

    def _direction(self,flg_branch):
        """Return -1 for a branch drawn leftwards, +1 for one drawn rightwards."""
        if flg_branch == "left":
            return -1
        if flg_branch == "right":
            return 1
        raise ValueError("flg_branch must be 'left' or 'right', not %r" % (flg_branch,))

    def _taxon_label(self,taxon):
        """Return the label for *taxon*: its 1-based index or its (truncated) name, plus any outlier suffix."""
        suffix = None
        # Lists (single-element leaves) are unhashable and never outliers.
        if self.outliers and not isinstance(taxon, list) and taxon in self.outliers:
            suffix = self.outliers[taxon]
        if self.flg_print_indices:
            # Look up the index on the plain name; the suffixed name is not
            # in taxaList and used to yield index 0.
            label = str(self.index(taxon))
            if suffix is not None:
                label += suffix
            return label
        if suffix is not None:
            taxon = taxon + suffix
        return self._format_str(taxon)

    def draw_split(self,item,x,y,flg_branch="right"):
        """Draw internal node *item* whose vertical split line is at x, centred on y.

        Children are stacked vertically, each taking size(child) rows.
        Returns (svg_lines, y_centre_of_split_line).
        """
        k = self._direction(flg_branch)
        svg = []
        sizes = [self.size(child) for child in item]
        plot_height = sum(sizes)*self.row_height
        top = y - plot_height/2
        # Vertical centre of each child's band of rows.
        centres = []
        for i, n in enumerate(sizes):
            if i == 0:
                centres.append(top + n*self.row_height/2)
            else:
                centres.append(top + sum(sizes[:i+1])*self.row_height - n*self.row_height/2)
        # Horizontal branches
        branch_length = self.branch_length_correction*self.scale*self.get_distance(item)/2.0
        x0 = x + k*branch_length
        for child, yc in zip(item, centres):
            svg.append("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" % (x,yc,x0,yc))
            if self.flg_print_branch_lengths:
                svg.append("<text x=\"%d\" y=\"%d\" font-family=\"Times New Roman\" font-size=\"%d\" style=\"text-anchor:middle\">%s</text>" %
                    (x+k*self.font_size,yc-3,self.font_size-3,self._format_num(float(abs(x0-x))/self.scale,2)))
            if type(child) == type([]) and not self._is_end_node(child):
                split,_ = self.draw_split(child,x0,yc,flg_branch)
            else:
                split,_ = self.draw_terminal_branch(child,x0,yc,abs(x0-x),flg_branch)
            svg += split
        # Vertical split line
        svg.append("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" % (x,centres[0],x,centres[-1]))
        return svg,(centres[0]+centres[-1])/2

    def draw_terminal_branch(self,item,x,y,branch_length=0,flg_branch="right"):
        """Draw a leaf or a leaf group ending at x, centred on y, with labels.

        In clustering mode a group is a filled triangle whose tip is at most
        cluster_width (and half the branch length) away from x; otherwise it
        is a vertical line.  Returns (svg_lines, y_centre).
        """
        k = self._direction(flg_branch)
        if k < 0:
            x0 = min(x+self.cluster_width,x+branch_length/2)
            anchor = "end"
        else:
            x0 = max(x-self.cluster_width,x-branch_length/2)
            anchor = "start"
        label_tag = "<text x=\"%d\" y=\"%d\" font-family=\"Times New Roman\" font-weight=\"bold\" font-size=\"%d\" style=\"text-anchor:%s\">%s</text>"
        svg = []
        if type(item) == type([]) and len(item) > 1:
            n = self.size(item)
            y1 = y-n*self.row_height/2
            y2 = y+n*self.row_height/2
            if self.flg_clustering:
                # Cluster triangle
                svg.append("<polygon points=\"%d,%d %d,%d %d,%d\" fill=\"black\" stroke=\"black\" stroke-linejoin=\"round\" />" %
                    (x0,y,x,y1+2,x,y2-2))
            else:
                # Vertical cluster line
                svg.append("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" % (x,y1+2,x,y2-2))
            for i, taxon in enumerate(item):
                y0 = y1 + self.row_height/2 + i*self.row_height + self.font_size/3
                svg.append(label_tag % (x+3*k,y0,self.font_size,anchor,self._taxon_label(taxon)))
                self.Y_coords.append(y0)
        else:
            y1 = y2 = y
            svg.append(label_tag % (x+3*k,y+self.font_size/3,self.font_size,anchor,self._taxon_label(item)))
            self.Y_coords.append(y)
        return svg,(y1+y2)/2

    def draw_title(self):
        """Return the SVG text element for self.title above the plot."""
        svg = []
        svg.append("<text x=\"%d\" y=\"%d\" font-family=\"Times New Roman\" font-weight=\"bold\" font-size=\"%d\" style=\"text-anchor:%s\">%s</text>" %
            (self.left_border,self.top_border-self.font_size*1.5,self.font_size+2,"start",self.title))
        return svg

    def draw_legend(self):
        """Return SVG for a two-column 'index: name' legend below the plot."""
        svg = []
        Y = self.plot_height+3*self.row_height
        indent = 30
        svg.append("<text x=\"%d\" y=\"%d\" font-family=\"Times New Roman\" font-weight=\"bold\" font-size=\"%d\" style=\"text-anchor:start\">LEGEND:</text>" %
            (self.left_border,Y,self.font_size+2))
        # Floor division: with "/" an odd taxon count never starts column 2 on Python 3.
        half = len(self.taxaList)//2
        column = 0
        for i, taxon in enumerate(self.taxaList):
            Y += self.row_height
            svg.append("<text x=\"%d\" y=\"%d\" font-family=\"Times New Roman\" font-size=\"%d\" style=\"text-anchor:start\">%s</text>" %
                (self.left_border+indent+column*self.plot_width/2,Y,self.font_size,"%d: %s" % (i+1,self._format_str(taxon))))
            if i+1 == half:
                Y = self.plot_height+3*self.row_height
                column = 1
        return svg

    def draw_ruler(self):
        """Return SVG for a scale bar one distance unit (self.scale px) long at the bottom."""
        svg = []
        svg.append("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" %
            (self.left_border,self.height-3,self.left_border+self.scale,self.height-3))
        svg.append("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" %
            (self.left_border-1,self.height-6,self.left_border-1,self.height))
        svg.append("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" %
            (self.left_border+self.scale+1,self.height-6,self.left_border+self.scale+1,self.height))
        svg.append("<text x=\"%d\" y=\"%d\" font-family=\"Times New Roman\" font-size=\"%d\" style=\"text-anchor:middle\">1.0</text>" %
            (self.left_border+self.scale/2,self.height-8,self.font_size-3))
        return svg
    
class TreeGraph(TreeAndClusterBasics):
    """Draw the tree growing rightwards from the left border, with a scale ruler."""
    def __init__(self):
        TreeAndClusterBasics.__init__(self)

    def set_svg_parameters(self):
        """Derive plot and drawing height from the number of leaves under the root."""
        if self.oTree and self.oTree[0]:
            self.plot_height = self.size(self.oTree[0])*self.row_height
            self.height = self.plot_height + self.top_border + 2*self.row_height

    def set_splits(self):
        """Compute branch_length_correction (capped at 1.0) from the right subtree length."""
        self.split_widths = [0,0,[]]
        if not self.oTree or len(self.oTree[0])!=2:
            return
        self.split_widths[2] = self.get_branch_length(self.oTree[0][1])
        total = sum(self.split_widths)
        if total:
            self.branch_length_correction = float(self.tree_width)/total
            if self.branch_length_correction > 1:
                self.branch_length_correction = 1.0
        else:
            # A leaf as right subtree has length 0; the capped ratio is then 1.0
            # (this used to raise ZeroDivisionError).
            self.branch_length_correction = 1.0
        self.split_widths = [v*self.branch_length_correction for v in self.split_widths]

    def svg(self,flg_clustering=False):
        """Return the whole tree as an SVG document string."""
        self.flg_clustering = flg_clustering
        # SVG attribute names are case-sensitive: "viewbox" is ignored by XML parsers.
        svg = ["<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 %d %d\">" % (self.width,self.height)]
        # TITLE
        svg += self.draw_title()
        # Root height: below the tallest child (guarded, empty trees used to raise IndexError).
        y0 = self.top_border
        if self.oTree and self.oTree[0]:
            y0 = self.top_border+max([self.size(child) for child in self.oTree[0]])*self.row_height
        # Add branches
        if self.oTree and len(self.oTree[0])==2:
            if type(self.oTree[0]) == type([]) and not self._is_end_node(self.oTree[0]):
                split,_ = self.draw_split(self.oTree[0],self.left_border,y0)
            else:
                split,_ = self.draw_terminal_branch(self.oTree[0],self.left_border,y0)
            svg += split
        # Horizontal root branch
        svg.append("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" % (0,y0,self.left_border,y0))
        # Legend
        svg += self.draw_ruler()
        if self.flg_print_indices:
            svg += self.draw_legend()
        svg.append("</svg>")
        return "\n".join(svg)
    
class ClusterGraph(TreeAndClusterBasics):
    """Draw a binary-rooted tree as two halves growing left and right of a
    central root node, over alternating shaded distance zones, with legend."""
    def __init__(self):
        TreeAndClusterBasics.__init__(self)
        self.flg_print_indices=True

    def set_svg_parameters(self):
        """Derive plot and drawing height from the larger half of a binary root."""
        if self.oTree and len(self.oTree[0])==2:
            self.plot_height = max(self.size(self.oTree[0][0]),self.size(self.oTree[0][1]))*self.row_height
            self.height = self.plot_height + self.top_border + self.size(self.oTree[0])*2*self.row_height + 2*self.row_height

    def set_splits(self):
        """Compute [left, root, right] widths scaled so that together they span tree_width."""
        # Numeric placeholders: svg() multiplies split_widths by self.scale,
        # which failed on the former ``[]`` placeholders.
        self.split_widths = [0,0,0]
        if not self.oTree or len(self.oTree[0])!=2:
            return
        self.split_widths[1] = self.get_distance(self.oTree[0])
        self.split_widths[0] = self.get_branch_length(self.oTree[0][0])
        self.split_widths[2] = self.get_branch_length(self.oTree[0][1])
        total = sum(self.split_widths)
        if total:
            self.branch_length_correction = float(self.tree_width)/total
        else:
            # All widths are 0; any factor leaves them 0 (used to divide by zero).
            self.branch_length_correction = 1.0
        self.split_widths = [v*self.branch_length_correction for v in self.split_widths]

    def _draw_zones(self):
        """Return SVG for the zone labels, shading of odd zones and the ruler under the plot."""
        svg = []
        zone_width = float(self.plot_width)/(self.tree_width+1)
        ruler_y = self.top_border+self.plot_height+1
        line_tag = "<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"%d\" stroke-linejoin=\"round\" />"
        for i in range(0,int(self.tree_width)+1,1):
            x_start = self.left_border+zone_width*i
            x_end = self.left_border+zone_width+zone_width*i
            # Zone label
            svg.append(("<text x=\"%d\" y=\"%d\" font-family=\"Times New Roman\" font-size=\"%d\" font-weight=\"bold\" style=\"text-anchor:middle\">%s</text>" %
                (x_start+zone_width/2,self.top_border-5,self.font_size,"Zone %d" % (i+1))))
            if i%2==1:
                # Shaded area; odd zones also get a thicker ruler segment
                svg.append("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" style=\"fill:aliceblue;stroke-width:0;\" />" %
                    (x_start,self.top_border,zone_width,self.plot_height))
                stroke = 2
            else:
                stroke = 1
            # Ruler segment, end tick and distance value at the zone's right edge
            svg.append(line_tag % (x_start,ruler_y,x_end,ruler_y,stroke))
            svg.append(line_tag % (x_end,ruler_y,x_end,self.top_border+self.plot_height+11,stroke))
            svg.append("<text x=\"%d\" y=\"%d\" font-family=\"Times New Roman\" font-size=\"%d\" style=\"text-anchor:end\">%s</text>" %
                (x_end-3,self.top_border+self.plot_height+self.font_size,self.font_size-3,self._format_num(float(i+1)*self.tree_width/(self.tree_width+1),2)))
        return svg

    def svg(self,flg_clustering=True):
        """Return the whole cluster graph as an SVG document string."""
        self.flg_clustering = flg_clustering
        # SVG attribute names are case-sensitive: "viewbox" is ignored by XML parsers.
        svg = ["<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 %d %d\">" % (self.width,self.height)]
        # Invisible 1x1 dot (fill-opacity 0)
        svg.append("<rect x=\"1\" y=\"1\" width=\"1\" height=\"1\" style=\"fill:white;stroke-width:0;fill-opacity:0.0\" />")
        # TITLE
        svg += self.draw_title()
        # Cluster shading
        svg += self._draw_zones()
        # Root node and root branch
        x1 = self.left_border+self.split_widths[0]*self.scale
        y0 = self.top_border+self.plot_height/2
        x2 = self.left_border+(self.split_widths[0]+self.split_widths[1])*self.scale
        # Horizontal root branch
        svg.append("<line x1=\"%d\" y1=\"%d\" x2=\"%d\" y2=\"%d\" stroke=\"black\" stroke-width=\"1\" stroke-linejoin=\"round\" />" % (x1,y0,x2,y0))
        svg.append("<text x=\"%d\" y=\"%d\" font-family=\"Times New Roman\" font-size=\"%d\" style=\"text-anchor:middle\">%s</text>" %
            ((x1+x2)/2,y0-8,self.font_size-3,self._format_num(float(abs(x2-x1))/self.scale,2)))
        # Root node
        svg.append("<circle cx=\"%d\" cy=\"%d\" r=\"5\" stroke=\"black\" stroke-width=\"1\" fill=\"black\" />" % (float(x1+x2)/2,y0))
        # Add branches: first child grows left from x1, second grows right from x2
        if self.oTree and len(self.oTree[0])==2:
            halves = ((self.oTree[0][0],"left",x1),(self.oTree[0][1],"right",x2))
            for branch,flg_branch,x in halves:
                if type(branch) == type([]) and not self._is_end_node(branch):
                    split,_ = self.draw_split(branch,x,y0,flg_branch)
                else:
                    split,_ = self.draw_terminal_branch(branch,x,y0,abs(x2-x1)/2,flg_branch)
                svg += split

        # LEGEND
        if self.flg_print_indices:
            svg += self.draw_legend()
        svg.append("</svg>")
        return "\n".join(svg)
    
'''
This module was modified by O. Reva for the program SWPhylo on 2016/05/29.
*******************************************************************************
tree.py

By: Jason Pell

An implementation of the neighbor-joining algorithm.

References:
N Saitou, M Nei. The neighbor joining method: a new method for reconstructing
phylogenetic trees. Molecular Biology and Evolution 4 (1987): no. 4, 406-425.
*******************************************************************************
'''
class NJ:
    def __init__(self):
        self.taxaList = []
        self.matrix = []
        self.tree = ""
        self.oTree = []
        self.phylip_matrix = ""
        
    def execute(self,matrix=[],taxaList=[],flg_clustering = False,outliers=[]):
        self.set(matrix,taxaList,flg_clustering,outliers)
        return self.doNeigJoin(matrix,taxaList)
        
    def calcDist(self, distMat, i, j):
       ''' 
       Returns the distance between two taxa.

       Receive: Distance matrix and the two taxa.
       Return: Distance between taxa.
       '''

       if i < j:
          i, j = j, i
       return distMat[i][j]

    def calcDistSum(self, distMat, i):
       '''
       Calculates the sum of distances for a taxa.

       Receive: Distance matrix and a taxa.
       Return: The sum of distances.
       '''

       sum = 0

       for k in range(len(distMat)):
          sum += distMat[i][k]

       for k in range(len(distMat)):
          sum += distMat[k][i]

       return sum

    def calcQ(self, distMat, i, j):
       '''
       Calculates the Q value for two taxa.

       Receive: Distance matrix and two taxa.
       Return: Q value between the two taxa.
       '''

       return (len(distMat)-2)*self.calcDist(distMat, i, j) - \
               self.calcDistSum(distMat, i) - \
               self.calcDistSum(distMat, j)

    def calcQMat(self, distMat):
        '''
        Calculates the lower triangle of the Q matrix.

        Receive: Distance matrix.
        Return: Q matrix; entries on and above the diagonal keep the initial
        values provided by tools.matrix.
        '''
        n = len(distMat)
        # Q values are fractional whenever distances are.  Allocate the matrix
        # with float type like the distance matrices in doNeigJoin; the former
        # ``int`` type could truncate Q and distort the minimum search.
        q = tools.matrix((n, n), float)
        # Each taxon's distance sum is needed for every pair it belongs to;
        # compute it once (O(n^2)) instead of twice per pair (O(n^3)).
        # Same formula as calcQ.
        sums = [self.calcDistSum(distMat, i) for i in range(n)]
        for i in range(1, n):
            for j in range(i):
                q[i][j] = (n-2)*self.calcDist(distMat, i, j) - sums[i] - sums[j]
        return q

    def calcDistOldNew(self, distMat, i, j):
       '''
       Calculates the distance from each of the old taxas (i,j) to the new combined taxa (ij).

       Receive: Distance matrix and the two old taxas.
       '''

       return (.5)*(self.calcDist(distMat, i, j)) + ((1./(2*(len(distMat)-2))) * \
              (self.calcDistSum(distMat,i) - self.calcDistSum(distMat, j)))

    def calcDistNewOther(self, distMat, k, f, g):
       '''
       Calculates the distance between the new combined taxa (fg) and another taxa (k).

       Receive: The distance matrix, old taxas (f,g) and the other taxa (k).
       Return: Distance between (fg) and k.
       '''

       return (.5)*(self.calcDist(distMat,f,k) - self.calcDistOldNew(distMat, f, g)) + \
              (.5)*(self.calcDist(distMat,g,k) - self.calcDistOldNew(distMat, g, f))
       
    def minQVal(self, q):
        '''
        Finds the minimum Q value, which will be the next combined taxa.

        Receive: Q matrix.
        Return: The minimum Q value and each of the taxa.
        '''

        iMin = 0
        jMin = 0
        qMin = 0

        for i in range(len(q)):
          for j in range(len(q)):
             if i==j:
                continue
             if min(qMin, q[i][j]) == q[i][j]:
                qMin = q[i][j]
                iMin = i
                jMin = j
        
        joined_taxa = [iMin, jMin]
        return qMin, joined_taxa

    def doNeigJoin(self,mat,taxaList):
        '''
        Executes the neighbor-joining algorithm until the distance matrix has size 1 (or less).

        Receive: The distance matrix (calcDist reads its lower triangle) and a list of the taxa names.
        Return: the tree produced by self._format_tree() after self.setTree()
        has been given the final nested taxa list.

        Algorithm (repeated while the matrix is larger than 1x1):
          1. Calculate Q matrix.
          2. Find lowest value in Q matrix and combine corresponding taxa.
          3. Calculate new distance matrix that is one size smaller than previous step.
        This used to recurse once per join, which exceeds Python's recursion
        limit for large taxa sets; the loop below does the same work.
        '''
        while len(mat) > 1:
            mat, taxaList = self._join_neighbours(mat, taxaList)
        self.setTree(taxaList,self.taxaList,self.flg_clustering)
        self.tree = self._format_tree()
        return self.tree

    def _join_neighbours(self,mat,taxaList):
        '''
        Perform one joining step.

        Return: (new distance matrix, new taxa list).  The joined pair becomes
        a nested list at position 0; the other taxa follow in their old order.
        '''
        minQ, joined_taxa = self.minQVal(self.calcQMat(mat))
        new_size = len(mat)-len(joined_taxa)+1
        newMat = tools.matrix((new_size, new_size), float)
        # Old positions of the taxa that were not joined, in order.  Tracking
        # them replaces the O(n) list.remove/list.index lookups per cell.
        remaining = [k for k in range(len(taxaList)) if k not in joined_taxa]
        newTaxaList = [[taxaList[k] for k in joined_taxa]] + [taxaList[k] for k in remaining]

        # Distances from the new combined taxon (column 0) to every other taxon.
        for i in range(1, new_size):
            newMat[i][0] = self.calcDistNewOther(mat, remaining[i-1], joined_taxa[0], joined_taxa[-1])

        # Copy over everything else from the old distance matrix (same index
        # ranges as the original implementation).
        for i in range(2, new_size):
            for j in range(1, new_size-1):
                newMat[i][j] = mat[remaining[i-1]][remaining[j-1]]

        return newMat, newTaxaList

    def getMaxInMatrix(self, mat):
        '''
        Finds the maximum value in a given matrix.

        Receive: A matrix (a sequence of rows indexed as mat[i][j]); rows may
                 have different lengths.
        Return: The largest entry (the later of equal entries), or 0 if the
                matrix has no entries.
        '''
        # Seed with -infinity rather than 0 so an all-negative matrix reports
        # its real maximum; NaN entries never satisfy ">=" and are skipped.
        maxVal = float('-inf')
        found = False

        for i in range(len(mat)):
            # Use each row's own length instead of assuming a square matrix.
            for j in range(len(mat[i])):
                if mat[i][j] >= maxVal:
                    maxVal = mat[i][j]
                    found = True

        return maxVal if found else 0

    def normalizeMatrix(self, mat):
        '''
        Turns alignment scores into distances, in place.

        Every entry strictly below the diagonal becomes (max - score), where max
        comes from getMaxInMatrix over the whole matrix, so the highest-scoring
        (most similar) pair gets distance 0.  The diagonal and upper triangle are
        left untouched.

        Receive: A matrix of BLOSUM alignment scores.
        Return: The same matrix object, now holding distances below the diagonal.
        '''
        ceiling = self.getMaxInMatrix(mat)

        lower_cells = ((row, col) for row in range(1, len(mat)) for col in range(row))
        for row, col in lower_cells:
            mat[row][col] = ceiling - mat[row][col]

        return mat

    def createTree(self, mat, taxaList):
        '''
        Builds an unrooted tree from a BLOSUM sequence alignment matrix.

        The scores are first converted to distances (normalizeMatrix, which
        mutates mat in place) and then fed to the neighbor-join recursion.

        Receive: A matrix of BLOSUM scores and a taxa list.
        Return: Whatever doNeigJoin returns (the formatted tree string).
        '''
        distances = self.normalizeMatrix(mat)
        tree = self.doNeigJoin(distances, taxaList)
        return tree
    
    def get_phylip_table(self):
        '''
        Builds a PHYLIP-style text table from self.taxaList and self.matrix.

        Layout: a line with the taxon count; one line per taxon holding its
        abbreviation from _phylip_name_format followed by a space and one
        "%f " cell per column; then a blank line, "Abbreviations:", and one
        tab-separated "abbreviation  name" line per taxon.  Each cell is the
        larger of the two mirrored entries matrix[i][j] and matrix[j][i].

        Return: The table as a single newline-joined string.
        '''
        taxon_count = len(self.taxaList)
        names = self._phylip_name_format(self.taxaList)

        def data_line(i):
            cells = (max(float(self.matrix[i][j]), float(self.matrix[j][i]))
                     for j in range(len(self.matrix[i])))
            return ("%s " % names[i]) + "".join("%f " % cell for cell in cells)

        lines = ["%d" % taxon_count]
        lines.extend(data_line(i) for i in range(taxon_count))
        lines.append("\nAbbreviations:")
        lines.extend("\t%s\t%s" % (names[i], self.taxaList[i]) for i in range(taxon_count))
        return "\n".join(lines)
    
    def __repr__(self):
        return self.tree
    
    def _get_branch_length(self,minQ):
        return 1.0
    
    def _format_tree(self):
        #return str(self.oTree).replace(' ', '').replace('[','(').replace(']',')').replace('\'','').replace('\"','').replace('\\','') + ";"
        def get_clade(ls):
            clade = "("
            for item in ls:
                if type(item)==type([]):
                    clade += "%s," % get_clade(item)
                else:
                    clade += "%s," % item
            if clade and clade[-1]==",":
                clade = clade[:-1]
            return clade + ")"
        return "%s;" % get_clade(self.oTree)[1:-1]
    
    def _format_phylip_matrix(self):
        abbreviations = self._phylip_name_format(self.taxaList)
        table = ["%d" % len(self.taxaList)]
        for i in range(len(self.taxaList)):
            table.append(abbreviations[i])
            for j in range(len(self.taxaList)):
                table[-1] += (" %f" % float(self.matrix[i][j]+self.matrix[j][i]))
        table.append("\nAbbreviations:")
        for i in range(len(self.taxaList)):
            table.append("\t%s\t%s" % (abbreviations[i],self.taxaList[i]))
        self.phylip_matrix = "\n".join(table)
        
    def _phylip_name_format(self,seqlist):
        abbreviations = []
        for name in seqlist:
            if len(name) > 10:
                name = name[:10]
            elif len(name) < 10:
                name += (" "*(10-len(name)))
            n = 1
            while name in abbreviations:
                name = "%s_%d" % (name[:-(len(str(n))+1)],n)
                n += 1
            abbreviations.append(name)
        return abbreviations
    
class Cladogram(NJ,CladogramGraph):
    '''
    NJ tree builder combined with the CladogramGraph renderer.

    The title is prefixed with "Cladogram: " and both the index and the
    branch-length print flags are switched off.
    '''

    def __init__(self,title=""):
        # Initialise both bases explicitly, NJ first.
        for base in (NJ, CladogramGraph):
            base.__init__(self)
        self.title = "Cladogram: " + title
        self.flg_print_indices = self.flg_print_branch_lengths = False
    
class Tree(NJ,TreeGraph):
    '''
    NJ tree builder combined with the TreeGraph renderer.

    The title is prefixed with "Tree: " and both the index and the
    branch-length print flags are switched off.
    '''

    def __init__(self,title=""):
        # Initialise both bases explicitly, NJ first.
        for base in (NJ, TreeGraph):
            base.__init__(self)
        self.title = "Tree: " + title
        self.flg_print_indices = self.flg_print_branch_lengths = False
    
class Cluster(NJ,ClusterGraph):
    '''
    NJ tree builder combined with the ClusterGraph renderer.

    The title is prefixed with "Clustering: " and both the index and the
    branch-length print flags are switched on.
    '''

    def __init__(self,title=""):
        # Initialise both bases explicitly, NJ first.
        for base in (NJ, ClusterGraph):
            base.__init__(self)
        self.title = "Clustering: " + title
        self.flg_print_indices = self.flg_print_branch_lengths = True

if __name__ == "__main__":
    # Demo driver.  seq_io is a project-local module; os is used by read_matrix.
    import os, seq_io

    def read_matrix(fname):
        '''
        Reads a square distance matrix.

        Layout: the first line holds the taxon count n; each of the next n lines
        holds the taxon name in columns 0-9, one separator character at column
        10, then whitespace-separated numeric values.

        Receive: Path of the matrix file.
        Return: (taxaList, mat) with mat made symmetric (the larger of each
                mirrored pair wins), or None after printing a message when the
                file does not exist.
        '''
        if not os.path.exists(fname):
            print()
            print("Wrong file name %s" % fname)
            return None
        # "with" closes the file even if a line fails to parse.
        with open(fname) as f:
            n = int(f.readline())
            taxaList = []
            mat = []
            for _ in range(n):
                row = f.readline()
                taxaList.append(row[:10].strip())
                # split() with no argument tolerates repeated blanks and a
                # trailing blank before the newline (get_phylip_table writes
                # one); split(" ") yields empty fields that float() rejects.
                mat.append([float(v) for v in row[11:].split()])
        # Check consistency: make mirrored entries agree.
        for i in range(len(taxaList) - 1):
            for j in range(i + 1, len(taxaList)):
                if mat[i][j] != mat[j][i]:
                    v = max(mat[i][j], mat[j][i])
                    mat[i][j] = mat[j][i] = v
        return taxaList, mat

    oIO = seq_io.IO()
    #fname = "input_matrix.txt"

    for name in ("Bacillus","Corynebacteria","Enterobacteria","Lactobacillus","Prochlorococcus2","Prochlorococcus3","Pseudomonas","Thermotoga"):
        print(name)
        loaded = read_matrix("%s_matrix.txt" % name)
        if loaded is None:
            # read_matrix already reported the missing file; skip this dataset
            # instead of crashing while unpacking None.
            continue
        taxaList, mat = loaded
        oCluster = Cluster(name)
        oCluster.execute(mat, taxaList)
        oIO.save(oCluster.svg(), "%s_cluster.svg" % name)
    '''    
    for name in ("Corynebacteria","Lactobacillus","Prochlorococcus","Pseudomonas","Thermotoga","Mycobacterium","Bacillus","Enterobacteria"):
        print name
        taxaList,mat = read_matrix("%s_table.txt" % name)
        oTree = Tree(name)
        oTree.execute(mat, taxaList)
        oIO.save(oTree.svg(),"%s_tree.svg" % name)
    '''  
    
    '''
    oTree = Tree()
    oTree.execute(mat, taxaList, False)
    oIO.save(oTree.svg(),"Mtb_tree.svg")
    '''
    
    '''
    taxaList,mat = read_matrix(fname)
    oTree = Tree()
    oTree.execute(mat, taxaList, True)
    oIO.save(oTree.svg(),"cluster.svg")
    
    
    taxaList = ['A', 'B', 'C', 'D', 'E', 'F']

    print taxaList

    mat = tools.matrix((6,6), float)
    mat[1][0] = 5
    mat[2][0] = 4
    mat[2][1] = 7
    mat[3][0] = 7
    mat[3][1] = 10
    mat[3][2] = 7
    mat[4][0] = 6
    mat[4][1] = 9
    mat[4][2] = 6
    mat[4][3] = 5
    mat[5][0] = 8
    mat[5][1] = 11
    mat[5][2] = 8
    mat[5][3] = 9
    mat[5][4] = 8

    oTree = Tree()
    oTree.execute(mat, taxaList, False)
    
    #oTree.setTree([['A','B','C'],[['D','E'],'F']],taxaList,True)
    #oTree.setTree([['A','B','C'],[['D','E'],'F']],taxaList,False)
    
    print oTree
    print oTree.oTree
    #print oGraph.calculate_size(oGraph.tree)
    oIO.save(oTree.tree,"tree.tre")
    #oIO.save(oTree.get_phylip_table(),"matrix.txt")
    oIO.save(oTree.svg(),"tree.svg")
    '''
