{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "import warnings\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "from tqdm.notebook import tqdm\n", "from collections import Counter" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "df_train = pd.read_csv('../ventilator-pressure-prediction-data/train.csv')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idbreath_idRCtime_stepu_inu_outpressure
01120500.0000000.08333405.837492
12120500.03365218.38304105.907794
23120500.06751422.50927807.876254
34120500.10154222.808822011.742872
45120500.13575625.355850012.234987
\n", "
" ], "text/plain": [ " id breath_id R C time_step u_in u_out pressure\n", "0 1 1 20 50 0.000000 0.083334 0 5.837492\n", "1 2 1 20 50 0.033652 18.383041 0 5.907794\n", "2 3 1 20 50 0.067514 22.509278 0 7.876254\n", "3 4 1 20 50 0.101542 22.808822 0 11.742872\n", "4 5 1 20 50 0.135756 25.355850 0 12.234987" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.head()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
breath_ididRCtime_stepu_inu_outpressure
01[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...[20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 2...[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5...[0.0, 0.0336523056030273, 0.067514419555664, 0...[0.0833340056346443, 18.38304147263472, 22.509...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...[5.837491705069121, 5.907793850520346, 7.87625...
12[81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 9...[20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 2...[20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 2...[0.0, 0.0339975357055664, 0.0681509971618652, ...[12.184337517135212, 13.980205443281347, 12.57...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...[6.048398141422782, 7.524743195898315, 9.28229...
23[161, 162, 163, 164, 165, 166, 167, 168, 169, ...[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5...[20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 2...[0.0, 0.03196382522583, 0.0639522075653076, 0....[0.0, 7.18724187099842, 13.338780645925038, 17...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...[5.064168105105749, 5.064168105105749, 6.75141...
34[241, 242, 243, 244, 245, 246, 247, 248, 249, ...[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5...[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5...[0.0, 0.0318536758422851, 0.0637614727020263, ...[0.0, 1.2625385852839184, 4.001352088243387, 6...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...[6.259304577776412, 5.767189559617911, 6.54051...
45[321, 322, 323, 324, 325, 326, 327, 328, 329, ...[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5...[0.0, 0.0339670181274414, 0.0687971115112304, ...[21.424374842054064, 28.504653017718358, 29.35...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...[6.329606723227637, 7.384138904995879, 8.36836...
...........................
75445125740[6035601, 6035602, 6035603, 6035604, 6035605, ...[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5...[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5...[0.0, 0.0348799228668212, 0.0689036846160888, ...[0.0, 0.0, 0.6739133329743916, 1.7007626873808...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...[6.470211014130073, 5.837491705069121, 6.25930...
75446125742[6035681, 6035682, 6035683, 6035684, 6035685, ...[20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 2...[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1...[0.0, 0.0338180065155029, 0.0677299499511718, ...[85.63023054349601, 100.0, 81.56681300270868, ...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...[6.329606723227637, 11.391361195715188, 20.952...
75447125743[6035761, 6035762, 6035763, 6035764, 6035765, ...[20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 2...[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1...[0.0, 0.0336830615997314, 0.0673832893371582, ...[0.0, 0.0, 0.0, 0.9501772243146738, 3.25671236...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...[6.259304577776412, 5.907793850520346, 6.04839...
75448125745[6035841, 6035842, 6035843, 6035844, 6035845, ...[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5...[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5...[0.0, 0.031679630279541, 0.0633506774902343, 0...[15.564236227541224, 23.58883636738182, 20.298...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...[6.962326032288589, 8.790181814020203, 12.4458...
75449125749[6035921, 6035922, 6035923, 6035924, 6035925, ...[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5...[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1...[0.0, 0.0331871509552001, 0.0663647651672363, ...[6.030572044220927, 25.50419568083585, 21.6147...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...[3.939333777886297, 5.345376686910621, 9.56350...
\n", "

75450 rows × 8 columns

\n", "
" ], "text/plain": [ " breath_id id \\\n", "0 1 [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14... \n", "1 2 [81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 9... \n", "2 3 [161, 162, 163, 164, 165, 166, 167, 168, 169, ... \n", "3 4 [241, 242, 243, 244, 245, 246, 247, 248, 249, ... \n", "4 5 [321, 322, 323, 324, 325, 326, 327, 328, 329, ... \n", "... ... ... \n", "75445 125740 [6035601, 6035602, 6035603, 6035604, 6035605, ... \n", "75446 125742 [6035681, 6035682, 6035683, 6035684, 6035685, ... \n", "75447 125743 [6035761, 6035762, 6035763, 6035764, 6035765, ... \n", "75448 125745 [6035841, 6035842, 6035843, 6035844, 6035845, ... \n", "75449 125749 [6035921, 6035922, 6035923, 6035924, 6035925, ... \n", "\n", " R \\\n", "0 [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 2... \n", "1 [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 2... \n", "2 [50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5... \n", "3 [50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5... \n", "4 [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ... \n", "... ... \n", "75445 [50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5... \n", "75446 [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 2... \n", "75447 [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 2... \n", "75448 [50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5... \n", "75449 [50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5... \n", "\n", " C \\\n", "0 [50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5... \n", "1 [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 2... \n", "2 [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 2... \n", "3 [50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5... \n", "4 [50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5... \n", "... ... \n", "75445 [50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5... \n", "75446 [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1... \n", "75447 [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1... \n", "75448 [50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5... \n", "75449 [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1... 
\n", "\n", " time_step \\\n", "0 [0.0, 0.0336523056030273, 0.067514419555664, 0... \n", "1 [0.0, 0.0339975357055664, 0.0681509971618652, ... \n", "2 [0.0, 0.03196382522583, 0.0639522075653076, 0.... \n", "3 [0.0, 0.0318536758422851, 0.0637614727020263, ... \n", "4 [0.0, 0.0339670181274414, 0.0687971115112304, ... \n", "... ... \n", "75445 [0.0, 0.0348799228668212, 0.0689036846160888, ... \n", "75446 [0.0, 0.0338180065155029, 0.0677299499511718, ... \n", "75447 [0.0, 0.0336830615997314, 0.0673832893371582, ... \n", "75448 [0.0, 0.031679630279541, 0.0633506774902343, 0... \n", "75449 [0.0, 0.0331871509552001, 0.0663647651672363, ... \n", "\n", " u_in \\\n", "0 [0.0833340056346443, 18.38304147263472, 22.509... \n", "1 [12.184337517135212, 13.980205443281347, 12.57... \n", "2 [0.0, 7.18724187099842, 13.338780645925038, 17... \n", "3 [0.0, 1.2625385852839184, 4.001352088243387, 6... \n", "4 [21.424374842054064, 28.504653017718358, 29.35... \n", "... ... \n", "75445 [0.0, 0.0, 0.6739133329743916, 1.7007626873808... \n", "75446 [85.63023054349601, 100.0, 81.56681300270868, ... \n", "75447 [0.0, 0.0, 0.0, 0.9501772243146738, 3.25671236... \n", "75448 [15.564236227541224, 23.58883636738182, 20.298... \n", "75449 [6.030572044220927, 25.50419568083585, 21.6147... \n", "\n", " u_out \\\n", "0 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "1 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "2 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "3 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "4 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "... ... \n", "75445 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "75446 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "75447 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "75448 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "75449 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "\n", " pressure \n", "0 [5.837491705069121, 5.907793850520346, 7.87625... 
import torch
from torch.utils.data import Dataset

class VentilatorDataset(Dataset):
    """One sample per breath, built from the per-timestep rows of the raw frame.

    Each item is a dict with:
        input: (seq_len, 5) float tensor of [R, C, u_in, cumsum(u_in), u_out]
        u_out: (seq_len,)  float tensor (1 during expiration, which is unscored)
        p:     (seq_len,)  float tensor of target pressures (all 0 for test data)
    """

    def __init__(self, df):
        # Work on a copy so the caller's DataFrame is never mutated
        # (the original added a 'pressure' column to the frame passed in).
        df = df.copy()
        if "pressure" not in df.columns:
            # Test data has no target; add a dummy so shapes stay consistent.
            df['pressure'] = 0

        # One row per breath: every feature column becomes a list of length seq_len.
        self.df = df.groupby('breath_id').agg(list).reset_index()

        self.prepare_data()

    def __len__(self):
        return self.df.shape[0]

    def prepare_data(self):
        """Build (n_breaths, seq_len) arrays and the stacked model input."""
        self.pressures = np.array(self.df['pressure'].values.tolist())

        rs = np.array(self.df['R'].values.tolist())
        cs = np.array(self.df['C'].values.tolist())
        u_ins = np.array(self.df['u_in'].values.tolist())

        self.u_outs = np.array(self.df['u_out'].values.tolist())

        # [:, None] inserts the feature axis, giving an (n, 1, seq_len) matrix
        # instead of an (n, seq_len) vector, so the concat stacks 5 features
        # into (n, 5, seq_len); transpose to (n, seq_len, 5) for batch_first.
        self.inputs = np.concatenate([
            rs[:, None],
            cs[:, None],
            u_ins[:, None],
            np.cumsum(u_ins, 1)[:, None],  # running total of inspiratory valve input
            self.u_outs[:, None]
        ], 1).transpose(0, 2, 1)

    def __getitem__(self, idx):
        data = {
            "input": torch.tensor(self.inputs[idx], dtype=torch.float),
            "u_out": torch.tensor(self.u_outs[idx], dtype=torch.float),
            "p": torch.tensor(self.pressures[idx], dtype=torch.float),
        }

        return data
5.0000e+01, 1.8094e+01, 5.9931e+02, 0.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 1.7194e+01, 6.1650e+02, 0.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 1.6419e+01, 6.3292e+02, 0.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 1.5745e+01, 6.4866e+02, 0.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 1.4932e+01, 6.6359e+02, 0.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 0.0000e+00, 6.6359e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 7.7922e-01, 6.6437e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 1.4390e+00, 6.6581e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 1.9942e+00, 6.6781e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 2.4672e+00, 6.7027e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 2.8634e+00, 6.7314e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 3.1978e+00, 6.7634e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 3.4784e+00, 6.7981e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 3.7164e+00, 6.8353e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 3.9168e+00, 6.8745e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.0863e+00, 
6.9153e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.2287e+00, 6.9576e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.3494e+00, 7.0011e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.4512e+00, 7.0456e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.5370e+00, 7.0910e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.6095e+00, 7.1371e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.6708e+00, 7.1838e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.7227e+00, 7.2310e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.7661e+00, 7.2787e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.8029e+00, 7.3267e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.8337e+00, 7.3751e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.8598e+00, 7.4237e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.8818e+00, 7.4725e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.9004e+00, 7.5215e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.9160e+00, 7.5706e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.9292e+00, 7.6199e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.9402e+00, 7.6693e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.9496e+00, 7.7188e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.9575e+00, 7.7684e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.9641e+00, 7.8180e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.9697e+00, 7.8677e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.9745e+00, 7.9175e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.9785e+00, 7.9673e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.9818e+00, 8.0171e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.9847e+00, 8.0669e+02, 1.0000e+00],\n", " [2.0000e+01, 5.0000e+01, 4.9871e+00, 8.1168e+02, 1.0000e+00]]),\n", " 'u_out': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.,\n", " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", " 1., 1., 1., 1., 
import torch.nn as nn


class RNNModel(nn.Module):
    """Per-timestep MLP encoder -> bidirectional LSTM -> dense head.

    Input:  (batch, seq_len, input_dim) float tensor.
    Output: (batch, seq_len, num_classes) predicted pressure per timestep.
    """

    def __init__(
        self,
        # Default aligned with VentilatorDataset, which always emits 5
        # features (R, C, u_in, cumsum(u_in), u_out); the old default of 4
        # would crash on the dataset's tensors.
        input_dim=5,
        lstm_dim=256,
        dense_dim=256,
        logit_dim=256,
        num_classes=1,
    ):
        super().__init__()

        # Multilayer perceptron (feedforward), applied independently per timestep.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, dense_dim // 2),
            nn.ReLU(),
            nn.Linear(dense_dim // 2, dense_dim),
            nn.ReLU(),
        )

        # Bidirectional, so the head sees 2 * lstm_dim features per step.
        self.lstm = nn.LSTM(dense_dim, lstm_dim, batch_first=True, bidirectional=True)

        self.logits = nn.Sequential(
            nn.Linear(lstm_dim * 2, logit_dim),
            nn.ReLU(),
            nn.Linear(logit_dim, num_classes),
        )

    def forward(self, x):
        features = self.mlp(x)
        features, _ = self.lstm(features)  # discard the (h_n, c_n) state
        pred = self.logits(features)
        return pred
def compute_metric(df, preds):
    """
    Metric for the problem: MAE restricted to inspiratory timesteps.

    Args:
        df: frame with 'pressure' and 'u_out' columns (scalars per row, or
            per-breath lists as produced by VentilatorDataset.df).
        preds: predictions with the same shape as the stacked targets.

    Returns:
        float: mean absolute error over timesteps where u_out == 0.
    """

    # u_out = 1 (expiration) is not scored by the competition.
    y = np.array(df['pressure'].values.tolist())
    w = 1 - np.array(df['u_out'].values.tolist())

    assert y.shape == preds.shape and w.shape == y.shape, (y.shape, preds.shape, w.shape)

    mae = w * np.abs(y - preds)
    mae = mae.sum() / w.sum()

    return mae


class VentilatorLoss(nn.Module):
    """
    Masked L1 loss that directly optimizes the competition metric.

    Returns one MAE value per sample (caller takes .mean() for the batch).
    """

    # Override forward (not __call__): nn.Module.__call__ dispatches to
    # forward and keeps the module's hook machinery intact.
    def forward(self, preds, y, u_out):
        w = 1 - u_out  # zero weight on expiratory (unscored) timesteps
        mae = w * (y - preds).abs()
        mae = mae.sum(-1) / w.sum(-1)

        return mae
def fit(
    model,
    train_dataset,
    val_dataset,
    loss_name="L1Loss",
    optimizer="Adam",
    epochs=50,
    batch_size=32,
    val_bs=32,
    warmup_prop=0.1,
    lr=1e-3,
    num_classes=1,
    verbose=1,
    first_epoch_eval=0,
    #device="cuda"
):
    """Trains `model` on `train_dataset` and evaluates on `val_dataset` each epoch.

    Args:
        model: torch model mapping (batch, seq, features) -> (batch, seq, 1).
        train_dataset / val_dataset: VentilatorDataset instances.
        loss_name: accepted but currently unused — VentilatorLoss is always used.
        optimizer: name of a torch.optim class (looked up via getattr).
        epochs, batch_size, val_bs, warmup_prop, lr: usual training knobs.
        num_classes: accepted but unused in this function.
        verbose: print every `verbose` epochs.
        first_epoch_eval: epoch from which validation results are *printed*
            (validation itself still runs every epoch).

    Returns:
        numpy array: validation predictions from the LAST epoch.

    NOTE(review): everything runs on the default device (CPU here) — the
    device handling is commented out throughout.
    """
    avg_val_loss = 0.

    # Optimizer is passed as a name and resolved on torch.optim.
    optimizer = getattr(torch.optim, optimizer)(model.parameters(), lr=lr)

    # Data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        #num_workers=NUM_WORKERS,
        pin_memory=True,
        #worker_init_fn=worker_init_fn
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=val_bs,
        shuffle=False,
        #num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    # Custom loss function matching the challenge metric.
    loss_fct = VentilatorLoss()

    # Linear warmup then linear decay of the learning rate, stepped per batch.
    num_warmup_steps = int(warmup_prop * epochs * len(train_loader))
    num_training_steps = int(epochs * len(train_loader))
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps, num_training_steps
    )

    for epoch in range(epochs):
        model.train()

        # Clear any stale gradients before the epoch starts.
        model.zero_grad()
        start_time = time.time()

        avg_loss = 0
        for data in train_loader:
            pred = model(data['input']).squeeze(-1)

            loss = loss_fct(
                pred,
                data['p'],
                data['u_out'],
            ).mean()
            loss.backward()
            avg_loss += loss.item() / len(train_loader)

            optimizer.step()
            scheduler.step()

            # Setting grads to None is a cheap alternative to zero_grad().
            for param in model.parameters():
                param.grad = None

        model.eval()
        mae, avg_val_loss = 0, 0
        preds = []

        # Validation: no gradients needed.
        with torch.no_grad():
            for data in val_loader:
                pred = model(data['input']).squeeze(-1)

                loss = loss_fct(
                    pred.detach(),
                    data['p'],
                    data['u_out'],
                ).mean()
                avg_val_loss += loss.item() / len(val_loader)

                preds.append(pred.detach().cpu().numpy())

        preds = np.concatenate(preds, 0)
        mae = compute_metric(val_dataset.df, preds)

        elapsed_time = time.time() - start_time
        if (epoch + 1) % verbose == 0:
            elapsed_time = elapsed_time * verbose
            lr = scheduler.get_last_lr()[0]
            print(
                f"Epoch {epoch + 1:02d}/{epochs:02d} \t lr={lr:.1e}\t t={elapsed_time:.0f}s \t"
                f"loss={avg_loss:.3f}",
                end="\t",
            )

            # Only the printing is gated; validation already ran above.
            if (epoch + 1 >= first_epoch_eval) or (epoch + 1 == epochs):
                print(f"val_loss={avg_val_loss:.3f}\tmae={mae:.3f}")
            else:
                print("")

    del (val_loader, train_loader, loss, data, pred)
    gc.collect()
    torch.cuda.empty_cache()

    return preds
def predict(
    model,
    dataset,
    batch_size=64,
    #device="cuda"
):
    """
    Runs inference over a dataset and collects the predictions.

    Args:
        model (torch model): Model to predict with.
        dataset (Dataset): Dataset yielding dicts with an 'input' tensor.
        batch_size (int, optional): Batch size. Defaults to 64.

    Returns:
        numpy array [len(dataset) x seq_len]: Predictions.
    """
    model.eval()

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    batch_preds = []
    with torch.no_grad():
        for batch in loader:
            out = model(batch['input']).squeeze(-1)
            batch_preds.append(out.cpu().numpy())

    return np.concatenate(batch_preds, 0)
def train(config, df_train, df_val, df_test, fold):
    """
    Trains and validates a model on one fold, then predicts the test set.

    Args:
        config (Config): Parameters.
        df_train (pandas dataframe): Training metadata.
        df_val (pandas dataframe): Validation metadata.
        df_test (pandas dataframe): Test metadata.
        fold (int): Selected fold (currently only used for naming when the
            weight-saving code below is re-enabled).

    Returns:
        tuple(np array, np array): (validation predictions from the last
        epoch, test predictions) — NOTE: two values, not one.
    """

    #seed_everything(config.seed)

    model = RNNModel(
        input_dim=config.input_dim,
        lstm_dim=config.lstm_dim,
        dense_dim=config.dense_dim,
        logit_dim=config.logit_dim,
        num_classes=config.num_classes,
    )
    model.zero_grad()

    # Wrap each split; VentilatorDataset adds a dummy pressure to the test set.
    train_dataset = VentilatorDataset(df_train)
    val_dataset = VentilatorDataset(df_val)
    test_dataset = VentilatorDataset(df_test)

    #n_parameters = count_parameters(model)

    print(f" -> {len(train_dataset)} training breathes")
    print(f" -> {len(val_dataset)} validation breathes")
    #print(f" -> {n_parameters} trainable parameters\n")

    pred_val = fit(
        model,
        train_dataset,
        val_dataset,
        loss_name=config.loss,
        optimizer=config.optimizer,
        epochs=config.epochs,
        batch_size=config.batch_size,
        val_bs=config.val_bs,
        lr=config.lr,
        warmup_prop=config.warmup_prop,
        verbose=config.verbose,
        first_epoch_eval=config.first_epoch_eval,
        #device=config.device,
    )

    pred_test = predict(
        model,
        test_dataset,
        batch_size=config.val_bs,
        #device=config.device
    )

    # if config.save_weights:
    #     save_model_weights(
    #         model,
    #         f"{config.selected_model}_{fold}.pt",
    #         cp_folder="",
    #     )

    # Free kernel memory between folds (it persists across cells/iterations).
    del (model, train_dataset, val_dataset, test_dataset)
    gc.collect()
    #torch.cuda.empty_cache()

    return pred_val, pred_test
from sklearn.model_selection import GroupKFold

def k_fold(config, df, df_test):
    """
    Performs a breath-grouped k-fold cross validation.

    Grouping on breath_id guarantees all timesteps of one breath land in the
    same fold, so no breath leaks between train and validation.

    Args:
        config (Config): Parameters (k, selected_folds, model/training knobs).
        df (pandas dataframe): Full training metadata (one row per timestep).
        df_test (pandas dataframe): Test metadata.

    Returns:
        tuple: (out-of-fold predictions aligned with df's rows,
                test predictions averaged over the selected folds).

    NOTE(review): if selected_folds skips folds, pred_oof keeps zeros there
    and the printed CV MAE is computed over those zeros too.
    """

    pred_oof = np.zeros(len(df))
    preds_test = []

    gkf = GroupKFold(n_splits=config.k)
    splits = list(gkf.split(X=df, y=df, groups=df["breath_id"]))

    for i, (train_idx, val_idx) in enumerate(splits):
        if i in config.selected_folds:
            print(f"\n------------- Fold {i + 1} / {config.k} -------------\n")

            df_train = df.iloc[train_idx].copy().reset_index(drop=True)
            df_val = df.iloc[val_idx].copy().reset_index(drop=True)

            pred_val, pred_test = train(config, df_train, df_val, df_test, i)

            # pred_val is (n_breaths, seq_len); flattening restores the
            # per-timestep row order of val_idx.
            pred_oof[val_idx] = pred_val.flatten()
            preds_test.append(pred_test.flatten())

    print(f'\n -> CV MAE : {compute_metric(df, pred_oof) :.3f}')

    return pred_oof, np.mean(preds_test, 0)


class Config:
    """
    Parameters used for training.
    """
    # General
    seed = 42
    verbose = 1          # print training progress every epoch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    #save_weights = True

    # k-fold
    k = 5
    selected_folds = [0, 1, 2, 3, 4]  # all folds

    # Model
    selected_model = 'rnn'
    input_dim = 5        # R, C, u_in, cumsum(u_in), u_out

    dense_dim = 512
    lstm_dim = 512
    logit_dim = 512
    num_classes = 1      # single regression target (pressure)

    # Training
    loss = "L1Loss"      # not used — fit() always applies VentilatorLoss
    optimizer = "Adam"
    batch_size = 128
    epochs = 50

    lr = 1e-3
    warmup_prop = 0      # no LR warmup

    val_bs = 256
    first_epoch_eval = 0  # print validation results from the first epoch on
"\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m pred_oof, pred_test = k_fold(\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mConfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mdf_train\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mdf_test\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m )\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mk_fold\u001b[0;34m(config, df, df_test)\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0mdf_val\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mval_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdrop\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0mpred_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpred_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdf_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdf_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdf_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0mpred_oof\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mval_idx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mpred_val\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(config, df_train, df_val, df_test, fold)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;31m#print(f\" -> {n_parameters} trainable parameters\\n\")\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m pred_val = fit(\n\u001b[0m\u001b[1;32m 38\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mtrain_dataset\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mTypeError\u001b[0m: fit() got an unexpected keyword argument 'loss_name'" ] } ], "source": [ "pred_oof, pred_test = k_fold(\n", " Config, \n", " df_train,\n", " df_test,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" }, "kernelspec": { "display_name": "Python 3.9.10 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.10" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }