#Figure 5: Shape detection


include("src/socg.jl")



# Training set: one generate_uniform_segment sample between [0,0] and [1,0]
# (label 0) per p value, followed by one generate_lp_ball sample for each
# p in 1:0.1:2 (label 1). Every sample has `sample_size` points.
sample_size = 50
data = vcat(
    [(generate_uniform_segment(sample_size, [0, 0], [1, 0]), 0) for _ in 1:length(1:0.1:2)],
    [(generate_lp_ball(sample_size, p), 1) for p in 1:0.1:2],
);

# Alternative representations of each sample, all paired with the same label:
#   data1_ : point coordinates flattened into a vector
#   data2  : pairwise squared-Euclidean distance matrix
#   data2_ : that distance matrix flattened into a vector
#   data3  : real part of the bootstrapped characteristic-function features
data1_ = [(vec(pts), lab) for (pts, lab) in data];
data2 = [(Distances.pairwise(Distances.SqEuclidean(1e-12), pts, dims = 2), lab) for (pts, lab) in data];
data2_ = [(vec(dmat), lab) for (dmat, lab) in data2];
data3 = [(real.(vcat(characteristic_function_boot(dmat, 1000, 5, [0.5])...)), lab) for (dmat, lab) in data2];

# Rasterise every training sample into a 28×28×1×1 count image for model_CNN.
# Points are the columns of the sample matrix (row 1 = first coordinate,
# row 2 = second). The bin index along each axis is the number of the 28
# edges of range(-1, 1, 28) that are <= the coordinate.
list_images_ = Vector{Any}(undef, length(data))

for (idx, (pts, label)) in enumerate(data)
    grid = range(-1, 1, 28)
    image_ = zeros(28, 28, 1, 1)
    for col in eachcol(pts)
        # A coordinate below -1 (e.g. from rounding) would give index 0.
        # Clamp to 1, exactly as the evaluation loop does, instead of
        # throwing a BoundsError.
        ix = max(1, count(<=(col[1]), grid))
        iy = max(1, count(<=(col[2]), grid))
        image_[ix, iy, 1, 1] += 1
    end
    list_images_[idx] = (image_, label)
end




# Classifier on raw coordinates: each data1_ input is a flattened sample of
# 100 values. A single sigmoid score is fitted to the 0/1 label with a
# squared loss for 1000 epochs.
model1 = Chain(Dense(100 => 15, relu), Dense(15 => 1, sigmoid), only);
optim1 = Flux.setup(Adam(), model1)
for _ in 1:1000
    Flux.train!(model1, data1_, optim1) do m, x, y
        (m(x) - y)^2
    end
end;


# Classifier on distance matrices: each data2_ input is a flattened pairwise
# distance matrix of 2500 values. Same architecture and squared loss as
# model1, 1000 epochs.
model2 = Chain(Dense(2500 => 15, relu), Dense(15 => 1, sigmoid), only);
optim2 = Flux.setup(Adam(), model2)
for _ in 1:1000
    Flux.train!(model2, data2_, optim2) do m, x, y
        (m(x) - y)^2
    end
end;


# Classifier on characteristic-function features (data3). The input layer
# expects 4 features per sample. Same squared loss, 1000 epochs.
model3 = Chain(Dense(4 => 15, relu), Dense(15 => 1, sigmoid), only);
optim3 = Flux.setup(Adam(), model3)
for _ in 1:1000
    Flux.train!(model3, data3, optim3) do m, x, y
        (m(x) - y)^2
    end
end;


# LeNet-style CNN on the 28×28 rasterised samples in list_images_. It ends in
# a single sigmoid score (made scalar by `only`) so it is trained with the
# same squared loss as the dense models.
model_CNN = Chain(
    # First convolutional block
    Conv((5, 5), 1 => 6, relu),    # 28x28x1 -> 24x24x6 (no padding)
    MaxPool((2, 2)),               # 24x24x6 -> 12x12x6

    # Second convolutional block
    Conv((5, 5), 6 => 16, relu),   # 12x12x6 -> 8x8x16
    MaxPool((2, 2)),               # 8x8x16  -> 4x4x16

    # Flatten output and feed to dense layers
    Flux.flatten,                  # 4x4x16  -> 256
    Dense(256 => 120, relu),       # 256     -> 120
    Dense(120 => 84, relu),        # 120     -> 84
    Dense(84 => 1, sigmoid),       # 84      -> 1 (score in (0, 1) for label 1)
    only,                          # 1-element output -> scalar
)
optim_CNN = Flux.setup(Adam(), model_CNN)
for epoch in 1:1000
    # m(x) is already a scalar (via `only`), so plain `-` matches the other models.
    Flux.train!((m, x, y) -> (m(x) - y)^2, model_CNN, list_images_, optim_CNN)
end;


# p values at which the classifiers are evaluated, and one squared-error
# accumulator per model (one entry per p).
vec_p = collect(1:0.1:2)

mse1, mse2, mse3, mse_CNN = (zeros(length(vec_p)) for _ in 1:4)

# Monte Carlo draws per p value.
MC = 100

# Monte Carlo evaluation. For each p, draw MC fresh generate_lp_ball samples
# (label 1), build the same four representations used in training, and
# accumulate each model's squared error against the target 1. Only ball
# samples are evaluated here.
for (idx, p) in enumerate(vec_p)
    grid = range(-1, 1, 28)
    for _ in 1:MC
        pts = generate_lp_ball(sample_size, p)
        dmat = Distances.pairwise(Distances.SqEuclidean(1e-12), pts, dims = 2)
        feats = real.(vcat(characteristic_function_boot(dmat, 1000, 5, [0.5])...))

        # 28×28 count image; indices below 1 are clamped to 1.
        img = zeros(28, 28, 1, 1)
        for col in eachcol(pts)
            img[max(1, count(<=(col[1]), grid)), max(1, count(<=(col[2]), grid)), 1, 1] += 1
        end

        mse1[idx] += (1 - model1(vec(pts)))^2
        mse2[idx] += (1 - model2(vec(dmat)))^2
        mse3[idx] += (1 - model3(feats))^2
        mse_CNN[idx] += (1 - model_CNN(img))^2
    end
end

# Convert the accumulated squared errors into Monte Carlo means.
mse1, mse2, mse3, mse_CNN = mse1 ./ MC, mse2 ./ MC, mse3 ./ MC, mse_CNN ./ MC

# Quick overlay of the four MSE curves with default series labels.
plot(vec_p, mse1)
for curve in (mse2, mse3, mse_CNN)
    plot!(vec_p, curve)
end

# Final figure: one MSE curve per input representation, distinguished by line
# style. Column order of MSE matches label_ and ls_.
MSE = [mse3 mse1 mse2 mse_CNN]

np = length(vec_p)

ls_ = hcat(fill(:solid, np), fill(:dashdot, np), fill(:dot, np), fill(:dash, np))
label_ = ["characteristic functions" "data coordinates" "distance matrices" "CNN"]

plot(vec_p, MSE;
     xlab = "p", ylab = "MSE", label = label_, ls = ls_,
     title = "MSE for classifier, ball vs segment", ylims = [-0.1, 1.1])








# Shape detection on a synthetic scene: sample several shapes, cluster the
# points with tomato_clustering, label every cluster with model3 (1 when its
# score is >= 0.5, i.e. classified like the lp-ball training samples, else 0),
# then give each point its cluster's label.
spl = generate_several_shapes([3000,1500,500,1000,500], [[1.7, 4.5, [0,0]],[2, 2.5, [4,0]],[1.2, 1, [-2,2]],[1.8, 2, [1,-1]],[1.3, 0.8, [-3,-3]]])
clust = tomato_clustering(spl)

# Cluster ids are used 0-based (hence the `+ 1` when indexing `labels`).
# Derive the id range from the data instead of hard-coding 0:10. That range
# overflowed `labels` with more than 11 clusters and fed empty 2×0 point sets
# to the classifier with fewer.
clusters = [spl[:, clust .== c] for c in 0:Int(maximum(clust))];

# One 0/1 label per cluster, from model3 applied to the characteristic-function
# features of the cluster's distance matrix.
labels = [1 * (model3(real.(vcat(characteristic_function_boot(
              Distances.pairwise(Distances.SqEuclidean(1e-12), pts, dims = 2),
              1000, 5, [0.5])...))) >= 0.5) for pts in clusters]

assignment = [labels[Int(c + 1)] for c in clust];

# Transpose to one point per row so the Python-side `spl[:,0]` / `spl[:,1]`
# pick out the two coordinates. Points are coloured by `assignment`, their
# cluster's predicted 0/1 label.
spl = permutedims(spl)

py"""
plt.scatter($spl[:,0],$spl[:,1],marker='.',s=1,c=$assignment)
plt.show()
"""
