{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\nPruning Unpromising Trials\n==========================\n\nThis feature automatically stops unpromising trials at the early stages of training (a.k.a. automated early stopping).\nOptuna provides interfaces to concisely implement the pruning mechanism in iterative training algorithms.\n\n\nActivating Pruners\n------------------\nTo turn on the pruning feature, call :func:`~optuna.trial.Trial.report` and :func:`~optuna.trial.Trial.should_prune` after each step of the iterative training.\n:func:`~optuna.trial.Trial.report` records the intermediate objective value for the current step, so the pruner can monitor the trial's progress.\n:func:`~optuna.trial.Trial.should_prune` decides whether to terminate a trial that does not meet the condition defined by the pruner; if it returns ``True``, raise ``optuna.TrialPruned`` to stop the trial.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import sklearn.datasets\nimport sklearn.linear_model\nimport sklearn.model_selection\n\nimport optuna\n\n\ndef objective(trial):\n    iris = sklearn.datasets.load_iris()\n    classes = list(set(iris.target))\n    train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(\n        iris.data, iris.target, test_size=0.25, random_state=0\n    )\n\n    # Sample the regularization strength on a log scale.\n    alpha = trial.suggest_float('alpha', 1e-5, 1e-1, log=True)\n    clf = sklearn.linear_model.SGDClassifier(alpha=alpha)\n\n    for step in range(100):\n        clf.partial_fit(train_x, train_y, classes=classes)\n\n        # Report the intermediate objective value (validation error) for this step.\n        intermediate_value = 1.0 - clf.score(valid_x, valid_y)\n        trial.report(intermediate_value, step)\n\n        # Stop the trial early if the pruner decides it is unpromising.\n        if trial.should_prune():\n            raise optuna.TrialPruned()\n\n    return 1.0 - clf.score(valid_x, valid_y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
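        "Set up the median stopping rule as the pruning condition: :class:`~optuna.pruners.MedianPruner` prunes a trial if its best intermediate value is worse than the median of the intermediate values reported by previous trials at the same step.\n\n"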
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "study = optuna.create_study(pruner=optuna.pruners.MedianPruner())\nstudy.optimize(objective, n_trials=20)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Executing the script above produces output like the following:\n\n.. code-block:: console\n\n         $ python prune.py\n         [I 2020-06-12 16:54:23,876] Trial 0 finished with value: 0.3157894736842105 and parameters: {'alpha': 0.00181467547181131}. Best is trial 0 with value: 0.3157894736842105.\n         [I 2020-06-12 16:54:23,981] Trial 1 finished with value: 0.07894736842105265 and parameters: {'alpha': 0.015378744419287613}. Best is trial 1 with value: 0.07894736842105265.\n         [I 2020-06-12 16:54:24,083] Trial 2 finished with value: 0.21052631578947367 and parameters: {'alpha': 0.04089428832878595}. Best is trial 1 with value: 0.07894736842105265.\n         [I 2020-06-12 16:54:24,185] Trial 3 finished with value: 0.052631578947368474 and parameters: {'alpha': 0.004018735937374473}. Best is trial 3 with value: 0.052631578947368474.\n         [I 2020-06-12 16:54:24,303] Trial 4 finished with value: 0.07894736842105265 and parameters: {'alpha': 2.805688697062864e-05}. Best is trial 3 with value: 0.052631578947368474.\n         [I 2020-06-12 16:54:24,315] Trial 5 pruned.\n         [I 2020-06-12 16:54:24,355] Trial 6 pruned.\n         [I 2020-06-12 16:54:24,511] Trial 7 finished with value: 0.052631578947368474 and parameters: {'alpha': 2.243775785299103e-05}. Best is trial 3 with value: 0.052631578947368474.\n         [I 2020-06-12 16:54:24,625] Trial 8 finished with value: 0.1842105263157895 and parameters: {'alpha': 0.007021209286214553}. Best is trial 3 with value: 0.052631578947368474.\n         [I 2020-06-12 16:54:24,629] Trial 9 pruned.\n         ...\n\nLog messages such as ``Trial 5 pruned.`` mark trials that were stopped before finishing all of their iterations.\nBy default, :class:`~optuna.pruners.MedianPruner` does not prune anything until five trials have finished (``n_startup_trials=5``), which is why trials 0 to 4 all ran to completion.\n\n\n"
      ]
    },
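    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see how the median stopping rule reaches its decision, it can be sketched in plain Python. The cell below is a simplified illustration under stated assumptions, not Optuna's implementation: the function ``median_rule_should_prune``, its ``{step: value}`` dictionary inputs, and the error rates are hypothetical, and minimization is assumed. Like :class:`~optuna.pruners.MedianPruner` with its default ``n_startup_trials=5``, it does not prune until five trials have completed; unlike the real pruner, it ignores the ``n_warmup_steps`` and ``interval_steps`` options.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import statistics\n\n\ndef median_rule_should_prune(current, completed, step, n_startup_trials=5):\n    # current: {step: value} reported so far by the running trial (hypothetical format).\n    # completed: list of {step: value} dicts, one per completed trial.\n    # Lower values are better (minimization), like the error rate used above.\n    if len(completed) < n_startup_trials:\n        return False  # Too few completed trials to compare against.\n    if step not in current:\n        return False  # The running trial has not reported this step yet.\n    values_at_step = [trial[step] for trial in completed if step in trial]\n    if not values_at_step:\n        return False  # No completed trial reported a value at this step.\n    # Best (lowest) value the running trial has reported up to this step.\n    best_so_far = min(value for s, value in current.items() if s <= step)\n    return best_so_far > statistics.median(values_at_step)\n\n\n# Hypothetical intermediate error rates of five completed trials.\ncompleted = [\n    {0: 0.40, 1: 0.30},\n    {0: 0.35, 1: 0.20},\n    {0: 0.50, 1: 0.45},\n    {0: 0.30, 1: 0.25},\n    {0: 0.45, 1: 0.35},\n]\n\n# The median at step 1 is 0.30, so a best value of 0.50 is pruned and 0.22 is kept.\nprint(median_rule_should_prune({0: 0.60, 1: 0.50}, completed, step=1))  # True\nprint(median_rule_should_prune({0: 0.60, 1: 0.22}, completed, step=1))  # False"
      ]
    },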
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Integration Modules for Pruning\n-------------------------------\nTo make implementing pruning even simpler, Optuna provides integration modules for several machine learning libraries.\n\nFor the complete list of Optuna's integration modules, see `integration_list`.\n\nFor example, :class:`~optuna.integration.XGBoostPruningCallback` adds pruning without changing the logic of the training loop.\n(See also the `example <https://github.com/optuna/optuna/blob/master/examples/pruning/xgboost_integration.py>`_ for the entire script.)\n\n.. code-block:: python\n\n        # 'validation-error' refers to the 'error' metric evaluated on the evals entry named 'validation'.\n        pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'validation-error')\n        bst = xgb.train(param, dtrain, evals=[(dvalid, 'validation')], callbacks=[pruning_callback])\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}