{ "cells": [ { "cell_type": "markdown", "id": "15766a1d", "metadata": { "papermill": { "duration": 0.022668, "end_time": "2021-09-23T21:20:08.536683", "exception": false, "start_time": "2021-09-23T21:20:08.514015", "status": "completed" }, "tags": [] }, "source": [ "# Deep Learning Starter: Simple LSTM\n", "\n", "This notebook leverages the time-series structure of the data.\n", "\n", "I expect sequential deep learning models to dominate this competition, so here is a simple LSTM architecture.\n", "\n", "Hyperparameters were barely tuned, so this baseline should be easy to improve.\n", "\n", "The code is adapted from previous work; some functions are documented, but their docstrings may be outdated.\n", "\n", "\n", "**Don't fork without upvoting ^^**" ] }, { "cell_type": "code", "execution_count": 2, "id": "d20019e6", "metadata": { "_kg_hide-input": true, "execution": { "iopub.execute_input": "2021-09-23T21:20:08.584490Z", "iopub.status.busy": "2021-09-23T21:20:08.583029Z", "iopub.status.idle": "2021-09-23T21:20:09.357402Z", "shell.execute_reply": "2021-09-23T21:20:09.356832Z", "shell.execute_reply.started": "2021-09-23T15:48:46.478686Z" }, "papermill": { "duration": 0.799234, "end_time": "2021-09-23T21:20:09.357559", "exception": false, "start_time": "2021-09-23T21:20:08.558325", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import os\n", "import warnings\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "from tqdm.notebook import tqdm\n", "from collections import Counter\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "NUM_WORKERS = 16" ] }, { "cell_type": "markdown", "id": "b0315ed0", "metadata": { "papermill": { "duration": 0.021801, "end_time": "2021-09-23T21:20:09.402779", "exception": false, "start_time": "2021-09-23T21:20:09.380978", "status": "completed" }, "tags": [] }, "source": [ "## Data" ] }, { "cell_type": "markdown", "id": "7cb25d05", "metadata": { "papermill": { 
"duration": 0.021737, "end_time": "2021-09-23T21:20:09.446272", "exception": false, "start_time": "2021-09-23T21:20:09.424535", "status": "completed" }, "tags": [] }, "source": [ "### Load" ] }, { "cell_type": "code", "execution_count": 3, "id": "fd520830", "metadata": { "execution": { "iopub.execute_input": "2021-09-23T21:20:09.495043Z", "iopub.status.busy": "2021-09-23T21:20:09.494517Z", "iopub.status.idle": "2021-09-23T21:20:23.057883Z", "shell.execute_reply": "2021-09-23T21:20:23.058339Z", "shell.execute_reply.started": "2021-09-23T15:48:46.489787Z" }, "papermill": { "duration": 13.590445, "end_time": "2021-09-23T21:20:23.058540", "exception": false, "start_time": "2021-09-23T21:20:09.468095", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Directory containing the competition CSV files; adjust to your environment.\n", "DATA_PATH = \"../ventilator-pressure-prediction-data/\"\n", "\n", "# os.path.join builds correct paths whether or not DATA_PATH ends with a separator.\n", "sub = pd.read_csv(os.path.join(DATA_PATH, 'sample_submission.csv'))\n", "df_train = pd.read_csv(os.path.join(DATA_PATH, 'train.csv'))\n", "df_test = pd.read_csv(os.path.join(DATA_PATH, 'test.csv'))\n", "\n", "# Small subset of the training data (rows with breath_id < 5) for quick inspection.\n", "df = df_train[df_train['breath_id'] < 5].reset_index(drop=True)" ] }, { "cell_type": "code", "execution_count": 4, "id": "fec7619a", "metadata": { "execution": { "iopub.execute_input": "2021-09-23T21:20:23.108914Z", "iopub.status.busy": "2021-09-23T21:20:23.108256Z", "iopub.status.idle": "2021-09-23T21:20:23.122181Z", "shell.execute_reply": "2021-09-23T21:20:23.122585Z", "shell.execute_reply.started": "2021-09-23T15:48:53.636373Z" }, "papermill": { "duration": 0.042028, "end_time": "2021-09-23T21:20:23.122704", "exception": false, "start_time": "2021-09-23T21:20:23.080676", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
| \n", " | id | \n", "breath_id | \n", "R | \n", "C | \n", "time_step | \n", "u_in | \n", "u_out | \n", "pressure | \n", "
|---|---|---|---|---|---|---|---|---|
| 0 | \n", "1 | \n", "1 | \n", "20 | \n", "50 | \n", "0.000000 | \n", "0.083334 | \n", "0 | \n", "5.837492 | \n", "
| 1 | \n", "2 | \n", "1 | \n", "20 | \n", "50 | \n", "0.033652 | \n", "18.383041 | \n", "0 | \n", "5.907794 | \n", "
| 2 | \n", "3 | \n", "1 | \n", "20 | \n", "50 | \n", "0.067514 | \n", "22.509278 | \n", "0 | \n", "7.876254 | \n", "
| 3 | \n", "4 | \n", "1 | \n", "20 | \n", "50 | \n", "0.101542 | \n", "22.808822 | \n", "0 | \n", "11.742872 | \n", "
| 4 | \n", "5 | \n", "1 | \n", "20 | \n", "50 | \n", "0.135756 | \n", "25.355850 | \n", "0 | \n", "12.234987 | \n", "