#!/usr/bin/perl -w

# File    : evalloss.pl
# Date    : 7 Feb 2006
# Version : 1.05
# Author  : Nicola L.C. Talbot
# Purpose : Perl script to evaluate MSE and NLPD for predictive uncertainty competition
# Syntax  : perl evalloss.pl <target file> <output file>

# Exactly two arguments are required: the targets file and the predictions file.
die "Syntax: evalloss.pl <target file> <output file>\n" unless (@ARGV == 2);

# The dataset name is the path component (no '/' or '\') that immediately
# precedes "_test.targets", "_valid.targets" or "_train.targets" in the
# targets file name.  It is lower-cased and later used to look up %minbinsize.
($dataset) = ($ARGV[0] =~ m/([^\\\/]+)_(?:test|valid|train)\.targets/)
   or die "Target file name has invalid format (<dataset>_<type>.targets expected)\n";
$dataset = lc($dataset);

# Minimum bin sizes, keyed by (lower-case) dataset name.  Quantile widths and
# Gaussian mixture spreads in a submission are checked against these values
# by check_min_bin_size() and gaussian_mixture().  Each entry is half of the
# figure written below.
# To support another dataset, add a "<dataset> => <min val>" entry.
%minbinsize = (
   precip => 0.0393 / 2,
   so2    => 2.79   / 2,
   temp   => 0.0153 / 2,
);

# Main program: read the targets file and the predictions file in step (one
# prediction line per target line), accumulate the squared error of each
# predictive mean and the negative log of each predictive density at the
# target, then print the MSE normalised by the target variance and the mean
# NLPD.
#
# Each prediction line starts with a type identifier:
#   1  <mean> <variance>                       single Gaussian
#   2  <mean> <variance> <weight> ...          Gaussian mixture (triplets)
#   0  <level> <value> <level> <value> ...     quantiles (pairs)
#
# NOTE: $mse, $nlpd, $n, @y, $var and $eps are deliberately package variables
# rather than file-scoped lexicals: several subroutines below localize
# variables with some of these names (e.g. local($n), local(@y), local($var)),
# and Perl refuses to localize a lexical that is in scope.

# Three-argument open treats the file names literally (a name starting with
# '<' / '>' or ending in '|' is not taken as a mode or a pipe).
open(my $target_fh, '<', $ARGV[0]) or die "Can't open $ARGV[0]: $!\n";
open(my $output_fh, '<', $ARGV[1]) or die "Can't open $ARGV[1]: $!\n";

# A plain decimal number with optional sign, fraction and exponent,
# e.g. "0", "-3", "2.5", ".5", "1e-3", "+4.2E+01".  No capturing groups.
my $number_re = qr/[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?/;

# Smallest density passed to log(), so the NLPD stays finite when a
# prediction assigns (numerically) zero density to a target.
$eps = 1e-323;

$mse  = 0;
$nlpd = 0;
$n    = 0;

while (my $target_line = <$target_fh>)
{
   # Each line of the targets file holds a single number.
   if ($target_line =~ m/^\s*($number_re)\s*\z/)
   {
      $y[$n] = $1;
   }
   else
   {
      die "Invalid format for '$ARGV[0]' on line $. - a single number expected\n";
   }

   my $output_line = <$output_fh>;
   die "Insufficient lines in $ARGV[1]\n" unless (defined($output_line));

   # Identifier plus at least two numeric parameters; every field numeric.
   my @fields = split(' ', $output_line);
   my @bad    = grep { !/\A$number_re\z/ } @fields;
   if (@fields < 3 or @bad)
   {
      die "$ARGV[1]: Line $. has invalid format: $output_line";
   }
   my ($type, @data) = @fields;

   my ($pred_mean, $density);

   if ($type == 1)          # single Gaussian: mean, variance
   {
      my ($mu, $sigma2) = @data;
      die "$ARGV[1]: Line $. Gaussian variance must be positive: $output_line"
         unless ($sigma2 > 0);
      $pred_mean = $mu;
      $density   = normal($y[$n], $mu, $sigma2);
   }
   elsif ($type == 2)       # Gaussian mixture: (mean, variance, weight) triplets
   {
      if (@data % 3 != 0)
      {
         die "$ARGV[1]: Line $. Gaussian mixtures must be expressed as a list of triplets:\n";
      }
      ($pred_mean, $density) = gaussian_mixture($y[$n], @data);
   }
   elsif ($type == 0)       # quantiles: (level, value) pairs
   {
      if (@data % 2 != 0)
      {
         die "$ARGV[1]: Line $. quantiles must be expressed as a list of pairs:\n";
      }
      if (@data < 4)
      {
         die "$ARGV[1]: Line $. quantiles must include at least two pairs:\n";
      }
      $pred_mean = quantile_mean(@data);
      $density   = quantile_prob($y[$n], @data);
   }
   else
   {
      die "$ARGV[1]: Line $. unknown identifier $type: $output_line";
   }

   $mse += ($y[$n] - $pred_mean) * ($y[$n] - $pred_mean);

   # log() needs a positive argument: clamp tiny densities to $eps, and also
   # an undefined one (quantile_prob leaves its result unset when the target
   # falls in none of its intervals).
   $density = $eps if (!defined($density) or $density < $eps);

   $nlpd -= log($density);
   $n++;
}

close($target_fh);

# Every output line must pair with a target line.
die "Too many lines in '$ARGV[1]'\n" if (defined(<$output_fh>));

close($output_fh);

# Check the target count *before* calling variance(): it divides by the
# count (via mean()) and by count - 1, so it dies for fewer than two targets.
die "Can't divide by zero!\n" if ($n == 0);
die "At least two targets are needed to normalise the MSE\n" if ($n < 2);

$var = variance(@y);
die "Targets have zero variance; can't normalise the MSE\n" if ($var <= 0);

# MSE is normalised by the sample variance of the targets.
$mse  /= ($n * $var);
$nlpd /= $n;

print "MSE  = $mse\n";
print "NLPD = $nlpd\n";


# normal($x, $mean, $var)
#
# Gaussian probability density at $x for mean $mean and *variance* $var:
#   exp(-(x - mean)^2 / (2 var)) / sqrt(2 pi var)
# No validation: a zero variance dies with a division by zero, and a negative
# one dies in sqrt().
sub normal{
   my ($point, $centre, $variance) = @_;
   my $root2pi = 2.50662827463100050241576528481105;   # sqrt(2*pi)

   my $offset = $point - $centre;
   return exp(-0.5 * $offset * $offset / $variance) / (sqrt($variance) * $root2pi);
}

# mean(@values)
#
# Arithmetic mean of the arguments.  Dies with Perl's "Illegal division by
# zero" when called with no arguments.
sub mean{
   my @values = @_;
   my $total  = 0.0;

   $total += $_ foreach (@values);

   return $total / scalar(@values);
}

# variance(@values)
#
# Unbiased sample variance of the arguments (squared deviations from mean()
# divided by count - 1).  Dies with a division by zero for fewer than two
# values.
sub variance{
   my @values = @_;
   my $centre = mean(@values);
   my $sum_sq = 0.0;

   foreach my $value (@values)
   {
      my $deviation = $value - $centre;
      $sum_sq += $deviation * $deviation;
   }

   return $sum_sq / (scalar(@values) - 1);
}

# quantile_mean(@data)
#
# Mean of the predictive density described by a flat list of (alpha, q)
# pairs, where alpha is a quantile level and q the corresponding value.
# The density (the one evaluated by quantile_prob) is uniform between
# adjacent quantiles and has exponential tails:
#   y < q[0]:   z1 * exp(-z1 * (q[0] - y) / alpha[0])          mass alpha[0]
#   y >= q[N]:  zN * exp(-zN * (y - q[N]) / (1 - alpha[N]))     mass 1 - alpha[N]
# Each tail is an exponential with scale alpha[0]/z1 (resp. (1-alpha[N])/zN),
# so the tails contribute
#   alpha[0] * (q[0] - alpha[0]/z1)   and   (1-alpha[N]) * (q[N] + (1-alpha[N])/zN)
# to the mean.  (Earlier versions used alpha^2/z here, which is not the mean
# of that density.)
#
# Reads the globals $dataset and %minbinsize (dataset-specific bin widths) and
# dies with a descriptive message if the quantiles are malformed.
sub quantile_mean{
   my @data = @_;
   my (@alpha, @q);

   if (@data % 2 != 0)
   {
      die "$ARGV[1] Line $.: quantiles must be expressed as a list of pairs\n";
   }

   for (my $i = 0; $i < @data; $i += 2)
   {
      push @alpha, $data[$i];
      push @q,     $data[$i+1];
   }

   if (@q < 2)
   {
      die "$ARGV[1] Line $.: at least two quantile pairs are required\n";
   }

   if (defined($minbinsize{$dataset}))
   {
      check_min_bin_size($dataset, $minbinsize{$dataset}, @q);
   }

   _check_quantiles(\@alpha, \@q);

   # Interior: each bin contributes (probability mass) * (bin midpoint).
   my $m_q = 0;
   for my $i (0 .. $#q - 1)
   {
      $m_q += 0.5*($q[$i] + $q[$i+1])*($alpha[$i+1] - $alpha[$i]);
   }

   my $last = $#q;
   my $z1 = ($alpha[1] - $alpha[0]) / ($q[1] - $q[0]);
   my $zN = ($alpha[$last] - $alpha[$last-1]) / ($q[$last] - $q[$last-1]);

   my $m_lt = $alpha[0] * ($q[0] - $alpha[0]/$z1);
   my $m_ut = (1 - $alpha[$last]) * ($q[$last] + (1 - $alpha[$last])/$zN);

   return $m_q + $m_lt + $m_ut;
}

# _check_quantiles(\@alpha, \@q)
#
# Die unless the (alpha, q) pairs describe a valid piecewise density: levels
# within [0, 1] and non-decreasing, values strictly increasing, and non-zero
# probability in the first and last bins (so the tail rates z1, zN are > 0).
sub _check_quantiles{
   my ($alpha, $q) = @_;

   for my $i (0 .. $#$q)
   {
      if ($alpha->[$i] < 0 or $alpha->[$i] > 1)
      {
         die "$ARGV[1] Line $.: quantile levels must lie between 0 and 1\n";
      }
      next if ($i == 0);
      if ($q->[$i] <= $q->[$i-1])
      {
         die "$ARGV[1] Line $.: quantile values must be strictly increasing\n";
      }
      if ($alpha->[$i] < $alpha->[$i-1])
      {
         die "$ARGV[1] Line $.: quantile levels must not decrease\n";
      }
   }

   if ($alpha->[1] == $alpha->[0] or $alpha->[-1] == $alpha->[-2])
   {
      die "$ARGV[1] Line $.: the first and last quantile bins must have non-zero probability\n";
   }

   return;
}

# quantile_prob($y, @data)
#
# Predictive density at $y for a quantile-based prediction.  @data is a flat
# list of (alpha, q) pairs: alpha is the quantile level, q its value.
#   q[i] <= y < q[i+1]:  uniform density (alpha[i+1]-alpha[i]) / (q[i+1]-q[i])
#   y < q[0]:            z1 * exp(-z1 * (q[0]-y) / alpha[0])
#   y >= q[N]:           zN * exp(-zN * (y-q[N]) / (1-alpha[N]))
# where z1 and zN are the densities of the first and last interior bins.
# A tail with zero probability (alpha[0] == 0, or alpha[N] == 1) has density
# 0 instead of causing a division by zero.  The exception is y == q[N], where
# zN is used, as it would be for a non-empty upper tail.
# Returns undef if $y falls in no interval, which can happen when the q
# values are not increasing.
sub quantile_prob{
   my ($y, @data) = @_;
   my (@alpha, @q);

   for (my $i = 0; $i <= $#data; $i += 2)
   {
      push @alpha, $data[$i];
      push @q,     $data[$i+1];
   }

   my $last = $#q;
   my $z1 = ($alpha[1] - $alpha[0]) / ($q[1] - $q[0]);
   my $zN = ($alpha[$last] - $alpha[$last-1]) / ($q[$last] - $q[$last-1]);
   my $prob_y;

   if ($y < $q[0])
   {
      my $lower_mass = $alpha[0];
      $prob_y = ($lower_mass > 0)
              ? $z1*exp(-$z1*abs($y - $q[0])/$lower_mass)
              : 0;
   }
   elsif ($y >= $q[$last])
   {
      my $upper_mass = 1 - $alpha[$last];
      if ($upper_mass > 0)
      {
         $prob_y = $zN*exp(-$zN*abs($y - $q[$last])/$upper_mass);
      }
      else
      {
         $prob_y = ($y == $q[$last]) ? $zN : 0;
      }
   }
   else
   {
      for my $i (0 .. $last - 1)
      {
         if ($y >= $q[$i] and $y < $q[$i+1])
         {
            $prob_y = ($alpha[$i+1] - $alpha[$i]) / ($q[$i+1] - $q[$i]);
            last;
         }
      }
   }

   return $prob_y;
}

# gaussian_mixture($y, @data)
#
# Evaluate a Gaussian-mixture prediction at the target $y.  @data is a flat
# list of (mean, variance, weight) triplets; the second element of each
# triplet is passed to normal() as a variance.
# Returns ($mean, $prob_y):
#   $mean   = sum(w_i * m_i)
#   $prob_y = sum(w_i * normal(y, m_i, v_i))
# All components are validated before any density is evaluated.  The sub dies
# if the list is not whole triplets, a weight is outside [0, 1], a variance is
# not positive or is below the dataset minimum (globals $dataset and
# %minbinsize), or the weights do not sum to 1 within 1e-5.
sub gaussian_mixture{
   my ($y, @data) = @_;

   if (!@data or @data % 3 != 0)
   {
      die "$ARGV[1] Line $.: Gaussian mixtures must be expressed as a list of triplets\n";
   }

   # Validation pass: a bad variance produces a readable error here instead
   # of a division by zero (or sqrt of a negative) inside normal().
   my $sum_w = 0;
   for (my $i = 0; $i < @data; $i += 3)
   {
      my ($m, $v, $w) = @data[$i .. $i+2];

      if ($w < 0 or $w > 1)
      {
         die "$ARGV[1] Line $.: Gaussian mixture weights must lie between 0 and 1\n";
      }

      if ($v <= 0)
      {
         die "$ARGV[1] Line $.: Gaussian mixture variances must be positive\n";
      }

      if (defined($minbinsize{$dataset}))
      {
         # The test is on the variance: v < (2*minbinsize)^2 is equivalent to
         # standard deviation < 2*minbinsize.
         # NOTE(review): quantile widths only need to be >= minbinsize, but
         # standard deviations here need to be >= 2*minbinsize.  Confirm
         # which threshold is intended.
         my $min_sd = 2*$minbinsize{$dataset};
         if ($v < $min_sd*$min_sd)
         {
            die "$ARGV[1] Line $.: Gaussian mixture standard deviations must be at least $min_sd for $dataset dataset\n";
         }
      }

      $sum_w += $w;
   }

   if ($sum_w < 1-1e-5 or $sum_w > 1+1e-5)
   {
      die "$ARGV[1] Line $.: Gaussian mixture weights must sum to 1 (sum(w) = $sum_w)\n";
   }

   # Evaluation pass.
   my ($mean, $prob_y) = (0, 0);
   for (my $i = 0; $i < @data; $i += 3)
   {
      my ($m, $v, $w) = @data[$i .. $i+2];
      $mean   += $m*$w;
      $prob_y += normal($y, $m, $v)*$w;
   }

   return ($mean, $prob_y);
}

# check_min_bin_size($dataset, $minsize, @q)
#
# Die if any two consecutive quantile values in @q are closer together than
# $minsize, the minimum bin size for $dataset.  The message names the file
# and line (via $ARGV[1] and $.) and the offending pair.  Returns nothing.
sub check_min_bin_size{
   my ($dataset, $minsize, @q) = @_;

   for my $i (1 .. $#q)
   {
      my $width = $q[$i] - $q[$i-1];
      next unless ($width < $minsize);

      die "'$dataset' dataset can't have quantile width less than $minsize\n",
          "$ARGV[1] Line $.: q[$i]-q[", $i-1, "]=", $width, "\n";
   }

   return;
}

# True value, harmless for a script; kept in case the file is loaded with
# require or do.
1;
