#!/usr/bin/env fspython

import os
import math
import numpy as np
import freesurfer as fs
import nibabel as nib

from scipy import ndimage
from scipy.optimize import minimize
from freesurfer.labelfusion import maxflow, performFrontPropagation3D


# ------ Utility Functions ------


def expandCropping(cropping, limits, padding):
    expanded = []
    for dim in range(len(cropping)):
        start = cropping[dim].start - padding
        if start < 0: start = 0
        stop = cropping[dim].stop + padding
        if stop > limits[dim]: stop = limits[dim]
        expanded.append(slice(start, stop))
    return tuple(expanded)


def renderDistanceMap(dists, inds):
    p = np.zeros(((num_labels,) + dists.shape))
    for l in range(num_labels):
        mask = inds == l
        p[l, mask] = dists[mask]
    return np.sum(p, axis=1)


def computeLabPost():
    p = np.swapaxes(dist_map * pil_theta, 0, 1)
    p /= np.sum(p, axis=0)
    p[np.isnan(p)] = 1 / num_labels
    return np.sum(p * np.swapaxes(QM, 0, 1), axis=1)


def gaussianPDFsmooth(x):
    d = (x - mus[:, np.newaxis]) ** 2
    z = -1 / (2 * sigmas * sigmas)
    return (1 / (sigmas[:, np.newaxis] * np.sqrt(2 * np.pi))) * np.exp(z[:, np.newaxis] * d) + 1e-30


def prepBiasFieldBase(insize, bias_field_order):
    orders = []
    for o in range(bias_field_order + 1):
        for x in range(o + 1):
            for y in range(o + 1):
                for z in range(o + 1):
                    if x + y + z == o:
                        orders.append([x, y, z])
    grid = np.meshgrid(range(insize[1]), range(insize[0]), range(insize[2]))
    grid = [(d - np.min(d)).astype(float) for d in grid]
    grid = [d / np.max(d) for d in grid]
    grid = [2 * d - 1 for d in grid]
    psi = np.zeros(insize + (len(orders),))
    for d in range(len(orders)):
        tmp = np.ones(insize);
        for dim, dd in enumerate(grid):
            for x in range(orders[d][dim]):
                tmp *= dd
        psi[:, :, :, d] = tmp
    return psi


def singleChannelCostGrad(theta):
    bf = np.exp(np.sum(theta * psiv_down, axis=1))
    ibf = it_down * bf
    pdf = gaussianPDFsmooth(ibf) * bf
    iv = 1 - 1/variances[:, np.newaxis] * ibf * (ibf - mus[:, np.newaxis])
    cost = np.zeros(num_nonzeros_down)
    aux_grad = np.zeros(num_nonzeros_down)
    for n in range(num_opinions):
        product = dist_map_down[n] * pdf
        lsum = np.sum(product, axis=0)
        lnum = np.sum(product * iv, axis=0)
        cost += QM[downsample_ind, n] * np.log(lsum)
        aux_grad += QM[downsample_ind, n] * lnum / lsum
    mean_cost = -np.sum(cost) / num_nonzeros_down
    grad = np.sum(aux_grad[:, np.newaxis] * psiv_down, axis=0) / -num_nonzeros_down
    grad[0] = 0  # first coefficient is zero by design
    return mean_cost, grad


def compareArrays(a, b, tol=1e-12):
    diff = np.abs(a - b)
    maximum = np.max(diff)
    if maximum > tol:
        fs.errorExit('arrays differ - max: ' + str(maximum))


def computeDistanceTransform(labelmask):
    # Crop and pad to a cube
    crop_slices = expandCropping(ndimage.find_objects(labelmask)[0], labelmask.shape, 20)
    cropped = labelmask[crop_slices] - 0.5
    n = np.max(cropped.shape)
    padding = [(math.ceil(i/2), math.floor(i/2)) for i in n - cropped.shape]
    D = np.pad(cropped, padding, 'constant', constant_values=np.min(cropped))

    # Compute label outline starting points
    v = np.zeros((n, n, n))
    for dim in (0, 1, 2):
        p1 = np.take(D, range(n-1),  axis=dim)
        p2 = np.take(D, range(1, n), axis=dim)
        z = np.expand_dims(np.zeros((n, n)), axis=dim)
        p = (p1 * p2) <= 0
        i = (np.concatenate((p, z), axis=dim) + np.concatenate((z, p), axis=dim)) > 0
        d = abs(p1 - p2)
        d[d < np.finfo(float).eps] = 1
        va = np.maximum(np.concatenate((abs(p1)/d, z), axis=dim), np.concatenate((z, abs(p2)/d), axis=dim))
        v[i] = np.maximum(v[i], va[i])
    
    # Run fast marching
    mask = v != 0
    values = v[mask] / n
    start_points = np.asfortranarray(np.where(mask))
    iters = int(1.2 * n ** 3)
    dist = performFrontPropagation3D(np.ones((n, n, n)), start_points, iters, values)
    dist *= n
    dist[D < 0] = -dist[D < 0]

    # Unpad and uncrop
    for dim, pad in enumerate(padding):
        dist = np.take(dist, range(pad[0], dist.shape[dim] - pad[1]), axis=dim)
    dt = np.ones(labelmask.shape) * np.min(dist)
    dt[crop_slices] = dist
    return dt


def saveVolume(vol, fname):
    im = nib.Nifti1Image(vol, fixed_image.affine)
    im.header['xyzt_units'] = fixed_image.header['xyzt_units']
    nib.save(im, fname)


# ------ Parse Command Line Arguments ------

parser = fs.ArgParser()
# Required
parser.add_argument('-i','--image', metavar='file', required=True, help='Input image filename.')
parser.add_argument('-s','--segs', metavar='file', nargs='+', required=True, help='Aligned segmentations to fuse.')
parser.add_argument('-o','--out', metavar='file', required=True, help='Output segmentation filename.')
parser.add_argument('-r','--rho', type=float, required=True, help='Rho.')
# Optional
parser.add_argument('--smooth', action='store_true', help='Perform markov random field label-smoothing.')
parser.add_argument('-b','--beta', type=float, default=0.3, help='beta. Defaults to 0.3.')
parser.add_argument('--bias', metavar='file', help='Save biasfield volume to filename.')
parser.add_argument('--bf-order', type=int, default=4, help='Bias field order. Defaults to 4.')
parser.add_argument('--max-lab', type=int, default=3, help='Maximum number of lab . Defaults to 3.')
parser.add_argument('-e','--exclude', type=int, nargs=4, action='append', help='Exclude a set of labels.')
parser.add_argument('--unary-weight', type=int, default=5, help='Unary term weight. Defaults to 5.')
parser.add_argument('-v', '--verbose', action='store_true', help='Verbose output.')
args = parser.parse_args()

# ------ Setup ------

# Options
rho = args.rho
beta = args.beta
bias_field_order = args.bf_order
max_lab = args.max_lab
unary_term_weight = args.unary_weight

# Constants
downsample_factor = 10
min_param_change = 2

# Check the output folder
outdir = os.path.dirname(os.path.abspath(args.out))
if not os.path.isdir(outdir):
    fs.errorExit('output directory %s does not exist' % outdir)

# Load the fixed image
fixed_image = nib.load(args.image)
fixed_vol = fixed_image.get_data()
fixed_affine = fixed_image.affine

# Locate main component of the fixed input
clusters, numclusters = ndimage.measurements.label(fixed_vol > 0)
if numclusters == 0:
    fs.errorExit('cannot find main component in fixed volume')
largest_label = np.argmax(np.bincount(clusters.flat)[1:]) + 1
mask = ndimage.binary_fill_holes(clusters == largest_label)

# Crop the mask to the size of the primary component (with 2-voxel padding) to speed things up
cropping = ndimage.find_objects(mask)[0]
cropping = expandCropping(cropping, mask.shape, 2)
mask = mask[cropping]

# Flatten the brainmatter mask voxels to a 1D array
nonzeros = fixed_vol[cropping][mask]
num_nonzeros = len(nonzeros)
it = np.interp(nonzeros, (nonzeros.min(), nonzeros.max()), (1, 1000))

# Load and crop segmentations
moving_vols = [nib.load(fname).get_data()[cropping] for fname in args.segs]
num_opinions = len(moving_vols)

# Extract unique labels
labels = np.unique(moving_vols)
num_labels = len(labels)
print('found %d unique labels' % num_labels)

# Create mask for excluded labels
min_vox = 100
kill_mask = np.full((num_opinions, num_nonzeros), False)
for label in args.exclude:
    counts = np.sum([label[0] in vol and label[1] in vol for vol in moving_vols])
    if counts > 0 and counts < num_opinions and num_opinions > 1:
        for n, vol in enumerate(moving_vols):
            if np.count_nonzero(vol == label[0]) < min_vox and np.count_nonzero(vol == label[1]) < min_vox:
                kill_mask[n] = np.logical_or(kill_mask[n], np.logical_or.reduce([vol[mask] == l for l in label]))

# ------ Compute Label Distance Transforms and Probabilities ------

print('computing distance maps and prior label probabilities')
prob_inds = []
prob_vals = []
mv = np.zeros((num_labels, num_nonzeros))
dist_map = np.zeros((num_opinions, num_labels, num_nonzeros))
for n, vol in enumerate(moving_vols):
    print(' %d) processing %s' % (n+1, args.segs[n]))
    aux = np.zeros((num_labels, num_nonzeros))
    for l, label in enumerate(labels):
        labelmask = vol == label
        if np.any(labelmask):
            dmap = computeDistanceTransform(labelmask)[mask]
        else:
            print('    note: label %d not found in volume' % label)
            dmap = np.ones(num_nonzeros) * -200
        aux[l, :] = dmap * rho
    aux -= np.amax(aux, axis=0)
    aux  = np.exp(aux)
    aux /= np.sum(aux, axis=0)
    # Compute distance map
    inds = np.argsort(aux, axis=0)[::-1][:max_lab]
    vals = aux[inds, np.arange(aux.shape[1])[np.newaxis]]
    dist_map[n] = renderDistanceMap(vals / np.sum(vals, axis=0), inds)
    aux[:, kill_mask[n]] = 0
    mv += aux
mv /= np.sum(mv, axis=0)

# ------ Prepare for Optimization ------

# Initialize means and variances
mus = np.sum(it * mv, axis=1) / np.sum(mv, axis=1)
variances = (1 * 1 + np.sum(mv * (it - mus[:, np.newaxis]) ** 2, axis=1)) / (1 + np.sum(mv, axis=1))
sigmas = np.sqrt(variances)

# Prepare basis functions for the bias field and initialize weights
psi = prepBiasFieldBase(mask.shape, bias_field_order)
bfc = np.zeros(psi.shape[3])

# Masked PSI
psiv = psi[mask, :]

# Downsample for bias field correction
num_nonzeros_down = int(np.round(num_nonzeros/downsample_factor))
downsample_ind = np.random.permutation(num_nonzeros)[:num_nonzeros_down]
dist_map_down = dist_map[:, :, downsample_ind]
it_down = it[downsample_ind]
psiv_down = psiv[downsample_ind]

# Initialize the intensity likelihood and allocate posteriors
ibf = np.copy(it)
pil_theta = gaussianPDFsmooth(ibf)

# Initialize Q
QM = np.sum(dist_map * pil_theta, axis=1)
QM[kill_mask] = 0
with np.errstate(divide='ignore', invalid='ignore'):
    QM /= np.sum(QM, axis=0)
QM[np.isnan(QM)] = 1 / num_opinions
QM = np.swapaxes(QM, 0, 1)

# ------ Main Optimization Loop ------

print('optimizing for intensity parameters and Q (E-step)')
max_iter = 20
for iteration in range(max_iter):
    print('running iteration %d' % (iteration + 1))

    mus_old = np.copy(mus)
    variances_old = np.copy(variances)

    # ----- Step A: optimize for intensity parameters -----
    
    # Update means and variances
    labpost = computeLabPost()
    labpost[~labpost.any(axis=1)] = 1
    mus = np.sum(ibf * labpost, axis=1) / np.sum(labpost, axis=1)
    variances = (1 * 1 + np.sum(labpost * (ibf - mus[:, np.newaxis]) ** 2, axis=1)) / (1 + np.sum(labpost, axis=1))
    sigmas = np.sqrt(variances)

    # Optimize bias field correction
    bfc = minimize(singleChannelCostGrad, bfc, method='BFGS', jac=True, tol=1e-4, options={'maxiter':max_iter+3-iteration}).x

    # Apply correction
    ibf = it * np.exp(np.sum(bfc * psiv, axis=1))

    # Update image intensity likelihood term
    pil_theta = gaussianPDFsmooth(ibf)

    # ----- Step B: optimize for Q (E-step) -----

    ct = np.sum(dist_map * pil_theta, axis=1)
    ct = np.swapaxes(ct, 0, 1)

    for q_iteration in range(10):
        QM_old = np.zeros(mask.shape + (num_opinions,))
        QM_old[mask, :] = QM

        if beta > 0:
            s = np.zeros(mask.shape + (num_opinions,))
            for dim in range(3):
                s += np.roll(QM_old,  1, axis=dim)
                s += np.roll(QM_old, -1, axis=dim)
        QM = np.ones((num_opinions, num_nonzeros))
        QM[kill_mask] = 0
        if beta > 0:
            QM *= np.exp(beta * np.swapaxes(s[mask, :], 0, 1))
        QM = np.swapaxes(QM, 0, 1)
        QM *= ct
        QM /= (np.finfo(float).eps + np.sum(QM, axis=1)[:, np.newaxis])

        # Termination condition
        inc = np.mean(np.abs(QM_old[mask, :] - QM))
        print('    mean Q increment: %.6f' % inc)
        if inc < 1e-4:
            break

    max_change = np.max((np.max(np.abs(mus_old - mus)), np.sqrt(np.max(np.abs(variances_old - variances)))))
    print('    max change: %.2f' % max_change)
    if iteration > 1 and max_change < min_param_change:
        break

# Once we have the parameters, we only have to optimize the labels
print('collecting label contributions from each atlas')

# Compute final labpost
labpost = computeLabPost()

# Create optimal segmentation
seg = np.zeros(fixed_vol.shape)
seg[cropping][mask] = labels[np.argmax(labpost, axis=0)]

# Markov random field label smoothing
if args.smooth:
    print('performing markov random field label-smoothing')
    ct = np.ones(mask.shape + (num_labels,))
    for l in range(num_labels):
        ct[mask, l] = -np.log(1e-12 + labpost[l] * (1 - 1e-12)) * unary_term_weight
    u = maxflow(np.asfortranarray(ct), 100, 0, 0.25, 0.11)
    seg = np.zeros(fixed_vol.shape)
    seg[cropping][mask] = labels[np.argmax(u, axis=3)][mask]

saveVolume(seg, args.out)

# Save bias field
if args.bias:
    bf = np.exp(np.sum(bfc * psi, axis=3))
    vol = np.zeros(fixed_vol.shape)
    vol[cropping] = bf
    saveVolume(vol, args.bias)
