Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
332 changes: 332 additions & 0 deletions nikki_exp_conv_temporal_static/demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,332 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import rasterio\n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"import lightgbm as lgbm\n",
"\n",
"from typing import Any, Dict, Optional, List\n",
"from tqdm import tqdm as tqdm\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import utils\n",
"import main\n",
"import sys"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Arguments to be passed"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"\n",
"data_dir = \"gs://earth-engine-seminar/urbanization/data/export_22122024\"\n",
"output_path=\"prediction_tiff.tiff\"\n",
"filter_size=5\n",
"block_coverage=0.3\n",
"total_blocks=100\n",
"test_size=0.3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading the input files"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"\n",
"labels = utils.files_in_dir(data_dir, \"label.tif\")\n",
"features = utils.files_in_dir(data_dir, \"feat.tif\")\n",
"train_labels = utils.files_in_dir(data_dir, \"label.tif\")\n",
"train_features = utils.files_in_dir(data_dir, \"feat.tif\")\n",
"test_labels = utils.load_tif_data(labels[0])\n",
"test_features = utils.load_tif_data(features[0])\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pre-processing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Removing duplicate features from the train and test feature sets"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def _drop_second_band(raster):\n",
"    \"\"\"Return raster.data with band index 1 removed.\n",
"\n",
"    Band 0 is reshaped to (1, height, width) using the sizes stored in\n",
"    raster.metadata and stacked in front of bands 2 and onwards.\n",
"    \"\"\"\n",
"    band_shape = (1, raster.metadata['height'], raster.metadata['width'])\n",
"    first_band = raster.data[0, :, :].reshape(band_shape)\n",
"    remaining_bands = raster.data[2:, :, :]\n",
"    return np.concatenate((first_band, remaining_bands), axis=0)\n",
"\n",
"\n",
"train_feature = _drop_second_band(train_features)\n",
"test_feature = _drop_second_band(test_features)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"block size 121\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/226 [00:00<?, ?it/s]/home/nikki/Downloads/urbanan_google_earth/main.py:186: RuntimeWarning: Mean of empty slice\n",
" new_data.append(np.concatenate(((np.nanmean(mean_data,axis=(1,2))),static_data),axis=0 ))\n",
"100%|██████████| 226/226 [07:41<00:00, 2.04s/it] \n",
"100%|██████████| 98/98 [00:08<00:00, 11.17it/s]\n",
" 2%|▏ | 41/2207 [00:04<03:46, 9.56it/s]/home/nikki/Downloads/urbanan_google_earth/main.py:92: RuntimeWarning: Mean of empty slice\n",
" new_data.append(np.concatenate(((np.nanmean(mean_data,axis=(1,2))),static_data),axis=0 ))\n",
"100%|██████████| 2207/2207 [04:13<00:00, 8.72it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(26442, 15) (26442,)\n",
"(4932645, 15) (4932645,)\n",
"(11466, 15) (11466,)\n",
"(18200, 15) (18200,)\n",
"(3304360, 15) (3304360,)\n",
"(3304360, 15) (3304360,)\n"
]
}
],
"source": [
"# main.pre_process returns nine values: features/labels for the train, test and\n",
"# validation splits, followed by one mask per split.\n",
"(\n",
"    x_train, y_train,\n",
"    x_test, y_test,\n",
"    x_val, y_val,\n",
"    train_mask, test_mask, val_mask,\n",
") = main.pre_process(\n",
"    train_feature, train_labels.data,\n",
"    test_feature, test_labels.data,\n",
"    block_coverage, total_blocks, test_size,\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Correlation matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Stack the features and the label column side by side. The column view gets its\n",
"# own name so that y_train keeps its original shape: rebinding y_train here would\n",
"# make this cell non-idempotent and change what the training cell receives.\n",
"y_train_column = np.asarray(y_train).reshape(-1, 1)\n",
"train = np.concatenate((x_train, y_train_column), axis=1)\n",
"\n",
"# Normalizing features (z-score per column); the prints are a sanity check that\n",
"# every column ends up with mean ~0 and std ~1.\n",
"col_means = np.mean(train, axis=0)\n",
"col_stds = np.std(train, axis=0)\n",
"# A constant column has zero std; divide by 1 there so it maps to 0 instead of\n",
"# producing NaN/inf and a divide-by-zero warning.\n",
"safe_stds = np.where(col_stds == 0, 1.0, col_stds)\n",
"train_normalized = (train - col_means) / safe_stds\n",
"print(np.mean(train_normalized, axis=0))\n",
"print(np.std(train_normalized, axis=0))\n",
"\n",
"# Pearson correlation does not change under per-column shift/scale, so it is\n",
"# computed on the raw stacked columns.\n",
"coerr = np.corrcoef(train, rowvar=False)\n",
"\n",
"column_names = [\n",
"    'Palmer Drought Severity Index', 'Precipitation accumulation',\n",
"    'min temp', 'max temp', '16-day NDVI avg', '16-day EVI avg',\n",
"    'LC_Type1', 'population_density', 'urban',\n",
"]\n",
"# A name list whose length differs from the matrix size cannot label every\n",
"# row/column correctly, so fall back to positional indices in that case.\n",
"if len(column_names) == coerr.shape[0]:\n",
"    tick_labels = column_names\n",
"else:\n",
"    print(f'Expected {coerr.shape[0]} column names but got {len(column_names)}; using column indices instead.')\n",
"    tick_labels = 'auto'\n",
"\n",
"fig, ax = plt.subplots(figsize=(12, 8))\n",
"sns.heatmap(coerr, annot=True, xticklabels=tick_labels, yticklabels=tick_labels,\n",
"            cmap='coolwarm', linewidths=0.5, ax=ax)\n",
"ax.set_title('Correlation Heatmap')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/nikki/Downloads/venv/lib/python3.12/site-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.000569 seconds.\n",
"You can set `force_row_wise=true` to remove the overhead.\n",
"And if memory is not enough, you can set `force_col_wise=true`.\n",
"[LightGBM] [Info] Total Bins 3483\n",
"[LightGBM] [Info] Number of data points in the train set: 18200, number of used features: 15\n",
"[LightGBM] [Info] Start training from score -6.551080\n",
"[LightGBM] [Info] Start training from score -7.611952\n",
"[LightGBM] [Info] Start training from score -6.198259\n",
"[LightGBM] [Info] Start training from score -0.003964\n",
"[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n",
"[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n"
]
}
],
"source": [
"\n",
"# Four-class LightGBM classifier; only the objective and class count are set here.\n",
"model = lgbm.LGBMClassifier(\n",
"    objective=\"multiclass\",\n",
"    num_class=4,\n",
")\n",
"# main.train is given the model plus the training and validation splits.\n",
"main.train(model, x_train, y_train, x_val, y_val)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Testing"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/nikki/Downloads/venv/lib/python3.12/site-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.\n",
" warnings.warn(\n"
]
}
],
"source": [
"# main.predict returns three values, unpacked below as the predictions, the\n",
"# confusion matrix and the text report printed in the Results section.\n",
"prediction_outputs = main.predict(model, x_test, y_test)\n",
"test_pred, cm, report = prediction_outputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Results"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.01 0.04 0.02 8586\n",
" 1 0.01 0.06 0.01 2260\n",
" 2 0.02 0.11 0.04 6524\n",
" 3 1.00 0.98 0.99 3286990\n",
"\n",
" accuracy 0.98 3304360\n",
" macro avg 0.26 0.30 0.27 3304360\n",
"weighted avg 0.99 0.98 0.98 3304360\n",
"\n"
]
}
],
"source": [
"print(report)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 318 489 152 7627]\n",
" [ 127 143 358 1632]\n",
" [ 213 115 717 5479]\n",
" [ 20568 16523 28307 3221592]]\n"
]
}
],
"source": [
"# Confusion matrix\n",
"print(cm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Saving"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Hand the test predictions, the test mask and the loaded test label raster to\n",
"# main.save_prediction, along with filter_size and the output_path to write to.\n",
"main.save_prediction(\n",
"    test_labels,\n",
"    filter_size,\n",
"    test_mask,\n",
"    test_pred,\n",
"    output_path,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading