diff --git a/models/training-tuning-scripts/fraud-detection-models/gnn-fraud-detection-training.ipynb b/models/training-tuning-scripts/fraud-detection-models/gnn-fraud-detection-training.ipynb index d66234974d..7decd59636 100644 --- a/models/training-tuning-scripts/fraud-detection-models/gnn-fraud-detection-training.ipynb +++ b/models/training-tuning-scripts/fraud-detection-models/gnn-fraud-detection-training.ipynb @@ -50,16 +50,16 @@ "source": [ "%load_ext autoreload\n", "%autoreload 2\n", - "import pandas as pd\n", - "import numpy as np\n", "import os\n", + "\n", "import dgl\n", + "import matplotlib.pylab as plt\n", "import numpy as np\n", - "import pandas as pd\n", "import torch\n", "import torch.nn as nn\n", "from model import HeteroRGCN\n", "from model import HinSAGE\n", + "from model import prepare_data\n", "from sklearn.metrics import accuracy_score\n", "from sklearn.metrics import auc\n", "from sklearn.metrics import average_precision_score\n", @@ -68,9 +68,15 @@ "from sklearn.metrics import roc_curve\n", "from torchmetrics.functional import accuracy\n", "from tqdm import trange\n", + "from training import build_fsi_graph\n", + "from training import evaluate\n", + "from training import get_metrics\n", + "from training import init_loaders\n", + "from training import save_model\n", + "from training import train\n", "from xgboost import XGBClassifier\n", - "from training import (get_metrics, evaluate, init_loaders, build_fsi_graph,\n", - " map_node_id, prepare_data, save_model, train)\n" + "\n", + "import cudf" ] }, { @@ -85,26 +91,6 @@ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" ] }, - { - "cell_type": "code", - "execution_count": 73, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "device(type='cuda', index=0)" - ] - }, - "execution_count": 73, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#device " - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -122,8 +108,8 @@ "# Replace training-data.csv and validation-data.csv with training & validation csv in dataset file.\n", "TRAINING_DATA ='../../datasets/training-data/fraud-detection-training-data.csv'\n", "VALIDATION_DATA = '../../datasets/validation-data/fraud-detection-validation-data.csv'\n", - "train_data = pd.read_csv(TRAINING_DATA)\n", - "inductive_data = pd.read_csv(VALIDATION_DATA)" + "train_data = cudf.read_csv(TRAINING_DATA)\n", + "inductive_data = cudf.read_csv(VALIDATION_DATA)" ] }, { @@ -141,16 +127,15 @@ "outputs": [], "source": [ "# Increase number of samples.\n", - "def augement_data(train_data=train_data, n=20):\n", - " max_id = inductive_data.index.max()\n", + "def augment_data(train_data=train_data, n=20):\n", + " train_data.drop(columns=['index'], inplace=True, axis=1)\n", " non_fraud = train_data[train_data['fraud_label'] == 0]\n", - " \n", - " non_fraud = non_fraud.drop(['index'], axis=1)\n", - " df_fraud = pd.concat([non_fraud for i in range(n)])\n", - " df_fraud.index = np.arange(1076, 1076 + df_fraud.shape[0])\n", - " df_fraud['index'] = df_fraud.index\n", - " \n", - " return pd.concat((train_data, df_fraud))" + " df_fraud = cudf.concat([non_fraud for _ in range(n)])\n", + " df_train = cudf.concat([train_data, df_fraud])\n", + " df_train.reset_index(inplace=True)\n", + " df_train['index'] = df_train.index\n", + "\n", + " return df_train" ] }, { @@ -159,7 +144,19 @@ "metadata": {}, "outputs": [], "source": [ - "train_data = augement_data(train_data, n=20)" + "train_data = augment_data(train_data, n=20)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# re-arange test data index\n", + "last_train_index = train_data.index.max()+1\n", + "inductive_data.index = np.arange(last_train_index, last_train_index + inductive_data.shape[0])\n", + "inductive_data['index'] = inductive_data.index" ] }, { @@ -173,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -183,11 +180,11 @@ "The distribution of fraud for the train data is:\n", " 0 11865\n", "1 188\n", - "Name: fraud_label, dtype: int64\n", + "Name: fraud_label, dtype: int32\n", "The distribution of fraud for the inductive data is:\n", " 0 244\n", "1 21\n", - "Name: fraud_label, dtype: int64\n" + "Name: fraud_label, dtype: int32\n" ] } ], @@ -196,38 +193,13 @@ "print('The distribution of fraud for the inductive data is:\\n', inductive_data['fraud_label'].value_counts())" ] }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "# split train, test and create nodes index\n", - "def prepare_data(df_train, df_test):\n", - " \n", - " train_idx_ = df_train.shape[0]\n", - " df = pd.concat([df_train, df_test], axis=0)\n", - " df['tran_id'] = df['index']\n", - "\n", - " meta_cols = ['tran_id', 'client_node', 'merchant_node']\n", - " for col in meta_cols:\n", - " map_node_id(df, col)\n", - "\n", - " train_idx = df['tran_id'][:train_idx_]\n", - " test_idx = df['tran_id'][train_idx_:]\n", - "\n", - " df['index'] = df['tran_id']\n", - " df.index = df['index']\n", - "\n", - " return (df.iloc[train_idx, :], df.iloc[test_idx, :], train_idx, test_idx, df['fraud_label'].values, df)" - ] - }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ + "# Split into training, testing datasets\n", "train_data, test_data, train_idx, inductive_idx, labels, df = prepare_data(train_data, inductive_data)" ] }, @@ -236,7 +208,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### 3. Construct transasction graph network" + "### 3. Construct transaction graph network" ] }, { @@ -253,45 +225,17 @@ "metadata": {}, "outputs": [], "source": [ - "meta_cols = [\"client_node\", \"merchant_node\", \"fraud_label\", \"index\", \"tran_id\"]\n", + "meta_cols = [\"client_node\", \"merchant_node\", \"index\"]\n", "\n", "# Build graph\n", "whole_graph, feature_tensors = build_fsi_graph(df, meta_cols)\n", "train_graph, _ = build_fsi_graph(train_data, meta_cols)\n", - "whole_graph = whole_graph.to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "# Dataset to tensors\n", - "feature_tensors = feature_tensors.to(device)\n", - "train_idx = torch.from_numpy(train_idx.values).to(device)\n", - "inductive_idx = torch.from_numpy(inductive_idx.values).to(device)\n", - "labels = torch.LongTensor(labels).to(device)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Graph(num_nodes={'client': 623, 'merchant': 388, 'transaction': 12053},\n", - " num_edges={('client', 'buy', 'transaction'): 12053, ('merchant', 'sell', 'transaction'): 12053, ('transaction', 'bought', 'client'): 12053, ('transaction', 'issued', 'merchant'): 12053},\n", - " metagraph=[('client', 'transaction', 'buy'), ('transaction', 'client', 'bought'), ('transaction', 'merchant', 'issued'), ('merchant', 'transaction', 'sell')])\n" - ] - } - ], - "source": [ - "# Show structure of training graph.\n", - "print(train_graph)" + "\n", + "# Dataset\n", + "feature_tensors = feature_tensors.float()\n", + "train_idx = torch.from_dlpack(train_idx.values.toDlpack()).long()\n", + "inductive_idx = torch.from_dlpack(inductive_idx.values.toDlpack()).long()\n", + "labels = torch.from_dlpack(labels.toDlpack()).long()" ] }, { @@ -312,31 +256,34 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Hyperparameters\n", "target_node = \"transaction\"\n", - "epochs = 20\n", + "epochs = 25\n", "in_size, hidden_size, out_size, n_layers,\\\n", " embedding_size = 111, 64, 2, 2, 1\n", - "batch_size = 100\n", - "hyperparameters = {\"in_size\": in_size, \"hidden_size\": hidden_size,\n", - " \"out_size\": out_size, \"n_layers\": n_layers,\n", - " \"embedding_size\": embedding_size,\n", - " \"target_node\": target_node,\n", - " \"epoch\": epochs}\n", - "\n", + "batch_size = 256\n", + "in_size, hidden_size, out_size, n_layers, embedding_size = 111, 64, 2, 2, 1\n", + "hyperparameters = {\n", + " \"in_size\": in_size,\n", + " \"hidden_size\": hidden_size,\n", + " \"out_size\": out_size,\n", + " \"n_layers\": n_layers,\n", + " \"embedding_size\": embedding_size,\n", + " \"target_node\": target_node,\n", + " \"epoch\": epochs\n", + "}\n", "\n", - "scale_pos_weight = train_data['fraud_label'].sum() / train_data.shape[0]\n", - "scale_pos_weight = torch.tensor(\n", - " [scale_pos_weight, 1-scale_pos_weight]).to(device)" + "scale_pos_weight = (labels[train_idx].sum() / train_data.shape[0]).item()\n", + "scale_pos_weight = torch.FloatTensor([scale_pos_weight, 1 - scale_pos_weight]).to(device)" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -345,7 +292,6 @@ " device), train_idx, test_idx=inductive_idx,\n", " val_idx=inductive_idx, g_test=whole_graph, batch_size=batch_size)\n", "\n", - "\n", "# Set model variables\n", "model = HinSAGE(train_graph, in_size, hidden_size, out_size, n_layers, embedding_size).to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n", @@ -354,314 +300,384 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/20 [00:00, ?it/s]" + " 0%| | 0/25 [00:00, ?it/s]/home/efajardo/miniconda3/envs/morpheus/lib/python3.10/site-packages/dgl/backend/pytorch/tensor.py:445: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", + " assert input.numel() == input.storage().size(), (\n", + " 4%|▍ | 1/25 [00:01<00:31, 1.30s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0/25 | Train Accuracy: 1.0 | Train Loss: 7.589520640056491\n", + "Validation Accuracy: 0.9207547307014465 auc 0.20833333333333334\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 8%|▊ | 2/25 [00:02<00:26, 1.15s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/25 | Train Accuracy: 1.0 | Train Loss: 118.13089790023514\n", + "Validation Accuracy: 0.9207547307014465 auc 0.7851288056206087\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 5%|▌ | 1/20 [00:02<00:47, 2.51s/it]" + " 12%|█▏ | 3/25 [00:03<00:24, 1.10s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0/20 | Train Accuracy: 1.0 | Train Loss: 4.077046836914391\n", - "Validation Accuracy: 0.9207547307014465 auc 0.13992974238875877\n" + "Epoch 2/25 | Train Accuracy: 1.0 | Train Loss: 38.45385842246469\n", + "Validation Accuracy: 0.9132075309753418 auc 0.8367486338797815\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 10%|█ | 2/20 [00:04<00:43, 2.40s/it]" + " 16%|█▌ | 4/25 [00:04<00:22, 1.09s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1/20 | Train Accuracy: 1.0 | Train Loss: 110.9858230000423\n", - "Validation Accuracy: 0.9207547307014465 auc 0.5852849336455894\n" + "Epoch 3/25 | Train Accuracy: 1.0 | Train Loss: 14.068548373878002\n", + "Validation Accuracy: 0.9245283007621765 auc 0.8592896174863389\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 15%|█▌ | 3/20 [00:07<00:40, 2.37s/it]" + " 20%|██ | 5/25 [00:05<00:21, 1.07s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 2/20 | Train Accuracy: 1.0 | Train Loss: 419.0077720507543\n", - "Validation Accuracy: 0.9207547307014465 auc 0.6083138173302107\n" + "Epoch 4/25 | Train Accuracy: 1.0 | Train Loss: 7.06401611212641\n", + "Validation Accuracy: 0.9358490705490112 auc 0.87743950039032\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 20%|██ | 4/20 [00:09<00:37, 2.34s/it]" + " 24%|██▍ | 6/25 [00:06<00:20, 1.06s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 3/20 | Train Accuracy: 1.0 | Train Loss: 176.4732639742433\n", - "Validation Accuracy: 0.9169811606407166 auc 0.7976190476190476\n" + "Epoch 5/25 | Train Accuracy: 1.0 | Train Loss: 6.148922558873892\n", + "Validation Accuracy: 0.9358490705490112 auc 0.882903981264637\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 25%|██▌ | 5/20 [00:11<00:34, 2.31s/it]" + " 28%|██▊ | 7/25 [00:07<00:19, 1.07s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 4/20 | Train Accuracy: 1.0 | Train Loss: 49.66766470632865\n", - "Validation Accuracy: 0.9245283007621765 auc 0.8080601092896175\n" + "Epoch 6/25 | Train Accuracy: 1.0 | Train Loss: 6.028049402870238\n", + "Validation Accuracy: 0.9358490705490112 auc 0.8889539422326307\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 30%|███ | 6/20 [00:14<00:33, 2.36s/it]" + " 32%|███▏ | 8/25 [00:08<00:17, 1.05s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 5/20 | Train Accuracy: 1.0 | Train Loss: 31.406425931840204\n", - "Validation Accuracy: 0.9283018708229065 auc 0.858216237314598\n" + "Epoch 7/25 | Train Accuracy: 1.0 | Train Loss: 5.655132191255689\n", + "Validation Accuracy: 0.9358490705490112 auc 0.8924668227946916\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 35%|███▌ | 7/20 [00:16<00:30, 2.35s/it]" + " 36%|███▌ | 9/25 [00:09<00:16, 1.05s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 6/20 | Train Accuracy: 1.0 | Train Loss: 24.368114110082388\n", - "Validation Accuracy: 0.9283018708229065 auc 0.8635831381733021\n" + "Epoch 8/25 | Train Accuracy: 1.0 | Train Loss: 5.50519098713994\n", + "Validation Accuracy: 0.9433962106704712 auc 0.8950039032006245\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 40%|████ | 8/20 [00:18<00:27, 2.31s/it]" + " 40%|████ | 10/25 [00:10<00:15, 1.05s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 7/20 | Train Accuracy: 1.0 | Train Loss: 17.363841364858672\n", - "Validation Accuracy: 0.9283018708229065 auc 0.8762685402029665\n" + "Epoch 9/25 | Train Accuracy: 1.0 | Train Loss: 5.286113580223173\n", + "Validation Accuracy: 0.9396226406097412 auc 0.8955893832943013\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 45%|████▌ | 9/20 [00:21<00:25, 2.33s/it]" + " 44%|████▍ | 11/25 [00:11<00:14, 1.05s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 8/20 | Train Accuracy: 1.0 | Train Loss: 16.201855568680912\n", - "Validation Accuracy: 0.9320755004882812 auc 0.8788056206088993\n" + "Epoch 10/25 | Train Accuracy: 1.0 | Train Loss: 5.151115204207599\n", + "Validation Accuracy: 0.9396226406097412 auc 0.9028103044496486\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 50%|█████ | 10/20 [00:23<00:23, 2.38s/it]" + " 48%|████▊ | 12/25 [00:12<00:13, 1.05s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 9/20 | Train Accuracy: 1.0 | Train Loss: 15.001215729862452\n", - "Validation Accuracy: 0.9320755004882812 auc 0.8873926619828258\n" + "Epoch 11/25 | Train Accuracy: 1.0 | Train Loss: 4.6040293434634805\n", + "Validation Accuracy: 0.9320755004882812 auc 0.9106167056986728\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 55%|█████▌ | 11/20 [00:25<00:21, 2.37s/it]" + " 52%|█████▏ | 13/25 [00:13<00:12, 1.05s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 10/20 | Train Accuracy: 1.0 | Train Loss: 14.861962082330137\n", - "Validation Accuracy: 0.9358490705490112 auc 0.8791959406713505\n" + "Epoch 12/25 | Train Accuracy: 1.0 | Train Loss: 4.592546273488551\n", + "Validation Accuracy: 0.947169840335846 auc 0.9080796252927401\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 60%|██████ | 12/20 [00:28<00:19, 2.40s/it]" + " 56%|█████▌ | 14/25 [00:14<00:11, 1.05s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 11/20 | Train Accuracy: 1.0 | Train Loss: 13.089418702758849\n", - "Validation Accuracy: 0.9320755004882812 auc 0.8858313817330211\n" + "Epoch 13/25 | Train Accuracy: 1.0 | Train Loss: 4.154761636629701\n", + "Validation Accuracy: 0.9358490705490112 auc 0.9067135050741607\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 65%|██████▌ | 13/20 [00:30<00:16, 2.40s/it]" + " 60%|██████ | 15/25 [00:15<00:10, 1.05s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 12/20 | Train Accuracy: 1.0 | Train Loss: 12.216756469802931\n", - "Validation Accuracy: 0.9320755004882812 auc 0.9127634660421545\n" + "Epoch 14/25 | Train Accuracy: 1.0 | Train Loss: 3.5454123290255666\n", + "Validation Accuracy: 0.9358490705490112 auc 0.9192037470725994\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 70%|███████ | 14/20 [00:33<00:14, 2.46s/it]" + " 64%|██████▍ | 16/25 [00:17<00:09, 1.05s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 13/20 | Train Accuracy: 1.0 | Train Loss: 12.858742844546214\n", - "Validation Accuracy: 0.9433962106704712 auc 0.9182279469164715\n" + "Epoch 15/25 | Train Accuracy: 1.0 | Train Loss: 3.1295745647512376\n", + "Validation Accuracy: 0.9358490705490112 auc 0.9311085089773614\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 75%|███████▌ | 15/20 [00:35<00:12, 2.43s/it]" + " 68%|██████▊ | 17/25 [00:18<00:08, 1.04s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 14/20 | Train Accuracy: 1.0 | Train Loss: 11.10123936785385\n", - "Validation Accuracy: 0.9320755004882812 auc 0.911592505854801\n" + "Epoch 16/25 | Train Accuracy: 1.0 | Train Loss: 3.140789811965078\n", + "Validation Accuracy: 0.9358490705490112 auc 0.9449648711943794\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 80%|████████ | 16/20 [00:38<00:09, 2.44s/it]" + " 72%|███████▏ | 18/25 [00:19<00:07, 1.04s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 15/20 | Train Accuracy: 1.0 | Train Loss: 15.444379360007588\n", - "Validation Accuracy: 0.9207547307014465 auc 0.8721701795472288\n" + "Epoch 17/25 | Train Accuracy: 1.0 | Train Loss: 2.7704373160377145\n", + "Validation Accuracy: 0.9358490705490112 auc 0.9533567525370804\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 85%|████████▌ | 17/20 [00:40<00:07, 2.39s/it]" + " 76%|███████▌ | 19/25 [00:20<00:06, 1.04s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 16/20 | Train Accuracy: 1.0 | Train Loss: 15.353719354665373\n", - "Validation Accuracy: 0.9169811606407166 auc 0.822599531615925\n" + "Epoch 18/25 | Train Accuracy: 1.0 | Train Loss: 3.2044623312540352\n", + "Validation Accuracy: 0.9320755004882812 auc 0.948087431693989\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 90%|█████████ | 18/20 [00:42<00:04, 2.35s/it]" + " 80%|████████ | 20/25 [00:21<00:05, 1.04s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 17/20 | Train Accuracy: 1.0 | Train Loss: 15.88208947563544\n", - "Validation Accuracy: 0.9283018708229065 auc 0.8528493364558939\n" + "Epoch 19/25 | Train Accuracy: 1.0 | Train Loss: 2.732395632308908\n", + "Validation Accuracy: 0.9433962106704712 auc 0.9510148321623731\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 95%|█████████▌| 19/20 [00:45<00:02, 2.33s/it]" + " 84%|████████▍ | 21/25 [00:22<00:04, 1.05s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 18/20 | Train Accuracy: 1.0 | Train Loss: 12.539632054162212\n", - "Validation Accuracy: 0.9283018708229065 auc 0.9022248243559718\n" + "Epoch 20/25 | Train Accuracy: 1.0 | Train Loss: 2.5043671822641045\n", + "Validation Accuracy: 0.9433962106704712 auc 0.9535519125683061\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 20/20 [00:47<00:00, 2.37s/it]" + " 88%|████████▊ | 22/25 [00:23<00:03, 1.05s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 19/20 | Train Accuracy: 1.0 | Train Loss: 13.172684742690763\n", - "Validation Accuracy: 0.9433962106704712 auc 0.9342310694769711\n" + "Epoch 21/25 | Train Accuracy: 1.0 | Train Loss: 2.1203417778451694\n", + "Validation Accuracy: 0.9132075309753418 auc 0.961943793911007\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 92%|█████████▏| 23/25 [00:24<00:02, 1.06s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22/25 | Train Accuracy: 1.0 | Train Loss: 10.550110493495595\n", + "Validation Accuracy: 0.9169811606407166 auc 0.8948087431693988\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 96%|█████████▌| 24/25 [00:25<00:01, 1.06s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23/25 | Train Accuracy: 1.0 | Train Loss: 12.623157457801426\n", + "Validation Accuracy: 0.9207547307014465 auc 0.9137392661982825\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 25/25 [00:26<00:00, 1.06s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24/25 | Train Accuracy: 1.0 | Train Loss: 14.581157505754874\n", + "Validation Accuracy: 0.9396226406097412 auc 0.8467993754879001\n" ] }, { @@ -673,20 +689,19 @@ } ], "source": [ - "\n", "for epoch in trange(epochs):\n", "\n", " train_acc, loss = train(\n", " model, loss_func, train_loader, labels, optimizer, feature_tensors,\n", - " target_node, device=device)\n", + " target_node)\n", " print(f\"Epoch {epoch}/{epochs} | Train Accuracy: {train_acc} | Train Loss: {loss}\")\n", - " val_logits, val_seed, _ = evaluate(model, val_loader, feature_tensors, target_node, device=device)\n", + " val_logits, val_seed, _ = evaluate(model, val_loader, feature_tensors, target_node)\n", " val_accuracy = accuracy(val_logits.argmax(1), labels.long()[val_seed].cpu(), \"binary\").item()\n", " val_auc = roc_auc_score(\n", " labels.long()[val_seed].cpu().numpy(),\n", " val_logits[:, 1].numpy(),\n", " )\n", - " print(f\"Validation Accuracy: {val_accuracy} auc {val_auc}\")\n" + " print(f\"Validation Accuracy: {val_accuracy} auc {val_auc}\")" ] }, { @@ -707,7 +722,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -734,27 +749,26 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Final Test Accuracy: 0.9320755004882812 auc 0.889344262295082\n" + "Final Test Accuracy: 0.9245283007621765 auc 0.8380171740827479\n" ] } ], "source": [ "# Create embeddings\n", - "_, train_seeds, train_embedding = evaluate(model, train_loader, feature_tensors, target_node, device=device)\n", - "test_logits, test_seeds, test_embedding = evaluate(model, test_loader, feature_tensors, target_node, device=device)\n", + "_, train_seeds, train_embedding = evaluate(model, train_loader, feature_tensors, target_node)\n", + "test_logits, test_seeds, test_embedding = evaluate(model, test_loader, feature_tensors, target_node)\n", "\n", "# compute metrics\n", "test_acc = accuracy(test_logits.argmax(dim=1), labels.long()[test_seeds].cpu(), \"binary\").item()\n", "test_auc = roc_auc_score(labels.long()[test_seeds].cpu().numpy(), test_logits[:, 1].numpy())\n", "\n", - "metrics_result = pd.DataFrame()\n", "print(f\"Final Test Accuracy: {test_acc} auc {test_auc}\")\n", "\n", "#acc, f_1, precision, recall, roc_auc, pr_auc, average_precision, _, _ = get_metrics(\n", @@ -781,22 +795,22 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ - "from xgboost import XGBClassifier\n" + "from xgboost import XGBClassifier" ] }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
XGBClassifier(base_score=None, booster=None, callbacks=None,\n", + "XGBClassifier(base_score=None, booster=None, callbacks=None,\n", " colsample_bylevel=None, colsample_bynode=None,\n", " colsample_bytree=None, early_stopping_rounds=None,\n", " enable_categorical=False, eval_metric=None, feature_types=None,\n", @@ -806,7 +820,7 @@ " max_delta_step=None, max_depth=None, max_leaves=None,\n", " min_child_weight=None, missing=nan, monotone_constraints=None,\n", " n_estimators=100, n_jobs=None, num_parallel_tree=None,\n", - " predictor=None, random_state=None, ...)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.XGBClassifier(base_score=None, booster=None, callbacks=None,\n", + " predictor=None, random_state=None, ...)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.XGBClassifier(base_score=None, booster=None, callbacks=None,\n", " colsample_bylevel=None, colsample_bynode=None,\n", " colsample_bytree=None, early_stopping_rounds=None,\n", " enable_categorical=False, eval_metric=None, feature_types=None,\n", @@ -832,7 +846,7 @@ " predictor=None, random_state=None, ...)" ] }, - "execution_count": 28, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -840,7 +854,7 @@ "source": [ "# Train XGBoost classifier on embedding vector\n", "classifier = XGBClassifier(n_estimators=100)\n", - "classifier.fit(train_embedding.cpu().numpy(), labels[train_seeds].cpu().numpy())\n" + "classifier.fit(train_embedding.cpu().numpy(), labels[train_seeds].cpu().numpy())" ] }, { @@ -853,11 +867,11 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ - "xgb_pred = classifier.predict_proba(test_embedding.cpu().numpy())\n" + "xgb_pred = classifier.predict_proba(test_embedding.cpu().numpy())" ] }, { @@ -878,14 +892,14 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Final Test Accuracy: 0.9245283018867925 auc 0.9055425448868072\n" + "Final Test Accuracy: 0.9320754716981132 auc 0.9040788446526152\n" ] } ], @@ -921,18 +935,18 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "model_dir= \"modelpath/\"\n", "\n", - "save_model(train_graph, model, hyperparameters, classifier, model_dir)\n" + "save_model(train_graph, model, hyperparameters, classifier, model_dir)" ] }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -949,14 +963,15 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ - "## For inference we can load from file as follows. \n", + "## For inference we can load from file as follows.\n", "from training import load_model\n", + "\n", "# do inference on loaded model, as follows\n", - "# hinsage_model, hyperparam, g = load_model(model_dir, device)" + "hinsage_model, hyperparam, g = load_model(model_dir)" ] }, { @@ -985,13 +1000,6 @@ "2.https://stellargraph.readthedocs.io/en/stable/hinsage.html?highlight=hinsage\n", "3.https://github.com/rapidsai/clx/blob/branch-0.20/examples/forest_inference/xgboost_training.ipynb\"" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {