# Code by Blanche Buet: point-cloud varifold estimators (Python source embedded in the py-string below).


 py"""
import numpy as np
import matplotlib.pyplot as plt
#from sklearn.metrics.pairwise import euclidean_distances
#from sklearn.neighbors import KDTree
from scipy import spatial
from mpl_toolkits import mplot3d

class PointCloudVarifold:
    '''Discrete varifold carried by a point cloud.

    After loadFromArray the instance holds, for each of its size points:

    * pts       -- (size, ambientDim) point coordinates,
    * mass      -- (size,) weights, initialised to one,
    * tgtProj   -- (size, ambientDim, ambientDim) tangent-space matrices,
    * tgtBasis  -- (size, ambientDim, varifoldDim) tangent bases,
    * normal    -- (size, ambientDim) normals,
    * curvature -- (size, ambientDim) array; allocated here but not filled
                   by any method of this class.

    Order of calls used by the methods below: loadFromArray, computeKDTree,
    then computeMassKnn / regressionKnn / regressionRadius, and finally
    computeSFFKnn (which reads tgtProj, normal and tgtBasis).
    '''

    def __init__(self, Ntot=10, d=2, n=3):
        '''Record the sizes only; arrays are allocated by loadFromArray.

        Ntot -- number of points,
        d    -- dimension of the varifold,
        n    -- dimension of the ambient space.
        '''
        self.varifoldDim = d
        self.ambientDim = n
        self.size = Ntot

    def loadFromArray(self, X):
        '''Store the points X (one point per row) and reset all per-point data.

        size and ambientDim are overwritten from the shape of X; varifoldDim
        keeps the value given to the constructor.
        '''
        X = np.asarray(X)  # no copy for ndarrays; also accepts nested lists
        Ntot, n = X.shape
        d = self.varifoldDim
        self.size = Ntot
        self.ambientDim = n
        self.pts = X
        self.mass = np.ones(Ntot)
        self.tgtProj = np.zeros((Ntot, n, n))
        self.tgtBasis = np.zeros((Ntot, n, d))
        self.normal = np.zeros((Ntot, n))
        self.curvature = np.zeros((Ntot, n))

    def computeKDTree(self):
        '''Build the KD-tree used by every neighbour query of this class.'''
        # Index the points stored on the instance. The previous version read a
        # bare name X, which only resolved when the caller had defined a
        # module-level X and raised NameError otherwise.
        self.tree = spatial.KDTree(self.pts)

    def regressionKnn(self, kT, kernel, massWeights = True, truncateToProjector = True):
        '''Estimate tangent data at every point from its kT nearest neighbours.

        For each point x_i the weighted matrix
            M = sum_j m_j * kernel(|x_i - x_j| / r) * (x_i - x_j)(x_i - x_j)^T
        is accumulated over the kT nearest neighbours returned by the tree
        (x_i itself is one of them), r being the largest of those distances.

        kT                  -- number of neighbours queried; the code indexes
                               the query result, so kT >= 2 is expected.
        kernel              -- callable applied to the normalised distance.
        massWeights         -- use self.mass as m_j if True, ones otherwise.
        truncateToProjector -- if True (codimension 1 assumed), store
                               I - v v^T in tgtProj, v in normal and eigenvector
                               columns 1:3 in tgtBasis, where v is the
                               eigenvector of the smallest eigenvalue of M.
                               If False, store M / kT in tgtProj.

        Prints a progress marker every 10000 points.
        '''
        X = self.pts
        Ntot = self.size
        n = self.ambientDim
        m = self.mass if massWeights else np.ones(Ntot)
        tree = self.tree
        for i in range(Ntot):
            x = X[i, :]
            dist, ind = tree.query(x, kT)
            r = dist[-1] # radius of the ball containing the kT neighbours
            M = np.zeros((n, n))
            for j in ind:
                z = x - X[j, :]
                M = M + m[j]*kernel(np.linalg.norm(z)/r)*np.outer(z, z)
            if truncateToProjector: # assume codimension 1
                # eigh returns eigenvalues in ascending order, so column 0 is
                # the direction of least spread, taken as the normal.
                _, v = np.linalg.eigh(M)
                nu = v[:, 0]
                self.tgtProj[i] = np.identity(n) - np.outer(nu, nu)
                self.normal[i] = nu
                # NOTE(review): the 1:3 slice is hard-coded for n = 3, d = 2.
                self.tgtBasis[i] = v[:, 1:3]
            else:
                self.tgtProj[i] = M/kT
            if i%10000 == 0:
                print(str(i)+'-')

    def computeMassKnn(self, kM, kernelOpt = False):
        '''Set mass[i] = r_i**d / kM, r_i being the kM-th neighbour distance.

        kM        -- number of neighbours queried (the point itself included);
                     kM >= 2 is expected since the query result is indexed.
        kernelOpt -- accepted for interface compatibility; not used.

        The constant (d-volume of the unit ball) is deliberately left out.
        '''
        X = self.pts
        d = self.varifoldDim
        tree = self.tree
        for i in range(self.size):
            dist, _ = tree.query(X[i, :], kM)
            r = dist[-1] # radius of the ball containing the kM neighbours
            self.mass[i] = r**d/kM # up to dimensional constant d-volume of unit ball

    def regressionRadius(self, r, kernel, massWeights = True):
        '''Fixed-radius variant of regressionKnn without projector truncation.

        For each point the same weighted matrix as in regressionKnn is
        accumulated over all points within distance r, and M / k is stored in
        tgtProj, k being the number of points found in the ball.

        r           -- radius of the neighbourhood ball.
        kernel      -- callable applied to the distance divided by r.
        massWeights -- use self.mass as weights if True, ones otherwise.
        '''
        X = self.pts
        Ntot = self.size
        n = self.ambientDim
        m = self.mass if massWeights else np.ones(Ntot)
        tree = self.tree
        for i in range(Ntot):
            x = X[i, :]
            ind = tree.query_ball_point(x, r)
            k = len(ind)
            M = np.zeros((n, n))
            for j in ind:
                z = x - X[j, :]
                M = M + m[j]*kernel(np.linalg.norm(z)/r)*np.outer(z, z)
            self.tgtProj[i] = M/k

    def computeSFFKnn(self, kT, kernelPrime, kernelXi):
        '''Estimate the second fundamental form from kT nearest neighbours.

        Requires tgtProj, normal and tgtBasis to be filled beforehand (see
        regressionKnn with truncateToProjector=True). Codimension 1 is
        assumed: the stored forms are (n-1) x (n-1).

        kT          -- number of neighbours queried; the point itself
                       (index 0 of the query) is skipped.
        kernelPrime -- callable giving the weight of each neighbour term.
        kernelXi    -- callable whose values are summed into the normalising
                       denominator.

        Creates or overwrites:
        sff            -- (size, n-1, n-1) second fundamental forms expressed
                          in tgtBasis,
        gaussCurvature -- (size,) determinants of sff,
        meanCurvature  -- (size, n) trace of sff times the normal.
        '''
        X = self.pts
        Ntot = self.size
        n = self.ambientDim
        #self.varifoldDim = n-1 # codim 1
        d = n-1
        self.sff = np.zeros((Ntot,d,d))
        self.gaussCurvature = np.zeros(Ntot)
        self.meanCurvature = np.zeros((Ntot,n))
        tree = self.tree
        for i in range(Ntot):
            x0 = X[i, :]
            T0 = self.tgtProj[i]
            normal0 = self.normal[i]
            basis0 = self.tgtBasis[i]
            dist, ind = tree.query(x0, kT)
            eps = dist[-1] # radius of the ball containing the kT neighbours
            Bijk = np.zeros((n,n,n))
            denom = 0
            for jneigh in range(1,kT):
                z = x0 - X[ind[jneigh], :]
                deltaZ = dist[jneigh]
                znormalized = z/deltaZ
                # NOTE(review): the original comments on the next two lines
                # read 'multp by mass m_j'; no mass factor is applied here --
                # confirm whether that is intended.
                w = self.varifoldDim*kernelPrime(deltaZ/eps)
                denom += kernelXi(deltaZ/eps) # Xi(0) = 0 for pairs otherwise add j=0
                T = self.tgtProj[ind[jneigh]]
                deltaT = T - T0
                for ii in range(n):
                    for jj in range(n):
                        for kk in range(n):
                            Bijk[ii,jj,kk] += w*np.vdot(znormalized, 0.5*(deltaT[jj,kk]*T[:,ii] + deltaT[ii,kk]*T[:,jj] - deltaT[ii,jj]*T[:,kk] ))
            # Contract the last index with the normal once, after all
            # neighbours have been accumulated.
            Bij = Bijk.dot(normal0)/(eps*denom)
            self.sff[i] = np.transpose(basis0).dot(Bij.dot(basis0))
            self.gaussCurvature[i] = np.linalg.det(self.sff[i])
            self.meanCurvature[i] = np.trace(self.sff[i])*normal0




def standard(x):
    '''Bump profile: exp(-1 / (1 - x**2)) where x < 1, and 0 elsewhere.

    Works on scalars and on numpy arrays. Arguments outside the support are
    replaced by 0 before the exponential is evaluated, so the denominator
    never vanishes there; the mask then zeroes the result.
    '''
    inside = x < 1
    t = inside * x
    return inside * np.exp(-1. / (1 - t**2))

def standardPrime(x):
    '''Derivative of the bump profile of standard, and 0 where x >= 1.

    Where x < 1 the value is -2 x / (1 - x**2)**2 * exp(-1 / (1 - x**2)).
    Works on scalars and on numpy arrays; arguments outside the support are
    replaced by 0 before evaluation and masked afterwards.
    '''
    inside = x < 1
    t = inside * x
    gap = 1 - t**2
    return inside * (-2 * t) / gap**2 * np.exp(-1. / gap)

def standardPair(x):
    '''Kernel paired with standardPrime: -x * standardPrime(x).

    Where x < 1 the value is 2 x**2 / (1 - x**2)**2 * exp(-1 / (1 - x**2)),
    and 0 elsewhere. Works on scalars and on numpy arrays.

    Note from the original author: it remains to divide by n when this is
    used with its pair standardPrime, or to adjust the constant in
    computeSFFKnn.
    '''
    inside = x < 1
    t = inside * x
    gap = 1 - t**2
    return inside * 2 * t**2 / gap**2 * np.exp(-1. / gap)

def constant(x):
    '''Indicator kernel: 1 where x < 1 and 0 elsewhere (scalar or array).'''
    inside = x < 1
    return inside * 1


 """


"""
    get_mass_normal_vector(sample)

Estimate per-point masses and unit normals of the point cloud `sample` with the
embedded Python class `PointCloudVarifold`.

`sample` is transposed before being handed to Python, where each row is read as
one point; `sample` is therefore expected to store one point per column.

The Python side uses 10 nearest neighbours for the masses and 20 nearest
neighbours with the `standard` kernel for the tangent regression. It assigns the
Python globals `V` and `X` on every call.

Returns the two-element vector `[mass, normal_vectors]`, where `mass` is the
Python array `V.mass` and `normal_vectors` is the transpose of `V.normal`
(one normal per column).

NOTE(review): presumably needs at least 20 points, since the neighbour queries
ask for 20 neighbours; confirm against the callers.
"""
function get_mass_normal_vector(sample)
    # Python expects one point per row.
    sample2 = collect(transpose(sample))

    # The embedded Python text is kept at column 0 so that it reaches the
    # interpreter exactly as written.
    py"""
V = PointCloudVarifold()
X = $sample2
V.loadFromArray(X)
V.computeKDTree()
V.computeMassKnn(10)
V.regressionKnn(20, standard)
#V.computeSFFKnn(20, standardPrime, standardPair)
"""

    mass = py"V.mass"
    # Back to one normal per column, matching the layout of the input.
    normal_vectors = collect(transpose(py"V.normal"))

    return [mass, normal_vectors]
end
#py"V.meanCurvature"
#py"V.gaussCurvature"


