# Figure 4: handwritten-digit classification — CNN vs. dense NN vs. characteristic-function features


include("src/socg.jl")


# Python-side setup, executed through the `py"..."` string macro (presumably PyCall,
# brought in by src/socg.jl — confirm):
#   * loads scikit-learn's digits dataset (8×8 images; reshaped as 8×8 in the loop below),
#   * displays sample 45 with its label as a quick sanity check,
#   * stores a standardized copy of the pixel features as `digits.data_nor`
#     (`preprocessing.scale`: per-feature zero mean / unit variance).
# The Python globals defined here (`digits`, `train_test_split`) are referenced again by
# the `py"..."` call inside the repetition loop.
py"""

import matplotlib.pyplot as plt

import numpy as np
from sklearn.datasets import load_digits
from sklearn import preprocessing 
from sklearn.model_selection import train_test_split

digits = load_digits()

sample_index = 45
plt.figure(figsize=(3, 3))
plt.imshow(digits.images[sample_index], cmap=plt.cm.gray_r,
           interpolation='nearest')
plt.title("label: %d" % digits.target[sample_index]);

digits.data_nor = preprocessing.scale(digits.data)

"""


# Experiment settings.
# `M` and `vect_t` are forwarded unchanged to `characteristic_function_boot`
# (defined outside this file, presumably in src/socg.jl).
M = 5            # alternative value: 2
vect_t = [0.5]   # alternative grid: [0.1, 0.2, 0.5, 1.0, 5.0, 10.0, 100.0]

# Number of independent random train/test splits.
Nrep = 50
# errors[r] holds the three per-method error values [CNN, NN, charact. function] of
# repetition r. The element type is concrete instead of `Any`.
errors = Vector{Vector{Float64}}(undef, Nrep)

# ---------------------------------------------------------------------------------------
# Helpers for the repeated train/test comparison below.
# ---------------------------------------------------------------------------------------

# Cross-entropy loss for one (input, digit) pair. Every model ends in `softmax`, so its
# output is already a probability vector. Digits are 0-9 while Julia indices start at 1,
# hence the explicit label range (output index k <-> digit k - 1).
# (A 0-1 loss such as `argmax(...) != y` has no usable gradient and cannot train a model.)
digit_loss(model, x, y) = Flux.crossentropy(vec(model(x)), Flux.onehot(y, 0:9))

# Most probable digit (0-9) predicted by `model` for input `x`.
predicted_digit(model, x) = Flux.onecold(vec(model(x)), 0:9)

# Number of misclassified samples in a collection of (input, digit) pairs.
count_errors(model, data) = count(d -> predicted_digit(model, d[1]) != d[2], data)

# Train `model` in place with Adam for `epochs` passes over `data`; returns `model`.
function train_model!(model, data; epochs = 20)
    opt_state = Flux.setup(Adam(), model)
    for _ in 1:epochs
        Flux.train!(digit_loss, model, data, opt_state)
    end
    return model
end

# Small CNN on a single 8×8×1 image: convolution, max-pooling, two dense layers.
build_cnn() = Chain(
    Conv((3, 3), 1 => 6, relu),   # 8×8×1 -> 6×6×6 (no padding)
    MaxPool((2, 2)),              # 6×6×6 -> 3×3×6
    Flux.flatten,                 # 3×3×6 -> 54
    Dense(54 => 20, relu),        # 54 -> 20
    Dense(20 => 10),              # 20 -> 10 class scores
    softmax,
)

# Fully connected network on the flattened 8×8 image.
build_nn() = Chain(
    Flux.flatten,                 # 8×8×1×1 -> 64
    Dense(64 => 128, relu),       # 64 -> 128
    Dense(128 => 10),             # 128 -> 10 class scores
    softmax,
)

# Small dense network on the characteristic-function feature vector.
build_char(nfeatures) = Chain(
    Dense(nfeatures => 15, relu),
    # NOTE(review): sigmoid before softmax caps the largest probability at about 0.23.
    # Kept as in the original architecture.
    Dense(15 => 10, sigmoid),
    softmax,
)

# Characteristic-function features of one image `x`, viewed as a weighted point cloud:
# pixel (i, j) contributes `floor(x[i, j] - offset)` copies of the point (i, j). The
# pairwise squared-Euclidean distance matrix of that cloud is passed to
# `characteristic_function_boot` (defined elsewhere, presumably src/socg.jl). The meaning
# of its argument 1000 is not visible here; it is kept unchanged.
function char_features(x, offset, M, vect_t)
    img = Int.(floor.(x .- offset))
    # Same point order as before: j outermost, then i, then the repeated copies.
    coords = [(i, j) for j in axes(img, 2) for i in axes(img, 1) for _ in 1:img[i, j]]
    data = permutedims(hcat(first.(coords), last.(coords)))   # 2 × npoints
    mat = Distances.pairwise(Distances.SqEuclidean(1e-12), data, dims = 2)
    return Float32.(real.(vcat(characteristic_function_boot(mat, 1000, M, vect_t)...)))
end

# ---------------------------------------------------------------------------------------
# Repeated random 80/20 splits: train the three models, record their test error rates.
# ---------------------------------------------------------------------------------------
for nrep in 1:Nrep
    @info "Repetition $nrep of $Nrep"

    # Fresh random split on the Python side (`digits` comes from the setup block).
    py"""
    X_train_nor, X_test_nor, y_train, y_test = train_test_split(
        digits.data_nor, digits.target, test_size=0.2)
    """

    train_sample_size = py"y_train.shape[0]"
    test_sample_size = py"y_test.shape[0]"

    # scikit-learn returns one sample per row; transpose so each column is one image.
    X_train = collect(transpose(py"X_train_nor"))
    X_test = collect(transpose(py"X_test_nor"))
    y_train = py"y_train"
    y_test = py"y_test"

    trainset = [(Float32.(reshape(X_train[:, i], 8, 8, 1, 1)), y_train[i])
                for i in 1:train_sample_size]
    testset = [(Float32.(reshape(X_test[:, i], 8, 8, 1, 1)), y_test[i])
               for i in 1:test_sample_size]

    model_CNN = train_model!(build_cnn(), trainset)
    model_NN = train_model!(build_nn(), trainset)

    # Minimum pixel value over train ∪ test. It is the same for every image, so it is
    # computed once here rather than once per image.
    pixel_offset = min(minimum(X_train), minimum(X_test))

    traindata = [(char_features(x, pixel_offset, M, vect_t), y) for (x, y) in trainset]
    testdata = [(char_features(x, pixel_offset, M, vect_t), y) for (x, y) in testset]

    model_char = train_model!(build_char(length(first(traindata)[1])), traindata)

    # Misclassification rate on the held-out test set, so divide by the test size only.
    errors[nrep] = [count_errors(model_CNN, testset),
                    count_errors(model_NN, testset),
                    count_errors(model_char, testdata)] ./ test_sample_size
end






# Figure: empirical distribution of the per-repetition test errors of the three methods.
# Each curve plots one method's sorted errors against the quantile level r/n (r = 1..n).
# A horizontal segment over [0, 1] marks that method's median error.
# Column order throughout: CNN on images, dense NN on images, characteristic-function model.
# NOTE(review): the data come from scikit-learn's 8×8 `load_digits`, not MNIST. Consider
# retitling.
error_matrix = reduce(hcat, errors)   # 3 × n, one column per repetition
nreps = size(error_matrix, 2)         # n is derived from `errors`, not hard-coded to 50

color_ = repeat(["gray80" "gray50" "gray20"], nreps)
label_ = ["method CNN images" "method NN images" "method charact. function"]
linewidth_ = repeat([1 1 2], nreps)   # thicker line for the characteristic-function model

plot((1:nreps) ./ nreps, permutedims(sort(error_matrix, dims = 2)), color = color_,
     label = label_, linewidth = linewidth_,
     title = "MNIST digits recognition,\n 0-1 error distribution and median")

# Median lines: a 2 × 3 matrix whose two identical rows hold each method's median
# (one row at x = 0 and one at x = 1).
median_errors = permutedims(median(error_matrix, dims = 2))   # 1 × 3
color_2 = repeat(["gray80" "gray50" "gray20"], 2)
linewidth_2 = repeat([1 1 2], 2)

plot!([0, 1], repeat(median_errors, 2), color = color_2, label = false,
      linewidth = linewidth_2, xlab = "x", ylab = "0-1 mean error, empirical quantile (x)")