{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loss functions for binary classification\n",
    "\n",
    "This notebook plots loss functions for binary classification.\n",
    "Each loss is drawn as a function of the signed margin $z = y^\\ast \\times f(x)$, and every figure is saved as both a PDF and a PNG file in the current working directory.\n",
    "\n",
    "Kevin Gimpel\n",
    "\n",
    "6/26/2020"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fonts and figure styling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render text with matplotlib's built-in mathtext rather than an external LaTeX installation\n",
    "# (a real boolean, not the string 'False', which only works because matplotlib coerces it)\n",
    "matplotlib.rcParams['text.usetex'] = False\n",
    "# Use Computer Modern as the math font\n",
    "matplotlib.rcParams['mathtext.fontset'] = 'cm'\n",
    "# Use Palatino as the text font (by default, setting Palatino as the text font causes Palatino to be the math font too)\n",
    "matplotlib.rcParams['font.family'] = 'Palatino'\n",
    "# Embed fonts as Type 42 (TrueType) in the saved PDF files\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the figure size and font sizes\n",
    "global_params = {'legend.fontsize': 22,\n",
    "                 'figure.figsize': (12, 8),\n",
    "                 'axes.labelsize': 24,\n",
    "                 'axes.titlesize': 24,\n",
    "                 'xtick.labelsize': 20,\n",
    "                 'ytick.labelsize': 20,\n",
    "                 # The preamble must be a single string: recent matplotlib releases reject a list here.\n",
    "                 # It is written without commas because older releases split this setting on commas.\n",
    "                 # The preamble is only consulted when text.usetex is enabled.\n",
    "                 'text.latex.preamble': r'\\usepackage{amsmath} \\usepackage{amsthm} \\usepackage{amssymb}'}\n",
    "plt.rcParams.update(global_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot settings and labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# variables for all plots:\n",
    "range_min = -4\n",
    "range_max = 4\n",
    "transparency = 1\n",
    "my01width = 3\n",
    "mywidth = 4\n",
    "stepsize = 0.1\n",
    "# m is the constant that appears in the labels and loss definitions below;\n",
    "# it represents the cost of an incorrect classification\n",
    "m = 1.0\n",
    "\n",
    "myred = '#FF2B00'\n",
    "mygreen = '#49BA00'\n",
    "myblue = '#06B2F4'\n",
    "mypurple = '#D973EF'\n",
    "myorange = '#FA9D00'\n",
    "mybrown = '#CE9716'\n",
    "\n",
    "zero_one_loss_label = r'0-1: $m\\,I\\,[z\\leq 0]$'\n",
    "perceptron_loss_label = r'perceptron: $\\max(0, -\\!z)$'\n",
    "hinge_loss_label = r'hinge: $\\max(0,m-z)$'\n",
    "log_loss_label = r'log: $\\log(1+\\exp(-z))$'\n",
    "exp_loss_label = r'exp: $\\exp(-z)$'\n",
    "softmax_margin_label = r'softmax-margin: $\\log(1+\\exp(m-z))$'\n",
    "risk_label = r'risk: $\\frac{m\\,\\exp(-z)}{1+\\exp(-z)}$'\n",
    "ramp_label = r'ramp: $\\max(0,m-z) - \\max(0,-\\!z)$'\n",
    "my_xlabel = r'$z = y^\\ast\\!\\!\\times f(x)$'\n",
    "my_ylabel = r'Loss'\n",
    "\n",
    "my_ymin = -0.1\n",
    "my_ymax = 4.1\n",
    "my_ytick_max = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute the losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 0-1 loss (actually it's \"0-m\" loss):\n",
    "x_01 = [range_min, 0, 0, range_max]\n",
    "y_01 = [m, m, 0, 0]\n",
    "\n",
    "# set the x range\n",
    "# np.linspace always includes both endpoints, whereas np.arange with a fractional step can gain or lose\n",
    "# the last point through floating-point rounding; if stepsize does not divide the range evenly,\n",
    "# the spacing is adjusted slightly so that the grid still ends exactly at range_max\n",
    "num_points = int(round((range_max - range_min) / stepsize)) + 1\n",
    "x = np.linspace(range_min, range_max, num_points)\n",
    "# the perceptron loss:\n",
    "y_perc = np.maximum(0,-x)\n",
    "# hinge loss:\n",
    "y_hinge = np.maximum(0,m-x)\n",
    "# log loss (note: some people define log loss using base 2 which will make it intersect with the point (0,m)):\n",
    "y_log = np.log(1 + np.exp(-x))\n",
    "# exp loss:\n",
    "y_exp = np.exp(-x)\n",
    "# softmax-margin:\n",
    "y_smm = np.log(1 + np.exp(m-x))\n",
    "# Bayes risk:\n",
    "y_risk = (m*np.exp(-x)) / (1 + np.exp(-x))\n",
    "# ramp loss:\n",
    "y_ramp1 = y_hinge - y_perc\n",
    "\n",
    "# other losses that are not currently being plotted in the cells below:\n",
    "# second form of ramp loss:\n",
    "#y_lazyhinge = np.maximum(0,-m-x)\n",
    "#y_ramp2 = y_perc - y_lazyhinge\n",
    "# soft version of ramp1:\n",
    "#y_sramp1 = np.log((1 + np.exp(m-x))/(1 + np.exp(-x)))\n",
    "# soft version of ramp2:\n",
    "#y_sramp2 = np.log((1 + np.exp(-x))/(1 + np.exp(-m-x)))\n",
    "# Savage loss:\n",
    "#y_savage = np.power(np.square(1 + np.exp(x)), -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting helper\n",
    "\n",
    "Every figure shows the 0-1 loss together with a few other losses, so the shared drawing and saving steps live in one function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_losses(curves, basename, legend_kwargs=None):\n",
    "    \"\"\"Plot the 0-1 loss together with the given loss curves, save the figure, and show it.\n",
    "\n",
    "    The x values, the 0-1 loss, the line widths, the axis labels, and the y-axis limits\n",
    "    are read from the notebook-level variables defined in the cells above.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    curves : iterable of (y, color, linestyle, label) tuples\n",
    "        One entry per loss curve; y holds the loss values evaluated on x.\n",
    "    basename : str\n",
    "        Output file name without extension; the figure is written to\n",
    "        basename + '.pdf' and basename + '.png'.\n",
    "    legend_kwargs : dict, optional\n",
    "        Extra keyword arguments passed to ax.legend().\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (fig, ax) : the matplotlib figure and axes that were created.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots()\n",
    "\n",
    "    # the 0-1 loss is drawn first, in translucent black, as the common reference curve\n",
    "    ax.plot(x_01, y_01, c='black', linewidth=my01width, label=zero_one_loss_label, alpha=0.6)\n",
    "    for y, color, linestyle, label in curves:\n",
    "        ax.plot(x, y, c=color, linewidth=mywidth, linestyle=linestyle, label=label, alpha=transparency)\n",
    "\n",
    "    ax.set_xlabel(my_xlabel)\n",
    "    ax.set_ylabel(my_ylabel)\n",
    "    ax.set_yticks(np.arange(0, my_ytick_max, step=1))\n",
    "    ax.set_ylim(my_ymin, my_ymax)\n",
    "    ax.legend(**(legend_kwargs or {}))\n",
    "\n",
    "    # save before plt.show(), because the inline backend closes the figure once it has been shown\n",
    "    for extension in ('pdf', 'png'):\n",
    "        fig.savefig(basename + '.' + extension, format=extension)\n",
    "    plt.show()\n",
    "    return fig, ax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perceptron, hinge, and log losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set to True to add the softmax-margin loss to this figure (the output file names change accordingly)\n",
    "use_smm = False\n",
    "\n",
    "curves = [(y_perc, myblue, 'dashed', perceptron_loss_label),\n",
    "          (y_hinge, myred, (0, (1, 1)), hinge_loss_label),\n",
    "          (y_log, mygreen, 'dashdot', log_loss_label)]\n",
    "\n",
    "if use_smm:\n",
    "    curves.append((y_smm, mypurple, (0, (6, 2)), softmax_margin_label))\n",
    "    # the legend uses a slightly smaller font when the fourth curve is present\n",
    "    fig, ax = plot_losses(curves, 'loss-perc-hinge-log-smm', legend_kwargs={'prop': {'size': 20}})\n",
    "else:\n",
    "    fig, ax = plot_losses(curves, 'loss-perc-hinge-log')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Log and exp losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig1, ax1 = plot_losses(\n",
    "    [(y_log, mygreen, 'dashdot', log_loss_label),\n",
    "     (y_exp, myorange, (0, (4, 1, 4, 1, 1, 1)), exp_loss_label)],\n",
    "    'loss-log-exp')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hinge and softmax-margin losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig2, ax2 = plot_losses(\n",
    "    [(y_hinge, myred, (0, (1, 1)), hinge_loss_label),\n",
    "     (y_smm, mypurple, (0, (6, 2)), softmax_margin_label)],\n",
    "    'loss-hinge-smm')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Risk and ramp losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig3, ax3 = plot_losses(\n",
    "    [(y_risk, myblue, 'dashed', risk_label),\n",
    "     (y_ramp1, myred, (0, (4, 1, 4, 1, 1, 1)), ramp_label),\n",
    "     # additional curves that are not currently plotted; to enable one,\n",
    "     # also uncomment its computation in the loss cell above:\n",
    "     #(y_ramp2, myorange, (0, (4, 1, 4, 1, 1, 1)), r'ramp2: $\\max(0,-\\!z) - \\max(0,-m-z)$'),\n",
    "     #(y_sramp1, mybrown, (0, (4, 1, 4, 1, 1, 1)), r'soft ramp: $\\log\\left(\\frac{1 + \\exp(m-z)}{1 + \\exp(-z)}\\right)$'),\n",
    "     #(y_sramp2, mypurple, (0, (4, 1, 4, 1, 1, 1)), r'sramp2: $\\log\\left(\\frac{1 + \\exp(-z)}{1 + \\exp(-m-z)}\\right)$'),\n",
    "     #(y_savage, mygreen, (0, (4, 1, 4, 1, 1, 1)), r'savage: $\\frac{1}{(1+\\exp(z))^2}$'),\n",
    "    ],\n",
    "    'loss-risk-ramp')"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda env:anaconda]",
   "language": "python",
   "name": "conda-env-anaconda-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}