Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 50 additions & 59 deletions src/nn/nn_train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
"""

"""


import argparse
import os
import random
Expand Down Expand Up @@ -41,6 +36,7 @@ class HydrophoneDataset(Dataset):
def __init__(
self,
csv_path: str,
features_list: list[str],
dtype: torch.dtype = torch.float32,
):
"""
Expand Down Expand Up @@ -69,25 +65,21 @@ def __init__(
if not rows:
raise ValueError("CSV is empty.")
header = [h.strip() for h in rows[0]]

try:
idx1 = header.index("Envelope H1")
idx2 = header.index("Envelope H2")
idx3 = header.index("Envelope H3")
idxy = header.index("Truth")
except ValueError as e:
raise ValueError(
"CSV must include header columns: Envelope H1, Envelope H2, Envelope H3, Truth"
) from e
col_idx = {name: i for i, name in enumerate(header)}
missing = [name for name in features_list if name not in col_idx]

if missing:
raise ValueError(f"CSV missing requested feature columns: {missing}")

idxy = col_idx["Truth"]

feats, labels = [], []
feature_indices = [col_idx[name] for name in features_list] # e.g. [idx1, idx3]

for r in rows[1:]:
if not r or all(c.strip() == "" for c in r):
continue

feats.append([float(r[idx2])])
# feats.append([float(r[idx2]), float(r[idx3])])
# feats.append([float(r[idx1]), float(r[idx2]), float(r[idx3])])
feats.append([float(r[i]) for i in feature_indices])
labels.append(int(float(r[idxy])))

X = torch.tensor(feats, dtype=dtype)
Expand Down Expand Up @@ -163,8 +155,7 @@ class using a Softmax activation at the output layer.
model : nn.Sequential
A sequential container implementing the layer stack.
"""
# def __init__(self, in_dim=1, num_classes=4, p_drop=0.2):
# def __init__(self, in_dim=2, num_classes=4, p_drop=0.2):

def __init__(self, in_dim=3, num_classes=4, p_drop=0.2):

super().__init__()
Expand Down Expand Up @@ -233,14 +224,19 @@ def main():
parser.add_argument("--batch", type=int, default=32)
parser.add_argument("--dropout", type=float, default=0.2)
parser.add_argument("--save_dir", type=str, default="artifacts")
parser.add_argument("--feature_cols", type=str, default="Envelope H1,Envelope H2,Envelope H3")
parser.add_argument("--conf_thresh", type=str, default="0.5,0.8")
args = parser.parse_args()

set_seed(13)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
feature_cols = [c.strip() for c in args.feature_cols.split(",")]
thresholds = [float(t.strip()) for t in args.conf_thresh.split(",") if t.strip()]


# Load dataset
ds = HydrophoneDataset(args.csv)
ds = HydrophoneDataset(csv_path=args.csv, features_list=feature_cols)
n = len(ds)
n_train = int(0.8 * n)
n_val = n - n_train
Expand All @@ -250,10 +246,7 @@ def main():
val_loader = DataLoader(val_ds, batch_size=args.batch)

# Build model
model = MLPProb(in_dim=1, num_classes=4, p_drop=args.dropout).to(device)
# model = MLPProb(in_dim=2, num_classes=4, p_drop=args.dropout).to(device)
# model = MLPProb(in_dim=3, num_classes=4, p_drop=args.dropout).to(device)

model = MLPProb(in_dim=len(feature_cols),num_classes=4, p_drop=args.dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
loss_fn = nn.NLLLoss() # using log(probs)

Expand Down Expand Up @@ -285,8 +278,9 @@ def main():
# Validation
model.eval()
val_loss, val_acc = 0, 0
val_conf_acc_50, val_coverage_50 = 0, 0
val_conf_acc_80, val_coverage_80 = 0, 0
val_conf_acc = {t: 0.0 for t in thresholds}
val_coverage = {t: 0.0 for t in thresholds}


with torch.no_grad():
for X, y in val_loader:
Expand All @@ -298,30 +292,29 @@ def main():
val_acc += accuracy(probs, y) * X.size(0)

# Confidence-aware metrics
acc_50, cov_50 = accuracy_with_pass(probs, y, confidence_threshold=0.5)
acc_80, cov_80 = accuracy_with_pass(probs, y, confidence_threshold=0.8)
val_conf_acc_50 += acc_50 * X.size(0)
val_coverage_50 += cov_50 * X.size(0)
val_conf_acc_80 += acc_80 * X.size(0)
val_coverage_80 += cov_80 * X.size(0)
for t in thresholds:
acc_t, cov_t = accuracy_with_pass(probs, y, confidence_threshold=t)
val_conf_acc[t] += acc_t * X.size(0)
val_coverage[t] += cov_t * X.size(0)

val_loss /= len(val_loader.dataset)
val_acc /= len(val_loader.dataset)
val_conf_acc_50 /= len(val_loader.dataset)
val_coverage_50 /= len(val_loader.dataset)
val_conf_acc_80 /= len(val_loader.dataset)
val_coverage_80 /= len(val_loader.dataset)
for t in thresholds:
val_conf_acc[t] /= len(val_loader.dataset)
val_coverage[t] /= len(val_loader.dataset)

if val_loss < best_val_loss:
best_val_loss = val_loss
best_acc = val_acc
best_state = {k: v.cpu() for k, v in model.state_dict().items()}

conf_parts = " | ".join(
[f"{int(t*100)}%: acc={val_conf_acc[t]:.3f} cov={val_coverage[t]:.3f}" for t in thresholds]
)
print(f"Epoch {epoch+1:03d}: "
f"train_loss={train_loss:.4f} acc={train_acc:.3f} | "
f"val_loss={val_loss:.4f} acc={val_acc:.3f} | "
f"50%: acc={val_conf_acc_50:.3f} cov={val_coverage_50:.3f} | "
f"80%: acc={val_conf_acc_80:.3f} cov={val_coverage_80:.3f}")
f"train_loss={train_loss:.4f} acc={train_acc:.3f} | "
f"val_loss={val_loss:.4f} acc={val_acc:.3f} | "
f"{conf_parts}")

# Save model + normalization stats
os.makedirs(args.save_dir, exist_ok=True)
Expand All @@ -330,28 +323,26 @@ def main():
# Calculate final confidence metrics for best model
model.load_state_dict(best_state)
model.eval()
final_conf_acc_50, final_coverage_50 = 0, 0
final_conf_acc_80, final_coverage_80 = 0, 0
final_conf_acc = {t: 0.0 for t in thresholds}
final_coverage = {t: 0.0 for t in thresholds}

with torch.no_grad():
for X, y in val_loader:
X, y = X.to(device), y.to(device)
probs = model(X)
acc_50, cov_50 = accuracy_with_pass(probs, y, confidence_threshold=0.5)
acc_80, cov_80 = accuracy_with_pass(probs, y, confidence_threshold=0.8)
final_conf_acc_50 += acc_50 * X.size(0)
final_coverage_50 += cov_50 * X.size(0)
final_conf_acc_80 += acc_80 * X.size(0)
final_coverage_80 += cov_80 * X.size(0)

final_conf_acc_50 /= len(val_loader.dataset)
final_coverage_50 /= len(val_loader.dataset)
final_conf_acc_80 /= len(val_loader.dataset)
final_coverage_80 /= len(val_loader.dataset)

print(f"BEST: val_loss={best_val_loss:.4f} acc={best_acc:.3f} | "
f"50%: acc={final_conf_acc_50:.3f} cov={final_coverage_50:.3f} | "
f"80%: acc={final_conf_acc_80:.3f} cov={final_coverage_80:.3f}")
for t in thresholds:
acc_t, cov_t = accuracy_with_pass(probs, y, confidence_threshold=t)
final_conf_acc[t] += acc_t * X.size(0)
final_coverage[t] += cov_t * X.size(0)

for t in thresholds:
final_conf_acc[t] /= len(val_loader.dataset)
final_coverage[t] /= len(val_loader.dataset)

final_parts = " | ".join(
[f"{int(t*100)}%: acc={final_conf_acc[t]:.3f} cov={final_coverage[t]:.3f}" for t in thresholds]
)
print(f"BEST: val_loss={best_val_loss:.4f} acc={best_acc:.3f} | {final_parts}")

torch.save({
"model_state_dict": best_state,
Expand Down