import sys,os.path
import pysal as psl
from pysal.cg.shapes import Point, LineSegment
import scipy as sp
import copy
import networkx as nx
import pickle as pk
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.axes_grid.anchored_artists import AnchoredText
import csv

#We start by defining various functions, all changeable parameters/variables come after

#This adds a text box to give a key of specified site names
def add_at(ax, t, loc=2):
    """Attach an anchored text box (e.g. a key of site names) to the axes.

    ax  : matplotlib axes to draw on.
    t   : text shown in the box.
    loc : anchor location code passed straight to AnchoredText.
    Returns the AnchoredText artist that was added.
    """
    text_box = AnchoredText(t, loc=loc, prop={'size': 12})
    ax.add_artist(text_box)
    return text_box

#Read the data from a given shapefile
def read_data(filepath):
    """Read site positions, names, cost rows and final sizes from a shapefile pair.

    Parameters
    ----------
    filepath : str
        Path without extension; '<filepath>.shp' and '<filepath>.dbf' are read
        feature by feature in parallel.

    Returns
    -------
    points : list of the shapes read from the .shp file (one per site).
    names : list of the values in column 1 of each .dbf record (site name).
    costs : array built from columns 3..-2 of each .dbf record, i.e. the
        distances from that site to all others.
    final_sizes : list of the values in column 2 of each .dbf record.

    Both readers are closed before returning, even if reading fails part-way.
    """
    from contextlib import closing

    costs = []
    points = []
    names = []
    final_sizes = []
    with closing(psl.open(filepath + '.shp')) as shp_reader, \
            closing(psl.open(filepath + '.dbf')) as dbf_reader:
        #Cycle through each feature in the shapefile i.e. site
        for shape, data in zip(shp_reader, dbf_reader):
            #NOTE THE INDEXES USED BELOW ARE SPECIFIC TO THESE SHAPEFILES, MAY NEED TO BE CHANGED#
            costs.append(data[3:-1])
            names.append(data[1])
            points.append(shape)
            final_sizes.append(data[2])
    #Convert the list of lists we have made into an array (i.e. matrix)
    costs = sp.array(costs)
    return points, names, costs, final_sizes
    
#This calculates a distance matrix based on the exact locations of the points
#This is not needed if we are going to use the costs provided in the matrix
def calc_distance_matrix(points):
    """Return the N x N array of straight-line distances between all pairs of points.

    Each point is read as p[0] (x) and p[1] (y). Row a, column b holds the
    Euclidean distance from points[a] to points[b].
    """
    def _euclid(a, b):
        dx = a[0] - b[0]
        dy = a[1] - b[1]
        return (dx**2 + dy**2)**0.5

    return sp.array([[_euclid(a, b) for b in points] for a in points])

#This counts the number of the true top 10 sites that came out in the top 10 sites of a simulation    
def top_10_overlap(true_top_10,sim_final):
    """Count how many of the true top-10 site indices also rank in the simulated top 10.

    true_top_10 : iterable of site indices (the observed top 10).
    sim_final : 2-D array whose first row holds the simulated size of each site.
        Indices are stably sorted by ascending size and the last ten are taken
        as the simulated top 10.
    Returns the number of entries of true_top_10 found in that simulated top 10.
    """
    sizes = sim_final[0]
    ranked = sorted(range(len(sizes)), key=lambda idx: sim_final[0, idx])
    sim_top = ranked[-10:]
    return sum(1 for idx in true_top_10 if idx in sim_top)

#Exactly the same as above, but for the top 19 true sites    
def top_19_overlap(true_top_19,sim_final):
    """Count how many of the true top-19 site indices also rank in the simulated top 19.

    Same procedure as top_10_overlap, with a cut-off of 19: indices of
    sim_final[0] are stably sorted by ascending size and the last 19 kept.
    """
    n_sites = len(sim_final[0])
    by_size = sorted(range(n_sites), key=lambda k: sim_final[0, k])
    sim_top_19 = set(by_size[-19:])
    return len([k for k in true_top_19 if k in sim_top_19])
    
#Make a series of snapshots of an individual run, which could be used to make a video
#We get the sizes of sites at each time step, and what the time was in each case    
def make_video(Zj_snapshots,t_snapshots,points,rivers,file_location):
    """Save one scatter-plot frame per snapshot, for stitching into a video.

    Zj_snapshots  : sequence of arrays; row [0] of each holds the site sizes
                    used to colour the points in that frame.
    t_snapshots   : times indexed like Zj_snapshots; shown as each frame's title.
    points        : sequence of (x, y) site positions.
    rivers        : sequence of polylines (sequences of (x, y)), drawn in blue.
    file_location : directory; frame k is saved as '<file_location>/<k>'.
    All open figures are closed at the end.
    """
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    lo_x, hi_x = min(xs), max(xs)
    lo_y, hi_y = min(ys), max(ys)

    #Figure width:height follows the spatial extent of the sites (height fixed at 8)
    fig = plt.figure(figsize=((8 * (hi_x - lo_x) / float(hi_y - lo_y), 8)))
    #Plot (almost) to the edge of the figure
    fig.subplots_adjust(bottom=0, left=0, top=0.97, right=1.0)
    ax = fig.add_subplot(111)

    for frame, Zj in enumerate(Zj_snapshots):
        ax.cla()

        sizes = Zj[0]
        ax.scatter(xs, ys, s=120, c=sizes, alpha=0.9, norm=None,
                   vmin=0, vmax=max(sizes), edgecolor='none')

        #Rivers for context
        for river in rivers:
            ax.plot([v[0] for v in river], [v[1] for v in river], 'b')

        ax.set_title('t = %.3f' % (t_snapshots[frame]))

        #Leave a 1000-unit margin around the sites
        ax.set_xlim(lo_x - 1000, hi_x + 1000)
        ax.set_ylim(lo_y - 1000, hi_y + 1000)

        plt.savefig(file_location + '/' + str(frame))

    plt.close('all')
    
#This matters: the graph built here is what gets plotted and measured below
#Build a Nystuen-Dacey graph from a given flow matrix
def build_graph(Sij,points):
    """Build a Nystuen-Dacey style directed graph from a flow matrix.

    Sij    : N x N flow array; row i holds the flows sent by site i.
    points : site positions; only len(points) is used (number of nodes).

    Self-flows (the diagonal) are ignored on a copy of Sij, so the caller's
    array is untouched. Each site i gets an edge to the site it sends the most
    flow to, unless that site's total inflow (column sum) is strictly smaller
    than site i's own. Returns the networkx DiGraph.
    """
    n_sites = len(points)
    graph = nx.DiGraph()
    graph.add_nodes_from(range(n_sites))

    flows = copy.copy(Sij)
    for k in range(n_sites):
        flows[k, k] = 0

    #Dominant destination per origin row, and total inflow per site
    #NOTE(review): a row that is all zero after clearing the diagonal makes
    #argmax return 0, which can add an edge to site 0 (or a self-loop on 0)
    dominant = sp.argmax(flows, axis=1)
    inflow = sp.sum(flows, axis=0)

    for origin in range(n_sites):
        target = dominant[origin]
        if not (inflow[target] < inflow[origin]):
            graph.add_edge(origin, target)

    return graph
    
def plot_graph(g,points,rivers,Zj,top_sites,alpha,betafactor,iival,river_cost_weight,file_location,site_names=None,dt_label=None):
    """Plot a site graph over the rivers, label the top sites, and save the figure.

    Parameters
    ----------
    g : graph whose nodes are site indices; each edge is drawn as a curved arrow.
    points : sequence of (x, y) site positions, indexed like the graph nodes.
    rivers : sequence of river polylines (sequences of (x, y)), drawn in blue.
    Zj : array whose first row ([0]) gives the site sizes used to colour nodes.
    top_sites : site indices to label. The LAST entry gets rank 1, the one
        before it rank 2, and so on (the script passes indices sorted by
        ascending size).
    alpha, betafactor, iival, river_cost_weight : parameter values shown in
        the title and encoded in the output file name.
    file_location : directory the figure is saved into.
    site_names : optional sequence of site names indexed like points. Defaults
        to the module-level ``names`` set by the script section of this file.
    dt_label : optional value for the 'dt' field of the file name. Defaults to
        the module-level ``dt_ord_mag``.

    All open figures are closed afterwards.
    """
    #Fall back to the script-level globals the original version used implicitly
    if site_names is None:
        site_names = names
    if dt_label is None:
        dt_label = dt_ord_mag

    #Blank label for every site, then rank labels and legend lines for the top sites
    rank_labels = dict.fromkeys(range(len(points)), '')
    legend_lines = []
    for rank, site in enumerate(reversed(top_sites), 1):
        rank_labels[site] = str(rank)
        legend_lines.append(str(rank) + ' - ' + site_names[site] + '\n')
    legend_text = ''.join(legend_lines)

    points_x = [p[0] for p in points]
    points_y = [p[1] for p in points]
    min_x, max_x = min(points_x), max(points_x)
    min_y, max_y = min(points_y), max(points_y)

    #Figure width:height follows the spatial extent of the sites (height fixed at 8)
    fig = plt.figure(figsize=(8 * (max_x - min_x) / float(max_y - min_y), 8))
    #Plot (almost) to the edge of the figure
    fig.subplots_adjust(bottom=0, left=0, top=0.97, right=1.0)
    ax = fig.add_subplot(111)

    sizes = Zj[0]
    #Position of every node, keyed by node index (one entry per point)
    loc = dict(enumerate(points))

    #Colour nodes on a 0..max(size) scale. Each drawing call only gets the
    #keyword arguments it uses; the old code passed label/font options to the
    #nodes call (and colour options to the labels call), which older networkx
    #ignored and newer networkx rejects as unexpected keyword arguments.
    nx.draw_networkx_nodes(g, pos=loc, ax=ax, node_color=sizes, vmin=0, vmax=max(sizes),
                           cmap='jet', linewidths=0, node_size=50)
    nx.draw_networkx_labels(g, pos=loc, ax=ax, labels=rank_labels,
                            font_size=25, font_color='r')

    #Curved arrow for each edge of the graph
    for e in g.edges():
        arrow = FancyArrowPatch(loc[e[0]], loc[e[1]], arrowstyle='->',
                                connectionstyle='arc3,rad=0.4', mutation_scale=20,
                                color='k', shrinkA=3, shrinkB=3)
        ax.add_patch(arrow)

    #Rivers for context
    for river in rivers:
        ax.plot([v[0] for v in river], [v[1] for v in river], 'b')

    #Leave a 1000-unit margin around the sites
    ax.set_xlim(min_x - 1000, max_x + 1000)
    ax.set_ylim(min_y - 1000, max_y + 1000)

    #Key of site names for the ranked sites
    add_at(ax, legend_text, loc=3)

    ax.set_title('Alpha='+str(round(alpha,2))+', Beta='+str(betafactor)+', Internal='+str(iival)+', River weight='+str(river_cost_weight))

    #round() before int() so float noise (e.g. 1099.9999999) does not truncate to 1099
    savename = (file_location + '/dt' + str(dt_label)
                + '_a' + str(int(round(1000 * alpha)))
                + '_beta' + str(int(betafactor))
                + '_iival' + str(int(iival))
                + '_riverweight' + str(int(river_cost_weight)))
    fig.savefig(savename)
    plt.close('all')
    
#This finds the average distance, weighted by flow
#We multiply the distance matrix by flow matrix and take the average value
def average_distance(dij,flow_array):
    """Return the flow-weighted average distance as a single number.

    dij, flow_array : N x N arrays of the same shape.
    Returns sum(dij * flow_array) / N**2 (N = number of rows).
    """
    product = dij * flow_array
    #ndarray.sum() reduces every axis; the builtin sum() used previously only
    #summed over the first axis and returned a length-N vector
    return product.sum() / float(len(product) ** 2)

#In this case we look at the average length of all edges in a Nystuen-Dacey graph    
def average_closest_distance(g,dij):
    """Return the mean length of the edges of graph g, looked up in dij.

    g   : graph whose edges() yields (i, j) node-index pairs.
    dij : 2-D distance array indexed as dij[i, j].
    Returns float('nan') when g has no edges instead of dividing by zero.
    """
    edgelens = [dij[e[0], e[1]] for e in g.edges()]
    if not edgelens:
        return float('nan')
    return sum(edgelens) / float(len(edgelens))

#Find the shortest distance from every site to its closest river    
def river_proximity(points,rivers):
    """Return, for each point, its shortest distance to any river segment.

    points : sequence of site positions (each wrapped in a pysal Point).
    rivers : sequence of polylines; consecutive vertices of each polyline
             form the segments that are measured against.
    Returns a list with one distance per point, in the order of points.
    """
    #Split every polyline into its individual line segments
    segments = [LineSegment(river[k], river[k + 1])
                for river in rivers
                for k in range(len(river) - 1)]

    def _nearest(point):
        probe = Point(point)
        return min(psl.cg.standalone.get_segment_point_dist(seg, probe)[0]
                   for seg in segments)

    return [_nearest(p) for p in points]

#Calculate the number of distinct rivers crossed by every i-j path
def river_intersection(points,rivers):
    """Return an N x N matrix counting the distinct rivers crossed by each straight i-j trip.

    points : sequence of N site positions.
    rivers : sequence of polylines; each is split into its line segments.

    For every pair of sites, the straight segment between them is tested
    against every segment of each individual river. The number of hits for a
    river is taken modulo 2, so a trip crossing one river polyline an even
    number of times adds 0 and an odd number adds 1. The result is symmetric
    with a zero diagonal.
    """
    #One list of line segments per individual river
    per_river = [[LineSegment(river[k], river[k + 1]) for k in range(len(river) - 1)]
                 for river in rivers]

    n_sites = len(points)
    crossings = sp.zeros((n_sites, n_sites))

    for i in range(n_sites):
        for j in range(i):
            trip = LineSegment(points[i], points[j])
            for segs in per_river:
                hits = sum(1 for seg in segs
                           if psl.cg.standalone.get_segments_intersect(trip, seg) is not None)
                crossed = hits % 2
                crossings[i, j] += crossed
                crossings[j, i] += crossed

    return crossings


#The full model
def model(alpha, betafactor, gamma, iival, real_dij, river_cost_matrix, river_cost_weight, dt, max_time, snapshot_freq, ic):
    """Simulate site-size dynamics driven by spatial-interaction flows.

    Parameters
    ----------
    alpha : exponent applied to the sizes Zj in the flow formula.
    betafactor : distance-decay strength; beta = betafactor / mean(real_dij).
    gamma : coefficient in the size update Zj <- Zj + dt*(Dj - gamma*Zj)*Zj.
    iival : cost placed on the diagonal of the distance matrix (internal flow).
    real_dij : N x N distance array; not modified.
    river_cost_matrix : N x N array, multiplied by river_cost_weight and added
        to the distances.
    dt : time step.  max_time : simulate until t reaches (at least) max_time.
    snapshot_freq : record every snapshot_freq-th step.
    ic : initial sizes; assumed 1 x N (as built by the script). Not modified.

    Returns
    -------
    Zj_snapshots, t_snapshots, Sij_snapshots : lists of recorded sizes, times
        and flow matrices. Each Sij was computed from the sizes before that
        step's update. The state after the last step is always the final entry.
    """
    #Work out how many sites we have
    N = len(real_dij)

    #Take a copy of the distance matrix so we do not tamper with the true value
    dij = copy.copy(real_dij)

    #Initial sizes; Xi is their transpose rescaled to sum to N (builtin sum over
    #the N x 1 column yields a length-1 array holding the total)
    Zj = ic
    Xi = Zj.T
    Xi = Xi * N / sum(Xi)

    #beta is set relative to the mean distance
    beta = betafactor / sp.mean(real_dij)

    #Diagonal entries (internal flow) cost the parameter value iival
    for i in range(N):
        dij[i, i] = iival

    #Add river_cost_weight for every river crossed
    dij = dij + river_cost_weight * river_cost_matrix

    #Distance-decay weights do not change during the run, so compute them once
    decay = sp.exp(-beta * dij)

    Zj_snapshots, t_snapshots, Sij_snapshots = [], [], []

    #Integer step count instead of a float-step sp.arange(0, max_time+dt, dt),
    #whose length depends on rounding; the tolerance absorbs rounding in max_time/dt
    n_steps = int(sp.ceil(max_time / float(dt) - 1e-9))
    time = sp.arange(n_steps + 1) * dt

    last_recorded = -1
    for step, t in enumerate(time):
        #Flows from each origin (row) to each destination (column)
        attractiveness = Zj ** alpha
        Sij = sp.dot(Xi / (sp.dot(decay, attractiveness.T) + 10**-8), attractiveness) * decay

        #Dj: total inflow to each site (sum over axis 0)
        Dj = sp.sum(Sij, axis=0)

        #Size update
        Zj = dt * (Dj - gamma * Zj) * Zj + Zj

        #Xi follows the new sizes, re-normalised to sum to N
        Xi = Zj.T
        Xi = Xi * N / sum(Xi)

        #Every snapshot_freq steps, add the current values to our records
        if step % snapshot_freq == 0:
            Zj_snapshots.append(Zj)
            t_snapshots.append(t)
            Sij_snapshots.append(Sij)
            last_recorded = step

    #Always keep the state after the last step so that [-1] is the final state,
    #even when the number of steps is not a multiple of snapshot_freq
    if len(time) > 0 and last_recorded != len(time) - 1:
        Zj_snapshots.append(Zj)
        t_snapshots.append(time[-1])
        Sij_snapshots.append(Sij)

    return Zj_snapshots, t_snapshots, Sij_snapshots


### THIS IS WHERE THE CHANGEABLE PART OF THE CODE BEGINS ###

print('start')

#The filename for the data we are going to use, relative to the directory of this script
filename='data/MB-sites_Jazira/MB_costmatrix'
#This is where the river data is located
river_filename='data/Waterway_Data/rivers.shp'

#These are settings for the time used in the simulations
#The total running time
max_time=30
#The size of timestep in terms of order of magnitude (larger number -> smaller timestep)
dt_ord_mag=2
dt=10**-dt_ord_mag
#Record every 5*10**(dt_ord_mag-1) steps, i.e. every 0.5 time units whatever dt is
snapshot_freq=5*10**(dt_ord_mag-1)

#This is a switch for whether to use uniform initial conditions or bias towards true final values
ic_sml=True

### Here we have a choice of how to set up the simulations ###
# Either just give a set series of values
alpha=0.9
betafactor=35
gamma=1
iival=1000
river_cost_weight=2500

#Or give lists of several values so that we will loop over all combinations
#(gammaall is not looped over below; every run uses the single gamma above)
alphaall=sp.arange(0.8,1.31,0.1)
betafactorall=sp.arange(15.0,45.1,10)
gammaall = [1]
iivalall = [100,500,1000,2500,5000]
rcwall=[100,500,1000,2500,5000]


#Now we actually run the simulations


#Read the data from the chosen file
points,names,real_dij,final_sizes=read_data(filename)


#Read in the river data (only the first feature of the river shapefile is used)
shp_reader=psl.open(river_filename)
try:
    water=next(shp_reader)
finally:
    shp_reader.close()
rivers=water.parts


#Get the number of rivers intersected by all i-j journeys
#Here for reasons of time we just read it from stored file
#NOTE: pickle can execute code on load; only use a trusted, locally generated file
with open('river_intersection_matrix_MB','rb') as mat_file:
    river_cost_matrix=pk.load(mat_file)

#But the following line would allow it to be calculated for any points and rivers
#river_cost_matrix=river_intersection(points,rivers)

##Calculate the proximity to rivers - at the moment this is not used further
#river_prox=river_proximity(points,rivers)


#Take the number of sites we have and set up a list for these indices
N=len(points)
all_sites=list(range(N))

#Find the indices of the top 10 and top 19 sites by sorting the indices on the basis of final size
sites_by_size=sorted(all_sites,key=lambda i: final_sizes[i])
top_10_sites=sites_by_size[-10:]
top_19_sites=sites_by_size[-19:]

#Find the indices of sites of various orders of magnitude - small, medium, large
s_sites=[i for i in all_sites if final_sizes[i]<=1]
m_sites=[i for i in all_sites if final_sizes[i]>1 and final_sizes[i]<=10]
l_sites=[i for i in all_sites if final_sizes[i]>10]

#Find the average final size of our small/medium/large sites
#(an empty category's value is never used, so 0.0 only avoids dividing by zero)
s_val=sum([final_sizes[i] for i in s_sites])/float(len(s_sites)) if s_sites else 0.0
m_val=sum([final_sizes[i] for i in m_sites])/float(len(m_sites)) if m_sites else 0.0
l_val=sum([final_sizes[i] for i in l_sites])/float(len(l_sites)) if l_sites else 0.0

#Set up a vector representing the initial conditions for the model
ic=sp.ones((1,N))

#Make changes if the 'modified initial conditions' switch is set to True
if ic_sml:

    #We simply give each site the average final value of its category
    for i in s_sites:
        ic[0,i]=s_val

    for i in m_sites:
        ic[0,i]=m_val

    for i in l_sites:
        ic[0,i]=l_val

#Normalise so that the sum is N. ic.sum() totals every entry; the builtin sum()
#over this 1 x N array returns its single row, which set every entry to N
ic=ic*N/ic.sum()

#Here we can set up a CSV file to take the results of a series of simulations
#('wb' is the mode the Python 2 csv module expects)
with open('results_series.csv','wb') as result_file:
    result_writer=csv.writer(result_file)
    header=['alpha','beta','iival','river_weight','SML','top_10','top_19','avg_dist','avg_closest']
    result_writer.writerow(header)

    #Now we loop through the various combinations of our parameter values and run a model each time
    for alpha in alphaall:
        for betafactor in betafactorall:
            for iival in iivalall:
                for river_cost_weight in rcwall:
                    #Print an update saying which model we are running at the moment
                    print('Starting run with parameter array %.2f, %.2f, %.2f, %.2f, %.2f' %(alpha,betafactor,gamma,iival,river_cost_weight))

                    #Run the model and take the results
                    Zj_snapshots, t_snapshots, Sij_snapshots=model(alpha, betafactor, gamma, iival, real_dij, river_cost_matrix, river_cost_weight, dt, max_time, snapshot_freq, ic)

                    #Could use make_video here to get a video of the simulation we just did, like this (commented out now):
#                    make_video(Zj_snapshots,t_snapshots,points,rivers,'result_video')

                    #Build a graph based on the final flow matrix
                    g=build_graph(Sij_snapshots[-1],points)

                    #Plot the final graph in a specified place
                    plot_graph(g,points,rivers,Zj_snapshots[-1],top_19_sites,alpha,betafactor,iival,river_cost_weight,'result_graphs')

                    #Work out how many of the top 10/19 we picked out in simulation
                    top_10=top_10_overlap(top_10_sites,Zj_snapshots[-1])
                    top_19=top_19_overlap(top_19_sites,Zj_snapshots[-1])

                    #Work out a couple of our average distance outputs
                    avg_dist=average_distance(real_dij,Sij_snapshots[-1])
                    avg_closest=average_closest_distance(g,real_dij)

                    #Set up a row of relevant results and write this to our output CSV
                    row=[alpha,betafactor,iival,river_cost_weight,ic_sml,top_10,top_19,avg_dist,avg_closest]
                    result_writer.writerow(row)

                    #Also save the final state somewhere in case we want to carry out more analysis later
                    #Give it a filename based on the parameter setup used
                    #(round() so float noise such as 1099.9999 does not truncate to 1099)
                    output_filename='outputs/a%d_b%d_ii%d_rc%d' %(round(alpha*1000),betafactor,iival,river_cost_weight)
                    with open(output_filename,'wb') as f:
                        pk.dump([Zj_snapshots[-1], t_snapshots[-1], Sij_snapshots[-1]],f)
