#' @title CoLSTIM Learner
#'
#' @include Learner.R
#'
#' @description
#' This Learner specializes [Learner] to match the CoLSTIM algorithm
#'
COLSTIM = R6Class("COLSTIM", inherit = Learner,
  public = list(

    #' @field M_mat (`matrix(d, d)`)\cr
    #' Stores the Gram matrix, i.e. the running sum of outer products of
    #' the observed contrast vectors.
    M_mat = NULL,

    #' @field theta_hat (`numeric(d)`)\cr
    #' Stores the current weight vector estimate.
    theta_hat = NULL,

    #' @field z_t (`list()` of `numeric(d)`)\cr
    #' Stores the contrast vectors (difference of the context vectors of
    #' the two chosen arms) over time, indexed by time step.
    z_t = NULL,

    #' @field o_t (`integer()`)\cr
    #' Stores the feedback obtained by the learner over time, indexed by
    #' time step.
    o_t = NULL,

    #' @field tau_0 (`integer()`)\cr
    #' Length of the initial exploration phase: while the time step is
    #' smaller than `tau_0`, a random pair of arms is chosen.
    tau_0 = NULL,

    #' @field d (`integer()`)\cr
    #' Dimensionality of the problem.
    d = NULL,

    #' @field n (`integer()`)\cr
    #' Number of arms.
    n = NULL,

    #' @field cdf (`function()`)\cr
    #' Comparison function used.
    cdf = NULL,

    #' @field df (`function()`)\cr
    #' Derivative of the comparison function used.
    df = NULL,

    #' @field pert_dis (`function()`)\cr
    #' Sampler of the perturbation distribution. It is called as
    #' `pert_dis(k)` and is expected to return `k` draws.
    pert_dis = NULL,

    #' @field fullMLE (`logical(1)`)\cr
    #' Specifies whether the full MLE should be used. If `FALSE`, a single
    #' gradient step with learning rate `eta` is performed instead.
    fullMLE = TRUE,

    #' @field eta (`numeric()`)\cr
    #' Learning rate of the SGD variant (only used if `fullMLE` is `FALSE`).
    eta = NULL,

    #' @field threshold (`numeric()`)\cr
    #' Threshold parameter. Perturbations are clipped to
    #' `[-threshold, threshold]`, and it scales the exploration bonus used
    #' to select the second arm.
    threshold = NULL,

    #' @field p_t (`function()`)\cr
    #' Coupling probability. It is called with the current time step and
    #' its result is used as the success probability of a Bernoulli draw:
    #' on success every arm gets its own perturbation, otherwise a single
    #' perturbation is shared by all arms.
    p_t = NULL,


    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #'
    #' @param data_model_specs (`list()`)\cr
    #'   List with entries `num_arms` (number of arms) and `dim`
    #'   (dimensionality of the problem).
    #' @param aggregation Passed on to the initializer of [Learner].
    #' @param tau_0 Length of the initial exploration phase.
    #' @param cdf Comparison function.
    #' @param df Derivative of the comparison function.
    #' @param pert_dis Sampler of the perturbation distribution.
    #' @param fullMLE Whether the full MLE should be used.
    #' @param eta Learning rate of the SGD variant.
    #' @param threshold Threshold parameter.
    #' @param p_t Coupling probability as a function of the time step.
    initialize = function(data_model_specs = list(num_arms = 2, dim = 1), aggregation, tau_0, cdf , df,  pert_dis, fullMLE = TRUE, eta =NULL, threshold = NULL, p_t = function(x) 0){
      super$initialize(aggregation = aggregation, action_size = 2)

      self$n         = data_model_specs$num_arms
      self$d         = data_model_specs$dim
      self$M_mat     = matrix(0, nrow = self$d, ncol = self$d)
      self$theta_hat = rep(0, self$d)
      self$z_t       = list()
      self$o_t       = c()
      self$cdf       = cdf
      self$df        = df
      self$pert_dis  = pert_dis
      self$tau_0     = tau_0
      self$fullMLE   = fullMLE
      self$eta       = eta
      self$threshold = threshold
      self$p_t       = p_t
    },

    #' @description
    #' Chooses the pair of arms to be compared in the current time step.
    #'
    #' @param data_model Object providing `getContext(timestep)`. The
    #'   returned context is used as a matrix whose columns are the
    #'   context vectors of the arms.
    #'
    #' @return Vector of length two with the indices of the chosen arms.
    action = function(data_model) {

      # Initial exploration: two distinct arms uniformly at random.
      if (self$timestep < self$tau_0) {
        return(sample(seq_len(self$n), 2, replace = FALSE))
      }

      X_t = data_model$getContext(self$timestep)

      # Refresh the weight estimate, either by the full MLE over all
      # observations or by a single gradient step on one observation.
      if (self$fullMLE) {
        self$theta_hat = MLE_estimate(feedback = self$o_t, covariates = self$z_t, comparison = self$cdf, comparison_der = self$df, theta_start = self$theta_hat)
      } else {
        # NOTE(review): reads the observation stored at index timestep - 1;
        # this presumes the time step was advanced after the last call to
        # update() -- confirm against Learner.
        self$theta_hat = self$theta_hat + self$eta / sqrt(self$timestep) * gradLikelihood(theta = self$theta_hat, covariates = self$z_t[[self$timestep - 1]], comparison = self$cdf, feedback = self$o_t[self$timestep - 1])
      }

      # Perturbations: by default one draw shared by all arms; with
      # probability p_t(timestep) an individual draw per arm.
      tilde_epsilon = rep(self$pert_dis(1), self$n)
      B_t           = rbinom(1, 1, self$p_t(self$timestep))
      if (B_t == 1) {
        tilde_epsilon = self$pert_dis(self$n)
      }
      clipped_epsilon = pmax(-self$threshold, pmin(self$threshold, tilde_epsilon))

      # The inverse Gram matrix is needed for every arm, so compute it once.
      # NOTE: solve() raises an error if M_mat is singular (it starts as
      # the zero matrix and only grows through update()).
      M_inv = solve(self$M_mat)

      # Per-arm width sqrt(x_i' M^-1 x_i). The diagonal is extracted
      # before taking the root so that (possibly negative) off-diagonal
      # entries never reach sqrt().
      arm_width = sqrt(diag(t(X_t) %*% M_inv %*% X_t))

      # First arm: maximiser of the perturbed utility estimate
      # (first index in case of ties).
      theta_tilde = t(X_t) %*% self$theta_hat + clipped_epsilon * arm_width
      i_t         = which.max(theta_tilde)

      # Second arm: maximiser of the estimated advantage over the first
      # arm plus an exploration bonus (first index in case of ties).
      selection = c()
      temp_max  = -Inf
      for (j in seq_len(self$n)) {
        tup_diff = X_t[, j] - X_t[, i_t]
        tup_norm = sqrt(t(tup_diff) %*% M_inv %*% tup_diff)
        temp     = tup_diff %*% self$theta_hat + self$threshold * tup_norm
        if (temp > temp_max) {
          temp_max  = temp
          selection = c(i_t, j)
        }
      }

      selection
    },

    #' @description
    #' Queries the feedback for the chosen pair of arms and updates the
    #' stored contrast vectors, the feedback history and the Gram matrix.
    #'
    #' @param chosen_arms Indices of the two arms that were chosen.
    #' @param data_model Object providing `getFeedback(arms, timestep)`
    #'   and `getContext(timestep)`.
    update = function(chosen_arms, data_model) {

      feedback = data_model$getFeedback(c(chosen_arms[1], chosen_arms[2]), self$timestep)

      X_t      = data_model$getContext(self$timestep)
      # Contrast vector of the duel: context of the first chosen arm minus
      # context of the second one.
      contrast = X_t[, chosen_arms[1]] - X_t[, chosen_arms[2]]

      self$z_t[[self$timestep]] = contrast
      self$o_t[self$timestep]   = feedback

      # update estimates
      self$M_mat = self$M_mat + outer(contrast, contrast)
    }
  )
)