function [net,P,g] = train(net, x, y, varargin)

% TRAIN
%
%    [net,P,g] = train(net, x, y, ...)
%
%    Fit the least-squares support vector machine NET to the input
%    patterns X (one pattern per row) and targets Y (a column with one
%    row per pattern).
%
%    Outputs:
%
%       net - the model with fields x, alpha and b filled in
%       P   - (optional) model selection criterion, computed from
%             net.criterion as part of the leave-one-out procedure
%       g   - (optional) row vector of derivatives of P; g(1) is taken
%             w.r.t. the regularisation parameter and g(i+1) w.r.t. the
%             i-th kernel parameter
%
%    Any additional arguments (varargin) are accepted but not used.
%
%    NOTE(review): the factors of log(2) in the gradient suggest that the
%    hyper-parameters are optimised on a base-2 logarithmic scale - confirm
%    against the caller.
%
%    NOTE(review): when the Cholesky factorisation fails, P is the raw
%    value returned by the criterion, whereas on the normal path it is
%    0.5*ntp*log(Q) plus a term in the kernel parameters - confirm that
%    this difference is intended.

%
% File        : @lssvc/train.m
%
% Date        : Friday 31st October 2003
%
% Author      : Dr Gavin C. Cawley
%
% Description :
%
% References  :
%
% History     : 31/10/2003 - v1.00
%
% Copyright   : (c) Dr Gavin C. Cawley, October 2003.
%

% get kernel matrix (and perhaps partial derivatives w.r.t. kernel parameters)

if nargout == 3

   [K,dK] = evaluate(net.kernel, x, x);

else

   K = evaluate(net.kernel, x, x);

end

% train least-squares support vector machine

ntp = size(x,1);

% fail early with a clear message rather than a concatenation error below

if size(y,1) ~= ntp || size(y,2) ~= 1

   error('lssvc:train:sizeMismatch', ...
         'y must be a column vector with one row per row of x (%d).', ntp);

end

one   = ones(ntp,1);
zero  = zeros(ntp,1);
H     = K + (net.lambda+1e-9)*eye(ntp);   % small jitter added to the ridge
net.x = x;
[R,p] = chol(H);                          % p == 0 <=> H is positive definite

if p == 0

   % solve H*nu = y and H*eta = 1 using the Cholesky factor, then
   % recover the bias and the dual coefficients

   xi            = R\(R'\[y one]);
   eta           = xi(:,2);
   nu            = xi(:,1);
   oneoversumeta = 1/sum(eta);
   net.b         = oneoversumeta*sum(nu);
   net.alpha     = nu - eta*net.b;

   % optionally perform leave-one-out cross-validation

   if nargout > 1

      Ri       = inv(R);

      % sum(Ri.^2,2) is the diagonal of Ri*Ri', so s is the diagonal of
      % the matrix A formed in the gradient section below

      s        = sum(Ri.^2,2) - oneoversumeta*eta.^2;
      r        = net.alpha./s;
      [Q,dPdr] = evaluate(net.criterion, r, y);
      omicron  = get(net.kernel, 'eta');
      d        = length(omicron);
      omega    = 0.5*sum(omicron.^2); 
      P        = 0.5*ntp*log(Q) + 0.5*d*log(omega+eps);
 
      % optionally compute gradient information

      if nargout > 2

         % set up a few useful scaling factors etc.

         A    = Ri*Ri' - (oneoversumeta*eta)*eta';
         xi   = dPdr./s; 
         zeta = xi.*r;
         nk   = size(dK,1);
         g    = zeros(1, nk+1);   % preallocate: [lambda, kernel parameters]

         % derivative w.r.t. regularisation parameter

         dCi    = diag(A'*A);
         dalpha = (A*net.alpha);
         g(1)   = 0.5*ntp*log(2)*net.lambda*sum(xi.*dalpha - zeta.*dCi)/Q;

         % derivative w.r.t. kernel parameter(s)

         for i=1:nk

            deekay = squeeze(dK(i,:,:));
            dCi    = diag((A*deekay)*A);
            dalpha = (A*(deekay*net.alpha));
            g(i+1) = 0.5*ntp*sum(xi.*dalpha - zeta.*dCi)/Q ...
                   + 0.5*d*log(2)*(omicron(i)^2)/omega;

         end

      end

   end

else

   % Cholesky factorisation failed: invert the full augmented system
   % directly, falling back on the pseudo-inverse if inv raises a warning

   M      = [H one ; one' 0];
   status = warning('off');

   try

      lastwarn('');
      Ci = inv(M);

      if ~isempty(lastwarn)

         Ci = pinv(M);

      end

   catch err

      % never leave the caller's session with all warnings disabled

      warning(status);
      rethrow(err);

   end

   warning(status);

   w         = Ci*[y ; 0];
   net.alpha = w(1:end-1);
   net.b     = w(end);

   if nargout > 1

      cii       = diag(Ci);
      cii       = cii(1:ntp);
      r         = net.alpha./max(cii,eps);
      [P,dPdr]  = evaluate(net.criterion, r, y);

      if nargout > 2

         nk = size(dK,1);
         g  = zeros(1, nk+1);   % preallocate: [lambda, kernel parameters]

         % NOTE(review): r above is guarded with max(cii,eps) but dr below
         % divides by the raw cii - confirm whether the guard should be
         % applied here as well.

         % derivative w.r.t. regularisation parameter

         dk     = [eye(ntp) zero ; zero' 0];
         dalpha = -Ci*[net.alpha ; 0];
         dalpha = dalpha(1:ntp);
         dcii   = -diag((Ci*dk)*Ci);
         dcii   = dcii(1:ntp);
         dr     = (dalpha - r.*dcii)./cii;  
         g(1)   = -log(2)*net.lambda*(dr'*dPdr)/ntp;

         % derivative w.r.t. kernel parameter(s)

         for i=1:nk

            dk     = squeeze(dK(i,:,:));
            dk     = [dk zero ; zero' 0];
            dcii   = diag((Ci*dk)*Ci);
            dcii   = dcii(1:ntp);
            dw     = (Ci*(dk*w));
            dalpha = dw(1:ntp);
            dr     = (dalpha - r.*dcii)./cii;
            g(i+1) = (dr'*dPdr)/ntp;

         end

      end

   end

end

% bye bye...

