"""

Solving the metabolic system in a reactive astrocytic domain

"""

import numpy as np
from dolfin import *
from cutfem import *
from timeit import default_timer as timer
from Dimensionless_parameters import *

# NOTE(review): `time` is imported explicitly here, after the star imports,
# because `from dolfin import *` may export a name `time` that is not the
# standard-library module; binding it here keeps `time.strftime` unambiguous.
import time

parameters["form_compiler"]["no-evaluate_basis_derivatives"] = False
parameters['allow_extrapolation'] = True
# Start timer (wall clock, reported at the end of the script)
startime = timer()

# Results of the previous parts live in a directory stamped with today's date
date = time.strftime("%d%m%Y")
results_dir = './AD_results_' + date

############################

#   READ FROM HDF5 File  #

############################

# Background mesh and level-set function written by an earlier part
hdf5_file = HDF5File(mpi_comm_world(), results_dir + '/hdfFile', 'r')
try:
    bg_mesh = Mesh()
    hdf5_file.read(bg_mesh, "bg_mesh", False)
    W = FunctionSpace(bg_mesh, "CG", 1)
    level_set = Function(W)
    hdf5_file.read(level_set, 'level_set')
finally:
    # Release the file handle even if one of the reads fails
    hdf5_file.close()


# read the volume from part 1
volume = np.load(results_dir + '/Volume.npy')


mesh = CutFEMTools_fictitious_domain_mesh(bg_mesh, level_set, 0, 0)

# Presumably the characteristic length used for non-dimensionalisation
# (it is passed to dim_param and inverted into the Gaussian width below)
L = np.load(results_dir + '/L.npy')

# Define the sum of ATP+ADP inside the cell
sum_atp_adp = 3.2
# define the dimensionless parameter alpha
alpha = 0.16


# define the dimensionless diffusion coefficient for each species
D_a_bar, D_b_bar, D_c_bar, D_d_bar, D_e_bar, D_f_bar = dimensionless_diffusion_coeff()

# Dimensionless parameters
t_c, f_bar, lac_bar, beta_a, beta_b, beta_c, beta_d, gamma_b, gamma_c, gamma_d, gamma_e, mu_e, mu_g, csi_b, csi_c, csi_e, tau_b, tau_c = dim_param(L, alpha, sum_atp_adp, 1.0)

# store the time scale used to make time dimensionless
np.save(results_dir + '/t_c.npy', t_c)

# Final dimensionless time: 90 physical time units scaled by t_c
T = 90/t_c  # final time
print('final time', T)
num_step = 250  # number of time steps
dt = T / num_step
k = Constant(dt)

# cutfem parameters
beta = 10.0
gamma = 0.1

# compute normals and cell sizes (used by the facet stabilisation terms)
n = FacetNormal(mesh)
h = CellSize(mesh)

# Compute Mesh-Levelset intersection and corresponding marker
mesh_cutter = MeshCutter(mesh, level_set)

# Reduced quadrature degree for the domain integrals
q_degree = 3

# Define new measures associated with the interior domains and facets
dx = Measure("dx")[mesh_cutter.domain_marker()]
dS = Measure("dS")[mesh_cutter.interior_facet_marker(0)]
dxq = dc(0, metadata={"num_cells": 1, 'quadrature_degree': q_degree})
dsq = dc(1, metadata={"num_cells": 1})

# Physical-domain measure: cells marked 0 plus the cut-cell measure dxq
dxc = dx(0, metadata={'quadrature_degree': q_degree}) + dxq

# Finite element space for one concentration; the final time `T` above is
# deliberately NOT reused as the name of this space.
P1 = FunctionSpace(mesh, "P", 1)
V0 = MixedFunctionSpace([P1] * 6)

# Define test functions
(v_1, v_2, v_3, v_4, v_5, v_6) = TestFunctions(V0)

# The unknown is a Function rather than a TrialFunction because the problem
# is nonlinear and is solved with Newton's method
u = Function(V0)

# Define the initial concentrations (dimensionless)
a_0 = 0.0
b_0 = 1.6 / sum_atp_adp
c_0 = 1.6 / sum_atp_adp
d_0 = 0.0
e_0 = 0.0
f_0 = 0.0


u_0 = Expression(('a_0', 'b_0', 'c_0', 'd_0', 'e_0', 'f_0'), a_0=a_0, b_0=b_0, c_0=c_0, d_0=d_0, e_0=e_0, f_0=f_0, degree=1)

u_n = interpolate(u_0, V0)

# Components of the current iterate and of the previous time step
a, b, c, d, e, f = split(u)
a_n, b_n, c_n, d_n, e_n, f_n = split(u_n)

# Reaction sites

from Enzymes import Gauss_expression_3d


# From Part 2 read the coordinates of the reaction sites.
# Rows 0:3, 3:6 and 6:9 hold the x/y/z coordinates of the hxk, pyrk and ldh
# sites; M is the number of sites per enzyme (columns).
list_of_enzymes = np.load(results_dir + '/enzymes_coordinates.npy')

M = len(list_of_enzymes[0])


coordinate_enzymes_hxk = list_of_enzymes[0:3]
coordinate_enzymes_pyrk = list_of_enzymes[3:6]
coordinate_enzymes_ldh = list_of_enzymes[6:9]

# Read coordinates of mitochondria (after transposing, all rows but the last
# are coordinates; the last row is used as a per-site width)
mito = np.load('Images/AD_mito4cutfem_scaled_sigma.npy')
mito_points = mito.T

# Read the adaptive parameter eta
eta_hxk, eta_pyrk, eta_ldh, eta_mito = np.load(results_dir + '/eta.npy')

# Dimensionless Gaussian width
sigma = 1./L

def Sum_Gaussian_Mito(M, coordinate_enzymes_test, sigma):
    """Sum one 3-D Gaussian expression per mitochondrion.

    Unlike ``Sum_Gaussian``, every site has its own width.

    Args:
        M: number of sites (columns) to include.
        coordinate_enzymes_test: 2-D array indexed ``[axis, site]``; rows
            0, 1 and 2 are the x, y and z coordinates.
        sigma: indexable of per-site widths, ``sigma[site]``.

    Returns:
        The sum of the ``Gauss_expression_3d`` terms, or the integer 0 when
        ``M`` is 0.
    """
    return sum(
        Gauss_expression_3d(
            coordinate_enzymes_test[0, site],
            coordinate_enzymes_test[1, site],
            coordinate_enzymes_test[2, site],
            sigma[site],
        )
        for site in range(M)
    )

def Sum_Gaussian(M, coordinate_enzymes_test, sigma):
    """Sum one 3-D Gaussian expression per reaction site, all with the same width.

    Args:
        M: number of sites. When it equals 1, ``coordinate_enzymes_test[0]``,
            ``[1]`` and ``[2]`` are passed directly as the x, y and z centre.
            Otherwise the array is indexed ``[axis, site]``.
        coordinate_enzymes_test: coordinates of the sites (see ``M``).
        sigma: common width passed to every ``Gauss_expression_3d`` call.

    Returns:
        A single Gaussian expression when ``M == 1``; otherwise the sum of the
        per-site expressions (the integer 0 when ``M`` is 0).
    """
    if M != 1:
        return sum(
            Gauss_expression_3d(
                coordinate_enzymes_test[0, site],
                coordinate_enzymes_test[1, site],
                coordinate_enzymes_test[2, site],
                sigma,
            )
            for site in range(M)
        )

    x_centre = coordinate_enzymes_test[0]
    y_centre = coordinate_enzymes_test[1]
    z_centre = coordinate_enzymes_test[2]
    return Gauss_expression_3d(x_centre, y_centre, z_centre, sigma)

# Spatial distribution of each enzyme: a sum of 3-D Gaussians, one per
# reaction site read above, all with the common width `sigma`.
Gaussian_hxk = Sum_Gaussian(M, coordinate_enzymes_hxk, sigma)
Gaussian_pyrk = Sum_Gaussian(M, coordinate_enzymes_pyrk, sigma)
Gaussian_ldh = Sum_Gaussian(M, coordinate_enzymes_ldh, sigma)
# Mitochondria: all rows of mito_points except the last are the centre
# coordinates; the last row supplies a per-site width.
Gaussian_mito = Sum_Gaussian_Mito(len(mito.T[0]), mito_points[:-1], mito_points[-1])



# Reaction rate constants (hxk, pyrk, ldh, mito -- presumably hexokinase,
# pyruvate kinase, lactate dehydrogenase and mitochondrial respiration)
k_hxk = Constant(0.0619)
k_pyrk = Constant(1.92)
k_ldh = Constant(0.719)
k_mito = Constant(0.0813)

# Rate of the term proportional to b (ATP) that converts b into c (ADP) in F
# below; presumably the cell's activity-driven ATP consumption.
K_act = Constant(0.169)

# Spatial reaction rates: Gaussian site density divided by its adaptive
# normalisation eta, scaled by the rate constant and the cell volume
K_hxk = Gaussian_hxk/eta_hxk  * k_hxk *  Constant(volume)
K_pyrk = Gaussian_pyrk/eta_pyrk  * k_pyrk *  Constant(volume)
K_ldh = Gaussian_ldh/eta_ldh  * k_ldh *  Constant(volume)
K_mito = Gaussian_mito/eta_mito  * k_mito *  Constant(volume)


################################################################

# Define the source term (glucose influx, species a)

# read the subdomain volume computed in Part 2
subdomain_volume = np.load('./AD_results_' + date + '/subdomain_volume_glc.npy')

# Influx density inside the influx balls, scaled by volume / subdomain_volume
influx =  f_bar * volume /subdomain_volume

radius_influx = 0.0307

# Centres of the three influx balls
x_infl, y_infl, z_infl =  0.313278, 0.520833, 0.0302059
x_infl2, y_infl2, z_infl2 = 0.104167, 0.0, 0.0815185
x_infl3, y_infl3, z_infl3 = 0.364583, 0.572917, 0.0402746

# f_1 equals `influx` strictly inside each ball (summed where balls overlap)
# and 0 elsewhere; it enters F as a source for species a.
f_1 = Expression('(pow(x[0] - x_0,2) + pow(x[1] - y_0,2) + pow(x[2]- z_0,2)) < (r * r) ? influx : 0', influx=influx, r=radius_influx, x_0 = x_infl , y_0 = y_infl, z_0 = z_infl, degree=1)
f_1 += Expression('(pow(x[0] - x_0,2) + pow(x[1] - y_0,2) + pow(x[2]- z_0,2)) < (r * r) ? influx : 0', influx=influx, r=radius_influx, x_0 = x_infl2 , y_0 = y_infl2, z_0 = z_infl2, degree=1)
f_1 += Expression('(pow(x[0] - x_0,2) + pow(x[1] - y_0,2) + pow(x[2]- z_0,2)) < (r * r) ? influx : 0', influx=influx, r=radius_influx, x_0 = x_infl3, y_0 = y_infl3, z_0 = z_infl3, degree=1)

# Define Lac efflux (species f)

subdomain_outflux_volume = np.load('./AD_results_' + date + '/subdomain_volume_lac.npy')

outflux = lac_bar * volume/ subdomain_outflux_volume

radius_outflux = 0.035

# Centres of the four efflux balls (also read by compute_LAC_efflux)
x_out, y_out, z_out = 0.923611, 1., 0.0203796
x_out2, y_out2, z_out2 =  1., 0.909722, 0.0135864
x_out3, y_out3, z_out3 = 0.506944, 0.993056, 0.0543457
x_out4, y_out4, z_out4 =0.989583, 0.510142, 0.0402746

# eta_f equals `outflux` strictly inside each ball and 0 elsewhere; it enters
# F as the linear sink term eta_f * f.
eta_f = Expression('(pow(x[0] - x_0,2) + pow(x[1] - y_0,2) + pow(x[2]- z_0,2)) < (r * r) ? outflux : 0', outflux = outflux, r=radius_outflux, x_0 = x_out , y_0 = y_out, z_0 = z_out, degree=1)
eta_f += Expression('(pow(x[0] - x_0,2) + pow(x[1] - y_0,2) + pow(x[2]- z_0,2)) < (r * r) ? outflux : 0', outflux = outflux, r=radius_outflux, x_0 = x_out2 , y_0 = y_out2, z_0 = z_out2, degree=1)
eta_f += Expression('(pow(x[0] - x_0,2) + pow(x[1] - y_0,2) + pow(x[2]- z_0,2)) < (r * r) ? outflux : 0', outflux = outflux, r=radius_outflux, x_0 = x_out3, y_0 = y_out3, z_0 = z_out3, degree=1)
eta_f += Expression('(pow(x[0] - x_0,2) + pow(x[1] - y_0,2) + pow(x[2]- z_0,2)) < (r * r) ? outflux : 0', outflux = outflux, r=radius_outflux, x_0 = x_out4, y_0 = y_out4, z_0 = z_out4, degree=1)

#  Weak formulation of the dimensionless system.
#  Species (per the output labels in the time loop): a = GLC, b = ATP,
#  c = ADP, d = GLY, e = PYR, f = LAC.
#  Time is discretised with backward Euler ((u - u_n) / k); diffusion and all
#  reaction terms are evaluated at the new state u, so F is nonlinear in u.
#  Every volume term uses dxc, i.e. only the physical part of the domain.

F = ((a - a_n) / k) * v_1 * dxc \
    + D_a_bar * dot(grad(a), grad(v_1)) * dxc + K_hxk * beta_a * a * b**2 * v_1 * dxc \
    + ((b - b_n) / k) * v_2 * dxc  \
    + D_b_bar * dot(grad(b), grad(v_2)) * dxc + 2 * K_hxk * beta_b * a * b**2 * v_2 * dxc - 2 * K_pyrk * gamma_b * d * c**2 * v_2 * dxc - 28 * K_mito * csi_b * e * c**28 * v_2 * dxc + K_act * tau_b * b * v_2 * dxc\
    + ((c - c_n) / k)*v_3 * dxc \
    + D_c_bar * dot(grad(c), grad(v_3)) * dxc - 2 * K_hxk * beta_c * a * b**2 * v_3 * dxc  + 2 * K_pyrk * gamma_c * d * c**2 * v_3 * dxc - K_act *tau_c * b * v_3 * dxc + 28 * K_mito * csi_c * e * c**28 * v_3 * dxc\
    + ((d - d_n) / k)*v_4 * dxc\
    + D_d_bar * dot(grad(d),grad(v_4)) * dxc - 2 * K_hxk * beta_d * a * b**2 * v_4 * dxc + K_pyrk * gamma_d * d * c**2 * v_4 * dxc\
    + ((e - e_n) / k)*v_5 * dxc\
    + D_e_bar * dot(grad(e),grad(v_5)) * dxc - K_pyrk * gamma_e * d * c**2 * v_5 * dxc + K_ldh * mu_e * e * v_5 * dxc + K_mito * csi_e * e * c**28 * v_5 * dxc\
    + ((f - f_n) / k)*v_6 * dxc\
    + D_f_bar * dot(grad(f),grad(v_6)) * dxc - K_ldh * mu_g * e * v_6 * dxc + eta_f * f * v_6 * dxc\
    - f_1 * v_1 * dxc


# Facet stabilisation: penalise the jump of the normal gradient across the
# interior facets marked 1 by interior_facet_marker(0), weighted by
# gamma * h * D for each species.
# NOTE(review): this is the usual CutFEM ghost-penalty form; presumably facets
# marked 1 are those touching cut cells -- confirm against the MeshCutter docs.
F += avg(gamma) * avg(h) * D_a_bar * dot(jump(grad(a), n), jump(grad(v_1), n)) * dS(1) + avg(gamma) * avg(h) * D_b_bar *dot(jump(grad(b), n), jump(grad(v_2), n)) * dS(1)\
    + avg(gamma) * avg(h) * D_c_bar * dot(jump(grad(c), n), jump(grad(v_3), n)) * dS(1) + avg(gamma) * avg(h) * D_d_bar * dot(jump(grad(d), n), jump(grad(v_4), n)) * dS(1)\
    + avg(gamma) * avg(h) * D_e_bar * dot(jump(grad(e), n), jump(grad(v_5), n)) * dS(1) + avg(gamma) * avg(h) * D_f_bar * dot(jump(grad(f), n), jump(grad(v_6), n)) * dS(1)



#Create VTK files for visualization output (written only if the VTK section
# in the time loop is uncommented)
vtkfile_a = File('AD_results_' + date + '/a/a.pvd')
vtkfile_b = File('AD_results_' + date + '/b/b.pvd')
vtkfile_c = File('AD_results_' + date + '/c/c.pvd')
vtkfile_d = File('AD_results_' + date + '/d/d.pvd')
vtkfile_e = File('AD_results_' + date + '/e/e.pvd')
vtkfile_f = File('AD_results_' + date + '/f/f.pvd')

# Compute Jacobian (Gateaux derivative of F w.r.t. u) for Newton's method
J = derivative(F, u)

# Fictitious domain
composite_mesh = CompositeMesh()
composite_mesh.add(mesh)

V = CompositeFunctionSpace(composite_mesh)
V.add(V0);
V.build();

# Constrain dofs outside
FidoTools_compute_constrained_dofs(V, mesh_cutter)

# NOTE: from here on `a` is the Jacobian FidoForm and `L` the residual
# FidoForm; the UFL component `a` (GLC) and the length scale `L` loaded at
# the top are no longer reachable under these names. The time loop relies on
# this rebinding.
a = FidoForm(V, V)
form_a = create_dolfin_form(J)
a.add(form_a, mesh_cutter)

L = FidoForm(V)
form_L = create_dolfin_form(F)
L.add(form_L, mesh_cutter);

# Space for the solution on the physical (cut) mesh of domain 0; in the
# visible file the *_inter Functions are used only by the commented-out VTK
# output in the time loop.
cutmesh0 = CutFEMTools_physical_domain_mesh(mesh, mesh_cutter.cut_cells(0), mesh_cutter.domain_marker(), 0);
V0Phys = FunctionSpace(cutmesh0, "CG", 1);

a_inter = Function(V0Phys);
b_inter = Function(V0Phys);
c_inter = Function(V0Phys);
d_inter = Function(V0Phys);
e_inter = Function(V0Phys);
f_inter = Function(V0Phys);

# Current (dimensionless) time, kept as a one-element list; the time loop
# updates t[0] in place
t = [0.0]

# Function that compute the average concentration inside the astrocyte
def compute_average_concentration(conc, V):
    """Return the integral of ``conc`` over the physical astrocyte domain.

    Despite its name, this does NOT divide by the volume. Every caller in this
    script divides the returned value by ``volume`` itself. The integral uses
    the module-level measure ``dxc`` (cells marked 0 plus the cut-cell measure
    ``dxq``). The cut-cell contribution is evaluated on the cut cells of
    domain 0 from the module-level ``mesh_cutter``, with an order-2 quadrature
    rule.

    Args:
        conc: scalar expression to integrate (e.g. one component of the
            mixed solution).
        V: CompositeFunctionSpace on which the composite form is built.

    Returns:
        The assembled integral.
    """
    composite = CompositeForm(V)
    composite.add(create_dolfin_form(conc * dxc))

    cut_mesh = mesh_cutter.cut_cells(0)
    rule = Quadrature(cut_mesh.type().cell_type(), cut_mesh.geometry().dim(), order=2)

    # Attach the cut-cell geometry and quadrature to the first cut form
    cut_part = composite.cut_form(0)
    cut_part.set_quadrature(0, rule)
    cut_part.set_cut_mesh(0, cut_mesh)
    cut_part.add_single_parent_mesh_id(0, 0)

    return composite_assemble(composite)

# Geometric set-up used by compute_LAC_efflux, built once per
# (bg_mesh, level_set) pair. Entries are (bg_mesh, level_set, setup) tuples
# matched by object identity, so cached objects are reused only for the very
# same inputs.
_lac_efflux_setup_cache = []


def _efflux_sphere_subdomain(x_c, y_c, z_c, r):
    """Return a SubDomain whose inside() accepts points strictly inside the ball of radius r at (x_c, y_c, z_c)."""

    class _EffluxSphere(SubDomain):
        def inside(self, x, on_boundary):
            # Strict '<', matching the eta_f Expression used in the weak form
            return (pow(x[0] - x_c, 2) + pow(x[1] - y_c, 2) + pow(x[2] - z_c, 2)) < (r * r)

    return _EffluxSphere()


def _build_lac_efflux_setup(bg_mesh, level_set):
    """Build the mesh, cell marker, efflux measure and spaces used by compute_LAC_efflux."""
    # Create mesh
    mesh = CutFEMTools_fictitious_domain_mesh(bg_mesh, level_set, 0, 0)

    # Compute Mesh-Levelset intersection and corresponding marker
    mesh_cutter = MeshCutter(mesh, level_set)
    marker_test = mesh_cutter.domain_marker()

    # Tag with 50 every cell the SubDomain test places inside an efflux ball
    centres = (
        (x_out, y_out, z_out),
        (x_out2, y_out2, z_out2),
        (x_out3, y_out3, z_out3),
        (x_out4, y_out4, z_out4),
    )
    for x_c, y_c, z_c in centres:
        _efflux_sphere_subdomain(x_c, y_c, z_c, radius_outflux).mark(marker_test, 50)

    # Define new measures associated with the marked cells
    dx_marked = Measure("dx", domain=mesh)[marker_test]
    # NOTE(review): dc(0) carries no information about the efflux balls, and
    # unlike compute_average_concentration no cut mesh/quadrature is attached
    # to the composite form -- confirm what this cut-cell term contributes.
    dxq = dc(0, metadata={"num_cells": 1})
    dx_eff = dx_marked(50) + dxq

    V0 = FunctionSpace(mesh, "Lagrange", 1)

    # Fictitious domain
    composite_mesh = CompositeMesh()
    composite_mesh.add(mesh)

    V = CompositeFunctionSpace(composite_mesh)
    V.add(V0)
    V.build()

    # Constrain dofs outside
    FidoTools_compute_constrained_dofs(V, mesh_cutter)

    # Keep every intermediate object referenced, not just the ones used later,
    # in case the compiled side only holds non-owning references to them.
    return {
        "mesh": mesh,
        "mesh_cutter": mesh_cutter,
        "marker": marker_test,
        "composite_mesh": composite_mesh,
        "V0": V0,
        "V": V,
        "dx_eff": dx_eff,
    }


def compute_LAC_efflux(bg_mesh, level_set, subdom2comp):
    """Integrate a field over the lactate-efflux regions.

    ``subdom2comp`` is projected onto a P1 space on the fictitious-domain mesh
    of ``bg_mesh``/``level_set``. It is then integrated over the cells marked
    as lying inside one of the four efflux balls (centres ``x_out*``, radius
    ``radius_outflux``) plus the cut-cell measure ``dc(0)``.

    The mesh, cutter, marker and spaces depend only on ``bg_mesh`` and
    ``level_set``. They are built on the first call and reused afterwards,
    instead of being rebuilt at every time step.

    Args:
        bg_mesh: background mesh.
        level_set: level-set function describing the domain.
        subdom2comp: field to integrate (e.g. the LAC concentration).

    Returns:
        The assembled integral.
    """
    for cached_bg_mesh, cached_level_set, setup in _lac_efflux_setup_cache:
        if cached_bg_mesh is bg_mesh and cached_level_set is level_set:
            break
    else:
        setup = _build_lac_efflux_setup(bg_mesh, level_set)
        _lac_efflux_setup_cache.append((bg_mesh, level_set, setup))

    u_e_V0 = project(subdom2comp, setup["V0"])
    psi0 = u_e_V0 * setup["dx_eff"]

    composite_form_psi0 = CompositeForm(setup["V"])
    composite_form_psi0.add(create_dolfin_form(psi0))

    return composite_assemble(composite_form_psi0)

# Time series of the domain-averaged concentration of each species
# (a = GLC, b = ATP, c = ADP, d = GLY, e = PYR, f = LAC)
list_a = []
list_b = []
list_c = []
list_d = []
list_e = []
list_f = []

time_list = []

# LAC integrated over the efflux regions at each recorded time
efflux_lac = []

# Parameters for the Newton solver (at most Nmax - 1 iterations per step)
Nmax = 50
abs_tol = 1.0e-12
rel_tol = 1.0e-07

# `pathology` (suffix of the output file names) is not defined in this file;
# presumably it comes from one of the star imports. Resolve it now so a
# missing definition cannot crash the script after the whole simulation.
try:
    pathology
except NameError:
    print("WARNING: 'pathology' is not defined; output file names get no suffix")
    pathology = ''


def _record_mean_concentrations(species):
    """Append integral(conc) / volume of each species to its time-series list.

    ``species`` holds the six concentrations in the order a, b, c, d, e, f.
    """
    histories = (list_a, list_b, list_c, list_d, list_e, list_f)
    for history, conc in zip(histories, species):
        history.append(compute_average_concentration(conc, V) / volume)


def _newton_solve():
    """Run Newton iterations on the residual form ``L`` for the current time step.

    Starts from the current content of ``u`` and updates it in place. After
    the first step, ``u`` holds the previous step's solution. Convergence
    is declared when the residual l2 norm drops below ``abs_tol``, or below
    ``rel_tol`` times the residual at the start of THIS solve. The reference
    used to be a residual assembled once before the whole simulation. At
    least one Newton update is always performed.

    Returns:
        (converged, iterations, residual_norm)
    """
    residual = composite_assemble(L)
    reference = residual.norm('l2')
    absolute = reference

    for iteration in range(1, Nmax):
        jacobian = composite_assemble(a)
        du = CompositeFunction(V)

        solve(jacobian, du.vector(), -residual, 'mumps')

        u.vector().axpy(1.0, du.part(0).vector())

        # The residual at the new iterate is also the next right-hand side,
        # so it is assembled only once per iteration
        residual = composite_assemble(L)

        # absolute and relative residual
        absolute = residual.norm('l2')
        relative = absolute / reference if reference > 0.0 else 0.0

        if absolute < abs_tol or relative < rel_tol:
            return True, iteration, absolute

    return False, Nmax - 1, absolute


# Record the initial state (t = 0)
time_list.append(t[0])
_record_mean_concentrations((a_n, b_n, c_n, d_n, e_n, f_n))
efflux_lac.append(compute_LAC_efflux(bg_mesh, level_set, f_n))

for i in range(num_step):
    print("Timestep", i)
    print("Time", t[0])

    converged, iterations, final_residual = _newton_solve()
    if not converged:
        print("WARNING: Newton did not converge after", iterations,
              "iterations; residual l2 norm", final_residual)

    _a, _b, _c, _d, _e, _f = u.split()

    ## To save the vtk uncomment this section.
    # 
    ## save  GLC
    #a_inter.interpolate(_a);
    ## save  ATP
    #b_inter.interpolate(_b);
    ## save  ADP
    #c_inter.interpolate(_c);
    ## save  GLY
    #d_inter.interpolate(_d);
    ## save  PYR
    #e_inter.interpolate(_e);
    ## save  LAC
    #f_inter.interpolate(_f);
    #
    #vtkfile_a << (a_inter, t[0])
    #vtkfile_b << (b_inter, t[0])
    #vtkfile_c << (c_inter, t[0])
    #vtkfile_d << (d_inter, t[0])
    #vtkfile_e << (e_inter, t[0])
    #vtkfile_f << (f_inter, t[0])

    # The converged state becomes the previous time step
    u_n._assign(u)

    t[0] = t[0] + dt

    time_list.append(t[0])

    # Average concentration inside the astrocyte of each species
    _record_mean_concentrations((_a, _b, _c, _d, _e, _f))

    # LAC in the subregions where the efflux is defined
    efflux_lac.append(compute_LAC_efflux(bg_mesh, level_set, _f))

# stop timer (wall-clock duration of the whole script)
aftersolve = timer()
tottime = aftersolve - startime
print('Final time', tottime)

# Create a single list with all the solutions
list_of_list = [list_a, list_b, list_c, list_d, list_e, list_f, time_list]

# save using numpy
print(list_a, list_b, list_c, list_d, list_e, list_f, time_list)
np.save('./AD_results_' + date + '/AD_3Dastro_list' + pathology, np.asarray(list_of_list))

np.save('./AD_results_' + date + '/lac_efflux' + pathology, np.asarray(efflux_lac))