This notebook plots loss functions for binary classification.

Kevin Gimpel

6/26/2020

In [None]:
import numpy as np
import matplotlib
%matplotlib inline

In [None]:
# Render text with matplotlib's built-in mathtext engine instead of an external LaTeX run.
# Use a real boolean here: the string 'False' only works because matplotlib coerces
# bool-like strings, and it reads like a truthy value.
matplotlib.rcParams['text.usetex'] = False
# Use Computer Modern as the math font
matplotlib.rcParams['mathtext.fontset'] = 'cm'
# Use Palatino as the text font (by default, setting Palatino as the text font causes Palatino to be the math font too,
# which is why the math font set is chosen explicitly above)
matplotlib.rcParams['font.family'] = 'Palatino'
# Embed fonts in saved PDFs as Type 42 (TrueType) rather than Type 3
matplotlib.rcParams['pdf.fonttype'] = 42

In [None]:
import matplotlib.pyplot as plt

# Set the figure size and font sizes
global_params = {
    'legend.fontsize': 22,
    'figure.figsize': (12, 8),
    'axes.labelsize': 24,
    'axes.titlesize': 24,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    # LaTeX preamble (only consulted when text.usetex is enabled).
    # This must be a single string: recent matplotlib releases validate this rcParam
    # as a str and reject a list. One \usepackage per package keeps the value free of
    # commas, so releases that parse it as a comma-separated list keep it intact too.
    'text.latex.preamble': r'\usepackage{amsmath}\usepackage{amsthm}\usepackage{amssymb}',
}
plt.rcParams.update(global_params)

In [None]:
# variables for all plots:
range_min = -4     # left end of the z axis
range_max = 4      # right end of the z axis
transparency = 1   # alpha used for every curve except the 0-1 loss
my01width = 3      # line width of the 0-1 loss
mywidth = 4        # line width of all other curves
stepsize = 0.1     # target spacing between sampled z values
# set the value of m in the labels below, i.e., a constant representing the cost of an incorrect classification
m = 1.0

# curve colors
myred = '#FF2B00'
mygreen = '#49BA00'
myblue = '#06B2F4'
mypurple = '#D973EF'
myorange = '#FA9D00'
mybrown = '#CE9716'

# legend / axis labels (mathtext); the formulas mirror the computations below
zero_one_loss_label = r'0-1: $m\,I\,[z\leq 0]$'
perceptron_loss_label = r'perceptron: $\max(0, -\!z)$'
hinge_loss_label = r'hinge: $\max(0,m-z)$'
log_loss_label = r'log: $\log(1+\exp(-z))$'
exp_loss_label = r'exp: $\exp(-z)$'
softmax_margin_label = r'softmax-margin: $\log(1+\exp(m-z))$'
risk_label = r'risk: $\frac{m\,\exp(-z)}{1+\exp(-z)}$'
ramp_label = r'ramp: $\max(0,m-z) - \max(0,-\!z)$'
my_xlabel = r'$z = y^\ast\!\!\times f(x)$'
my_ylabel = r'Loss'

# y-axis limits and the (exclusive) upper bound for the integer y ticks
my_ymin = -0.1
my_ymax = 4.1
my_ytick_max = 5

# compute losses

# 0-1 loss (actually it's "0-m" loss): a step from m down to 0 at z = 0,
# described by its four corner points
x_01 = [range_min, 0, 0, range_max]
y_01 = [m, m, 0, 0]

# set the x range
# np.arange with a fractional step accumulates rounding error, so the grid points are
# inexact and the last point can overshoot range_max. np.linspace pins both endpoints
# exactly; if stepsize does not divide the range evenly, the spacing is adjusted
# slightly so that the grid still ends at range_max.
num_points = int(round((range_max - range_min) / stepsize)) + 1
x = np.linspace(range_min, range_max, num_points)
# the perceptron loss:
y_perc = np.maximum(0, -x)
# hinge loss:
y_hinge = np.maximum(0, m - x)
# log loss (note: some people define log loss using base 2, which makes it pass through
# the point (0, 1), i.e. (0, m) when m = 1).
# logaddexp(0, a) == log(1 + exp(a)) but does not overflow for large a.
y_log = np.logaddexp(0, -x)
# exp loss:
y_exp = np.exp(-x)
# softmax-margin:
y_smm = np.logaddexp(0, m - x)
# Bayes risk: m*exp(-x) / (1 + exp(-x)) == m / (1 + exp(x)), evaluated as
# m * exp(-log(1 + exp(x))) so that neither tail produces inf/inf
y_risk = m * np.exp(-np.logaddexp(0, x))
# ramp loss (hinge minus perceptron: m for z <= 0, m - z on [0, m], 0 for z >= m):
y_ramp1 = y_hinge - y_perc

# other losses that are not currently being plotted in the cells below:
# second form of ramp loss:
#y_lazyhinge = np.maximum(0,-m-x)
#y_ramp2 = y_perc - y_lazyhinge
# soft version of ramp1:
#y_sramp1 = np.log((1 + np.exp(m-x))/(1 + np.exp(-x)))
# soft version of ramp2:
#y_sramp2 = np.log((1 + np.exp(-x))/(1 + np.exp(-m-x)))
# Savage loss:
#y_savage = np.power(np.square(1 + np.exp(x)), -1)

In [None]:
# Figure: 0-1, perceptron, hinge and log losses (optionally with softmax-margin).
fig, ax = plt.subplots()

# Set to True to also draw the softmax-margin loss; this also switches the legend
# font size and the output file names.
use_smm = False

# One entry per curve, in drawing order: (x values, y values, ax.plot keyword arguments).
# Tuple linestyles are matplotlib's (offset, (on, off, ...)) dash patterns.
curves = [
    (x_01, y_01, dict(c='black', linewidth=my01width, label=zero_one_loss_label, alpha=0.6)),
    (x, y_perc, dict(c=myblue, linewidth=mywidth, linestyle='dashed', label=perceptron_loss_label, alpha=transparency)),
    (x, y_hinge, dict(c=myred, linewidth=mywidth, linestyle=(0, (1, 1)), label=hinge_loss_label, alpha=transparency)),
    (x, y_log, dict(c=mygreen, linewidth=mywidth, linestyle='dashdot', label=log_loss_label, alpha=transparency)),
]
if use_smm:
    curves.append(
        (x, y_smm, dict(c=mypurple, linewidth=mywidth, linestyle=(0, (6, 2)), label=softmax_margin_label, alpha=transparency))
    )

for xs, ys, style in curves:
    ax.plot(xs, ys, **style)

# Axis labels, integer y ticks and fixed y limits.
ax.set_xlabel(my_xlabel)
ax.set_ylabel(my_ylabel)
ax.set_yticks(np.arange(0, my_ytick_max, step=1))
ax.set_ylim(my_ymin, my_ymax)

if use_smm:
    # Slightly smaller legend text to fit the extra, longer entry.
    ax.legend(prop={'size': 20})
    outputs = (('loss-perc-hinge-log-smm.pdf', 'pdf'), ('loss-perc-hinge-log-smm.png', 'png'))
else:
    ax.legend()
    outputs = (('loss-perc-hinge-log.pdf', 'pdf'), ('loss-perc-hinge-log.png', 'png'))

# Save before showing so the written files contain the finished figure.
for path, fmt in outputs:
    fig.savefig(path, format=fmt)

plt.show()

In [None]:
# Figure: 0-1 loss compared with the log and exp losses.
fig1, ax1 = plt.subplots()

# One entry per curve, in drawing order: (x values, y values, ax.plot keyword arguments).
curves = [
    (x_01, y_01, dict(c='black', linewidth=my01width, label=zero_one_loss_label, alpha=0.6)),
    (x, y_log, dict(c=mygreen, linewidth=mywidth, linestyle='dashdot', label=log_loss_label, alpha=transparency)),
    (x, y_exp, dict(c=myorange, linewidth=mywidth, linestyle=(0, (4, 1, 4, 1, 1, 1)), label=exp_loss_label, alpha=transparency)),
]
for xs, ys, style in curves:
    ax1.plot(xs, ys, **style)

ax1.legend()
# Axis labels, integer y ticks and fixed y limits (the exp loss is clipped at my_ymax).
ax1.set_xlabel(my_xlabel)
ax1.set_ylabel(my_ylabel)
ax1.set_yticks(np.arange(0, my_ytick_max, step=1))
ax1.set_ylim(my_ymin, my_ymax)

# Save before showing so the written files contain the finished figure.
for path, fmt in (('loss-log-exp.pdf', 'pdf'), ('loss-log-exp.png', 'png')):
    fig1.savefig(path, format=fmt)
plt.show()

In [None]:
# Figure: 0-1 loss compared with the hinge loss and its smooth counterpart, softmax-margin.
fig2, ax2 = plt.subplots()

# One entry per curve, in drawing order: (x values, y values, ax.plot keyword arguments).
curves = [
    (x_01, y_01, dict(c='black', linewidth=my01width, label=zero_one_loss_label, alpha=0.6)),
    (x, y_hinge, dict(c=myred, linewidth=mywidth, linestyle=(0, (1, 1)), label=hinge_loss_label, alpha=transparency)),
    (x, y_smm, dict(c=mypurple, linewidth=mywidth, linestyle=(0, (6, 2)), label=softmax_margin_label, alpha=transparency)),
]
for xs, ys, style in curves:
    ax2.plot(xs, ys, **style)

ax2.legend()
# Axis labels, integer y ticks and fixed y limits.
ax2.set_xlabel(my_xlabel)
ax2.set_ylabel(my_ylabel)
ax2.set_yticks(np.arange(0, my_ytick_max, step=1))
ax2.set_ylim(my_ymin, my_ymax)

# Save before showing so the written files contain the finished figure.
for path, fmt in (('loss-hinge-smm.pdf', 'pdf'), ('loss-hinge-smm.png', 'png')):
    fig2.savefig(path, format=fmt)
plt.show()

In [None]:
# Figure: 0-1 loss compared with the risk and ramp losses.
fig3, ax3 = plt.subplots()

# One entry per curve, in drawing order: (x values, y values, ax.plot keyword arguments).
curves = [
    (x_01, y_01, dict(c='black', linewidth=my01width, label=zero_one_loss_label, alpha=0.6)),
    (x, y_risk, dict(c=myblue, linewidth=mywidth, linestyle='dashed', label=risk_label, alpha=transparency)),
    (x, y_ramp1, dict(c=myred, linewidth=mywidth, linestyle=(0, (4, 1, 4, 1, 1, 1)), label=ramp_label, alpha=transparency)),
]
for xs, ys, style in curves:
    ax3.plot(xs, ys, **style)

# Optional extra curves; they need the corresponding (currently commented-out) arrays
# from the loss-definition cell before they can be enabled.
#ax3.plot(x, y_ramp2, c=myorange, linewidth=4, linestyle=(0,(4,1,4,1,1,1)), label=r'ramp2: $\max(0,-\!z) - \max(0,-m-z)$', alpha=transparency)
#ax3.plot(x, y_sramp1, c=mybrown, linewidth=4, linestyle=(0,(4,1,4,1,1,1)), label=r'soft ramp: $\log\left(\frac{1 + \exp(m-z)}{1 + \exp(-z)}\right)$', alpha=transparency)
#ax3.plot(x, y_sramp2, c=mypurple, linewidth=4, linestyle=(0,(4,1,4,1,1,1)), label=r'sramp2: $\log\left(\frac{1 + \exp(-z)}{1 + \exp(-m-z)}\right)$', alpha=transparency)
#ax3.plot(x, y_savage, c=mygreen, linewidth=4, linestyle=(0,(4,1,4,1,1,1)), label=r'savage: $\frac{1}{(1+\exp(z))^2}$', alpha=transparency)

ax3.legend()
# Axis labels, integer y ticks and fixed y limits.
ax3.set_xlabel(my_xlabel)
ax3.set_ylabel(my_ylabel)
ax3.set_yticks(np.arange(0, my_ytick_max, step=1))
ax3.set_ylim(my_ymin, my_ymax)

# Save before showing so the written files contain the finished figure.
for path, fmt in (('loss-risk-ramp.pdf', 'pdf'), ('loss-risk-ramp.png', 'png')):
    fig3.savefig(path, format=fmt)
plt.show()