{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fit Model\n", "This notebook fits a 3-state classification model on a training set and calculates metrics on a test set." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "**Rule 9: Design Your Notebooks to Be Read, Run, and Explored.** We use ipywidgets to present the user with a pull-down menu to select a machine learning model.\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import mlutils\n", "from sklearn import svm, metrics\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.neural_network import MLPClassifier\n", "from sklearn.model_selection import train_test_split\n", "import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23\n", "from ipywidgets import widgets" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# column names\n", "feature_col = \"features\" # feature vector\n", "value_col = \"foldClass\" # fold class to be predicted" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Read data set with fold type classifications and feature vectors" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "df = pd.read_json(\"./intermediate_data/features.json\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of data: 5370 \n", "\n" ] }, { "data": { "text/html": [ "
| | Exptl. | FreeRvalue | R-factor | alpha | beta | coil | features | foldClass | length | ngram | pdbChainId | resolution | secondary_structure | sequence |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | XRAY | 0.26 | 0.19 | 0.469945 | 0.046448 | 0.483607 | [-2.6183412084, -0.37215537190000003, 0.140630... | alpha | 366 | [SRM, RMP, MPS, PSP, SPP, PPM, PMP, MPV, PVP, ... | 16VP.A | 2.1 | CCSCCCCCCCCHHHHHHHHHHHHTCTTHHHHHHHHHHCCCCCSTTS... | SRMPSPPMPVPPAALFNRLLDDLGFSAGPALCTMLDTWNEDLFSAL... |
| 1000 | XRAY | 0.23 | 0.18 | 0.504630 | 0.004630 | 0.490741 | [-2.4130836608, -0.5122827316, 0.1969318015, -... | alpha | 216 | [MEA, EAD, ADV, DVE, VEQ, EQQ, QQA, QAL, ALT, ... | 1PBW.B | 2.0 | CCCCCCCCCCCCCCHHHHCCTTSCSCHHHHHHHHHHHHHHTTCTTT... | MEADVEQQALTLPDLAEQFAPPDIAPPLLIKLVEAIEKKGLECSTL... |