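# Analyzes the branch structure of a neuron reconstruction stored in SWC format:
# total membrane area, total length and number of branches of the apical dendrite,
# basal dendrite and axon.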

from numpy import *
from pylab import *
import os, os.path, glob, re, fnmatch, sys

class LinkedPoint:
	"""A single SWC sample point together with the IDs of the points linked to it."""

	def __init__(self,x,y,z,r,ID,Type,zSig):
		# position and radius
		self.x, self.y, self.z = x, y, z
		self.r = r
		# annotations read from the file
		self.ID = ID
		self.Type = Type
		self.zSig = zSig
		# IDs of neighbouring points (parent and children)
		self.conn = set()

	def addConn(self,ID):
		"""Link this point to the point with the given ID (repeated IDs are stored once)."""
		self.conn.add(ID)

	def delConn(self,ID):
		"""Remove the link to ID; removing an ID that is not linked is a no-op."""
		self.conn.discard(ID)

	def numConn(self):
		"""Return how many distinct points this point is linked to."""
		linked = self.conn
		return len(linked)

def getLinkedPointsFromSWC(swcfilename,nx=0,ny=0,bw=None,bwz=None,bws=None,idOffset=0,xOffset=0,yOffset=0,zOffset=0):
	"""Parse an SWC file into a list of LinkedPoint objects with symmetric links.

	Each data line must have at least seven whitespace-separated fields, read as:
	ID, type, y, x, z, radius, parent ID (the two coordinate columns are stored
	swapped on purpose). Blank lines and lines whose first field starts with '#'
	are skipped.

	Parameters:
		swcfilename: path to the SWC file.
		nx, ny: exclusive upper bounds of the first and second index of bw/bwz/bws,
			used to clip the averaging window. With the defaults of 0 the window is
			empty and z stays as read from the file.
		bw, bwz, bws: optional arrays indexed [x, y]. Only when all three are given,
			z is replaced by the mean of bwz over the window pixels where bw is true,
			and zSig becomes whether sum(bws) over those pixels equals their count
			(i.e. all flags set, if bws holds 0/1 values).
		idOffset: added to every point ID and to every parent ID other than -1.
		xOffset, yOffset, zOffset: added to x, y and the (possibly refined) z.

	Returns:
		list of LinkedPoint, one per data line in file order; each point's conn
		holds the IDs of its parent and its children.

	Raises:
		ValueError: a field is not numeric, or a parent ID matches no point ID.
		IndexError: a data line has fewer than seven fields.
	"""
	with open(swcfilename) as f:
		lines = f.readlines()

	LinkedPoints = []
	pIDs = []
	IDs = []
	# numpy arrays must be compared with `is`, `!= None` is elementwise and ambiguous in a boolean context
	useZMaps = bw is not None and bwz is not None and bws is not None
	# NOTE(review): min() below makes this an upper bound on the window half-width despite the name — confirm intent
	zAveSizeMin = 2
	for line in lines:
		ll = line.split()	# any run of spaces/tabs separates fields
		if not ll or ll[0].startswith('#'):
			continue
		nn = float(ll[0]) + idOffset	# label of the point
		tp = float(ll[1])	# point type
		py = float(ll[2]) + yOffset	# note the inversion of x, y.
		px = float(ll[3]) + xOffset
		z = float(ll[4])
		r = float(ll[5])	# radius of the sphere.
		zsig = 1
		if useZMaps:
			# refine z from the maps around the point
			zAveSize = min(zAveSizeMin, int(round(r * 0.3)))
			zz = []
			zs = []
			for ii in range(max(0, int(px) - zAveSize), min(nx, int(px) + zAveSize + 1)):
				for jj in range(max(0, int(py) - zAveSize), min(ny, int(py) + zAveSize + 1)):
					if not bw[ii, jj]:
						continue
					zz.append(bwz[ii, jj])
					zs.append(bws[ii, jj])
			if len(zz) > 0:
				z = mean(zz)
				zsig = sum(zs) == len(zs)
		z += zOffset
		parentID = float(ll[6])
		if parentID != -1:	# -1 marks a root and is never offset
			parentID += idOffset
		pIDs.append(parentID)
		IDs.append(nn)
		LinkedPoints.append(LinkedPoint(px, py, z, r, nn, tp, zsig))

	# ID -> index of its first occurrence (same as list.index), so linking is linear
	idToInd = {}
	for ind, pointID in enumerate(IDs):
		idToInd.setdefault(pointID, ind)
	for ii, pid in enumerate(pIDs):
		if pid == -1:
			continue
		try:
			iid = idToInd[pid]
		except KeyError:
			raise ValueError('parent ID %s of point %s not found in %s' % (pid, IDs[ii], swcfilename))
		LinkedPoints[ii].addConn(pid)
		LinkedPoints[iid].addConn(IDs[ii])
	return LinkedPoints

def analyzeStructure(swcFilename):
		 
	print 'Analyzing bursting cell ', swcFilename
	apicalBranches, basalBranches, axonBranches = getBranchedStructure(swcFilename)

	apicalArea, apicalLen, apicalNumBranch, apicalBranchOrders, apicalBranchRadii \
	= getBranchStats(apicalBranches)	
	print 'Apical totArea: ',apicalArea,' totLen: ',apicalLen,' numBranch: ',apicalNumBranch

	basalArea, basalLen, basalNumBranch, basalBranchOrders, basalBranchRadii \
	= getBranchStats(basalBranches)	
	print 'Basal totArea: ',basalArea,' totLen: ',basalLen,' numBranch: ',basalNumBranch

	axonArea, axonLen, axonNumBranch, axonBranchOrders, axonBranchRadii \
	= getBranchStats(axonBranches)	
	print 'Axon totArea: ',axonArea,' totLen: ',axonLen,' numBranch: ',axonNumBranch
	
	areaDen = apicalArea + basalArea
	lenDen = apicalLen + basalLen
	numBranchDen = apicalNumBranch + basalNumBranch

	print 'Total dendrite area: ',areaDen
	print 'Total dendrite length:', lenDen
	print 'Total dendrite number of branches : ',numBranchDen 

def getBranchedStructure(filename):
	# from a swc scaled file create the branch structure of the neuron. 
	# the swc should contain annotations of the points, 1, soma, 2, axon, 3, basal dendrite, 4 apical dendrite
	
	if filename.find("scaled") == -1:
		print "The swc file "+filename+" is not scaled. Need a swc file with .scaled.swc"
		return 
	linkedPoints = getLinkedPointsFromSWC(filename)
	nP = len(linkedPoints)

	types = []
	numCs = []
	IDInds = {}		
	for ii in range(nP):
		pp = linkedPoints[ii]
		types.append(pp.Type)
		numCs.append(pp.numConn())
		IDInds[pp.ID] = ii
	# check if the types are set 
	utypes = unique(types)
	if len(find((utypes < 1) & (utypes > 4))) != 0:
		print 'The branch types are not set correctly. Reminder, soma 1, axon 2, basal, 3, apical 4'
		return	
		
	# find the starting points for the segments
	print 'Finding the starting points of axon, apical dendrite and basal dendrite...'
	axonStartInds = []
	apicalStartInds = []
	basalStartInds = []
	flagUsed = zeros(nP)
	for ii in range(nP):
		pp = linkedPoints[ii]
		if pp.Type == 1:	# soma point
			flagUsed[ii] = 1
			continue
		for id in pp.conn:
			iid = IDInds[id]
			if linkedPoints[iid].Type == 1:
				if pp.Type == 2:
					axonStartInds.append(ii)
				elif pp.Type == 3:
					basalStartInds.append(ii)
				elif pp.Type == 4:
					apicalStartInds.append(ii)
	print 'Number of starting points: axon ',len(axonStartInds), \
			' apical dendrite ',len(apicalStartInds), 	\
			' basal dendrite ',len(	basalStartInds)	

	print 'Getting apical dendritic branches ...'
	apicalBranches  = []
	for startPointInd in apicalStartInds:
		branches = []; connections = []
		getSegment(startPointInd,linkedPoints,IDInds,flagUsed, branches, connections)
		branchOrders = zeros(len(branches))
		branchOrders[0] = 1
		branchParents = zeros(len(branches))
		branchParents[0] = -1
		getBranchOrders(0,connections,branchOrders,branchParents)
		for ii in range(len(branches)):
			branches[ii] += (branchOrders[ii],branchParents[ii])
		apicalBranches.append(branches)

	print 'Getting basal dendritic branches ...'
	basalBranches  = []
	for startPointInd in basalStartInds:
		branches = []; connections = []
		getSegment(startPointInd,linkedPoints,IDInds,flagUsed, branches, connections)
		branchOrders = zeros(len(branches))
		branchOrders[0] = 1
		branchParents = zeros(len(branches))
		branchParents[0] = -1
		getBranchOrders(0,connections,branchOrders,branchParents)
		for ii in range(len(branches)):
			branches[ii] += (branchOrders[ii],branchParents[ii])
		basalBranches.append(branches)
		
	print 'Getting axon bracnhes ....'
	axonBranches  = []
	for startPointInd in axonStartInds:
		branches = []; connections = []
		getSegment(startPointInd,linkedPoints,IDInds,flagUsed, branches, connections)
		branchOrders = zeros(len(branches))
		branchOrders[0] = 1
		branchParents = zeros(len(branches))
		branchParents[0] = -1
		getBranchOrders(0,connections,branchOrders,branchParents)
		for ii in range(len(branches)):
			branches[ii] += (branchOrders[ii],branchParents[ii])
		axonBranches.append(branches)

	return apicalBranches, basalBranches, axonBranches

def getBranchOrders(startSegNum,connections,branchOrders,branchParents):
	"""Propagate branch order and parent index from segment startSegNum to its descendants.

	connections is a list of [segment, child, child, ...] lists (as built by
	getSegment). For every entry whose first element is startSegNum, each listed
	child gets order branchOrders[startSegNum] + 1 and parent startSegNum, and is
	then processed recursively. branchOrders and branchParents are updated in
	place; nothing is returned.
	"""
	for conn in connections:
		if conn[0] != startSegNum:
			continue
		# a one-element entry has no children, so this slice is simply empty
		for child in conn[1:]:
			branchOrders[child] = branchOrders[startSegNum] + 1
			branchParents[child] = startSegNum
			getBranchOrders(child, connections, branchOrders, branchParents)

def getBranchStats(Branches):
	"""Summarise the branches of one compartment.

	Branches: list of trees; each tree is a list of 5-tuples
		(area, length, mean radius, branch order, parent branch index).

	Returns (areaTot, lenTot, numBranches, brOrder, brRadMean):
		the summed area and length over all branches, the branch count, the sorted
		distinct branch orders, and for each of those orders the mean of the
		per-branch mean radii. For empty input the last two are empty arrays.
	"""
	areaTot = 0
	lenTot = 0
	numBranches = 0
	branchOrders = []
	branchRadii = []
	for br in Branches:
		for (aa, ll, rr, bo, pr) in br:	# area, length, mean radius, branch order, branch parent index.
			numBranches += 1
			areaTot += aa
			lenTot += ll
			branchOrders.append(bo)
			branchRadii.append(rr)
	# average radius for the same order; boolean masks replace pylab's find(),
	# which newer matplotlib no longer provides
	branchOrders = array(branchOrders)
	branchRadii = array(branchRadii)
	brOrder = unique(branchOrders)
	brRadMean = array([mean(branchRadii[branchOrders == order]) for order in brOrder])
	return areaTot, lenTot, numBranches, brOrder, brRadMean
					
def getSegment(startPointInd,linkedPoints,IDInds,flagUsed, branches, connections):
	"""Recursively trace one unbranched segment and every segment hanging off its end.

	Starting at linkedPoints[startPointInd], follow the chain of points that have
	exactly two connections, marking each visited index in flagUsed. Stop at a
	point with one connection (a tip), zero or three-plus connections (a branch
	point), or a two-connection point whose neighbours are both already visited.
	The segment's (area, length, meanRadius) is appended to branches. getSegment
	is then called on every still-unvisited neighbour of the segment's last point.

	Parameters:
		startPointInd: index in linkedPoints of the segment's first point.
		linkedPoints: list of points with x, y, z, r and a conn set of neighbour IDs.
		IDInds: maps a neighbour ID (as stored in conn) to its index in linkedPoints.
		flagUsed: per-index visited flags, modified in place (1 = visited).
		branches: output list, one (area, length, meanRadius) tuple per segment.
			The area is the lateral surface of the truncated cones between
			consecutive points, and meanRadius averages r over the segment's points.
		connections: output list of [segment index, child segment indices...];
			each entry is appended after its children's entries (post-order).

	Returns:
		the index in branches of the segment starting at startPointInd.
	"""
	br = []
	ind = startPointInd
	while True:
		PP = linkedPoints[ind]
		br.append(ind)
		flagUsed[ind] = 1
		if len(PP.conn) != 2:
			# one connection: end point; otherwise a branch point. The segment ends here.
			break
		nextInd = None
		for kk in PP.conn:
			kid = IDInds[kk]
			if flagUsed[kid] != 1:
				nextInd = kid
				break
		if nextInd is None:
			# both neighbours already visited (a loop in the data): end the segment
			# here instead of re-appending this point forever
			break
		ind = nextInd

	# compute branch geometry.
	area = 0
	length = 0
	meanRadius = linkedPoints[br[0]].r
	for i1, i2 in zip(br[:-1], br[1:]):
		p1 = linkedPoints[i1]
		p2 = linkedPoints[i2]
		dx = p2.x - p1.x; dy = p2.y - p1.y; dz = p2.z - p1.z; dr = p2.r - p1.r
		ll2 = dx*dx + dy*dy + dz*dz
		length += sqrt(ll2)
		# lateral area of a truncated cone with end radii r1, r2
		area += pi * (p1.r + p2.r) * sqrt(ll2 + dr*dr)
		meanRadius += p2.r
	meanRadius /= len(br)

	brID = len(branches)
	branches.append((area, length, meanRadius))
	connIDs = [brID]
	for kk in PP.conn:
		kid = IDInds[kk]
		if flagUsed[kid] == 1:
			continue
		connIDs.append(getSegment(kid, linkedPoints, IDInds, flagUsed, branches, connections))
	connections.append(connIDs)
	return brID


if __name__ == '__main__':
	# Command-line entry point: pass the path of a .scaled.swc file.
	# Guarded so that importing this module does not run the analysis or exit.
	if len(sys.argv) > 1:
		swcFilename = sys.argv[1]
	else:
		print 'Error: please supply swcFilename.'
		sys.exit(1)	# non-zero status signals the usage error to the shell

	analyzeStructure(swcFilename)
