diff --git a/configs/arrhythmia-class-2/demo.json b/configs/arrhythmia-class-2/demo.json new file mode 100644 index 00000000..297b12fe --- /dev/null +++ b/configs/arrhythmia-class-2/demo.json @@ -0,0 +1,29 @@ +{ + "job_dir": "./results/arrhythmia-class-2", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 800, + "num_classes": 2, + "model_file": "./results/arrhythmia-class-2/model.tflite", + "threshold": 0.75, + "backend": "pc", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/arrhythmia-class-2/evaluate.json b/configs/arrhythmia-class-2/evaluate.json new file mode 100644 index 00000000..a9b462a7 --- /dev/null +++ b/configs/arrhythmia-class-2/evaluate.json @@ -0,0 +1,31 @@ +{ + "job_dir": "./results/arrhythmia-class-2", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 800, + "num_classes": 2, + "samples_per_patient": [25, 200], + "test_patients": 1000, + "test_size": 100000, + "model_file": "./results/arrhythmia-class-2/model.tf", + "threshold": 0.75, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/arrhythmia-class-2/export.json b/configs/arrhythmia-class-2/export.json new file mode 100644 index 00000000..25ace2a4 --- /dev/null +++ b/configs/arrhythmia-class-2/export.json @@ -0,0 +1,36 @@ +{ + "job_dir": "./results/arrhythmia-class-2", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 800, + "num_classes": 2, + "samples_per_patient": [5, 40], + "test_patients": 1000, + "test_size": 10000, + "model_file": "./results/arrhythmia-class-2/model.tf", + "quantization": true, + "use_logits": false, + "threshold": 0.75, + "val_acc_threshold": 0.98, + "tflm_var_name": "g_arrhythmia_model", + "tflm_file": "./results/arrhythmia-class-2/arrhythmia_model_buffer.h", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/arrhythmia-class-2/train.json b/configs/arrhythmia-class-2/train.json new file mode 100644 index 00000000..c69463c7 --- /dev/null +++ b/configs/arrhythmia-class-2/train.json @@ -0,0 +1,78 @@ +{ + "job_dir": "./results/arrhythmia-class-2", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 800, + "num_classes": 2, + "samples_per_patient": [25, 200], + "val_samples_per_patient": [25, 200], + "train_patients": 10000, + "val_file": "./results/arrhythmia-class-2-10000pt-200fs-4s.pkl", + "val_patients": 0.20, + "val_size": 100000, + "batch_size": 256, + "buffer_size": 100000, + "epochs": 200, + "steps_per_epoch": 20, + "val_metric": "loss", + "datasets": ["icentia11k"], + "lr_rate": 5e-3, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ], + "augmentations": [ + { + "name": "baseline_wander", + "args": { + "amplitude": [0.5, 1.0], + "frequency": [0.4, 0.5] + } + }, + { + "name": "motion_noise", + "args": { + "amplitude": [0.2, 0.4], + "frequency": [0.4, 0.6] 
+ } + }, + { + "name": "burst_noise", + "args": { + "burst_number": [0, 4], + "amplitude": [0.05, 0.5], + "frequency": [80, 100] + } + }, + { + "name": "powerline_noise", + "args": { + "amplitude": [0.005, 0.01], + "frequency": [50, 60] + } + }, + { + "name": "noise_sources", + "args": { + "num_sources": [1, 2], + "amplitude": [0.04, 0.1], + "frequency": [10, 40] + } + } + ] +} diff --git a/configs/arrhythmia-class-3/demo.json b/configs/arrhythmia-class-3/demo.json new file mode 100644 index 00000000..9eb52c18 --- /dev/null +++ b/configs/arrhythmia-class-3/demo.json @@ -0,0 +1,29 @@ +{ + "job_dir": "./results/arrhythmia-class-3", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 800, + "num_classes": 3, + "model_file": "./results/arrhythmia-class-3/model.tflite", + "threshold": 0.75, + "backend": "pc", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/arrhythmia-class-3/evaluate.json b/configs/arrhythmia-class-3/evaluate.json new file mode 100644 index 00000000..6b61ebe6 --- /dev/null +++ b/configs/arrhythmia-class-3/evaluate.json @@ -0,0 +1,31 @@ +{ + "job_dir": "./results/arrhythmia-class-3", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 800, + "num_classes": 3, + "samples_per_patient": [25, 100, 100], + "test_patients": 1000, + "test_size": 50000, + "model_file": "./results/arrhythmia-class-3/model.tf", + "threshold": 0.5, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/arrhythmia-class-3/export.json b/configs/arrhythmia-class-3/export.json new file mode 100644 index 00000000..be1ba614 --- /dev/null +++ b/configs/arrhythmia-class-3/export.json @@ -0,0 +1,36 @@ +{ + "job_dir": "./results/arrhythmia-class-3", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 800, + "num_classes": 3, + "samples_per_patient": [5, 20, 20], + "test_patients": 1000, + "test_size": 10000, + "model_file": "./results/arrhythmia-class-3/model.tf", + "quantization": true, + "use_logits": false, + "threshold": 0.5, + "val_acc_threshold": 0.98, + "tflm_var_name": "g_arrhythmia_model", + "tflm_file": "./results/arrhythmia-class-3/arrhythmia_model_buffer.h", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/arrhythmia-class-3/train.json b/configs/arrhythmia-class-3/train.json new file mode 100644 index 00000000..13fda675 --- /dev/null +++ b/configs/arrhythmia-class-3/train.json @@ -0,0 +1,79 @@ +{ + "job_dir": "./results/arrhythmia-class-3", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 800, + "num_classes": 3, + "samples_per_patient": [25, 100, 100], + "val_samples_per_patient": [25, 100, 100], + "train_patients": 10000, + "val_file": "./results/arrhythmia-class-3-10000pt-200fs-4s.pkl", + "val_patients": 0.20, + "val_size": 100000, + "batch_size": 256, + "buffer_size": 100000, + "epochs": 150, + "steps_per_epoch": 20, + "val_metric": "loss", + "datasets": ["icentia11k"], + "lr_rate": 5e-3, + "lr_cycles": 1, + "preprocesses": [ + { + "name": 
"filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ], + "augmentations": [ + { + "name": "baseline_wander", + "args": { + "amplitude": [0.5, 1.0], + "frequency": [0.4, 0.5] + } + }, + { + "name": "motion_noise", + "args": { + "amplitude": [0.2, 0.4], + "frequency": [0.4, 0.6] + } + }, + { + "name": "burst_noise", + "args": { + "burst_number": [0, 4], + "amplitude": [0.05, 0.5], + "frequency": [80, 100] + } + }, + { + "name": "powerline_noise", + "args": { + "amplitude": [0.005, 0.01], + "frequency": [50, 60] + } + }, + { + "name": "noise_sources", + "args": { + "num_sources": [1, 2], + "amplitude": [0.04, 0.1], + "frequency": [10, 40] + } + } + ] +} diff --git a/configs/beat-class-2/demo.json b/configs/beat-class-2/demo.json new file mode 100644 index 00000000..aa4b4e15 --- /dev/null +++ b/configs/beat-class-2/demo.json @@ -0,0 +1,28 @@ +{ + "job_dir": "./results/beat-class-2", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 160, + "num_classes": 2, + "model_file": "./results/beat-class-2/model.tflite", + "backend": "pc", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/beat-class-2/evaluate.json b/configs/beat-class-2/evaluate.json new file mode 100644 index 00000000..9e9fdd88 --- /dev/null +++ b/configs/beat-class-2/evaluate.json @@ -0,0 +1,31 @@ +{ + "job_dir": "./results/beat-class-2", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 160, + "num_classes": 2, + "samples_per_patient": [25, 200], + "test_patients": 1000, + "test_size": 50000, + "model_file": "./results/beat-class-3/model.tf", + "threshold": 0.50, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/beat-class-2/export.json b/configs/beat-class-2/export.json new file mode 100644 index 00000000..1972bc9e --- /dev/null +++ b/configs/beat-class-2/export.json @@ -0,0 +1,36 @@ +{ + "job_dir": "./results/beat-class-2", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 160, + "num_classes": 2, + "samples_per_patient": [5, 40], + "test_patients": 1000, + "test_size": 10000, + "model_file": "./results/beat-class-2/model.tf", + "quantization": true, + "use_logits": false, + "threshold": 0.50, + "val_acc_threshold": 0.98, + "tflm_var_name": "g_beat_model", + "tflm_file": "./results/beat-class-2/beat_model_buffer.h", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/train-arrhythmia-model.json b/configs/beat-class-2/train.json similarity index 66% rename from configs/train-arrhythmia-model.json rename to configs/beat-class-2/train.json index b3ade9db..a9216d81 100644 --- a/configs/train-arrhythmia-model.json +++ b/configs/beat-class-2/train.json @@ -1,22 +1,40 @@ { - "job_dir": "./results/arrhythmia", + "job_dir": "./results/beat-class-2", "ds_path": "./datasets", "sampling_rate": 200, - "frame_size": 800, + "frame_size": 160, "num_classes": 2, - 
"samples_per_patient": [100, 800], - "val_samples_per_patient": [100, 800], + "samples_per_patient": [25, 200], + "val_samples_per_patient": [25, 200], "train_patients": 10000, - "val_file": "./results/arrhythmia-10000pt-200fs-4s.pkl", - "val_patients": 0.10, - "val_size": 200000, + "val_patients": 0.20, + "val_size": 100000, "batch_size": 256, "buffer_size": 100000, - "epochs": 200, + "epochs": 150, "steps_per_epoch": 20, "val_metric": "loss", "datasets": ["icentia11k"], "lr_rate": 5e-3, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ], "augmentations": [ { "name": "baseline_wander", diff --git a/configs/beat-class-3/demo.json b/configs/beat-class-3/demo.json new file mode 100644 index 00000000..309fb370 --- /dev/null +++ b/configs/beat-class-3/demo.json @@ -0,0 +1,28 @@ +{ + "job_dir": "./results/beat-class-3", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 160, + "num_classes": 3, + "model_file": "./results/beat-class-3/model.tflite", + "backend": "pc", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/beat-class-3/evaluate.json b/configs/beat-class-3/evaluate.json new file mode 100644 index 00000000..36d88b30 --- /dev/null +++ b/configs/beat-class-3/evaluate.json @@ -0,0 +1,31 @@ +{ + "job_dir": "./results/beat-class-3", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 160, + "num_classes": 3, + "samples_per_patient": [25, 100, 100], + "test_patients": 1000, + "test_size": 50000, + "model_file": "./results/beat-class-3/model.tf", + "threshold": 0.50, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/beat-class-3/export.json b/configs/beat-class-3/export.json new file mode 100644 index 00000000..6a91783a --- /dev/null +++ b/configs/beat-class-3/export.json @@ -0,0 +1,36 @@ +{ + "job_dir": "./results/beat-class-3", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 160, + "num_classes": 3, + "samples_per_patient": [5, 20, 20], + "test_patients": 1000, + "test_size": 10000, + "model_file": "./results/beat-class-3/model.tf", + "quantization": true, + "use_logits": false, + "threshold": 0.50, + "val_acc_threshold": 0.98, + "tflm_var_name": "g_beat_model", + "tflm_file": "./results/beat-class-3/beat_model_buffer.h", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/train-beat-model.json b/configs/beat-class-3/train.json similarity index 65% rename from configs/train-beat-model.json rename to configs/beat-class-3/train.json index 13546e30..5b00b178 100644 --- a/configs/train-beat-model.json +++ b/configs/beat-class-3/train.json @@ -1,22 +1,40 @@ { - "job_dir": "./results/beat", + "job_dir": "./results/beat-class-3", "ds_path": "./datasets", "sampling_rate": 200, "frame_size": 160, "num_classes": 3, - "samples_per_patient": [50, 400, 400], - 
"val_samples_per_patient": [50, 100, 100], + "samples_per_patient": [25, 100, 100], + "val_samples_per_patient": [25, 100, 100], "train_patients": 10000, - "val_file-dis": "./results/beat-1000pt-800ms-200fs-4n.pkl", - "val_patients": 0.10, - "val_size": 20000, - "batch_size": 512, - "buffer_size": 10000, + "val_patients": 0.20, + "val_size": 100000, + "batch_size": 256, + "buffer_size": 100000, "epochs": 150, "steps_per_epoch": 20, "val_metric": "loss", "datasets": ["icentia11k"], "lr_rate": 5e-3, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ], "augmentations": [ { "name": "baseline_wander", diff --git a/configs/demo-arrhythmia.json b/configs/demo-arrhythmia.json deleted file mode 100644 index 5e7a60ee..00000000 --- a/configs/demo-arrhythmia.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "job_dir": "./results/arrhythmia", - "ds_path": "./datasets", - "sampling_rate": 200, - "frame_size": 800, - "num_classes": 2, - "model_file": "./results/arrhythmia/model.tflite", - "backend": "pc" -} diff --git a/configs/demo-beat.json b/configs/demo-beat.json deleted file mode 100644 index dc95f0ea..00000000 --- a/configs/demo-beat.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "job_dir": "./results/beat", - "ds_path": "./datasets", - "sampling_rate": 200, - "frame_size": 160, - "num_classes": 3, - "model_file": "./results/beat/model.tflite", - "backend": "pc" -} diff --git a/configs/demo-segmentation.json b/configs/demo-segmentation.json deleted file mode 100644 index b0d194bf..00000000 --- a/configs/demo-segmentation.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "job_dir": "./results/segmentation", - "ds_path": "./datasets", - "num_classes": 4, - "sampling_rate": 200, - "frame_size": 512, - "model_file": "./results/segmentation/model.tflite", - "backend": "pc" -} diff --git a/configs/evaluate-arrhythmia-model.json b/configs/evaluate-arrhythmia-model.json deleted file mode 100644 index 6ee89f66..00000000 --- a/configs/evaluate-arrhythmia-model.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "job_dir": "./results/arrhythmia", - "ds_path": "./datasets", - "sampling_rate": 200, - "frame_size": 800, - "num_classes": 2, - "samples_per_patient": [50, 400, 200], - "test_patients": 1000, - "test_size": 100000, - "model_file": "./results/arrhythmia/model.tf", - "model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v0", - "threshold": 0.75 -} diff --git a/configs/evaluate-beat-model.json b/configs/evaluate-beat-model.json deleted file mode 100644 index e6aa48a2..00000000 --- a/configs/evaluate-beat-model.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "job_dir": "./results/beat", - "ds_path": "./datasets", - "sampling_rate": 200, - "frame_size": 160, - "num_classes": 3, - "samples_per_patient": [50, 100, 100], - "test_patients": 1000, - "test_size": 50000, - "model_file": "./results/beat/model.tf", - "model_file_rmt": "wandb://ambiq/model-registry/heartkit-beat:v1", - "threshold": 0.50 -} diff --git a/configs/evaluate-segmentation-model.json b/configs/evaluate-segmentation-model.json deleted file mode 100644 index ffc60b14..00000000 --- a/configs/evaluate-segmentation-model.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "job_dir": "./results/segmentation-pre-tcn-14", - "ds_path": "./datasets", - "sampling_rate": 200, - "frame_size": 512, - "num_classes": 4, - "samples_per_patient": 100, - "test_size": 10000, - "num_pts": 600, - "use_logits": false, - 
"datasets": ["synthetic", "ludb"], - "model_file": "./results/segmentation-pre-tcn-14/model.tf", - "model_file_rmt": "wandb://ambiq/model-registry/heartkit-segmentation:v1", - "threshold": 0.50 -} diff --git a/configs/export-arrhythmia-model.json b/configs/export-arrhythmia-model.json deleted file mode 100644 index 01c84bf0..00000000 --- a/configs/export-arrhythmia-model.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "job_dir": "./results/arrhythmia", - "ds_path": "./datasets", - "sampling_rate": 200, - "frame_size": 800, - "num_classes": 2, - "samples_per_patient": [5, 40], - "test_patients": 1000, - "test_size": 10000, - "model_file": "./results/arrhythmia/model.tf", - "model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v0", - "quantization": true, - "use_logits": false, - "threshold": 0.80, - "val_acc_threshold": 0.98, - "tflm_var_name": "g_arrhythmia_model", - "tflm_file": "./evb/src/arrhythmia_model_buffer.h" -} diff --git a/configs/export-beat-model.json b/configs/export-beat-model.json deleted file mode 100644 index a68f3f85..00000000 --- a/configs/export-beat-model.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "job_dir": "./results/beat", - "ds_path": "./datasets", - "sampling_rate": 200, - "frame_size": 160, - "num_classes": 3, - "samples_per_patient": [10, 20, 20], - "test_patients": 1000, - "test_size": 10000, - "model_file": "./results/beat/model.tf", - "model_file_rmt": "wandb://ambiq/model-registry/heartkit-beat:v1", - "quantization": true, - "use_logits": false, - "threshold": 0.50, - "val_acc_threshold": 0.98, - "tflm_var_name": "g_beat_model", - "tflm_file": "./evb/src/beat_model_buffer.h" -} diff --git a/configs/export-segmentation-model.json b/configs/export-segmentation-model.json deleted file mode 100644 index 87a812a1..00000000 --- a/configs/export-segmentation-model.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "job_dir": "./results/segmentation-pre-tcn-14", - "ds_path": "./datasets", - "sampling_rate": 200, - "frame_size": 512, - "samples_per_patient": 100, - "num_pts": 400, - "test_size": 1000, - "datasets": ["synthetic", "ludb"], - "model_file": "./results/segmentation-pre-tcn-14/model.tf", - "model_file_rmt": "wandb://ambiq/model-registry/heartkit-segmentation:v1", - "quantization": true, - "use_logits": false, - "threshold": 0.50, - "val_acc_threshold": 0.98, - "tflm_var_name": "g_segmentation_model", - "tflm_file_dis": "./evb/src/segmentation_model_buffer.h" - -} diff --git a/configs/pretrain-segmentation-model.json b/configs/pretrain-segmentation-model.json deleted file mode 100644 index 7252f7a9..00000000 --- a/configs/pretrain-segmentation-model.json +++ /dev/null @@ -1,58 +0,0 @@ -{ - "job_dir": "./results/segmentation-pre", - "ds_path": "./datasets", - "sampling_rate": 200, - "frame_size": 512, - "num_classes": 4, - "samples_per_patient": 100, - "val_samples_per_patient": 100, - "val_patients": 0.10, - "batch_size": 512, - "buffer_size": 25000, - "epochs": 80, - "steps_per_epoch": 100, - "val_metric": "loss", - "datasets": ["ludb", "synthetic"], - "lr_rate": 5e-3, - "num_pts": 400, - "quantization": false, - "augmentations": [ - { - "name": "baseline_wander", - "args": { - "amplitude": [0.5, 1.5], - "frequency": [0.4, 0.5] - } - }, - { - "name": "motion_noise", - "args": { - "amplitude": [0.5, 1.5], - "frequency": [0.4, 0.7] - } - }, - { - "name": "burst_noise", - "args": { - "burst_number": [2, 10], - "amplitude": [0.5, 1.5], - "frequency": [40, 100] - } - }, - { - "name": "powerline_noise", - "args": { - "amplitude": [0.005, 0.01], - "frequency": [50, 
60] - } - }, - { - "name": "noise_sources", - "args": { - "num_sources": [1, 8], - "amplitude": [0.05, 0.25], - "frequency": [10, 40] - } - } - ] -} diff --git a/configs/segmentation-class-2/demo.json b/configs/segmentation-class-2/demo.json new file mode 100644 index 00000000..38044f1c --- /dev/null +++ b/configs/segmentation-class-2/demo.json @@ -0,0 +1,28 @@ +{ + "job_dir": "./results/segmentation-class-2", + "ds_path": "./datasets", + "num_classes": 2, + "sampling_rate": 200, + "frame_size": 512, + "model_file": "./results/segmentation-class-2/model.tflite", + "backend": "pc", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/segmentation-class-2/evaluate.json b/configs/segmentation-class-2/evaluate.json new file mode 100644 index 00000000..bc93dd11 --- /dev/null +++ b/configs/segmentation-class-2/evaluate.json @@ -0,0 +1,33 @@ +{ + "job_dir": "./results/segmentation-class-2", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 512, + "num_classes": 2, + "samples_per_patient": 100, + "test_size": 10000, + "num_pts": 500, + "use_logits": false, + "datasets": ["synthetic", "ludb"], + "model_file": "./results/segmentation-class-2/model.tf", + "threshold": 0.50, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/segmentation-class-2/export.json b/configs/segmentation-class-2/export.json new file mode 100644 index 00000000..393ca9b4 --- /dev/null +++ b/configs/segmentation-class-2/export.json @@ -0,0 +1,38 @@ +{ + "job_dir": "./results/segmentation-class-2", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 512, + "num_classes": 2, + "samples_per_patient": 100, + "num_pts": 500, + "test_size": 1000, + "datasets": ["synthetic", "ludb"], + "model_file": "./results/segmentation-class-2/model.tf", + "quantization": true, + "use_logits": false, + "threshold": 0.50, + "val_acc_threshold": 0.98, + "tflm_var_name": "segmentation_flatbuffer", + "tflm_file": "./results/segmentation-class-2/segmentation_flatbuffer.h", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + + ] +} diff --git a/configs/pretrain-segmentation-model-ln.json b/configs/segmentation-class-2/finetune.json similarity index 65% rename from configs/pretrain-segmentation-model-ln.json rename to configs/segmentation-class-2/finetune.json index fa88c075..9f627a9a 100644 --- a/configs/pretrain-segmentation-model-ln.json +++ b/configs/segmentation-class-2/finetune.json @@ -1,21 +1,42 @@ { - "job_dir": "./results/segmentation-pre-ln", + "job_dir": "./results/segmentation-class-2", + "model_file": "./results/segmentation-class-2/model.tf", "ds_path": "./datasets", "sampling_rate": 200, "frame_size": 512, - "num_classes": 4, + "num_classes": 2, "samples_per_patient": 100, "val_samples_per_patient": 100, "val_patients": 0.10, - "batch_size": 512, + "batch_size": 128, "buffer_size": 25000, - "epochs": 80, + "epochs": 40, "steps_per_epoch": 100, "val_metric": "loss", "datasets": ["ludb", "synthetic"], - "lr_rate": 5e-3, - "num_pts": 400, - 
"quantization": false, + "lr_rate": 1e-4, + "lr_cycles": 1, + "num_pts": 500, + "quantization": true, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ], "augmentations": [ { "name": "baseline_wander", diff --git a/configs/segmentation-class-2/train.json b/configs/segmentation-class-2/train.json new file mode 100644 index 00000000..4d86eb0b --- /dev/null +++ b/configs/segmentation-class-2/train.json @@ -0,0 +1,90 @@ +{ + "job_dir": "./results/segmentation-class-2", + "ds_path": "./datasets", + "sampling_rate": 100, + "frame_size": 256, + "num_classes": 2, + "samples_per_patient": 100, + "val_samples_per_patient": 100, + "val_patients": 0.10, + "batch_size": 128, + "buffer_size": 10000, + "epochs": 100, + "steps_per_epoch": 100, + "val_metric": "loss", + "datasets": ["ludb", "synthetic"], + "lr_rate": 5e-3, + "num_pts": 400, + "quantization": false, + "model": "unet", + "model_params": { + "blocks": [ + {"filters": 8, "depth": 1, "ddepth": 1, "kernel": [1, 5], "pool": [1, 3], "strides": [1, 2], "skip": true}, + {"filters": 12, "depth": 1, "ddepth": 1, "kernel": [1, 5], "pool": [1, 3], "strides": [1, 2], "skip": true}, + {"filters": 16, "depth": 1, "ddepth": 1, "kernel": [1, 5], "pool": [1, 3], "strides": [1, 2], "skip": true}, + {"filters": 24, "depth": 1, "ddepth": 1, "kernel": [1, 5], "pool": [1, 3], "strides": [1, 2], "skip": true}, + {"filters": 32, "depth": 1, "ddepth": 1, "kernel": [1, 5], "pool": [1, 3], "strides": [1, 2], "skip": true} + ], + "output_kernel_size": [1, 5], + "include_top": true, + "use_logits": true + }, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": false, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ], + "augmentations": [ + { + "name": "baseline_wander", + "args": { + "amplitude": [0.5, 1.5], + "frequency": [0.4, 0.5] + } + }, + { + "name": "motion_noise", + "args": { + "amplitude": [0.5, 1.5], + "frequency": [0.4, 0.7] + } + }, + { + "name": "burst_noise", + "args": { + "burst_number": [2, 10], + "amplitude": [0.5, 1.5], + "frequency": [40, 100] + } + }, + { + "name": "powerline_noise", + "args": { + "amplitude": [0.005, 0.01], + "frequency": [50, 60] + } + }, + { + "name": "noise_sources", + "args": { + "num_sources": [1, 8], + "amplitude": [0.05, 0.25], + "frequency": [10, 40] + } + } + ] +} diff --git a/configs/segmentation-class-4/demo.json b/configs/segmentation-class-4/demo.json new file mode 100644 index 00000000..b509d293 --- /dev/null +++ b/configs/segmentation-class-4/demo.json @@ -0,0 +1,28 @@ +{ + "job_dir": "./results/segmentation-class-4", + "ds_path": "./datasets", + "num_classes": 4, + "sampling_rate": 200, + "frame_size": 512, + "model_file": "./results/segmentation-class-4/model.tflite", + "backend": "pc", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/segmentation-class-4/evaluate.json b/configs/segmentation-class-4/evaluate.json new file mode 100644 index 00000000..a894ffe7 --- /dev/null +++ b/configs/segmentation-class-4/evaluate.json @@ -0,0 +1,33 @@ +{ + "job_dir": "./results/segmentation-class-4", + "ds_path": "./datasets", + 
"sampling_rate": 200, + "frame_size": 512, + "num_classes": 4, + "samples_per_patient": 100, + "test_size": 10000, + "num_pts": 500, + "use_logits": false, + "datasets": ["synthetic", "ludb"], + "model_file": "./results/segmentation-class-4/model.tf", + "threshold": 0.50, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/segmentation-class-4/export.json b/configs/segmentation-class-4/export.json new file mode 100644 index 00000000..795c30b4 --- /dev/null +++ b/configs/segmentation-class-4/export.json @@ -0,0 +1,38 @@ +{ + "job_dir": "./results/segmentation-class-4", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 512, + "num_classes": 4, + "samples_per_patient": 100, + "num_pts": 500, + "test_size": 1000, + "datasets": ["synthetic", "ludb"], + "model_file": "./results/segmentation-class-4/model.tf", + "quantization": true, + "use_logits": false, + "threshold": 0.50, + "val_acc_threshold": 0.98, + "tflm_var_name": "segmentation_flatbuffer", + "tflm_file": "./results/segmentation-class-4/segmentation_flatbuffer.h", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + + ] +} diff --git a/configs/train-segmentation-model.json b/configs/segmentation-class-4/finetune.json similarity index 72% rename from configs/train-segmentation-model.json rename to configs/segmentation-class-4/finetune.json index a63c15a9..cc0c2af3 100644 --- a/configs/train-segmentation-model.json +++ b/configs/segmentation-class-4/finetune.json @@ -1,10 +1,10 @@ { - "job_dir": "./results/segmentation", - "model_file": "./results/segmentation-pre/model.tf", + "job_dir": "./results/segmentation-class-4", + "model_file": "./results/segmentation-class-4/model.tf", "ds_path": "./datasets", "sampling_rate": 200, - "num_classes": 4, "frame_size": 512, + "num_classes": 4, "samples_per_patient": 100, "val_samples_per_patient": 100, "val_patients": 0.10, @@ -16,8 +16,27 @@ "datasets": ["ludb", "synthetic"], "lr_rate": 1e-4, "lr_cycles": 1, - "num_pts": 400, + "num_pts": 500, "quantization": true, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ], "augmentations": [ { "name": "baseline_wander", diff --git a/configs/segmentation-class-4/train.json b/configs/segmentation-class-4/train.json new file mode 100644 index 00000000..822779b8 --- /dev/null +++ b/configs/segmentation-class-4/train.json @@ -0,0 +1,90 @@ +{ + "job_dir": "./results/segmentation-class-4", + "ds_path": "./datasets", + "sampling_rate": 200, + "frame_size": 512, + "num_classes": 4, + "samples_per_patient": 100, + "val_samples_per_patient": 100, + "val_patients": 0.10, + "batch_size": 128, + "buffer_size": 10000, + "epochs": 100, + "steps_per_epoch": 100, + "val_metric": "loss", + "datasets": ["ludb", "synthetic"], + "lr_rate": 5e-3, + "num_pts": 500, + "quantization": false, + "model": "unet", + "model_params": { + "blocks": [ + {"filters": 8, "depth": 1, "ddepth": 1, "kernel": [1, 5], "pool": [1, 3], "strides": [1, 2], "skip": true}, + {"filters": 16, "depth": 1, "ddepth": 1, "kernel": [1, 5], "pool": 
[1, 3], "strides": [1, 2], "skip": true}, + {"filters": 24, "depth": 1, "ddepth": 1, "kernel": [1, 5], "pool": [1, 3], "strides": [1, 2], "skip": true}, + {"filters": 32, "depth": 1, "ddepth": 1, "kernel": [1, 5], "pool": [1, 3], "strides": [1, 2], "skip": true}, + {"filters": 40, "depth": 1, "ddepth": 1, "kernel": [1, 5], "pool": [1, 3], "strides": [1, 2], "skip": true} + ], + "output_kernel_size": [1, 5], + "include_top": true, + "use_logits": true + }, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ], + "augmentations": [ + { + "name": "baseline_wander", + "args": { + "amplitude": [0.5, 1.5], + "frequency": [0.4, 0.5] + } + }, + { + "name": "motion_noise", + "args": { + "amplitude": [0.5, 1.5], + "frequency": [0.4, 0.7] + } + }, + { + "name": "burst_noise", + "args": { + "burst_number": [2, 10], + "amplitude": [0.5, 1.5], + "frequency": [40, 100] + } + }, + { + "name": "powerline_noise", + "args": { + "amplitude": [0.005, 0.01], + "frequency": [50, 60] + } + }, + { + "name": "noise_sources", + "args": { + "num_sources": [1, 8], + "amplitude": [0.05, 0.25], + "frequency": [10, 40] + } + } + ] +} diff --git a/configs/segmentation-tcn-class-2/demo.json b/configs/segmentation-tcn-class-2/demo.json new file mode 100644 index 00000000..7042c1f0 --- /dev/null +++ b/configs/segmentation-tcn-class-2/demo.json @@ -0,0 +1,29 @@ +{ + "job_dir": "./results/segmentation-tcn-class-2", + "ds_path": "./datasets", + "num_classes": 2, + "sampling_rate": 100, + "frame_size": 250, + "model_file": "./results/segmentation-tcn-class-2/model.tflite", + "backend": "pc", + "datasets": ["icentia11k"], + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + ] +} diff --git a/configs/segmentation-tcn-class-2/evaluate.json b/configs/segmentation-tcn-class-2/evaluate.json new file mode 100644 index 00000000..76ba8620 --- /dev/null +++ b/configs/segmentation-tcn-class-2/evaluate.json @@ -0,0 +1,34 @@ +{ + "job_dir": "./results/segmentation-tcn-class-2", + "ds_path": "./datasets", + "sampling_rate": 100, + "frame_size": 250, + "num_classes": 2, + "samples_per_patient": 100, + "test_size": 10000, + "num_pts": 500, + "use_logits": false, + "datasets": ["synthetic", "ludb"], + "model_file": "./results/segmentation-tcn-class-2/model.tf", + "threshold": 0.50, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + + ] +} diff --git a/configs/segmentation-tcn-class-2/export.json b/configs/segmentation-tcn-class-2/export.json new file mode 100644 index 00000000..8bf32946 --- /dev/null +++ b/configs/segmentation-tcn-class-2/export.json @@ -0,0 +1,38 @@ +{ + "job_dir": "./results/segmentation-tcn-class-2", + "ds_path": "./datasets", + "sampling_rate": 100, + "frame_size": 250, + "num_classes": 2, + "samples_per_patient": 100, + "num_pts": 500, + "test_size": 1000, + "datasets": ["icentia11k"], + "model_file": "./results/segmentation-tcn-class-2/model.tf", + "quantization": true, + "use_logits": false, + "threshold": 0.50, + "val_acc_threshold": 0.98, + "tflm_var_name": "segmentation_flatbuffer", + "tflm_file": 
"./results/segmentation-tcn-class-2/segmentation_flatbuffer.h", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + + ] +} diff --git a/configs/segmentation-tcn-class-2/train.json b/configs/segmentation-tcn-class-2/train.json new file mode 100644 index 00000000..e1d4c304 --- /dev/null +++ b/configs/segmentation-tcn-class-2/train.json @@ -0,0 +1,95 @@ +{ + "job_dir": "./results/segmentation-tcn-class-2", + "ds_path": "./datasets", + "sampling_rate": 100, + "frame_size": 250, + "num_classes": 2, + "samples_per_patient": 10, + "val_samples_per_patient": 10, + "val_patients": 0.10, + "batch_size": 128, + "buffer_size": 40000, + "epochs": 100, + "steps_per_epoch": 50, + "val_metric": "loss", + "datasets": ["icentia11k", "synthetic", "ludb"], + "lr_rate": 1e-3, + "lr_cycles": 1, + "num_pts": 500, + "quantization": false, + "model": "tcn", + "model_params": { + "input_kernel": [1, 3], + "input_norm": "batch", + "blocks": [ + {"depth": 1, "branch": 1, "filters": 12, "kernel": [1, 3], "dilation": [1, 1], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 0, "norm": "batch"}, + {"depth": 1, "branch": 1, "filters": 20, "kernel": [1, 3], "dilation": [1, 1], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 2, "norm": "batch"}, + {"depth": 1, "branch": 1, "filters": 28, "kernel": [1, 3], "dilation": [1, 2], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 2, "norm": "batch"}, + {"depth": 1, "branch": 1, "filters": 36, "kernel": [1, 3], "dilation": [1, 4], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 2, "norm": "batch"}, + {"depth": 1, "branch": 1, "filters": 40, "kernel": [1, 3], "dilation": [1, 8], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 2, "norm": "batch"} + ], + "output_kernel": [1, 3], + "include_top": true, + "use_logits": true, + "model_name": "tcn" + }, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 3, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + + ], + "augmentations": [ + { + "name": "baseline_wander", + "args": { + "amplitude": [0.5, 1.5], + "frequency": [0.4, 0.5] + } + }, + { + "name": "motion_noise", + "args": { + "amplitude": [0.5, 1.5], + "frequency": [0.4, 0.7] + } + }, + { + "name": "burst_noise", + "args": { + "burst_number": [2, 10], + "amplitude": [0.5, 1.5], + "frequency": [40, 100] + } + }, + { + "name": "powerline_noise", + "args": { + "amplitude": [0.005, 0.01], + "frequency": [50, 60] + } + }, + { + "name": "noise_sources", + "args": { + "num_sources": [1, 8], + "amplitude": [0.05, 0.25], + "frequency": [10, 40] + } + } + ] +} diff --git a/configs/segmentation-tcn-class-4/evaluate.json b/configs/segmentation-tcn-class-4/evaluate.json new file mode 100644 index 00000000..5e10ed3e --- /dev/null +++ b/configs/segmentation-tcn-class-4/evaluate.json @@ -0,0 +1,34 @@ +{ + "job_dir": "./results/segmentation-tcn-class-4", + "ds_path": "./datasets", + "sampling_rate": 100, + "frame_size": 250, + "num_classes": 4, + "samples_per_patient": 100, + "test_size": 10000, + "num_pts": 500, + "use_logits": false, + "datasets": ["synthetic", "ludb"], + "model_file": "./results/segmentation-tcn-class-4/model.tf", + "threshold": 0.50, + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 5, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + 
"args": { + "eps": 0.1, + "axis": null + } + } + + ] +} diff --git a/configs/segmentation-tcn-class-4/export.json b/configs/segmentation-tcn-class-4/export.json new file mode 100644 index 00000000..502d1f51 --- /dev/null +++ b/configs/segmentation-tcn-class-4/export.json @@ -0,0 +1,38 @@ +{ + "job_dir": "./results/segmentation-tcn-class-4", + "ds_path": "./datasets", + "sampling_rate": 100, + "frame_size": 250, + "num_classes": 4, + "samples_per_patient": 100, + "num_pts": 500, + "test_size": 1000, + "datasets": ["synthetic", "ludb"], + "model_file": "./results/segmentation-tcn-class-4/model.tf", + "quantization": true, + "use_logits": false, + "threshold": 0.50, + "val_acc_threshold": 0.98, + "tflm_var_name": "segmentation_flatbuffer", + "tflm_file": "./results/segmentation-tcn-class-4/segmentation_flatbuffer.h", + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 5, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + + ] +} diff --git a/configs/pretrain-segmentation-model-tcn.json b/configs/segmentation-tcn-class-4/train.json similarity index 64% rename from configs/pretrain-segmentation-model-tcn.json rename to configs/segmentation-tcn-class-4/train.json index 23930b5b..11ffd79e 100644 --- a/configs/pretrain-segmentation-model-tcn.json +++ b/configs/segmentation-tcn-class-4/train.json @@ -1,13 +1,13 @@ { - "job_dir": "./results/segmentation-pre-tcn-15", + "job_dir": "./results/segmentation-tcn-class-4", "ds_path": "./datasets", - "sampling_rate": 200, - "frame_size": 512, + "sampling_rate": 100, + "frame_size": 250, "num_classes": 4, "samples_per_patient": 50, "val_samples_per_patient": 50, "val_patients": 0.10, - "batch_size": 256, + "batch_size": 128, "buffer_size": 10000, "epochs": 100, "steps_per_epoch": 50, @@ -15,25 +15,45 @@ "datasets": ["ludb", "synthetic"], "lr_rate": 1e-3, "lr_cycles": 1, - "num_pts": 400, + "num_pts": 500, "quantization": false, "model": "tcn", "model_params": { "input_kernel": [1, 3], "input_norm": "batch", "blocks": [ - {"depth": 1, "branch": 1, "filters": 8, "kernel": [1, 3], "dilation": [1, 1], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 0, "norm": "batch"}, - {"depth": 2, "branch": 1, "filters": 16, "kernel": [1, 3], "dilation": [1, 1], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 4, "norm": "batch"}, - {"depth": 2, "branch": 1, "filters": 24, "kernel": [1, 3], "dilation": [1, 2], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 4, "norm": "batch"}, - {"depth": 2, "branch": 1, "filters": 32, "kernel": [1, 3], "dilation": [1, 4], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 4, "norm": "batch"}, - {"depth": 1, "branch": 1, "filters": 40, "kernel": [1, 3], "dilation": [1, 8], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 4, "norm": "batch"} + {"depth": 1, "branch": 1, "filters": 12, "kernel": [1, 3], "dilation": [1, 1], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 0, "norm": "batch"}, + {"depth": 1, "branch": 1, "filters": 20, "kernel": [1, 3], "dilation": [1, 1], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 2, "norm": "batch"}, + {"depth": 1, "branch": 1, "filters": 28, "kernel": [1, 3], "dilation": [1, 2], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 2, "norm": "batch"}, + {"depth": 1, "branch": 1, "filters": 36, "kernel": [1, 3], "dilation": [1, 4], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 2, "norm": "batch"}, + {"depth": 1, "branch": 1, "filters": 40, "kernel": [1, 3], "dilation": [1, 8], "dropout": 0.10, "ex_ratio": 1, "se_ratio": 2, "norm": 
"batch"} ], "output_kernel": [1, 3], "include_top": true, "use_logits": true, "model_name": "tcn" }, - "augmentations-dis": [ + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 5, + "forward_backward": true, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + + ], + "augmentations": [ { "name": "baseline_wander", "args": { diff --git a/configs/pretrain-segmentation-model-unext.json b/configs/segmentation-unext-class-4/train.json similarity index 58% rename from configs/pretrain-segmentation-model-unext.json rename to configs/segmentation-unext-class-4/train.json index da4f04cc..ed631581 100644 --- a/configs/pretrain-segmentation-model-unext.json +++ b/configs/segmentation-unext-class-4/train.json @@ -1,36 +1,56 @@ { - "job_dir": "./results/segmentation-pre-unext-1", + "job_dir": "./results/segmentation-unext-class-4", "ds_path": "./datasets", - "sampling_rate": 200, - "frame_size": 512, + "sampling_rate": 100, + "frame_size": 256, "num_classes": 4, - "samples_per_patient": 50, - "val_samples_per_patient": 50, + "samples_per_patient": 100, + "val_samples_per_patient": 100, "val_patients": 0.10, - "batch_size": 256, + "batch_size": 128, "buffer_size": 10000, "epochs": 100, - "steps_per_epoch": 50, + "steps_per_epoch": 100, "val_metric": "loss", "datasets": ["ludb"], - "lr_rate": 1e-3, + "lr_rate": 5e-3, "lr_cycles": 1, "num_pts": 400, "quantization": false, "model": "unext", "model_params": { "blocks": [ - {"filters": 8, "depth": 2, "ddepth": 1, "kernel": [1, 3], "pool": [1, 2], "strides": [1, 2], "skip": true, "expand_ratio": 1, "se_ratio": 1, "dropout": 0}, - {"filters": 16, "depth": 2, "ddepth": 1, "kernel": [1, 3], "pool": [1, 2], "strides": [1, 2], "skip": true, "expand_ratio": 1, "se_ratio": 4, "dropout": 0}, - {"filters": 24, "depth": 2, "ddepth": 1, "kernel": [1, 3], "pool": [1, 2], "strides": [1, 2], "skip": true, "expand_ratio": 1, "se_ratio": 4, "dropout": 0}, - {"filters": 32, "depth": 2, "ddepth": 1, "kernel": [1, 3], "pool": [1, 2], "strides": [1, 2], "skip": true, "expand_ratio": 1, "se_ratio": 4, "dropout": 0}, - {"filters": 40, "depth": 2, "ddepth": 1, "kernel": [1, 3], "pool": [1, 2], "strides": [1, 2], "skip": true, "expand_ratio": 1, "se_ratio": 4, "dropout": 0} + {"filters": 8, "depth": 1, "ddepth": 1, "kernel": [1, 5], "pool": [1, 2], "strides": [1, 2], "skip": true, "expand_ratio": 1, "se_ratio": 1, "dropout": 0}, + {"filters": 16, "depth": 2, "ddepth": 1, "kernel": [1, 5], "pool": [1, 2], "strides": [1, 2], "skip": true, "expand_ratio": 1, "se_ratio": 4, "dropout": 0}, + {"filters": 24, "depth": 2, "ddepth": 1, "kernel": [1, 5], "pool": [1, 2], "strides": [1, 2], "skip": true, "expand_ratio": 1, "se_ratio": 4, "dropout": 0}, + {"filters": 32, "depth": 1, "ddepth": 1, "kernel": [1, 5], "pool": [1, 2], "strides": [1, 2], "skip": true, "expand_ratio": 1, "se_ratio": 4, "dropout": 0}, + {"filters": 40, "depth": 1, "ddepth": 1, "kernel": [1, 5], "pool": [1, 2], "strides": [1, 2], "skip": true, "expand_ratio": 1, "se_ratio": 4, "dropout": 0} ], "output_kernel_size": [1, 5], "include_top": true, "use_logits": true }, - "augmentations-dis": [ + "preprocesses": [ + { + "name": "filter", + "args": { + "lowcut": 0.5, + "highcut": 30, + "order": 5, + "forward_backward": false, + "axis": 0 + } + }, + { + "name": "znorm", + "args": { + "eps": 0.1, + "axis": null + } + } + + ], + "augmentations": [ { "name": "baseline_wander", "args": { diff --git a/docs/arrhythmia/methods.md 
b/docs/arrhythmia/methods.md index baf8e960..dd3dee43 100644 --- a/docs/arrhythmia/methods.md +++ b/docs/arrhythmia/methods.md @@ -2,7 +2,7 @@ ## Datasets -For training arrhythmia classification models, we use the [Icentia11k dataset](https://physionet.org/content/icentia11k-continuous-ecg/1.0.0/). This dataset consists of single lead ECG recordings from 11,000 patients and 2 billion labelled beats. +For training arrhythmia classification models, we currently use the [Icentia11k dataset](https://physionet.org/content/icentia11k-continuous-ecg/1.0.0/). This dataset consists of single-lead ECG recordings from 11,000 patients and 2 billion labelled beats. --- @@ -12,15 +12,16 @@ The arrhythmia model utilizes a 1-D CNN built using MBConv style blocks that inc --- -## Feature Extraction +## Feature Sets -The arrhythmia classification models is trained directly on single channel ECG data. No feature extraction is performed. However, we do preprocess the data by applying a bandpass filter to remove noise followed by downsampling. +### ECG Signal + +The model is trained directly on single-channel ECG data. No feature extraction is performed; the data is only preprocessed by applying a bandpass filter to remove noise, followed by downsampling. The signal is then normalized by subtracting the mean and dividing by the standard deviation. We also add a small epsilon value to the standard deviation to avoid division by zero. ---- -## Feature Normalization +### HR/HRV Metrics -The filtered ECG signals are normalized by subtracting the mean and dividing by the standard deviation. We also add a small epsilon value to the standard deviation to avoid division by zero. +From either ECG or PPG signals, we identify the R peaks (or systolic peaks) and compute a variety of heart rate (HR) and heart rate variability (HRV) metrics from the inter-beat intervals (IBI). ---
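The filtering and normalization described above correspond to the `filter` and `znorm` preprocess entries used throughout the configs (0.5-30 Hz bandpass, order 3, forward-backward; `eps` of 0.1). A minimal sketch assuming SciPy/NumPy is shown below; it mirrors the described steps and is not HeartKit's actual implementation:

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

def preprocess_ecg(x: np.ndarray, fs: float = 200.0, lowcut: float = 0.5,
                   highcut: float = 30.0, order: int = 3, eps: float = 0.1) -> np.ndarray:
    """Bandpass filter (zero-phase, i.e. forward-backward) then z-normalize."""
    sos = butter(order, [lowcut, highcut], btype="bandpass", fs=fs, output="sos")
    y = sosfiltfilt(sos, x, axis=0)
    # znorm: subtract mean, divide by std; eps guards against division by zero
    return (y - y.mean()) / (y.std() + eps)
```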
diff --git a/docs/arrhythmia/overview.md b/docs/arrhythmia/overview.md index c5e7eceb..639e3f6c 100644 --- a/docs/arrhythmia/overview.md +++ b/docs/arrhythmia/overview.md @@ -16,17 +16,19 @@ There are a variety of heart arrhythmias that can be detected using ECG signals. Atrial fibrillation (AFIB) is a type of arrhythmia where the atria (upper chambers of the heart) beat irregularly and out of sync with the ventricles (lower chambers of the heart). AFIB is the most common type of arrhythmia and can lead to serious complications such as stroke and heart failure. AFIB is typically characterized by the following: - * Re-entrant circuit is present - * Atrium depolarization is 250-350 bpm - * Characteristic sawtooth pattern in P waves (f waves) - * QRS: Narrow (normal) - * AV conduction ratio (2:1, 3:1) + * Irregularly irregular rhythm + * No P waves + * Variable ventricular rate + * QRS complexes usually < 120ms + * Fibrillatory waves may be present === "AFL" Atrial flutter (AFL) is a type of arrhythmia where the atria (upper chambers of the heart) beat regularly but faster than normal. AFL is less common than AFIB and can lead to serious complications such as stroke and heart failure. AFL is typically characterized by the following: - * AP fire chaotically within pulmonary veins / atrium (quiver) - * Atrium depolarization is 400-600 bpm (ventricular 100-200 bpm) - * AV node is intermittently refractory - * Varying RR intervals + * Narrow complex tachycardia + * Regular atrial activity at ~300 bpm + * Loss of the isoelectric baseline + * “Saw-tooth” pattern of inverted flutter waves in leads II, III, aVF + * Upright flutter waves in V1 that may resemble P waves + * Ventricular rate depends on AV conduction ratio diff --git a/docs/arrhythmia/results.md b/docs/arrhythmia/results.md index a65a6e91..db9eba53 100644 --- a/docs/arrhythmia/results.md +++ b/docs/arrhythmia/results.md @@ -2,7 +2,7 @@ ## Overview -The results of the arrhythmia models when testing on 1,000 patients (not used during training) is summarized below. The baseline model is simply selecting the argmax of model outputs (`normal`, `AFIB/AFL`). The 75% confidence version adds inconclusive label that is assigned when softmax output is less than 75% for any model output. +The results of the pretrained arrhythmia models when testing on 1,000 patients (not used during training) are summarized below. The baseline model simply selects the argmax of model outputs (e.g. `AFIB/AFL`). The 75% confidence version adds an inconclusive label that is assigned when no class reaches a softmax output of 75%. | Task | Params | FLOPS | Metric | Cycles/Inf | Time/Inf | | -------------- | -------- | ------- | ---------- | ---------- | ---------- | diff --git a/docs/assets/favicon.png b/docs/assets/favicon.png index 8b061b71..70383b99 100644 Binary files a/docs/assets/favicon.png and b/docs/assets/favicon.png differ diff --git a/docs/assets/logo.png b/docs/assets/logo.png index 8b061b71..70383b99 100644 Binary files a/docs/assets/logo.png and b/docs/assets/logo.png differ diff --git a/docs/overview.md b/docs/overview.md index 7f838f44..6b7c65ca 100644 --- a/docs/overview.md +++ b/docs/overview.md @@ -1,4 +1,6 @@ -# Overview +# :octicons-heart-fill-24:{ .heart } Overview + + __HeartKit__ can be used as either a CLI-based app or as a Python package to perform advanced experimentation. 
In both forms, HeartKit exposes a number of modes and tasks discussed below: @@ -102,7 +104,7 @@ The `train` command is used to train a HeartKit model for the specified `task` a === "CLI" ```bash - heartkit --task arrhythmia --mode train --config ./configs/train-arrhythmia-model.json + heartkit --task arrhythmia --mode train --config ./configs/arrhythmia-class-2/train.json ``` === "Python" @@ -111,7 +113,7 @@ The `train` command is used to train a HeartKit model for the specified `task` a import heartkit as hk hk.arrhythmia.train(hk.defines.HeartTrainParams( - job_dir="./results/arrhythmia", + job_dir="./results/arrhythmia-class-2", ds_path="./datasets", sampling_rate=200, frame_size=800, @@ -141,7 +143,7 @@ The `evaluate` command will evaluate the performance of the model on the reserve === "CLI" ```bash - heartkit --task arrhythmia --mode evaluate --config ./configs/evaluate-arrhythmia-model.json + heartkit --task arrhythmia --mode evaluate --config ./configs/arrhythmia-class-2/evaluate.json ``` === "Python" @@ -150,7 +152,7 @@ The `evaluate` command will evaluate the performance of the model on the reserve import heartkit as hk hk.arrhythmia.evaluate(hk.defines.HeartTestParams( - job_dir="./results/arrhythmia", + job_dir="./results/arrhythmia-class-2", ds_path="./datasets", sampling_rate=200, frame_size=800, @@ -158,7 +160,7 @@ The `evaluate` command will evaluate the performance of the model on the reserve samples_per_patient=[100, 800], test_patients=1000, test_size=100000, - model_file="./results/arrhythmia/model.tf", + model_file="./results/arrhythmia-class-2/model.tf", threshold=0.75 )) ``` @@ -174,7 +176,7 @@ The `export` command will convert the trained TensorFlow model into both TensorF === "CLI" ```bash - heartkit --task arrhythmia --mode export --config ./configs/export-arrhythmia-model.json + heartkit --task arrhythmia --mode export --config ./configs/arrhythmia-class-2/export.json ``` === "Python" @@ -183,17 +185,17 @@ The `export` command will convert the trained TensorFlow model into both TensorF import heartkit as hk hk.arrhythmia.export(hk.defines.HeartExportParams( - job_dir="./results/arrhythmia", + job_dir="./results/arrhythmia-class-2", ds_path="./datasets", sampling_rate=200, frame_size=800, num_classes=2, samples_per_patient=[100, 500, 100], - model_file="./results/arrhythmia/model.tf", + model_file="./results/arrhythmia-class-2/model.tf", quantization=true, threshold=0.95, tflm_var_name="g_arrhythmia_model", - tflm_file="./evb/src/arrhythmia_model_buffer.h" + tflm_file="./results/arrhythmia-class-2/arrhythmia_model_buffer.h" )) ``` @@ -208,7 +210,7 @@ The `demo` command is used to run a task-level demonstration using either the PC === "CLI" ```bash - heartkit --task arrhythmia --mode demo --config ./configs/demo-arrhythmia.json + heartkit --task arrhythmia --mode demo --config ./configs/arrhythmia-class-2/demo.json ``` === "Python" @@ -217,12 +219,12 @@ The `demo` command is used to run a task-level demonstration using either the PC import heartkit as hk hk.arrhythmia.demo(hk.defines.HKDemoParams( - job_dir="./results/arrhythmia", + job_dir="./results/arrhythmia-class-2", ds_path="./datasets", sampling_rate=200, frame_size=800, num_classes=2, - model_file="./results/arrhythmia/model.tflite", + model_file="./results/arrhythmia-class-2/model.tflite", backend="pc" )) ``` diff --git a/docs/tutorials/heartkit-demo.md b/docs/tutorials/heartkit-demo.md index 0bbfe6ad..9d3b51dc 100644 --- a/docs/tutorials/heartkit-demo.md +++ b/docs/tutorials/heartkit-demo.md @@ -1,7 
+1,10 @@ # :octicons-heart-fill-24:{ .heart } HeartKit Tutorial +## Overview -![HeartKit Architecture](../assets/heartkit-architecture.svg) +The HeartKit demo highlights a number of key features of the HeartKit library. By leveraging a modern multi-head network architecture coupled with Ambiq's ultra low-power SoC, the demo is designed to be **efficient**, **explainable**, and **extensible**. + +The architecture consists of an **ECG segmentation** model followed by three upstream heads: **HRV head**, **arrhythmia head**, and **beat head**. The ECG segmentation model serves as the backbone and is used to annotate every sample as either P-wave, QRS, T-wave, or none. The arrhythmia head is used to detect the presence of Atrial Fibrillation (AFIB) or Atrial Flutter (AFL). The HRV head is used to calculate heart rate, rhythm (e.g., bradycardia), and heart rate variability from the R peaks. Lastly, the beat head is used to identify individual irregular beats (PAC, PVC). This tutorial shows running the full HeartKit demonstrator on the Apollo 4 EVB. The basic flow chart is depicted below. @@ -21,6 +24,37 @@ flowchart LR In the first stage, 10 seconds of sensor data is collected- either directly from the MAX86150 sensor or test data from the PC. In stage 2, the data is preprocessed by bandpass filtering and standardizing. The data is then fed into the HeartKit models to perform inference. Finally, in stage 4, the ECG data and classification results will be displayed in the front-end UI. +--- + +## Architecture + +The HeartKit demo leverages a multi-head network: a backbone segmentation model followed by three upstream heads: + +* __Segmentation backbone__ utilizes a custom 1-D UNET architecture to perform ECG segmentation. +* __HRV head__ utilizes segmentation results to derive a number of useful metrics including heart rate, rhythm, and RR intervals. +* __Arrhythmia head__ utilizes a 1-D MBConv CNN to detect arrhythmias including AFIB and AFL. +* __Beat-level head__ utilizes a 1-D MBConv CNN to detect irregular individual beats (PAC, PVC). + +![HeartKit Architecture](../assets/heartkit-architecture.svg) + +### ECG Segmentation + +The ECG segmentation model serves as the backbone and is used to annotate every sample as either P-wave, QRS, T-wave, or none. The resulting ECG data and segmentation mask are then fed into the upstream “heads”. This model utilizes a custom 1-D UNET architecture with additional skip connections between encoder and decoder blocks. The encoder blocks are convolution-based and include both expansion and inverted residual layers. The only preprocessing performed is band-pass filtering and standardization on the window of ECG data. + +### HRV Head + +The HRV head uses only DSP and statistics (i.e. no neural network is used). Using a combination of the segmentation results and a QRS filter, the HRV head detects R-peak candidates. RR intervals are extracted and filtered, and then used to derive a variety of HRV metrics including heart rate, rhythm, SDNN, SDRR, SDANN, etc. (see the sketch below). All of the identified R peaks are further fed to the beat classifier head. Note that if the segmentation model is not enabled, the HRV head falls back to identifying R peaks purely from the gradient of the QRS signal.
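To make the interval statistics concrete, below is a minimal sketch of how HR and a few common HRV metrics can be derived from detected R-peak indices. `rpeaks_to_hrv` is a hypothetical helper for illustration, not the demo's actual code:

```python
import numpy as np

def rpeaks_to_hrv(rpeaks: np.ndarray, fs: float) -> dict:
    """Derive basic HR/HRV metrics from R-peak sample indices."""
    rr = np.diff(rpeaks) / fs          # RR intervals in seconds
    rr = rr[(rr > 0.3) & (rr < 2.0)]   # crude physiological filtering
    return {
        "hr_bpm": 60.0 / rr.mean(),    # mean heart rate
        "sdnn_ms": 1000.0 * rr.std(),  # std. dev. of RR intervals
        "rmssd_ms": 1000.0 * np.sqrt(np.mean(np.diff(rr) ** 2)),  # successive differences
    }

# e.g. R peaks roughly 0.8 s apart, sampled at 200 Hz
print(rpeaks_to_hrv(np.array([10, 170, 332, 490, 652]), fs=200))
```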
The arrhythmia model utilizes a 1-D CNN built using MBConv-style blocks that incorporate expansion, inverted residual, and squeeze-and-excitation layers. Furthermore, longer filter and stride lengths are utilized in the initial layers to capture longer temporal dependencies. + +### Beat Head + +The beat head is used to extract individual beats and classify them as either normal, premature/ectopic atrial contraction (PAC), premature/ectopic ventricular contraction (PVC), or noise. In addition to the target beat, the surrounding beats are also fed into the network as context. The “neighboring” beats are determined based on the average RR interval rather than the actual R peak locations. The beat head also utilizes a 1-D CNN built using MBConv-style blocks. + +--- + ## Demo Setup Please follow the [EVB Setup Guide](./evb-setup.md) to prepare the EVB and connect it to the PC. To use the pre-trained models, please skip to the [Run Demo](#run-demo) section. @@ -64,6 +98,8 @@ heartkit \ --config ./configs/train-beat-model.json ``` +--- + ### 2. Evaluate all the models 2.1 Evaluate the segmentation model performance: @@ -125,6 +161,8 @@ heartkit \ --config ./configs/evaluate-beat-model.json ``` !!! note Review `./evb/src/constants.h` and ensure the settings match the configuration file. +--- + ## Run Demo Please open three terminals to ease running the demo. We shall refer to these as __EVB Terminal__, __REST Terminal__, and __PC Terminal__. @@ -167,3 +205,5 @@ Now that the EVB client, PC client, and PC REST server are running, press either To shut down the PC client, a keyboard interrupt can be used (e.g. `[CTRL]+C`) in __PC Terminal__. Likewise, a keyboard interrupt can be used (e.g. `[CTRL]+C`) to stop the PC REST server in __REST Terminal__. + +--- diff --git a/evb/src/main.cc b/evb/src/main.cc index 9eba46ae..fa68edd5 100644 --- a/evb/src/main.cc +++ b/evb/src/main.cc @@ -29,10 +29,6 @@ #include "main.h" #include "model.h" -// RPC -static uint8_t rpcRxBuffer[USB_RX_BUFSIZE]; -static uint8_t rpcTxBuffer[USB_TX_BUFSIZE]; - // TFLM alignas(16) unsigned char modelBuffer[1024 * MAX_MODEL_SIZE]; static TfLiteTensor *inputs; @@ -46,6 +42,9 @@ static uint32_t outputIdx = 0; static AppState state = IDLE_STATE; static uint32_t app_err = 0; +// RPC +static uint8_t rpcRxBuffer[USB_RX_BUFSIZE]; +static uint8_t rpcTxBuffer[USB_TX_BUFSIZE]; ns_rpc_config_t rpcConfig = {.api = &ns_rpc_gdo_V1_0_0, .mode = NS_RPC_GENERICDATA_SERVER, .rx_buf = rpcRxBuffer, @@ -103,7 +102,7 @@ ns_rpc_data_to_evb_cb(const dataBlock *block) { if (modelIdx >= block->length) { ns_printf("Received model (%d)\n", block->length); reset_state(); - setup_model(modelBuffer, inputs, outputs); + model_setup(modelBuffer, inputs, outputs); modelInitialized = true; } } @@ -203,7 +202,7 @@ setup() { NS_TRY(ns_rpc_genericDataOperations_init(&rpcConfig), "RPC Init Failed\n"); // Initialize model - NS_TRY(init_model(), "Model init failed\n"); + NS_TRY(model_init(), "Model init failed\n"); ns_delay_us(5000); ns_lp_printf("Inference engine running...\n"); @@ -223,7 +222,7 @@ loop() { case INFERENCE_STATE: ns_printf("INFERENCE_STATE\n"); gpio_write(GPIO_TRIGGER, 1); - app_err = run_model(); + app_err = model_run(); gpio_write(GPIO_TRIGGER, 0); state = IDLE_STATE; break; diff --git a/evb/src/main.h b/evb/src/main.h index 8c675e6b..3b59a3d2 100644 --- a/evb/src/main.h +++ b/evb/src/main.h @@ -14,8 +14,17 @@ enum AppState { IDLE_STATE, INFERENCE_STATE, FAIL_STATE }; typedef enum AppState AppState; +/** + * @brief Application setup + * + */ void setup(void); + +/** + * @brief Application loop + * + */ void loop(void); diff --git a/evb/src/model.cc
b/evb/src/model.cc index 42fe8a42..a9eb978b 100644 --- a/evb/src/model.cc +++ b/evb/src/model.cc @@ -24,7 +24,7 @@ static tflite::MicroInterpreter *interpreter = nullptr; static tflite::MicroProfiler *profiler = nullptr; uint32_t -init_model() { +model_init() { tflite::MicroErrorReporter micro_error_reporter; errorReporter = &micro_error_reporter; @@ -151,7 +151,7 @@ init_model() { } uint32_t -setup_model(const void *modelBuffer, TfLiteTensor *inputs, TfLiteTensor *outputs) { +model_setup(const void *modelBuffer, TfLiteTensor *inputs, TfLiteTensor *outputs) { size_t bytesUsed; TfLiteStatus allocateStatus; model = tflite::GetModel(modelBuffer); @@ -183,7 +183,7 @@ setup_model(const void *modelBuffer, TfLiteTensor *inputs, TfLiteTensor *outputs } uint32_t -run_model() { +model_run() { TfLiteStatus invokeStatus; invokeStatus = interpreter->Invoke(); if (invokeStatus != kTfLiteOk) { diff --git a/evb/src/model.h b/evb/src/model.h index 62deeb60..bcacea6e 100644 --- a/evb/src/model.h +++ b/evb/src/model.h @@ -4,11 +4,31 @@ #include "tensorflow/lite/micro/micro_common.h" #include +/** + * @brief Initialize the model + * + * @return uint32_t + */ uint32_t -init_model(); +model_init(); + +/** + * @brief Setup the model + * + * @param modelBuffer Pointer to the model buffer + * @param inputs Pointer to the input tensor + * @param outputs Pointer to the output tensor + * @return uint32_t + */ uint32_t -setup_model(const void *modelBuffer, TfLiteTensor *inputs, TfLiteTensor *outputs); +model_setup(const void *modelBuffer, TfLiteTensor *inputs, TfLiteTensor *outputs); + +/** + * @brief Run the model + * + * @return uint32_t + */ uint32_t -run_model(); +model_run(); #endif // __HK_MODEL_H diff --git a/heartkit/__init__.py b/heartkit/__init__.py index 63b1e48a..54336208 100644 --- a/heartkit/__init__.py +++ b/heartkit/__init__.py @@ -1,7 +1,17 @@ import os from importlib.metadata import version -from . import arrhythmia, beat, cli, datasets, metrics, models, segmentation, tflite +from .
import ( + arrhythmia, + beat, + cli, + datasets, + defines, + metrics, + models, + segmentation, + tflite, +) from .utils import setup_logger __version__ = version(__name__) diff --git a/heartkit/arrhythmia/demo.py b/heartkit/arrhythmia/demo.py index 8fc07c3e..c406a93a 100644 --- a/heartkit/arrhythmia/demo.py +++ b/heartkit/arrhythmia/demo.py @@ -52,7 +52,7 @@ def demo(params: HeartDemoParams): start, stop = x.shape[0] - params.frame_size, x.shape[0] else: start, stop = i, i + params.frame_size - xx = prepare(x[start:stop, :], sample_rate=params.sampling_rate) + xx = prepare(x[start:stop, :], sample_rate=params.sampling_rate, preprocesses=params.preprocesses) runner.set_inputs(xx) runner.perform_inference() yy = runner.get_outputs() diff --git a/heartkit/arrhythmia/utils.py b/heartkit/arrhythmia/utils.py index b84280b5..55b9fcae 100644 --- a/heartkit/arrhythmia/utils.py +++ b/heartkit/arrhythmia/utils.py @@ -2,30 +2,44 @@ from typing import Any, cast import numpy.typing as npt -import physiokit as pk import tensorflow as tf from rich.console import Console -from ..datasets import HeartKitDataset, IcentiaDataset, augment_pipeline -from ..defines import HeartExportParams, HeartTask, HeartTestParams, HeartTrainParams +from ..datasets import ( + HeartKitDataset, + IcentiaDataset, + augment_pipeline, + preprocess_pipeline, +) +from ..defines import ( + HeartExportParams, + HeartTask, + HeartTestParams, + HeartTrainParams, + PreprocessParams, +) from ..models import EfficientNetParams, EfficientNetV2, MBConvParams, generate_model console = Console() -def prepare(x: npt.NDArray, sample_rate: float) -> npt.NDArray: - """Prepare dataset.""" - x = pk.signal.filter_signal( - x, - lowcut=0.5, - highcut=30, - order=3, - sample_rate=sample_rate, - axis=0, - forward_backward=True, - ) - x = pk.signal.normalize_signal(x, eps=0.1, axis=None) - return x +def prepare(x: npt.NDArray, sample_rate: float, preprocesses: list[PreprocessParams]) -> npt.NDArray: + """Prepare dataset. 
+ + Args: + x (npt.NDArray): Input signal + sample_rate (float): Sampling rate + preprocesses (list[PreprocessParams]): Preprocessing pipeline + + Returns: + npt.NDArray: Prepared signal + """ + if not preprocesses: + # Fall back to legacy defaults; sample_rate is supplied by preprocess_pipeline + preprocesses = [ + PreprocessParams(name="filter", args=dict(axis=0, lowcut=0.5, highcut=30, order=3, forward_backward=True)), + PreprocessParams(name="znorm", args=dict(axis=None, eps=0.1)), + ] + return preprocess_pipeline(x, preprocesses=preprocesses, sample_rate=sample_rate) def load_dataset(ds_path: Path, frame_size: int, sampling_rate: int, class_map: dict[int, int]) -> HeartKitDataset: @@ -68,7 +82,7 @@ def preprocess(x: npt.NDArray) -> npt.NDArray: xx = x.copy().squeeze() if params.augmentations: xx = augment_pipeline(xx, augmentations=params.augmentations, sample_rate=params.sampling_rate) - xx = prepare(xx, sample_rate=params.sampling_rate) + xx = prepare(xx, sample_rate=params.sampling_rate, preprocesses=params.preprocesses) return xx # Create TF datasets @@ -124,7 +138,7 @@ def load_test_dataset( def preprocess(x: npt.NDArray) -> npt.NDArray: xx = x.copy().squeeze() - xx = prepare(xx, sample_rate=params.sampling_rate) + xx = prepare(xx, sample_rate=params.sampling_rate, preprocesses=params.preprocesses) return xx with console.status("[bold green] Loading test dataset..."): diff --git a/heartkit/beat/demo.py b/heartkit/beat/demo.py index fd1e6c4a..ccf48d14 100644 --- a/heartkit/beat/demo.py +++ b/heartkit/beat/demo.py @@ -70,7 +70,7 @@ def demo(params: HeartDemoParams): x[start + avg_rr : stop + avg_rr], ) ).T - xx = prepare(xx, sample_rate=params.sampling_rate) + xx = prepare(xx, sample_rate=params.sampling_rate, preprocesses=params.preprocesses) runner.set_inputs(xx) runner.perform_inference() yy = runner.get_outputs() diff --git a/heartkit/beat/utils.py b/heartkit/beat/utils.py index f1ce1aae..de526f90 100644 --- a/heartkit/beat/utils.py +++ b/heartkit/beat/utils.py @@ -2,30 +2,44 @@ from typing import Any import numpy.typing as npt -import physiokit as pk import tensorflow as tf from rich.console import Console -from ..datasets import HeartKitDataset, IcentiaDataset, augment_pipeline -from ..defines import HeartExportParams, HeartTask, HeartTestParams, HeartTrainParams +from ..datasets import ( + HeartKitDataset, + IcentiaDataset, + augment_pipeline, + preprocess_pipeline, +) +from ..defines import ( + HeartExportParams, + HeartTask, + HeartTestParams, + HeartTrainParams, + PreprocessParams, +) from ..models import EfficientNetParams, EfficientNetV2, MBConvParams, generate_model console = Console() -def prepare(x: npt.NDArray, sample_rate: float) -> npt.NDArray: - """Prepare dataset.""" - x = pk.signal.filter_signal( - x, - lowcut=0.5, - highcut=30, - order=3, - sample_rate=sample_rate, - axis=0, - forward_backward=True, - ) - x = pk.signal.normalize_signal(x, eps=0.1, axis=None) - return x +def prepare(x: npt.NDArray, sample_rate: float, preprocesses: list[PreprocessParams]) -> npt.NDArray: + """Prepare dataset.
+ + Args: + x (npt.NDArray): Input signal + sample_rate (float): Sampling rate + preprocesses (list[PreprocessParams]): Preprocessing pipeline + + Returns: + npt.NDArray: Prepared signal + """ + if not preprocesses: + # Fall back to legacy defaults; sample_rate is supplied by preprocess_pipeline + preprocesses = [ + PreprocessParams(name="filter", args=dict(axis=0, lowcut=0.5, highcut=30, order=3, forward_backward=True)), + PreprocessParams(name="znorm", args=dict(axis=None, eps=0.1)), + ] + return preprocess_pipeline(x, preprocesses=preprocesses, sample_rate=sample_rate) def load_dataset(ds_path: Path, frame_size: int, sampling_rate: int, class_map: dict[int, int]) -> HeartKitDataset: @@ -75,7 +89,7 @@ def preprocess(x: npt.NDArray) -> npt.NDArray: sample_rate=params.sampling_rate, ) # END FOR - xx = prepare(xx, sample_rate=params.sampling_rate) + xx = prepare(xx, sample_rate=params.sampling_rate, preprocesses=params.preprocesses) return xx # Create TF datasets @@ -127,7 +141,7 @@ def load_test_dataset( def preprocess(x: npt.NDArray) -> npt.NDArray: xx = x.copy() - xx = prepare(xx, sample_rate=params.sampling_rate) + xx = prepare(xx, sample_rate=params.sampling_rate, preprocesses=params.preprocesses) return xx with console.status("[bold green] Loading test dataset..."): diff --git a/heartkit/datasets/__init__.py b/heartkit/datasets/__init__.py index fec5dbfb..3026361d 100644 --- a/heartkit/datasets/__init__.py +++ b/heartkit/datasets/__init__.py @@ -1,4 +1,4 @@ -from .augmentation import AugmentationParams, augment_pipeline +from .augmentation import augment_pipeline, preprocess_pipeline from .dataset import HeartKitDataset from .download import download_datasets from .icentia11k import IcentiaDataset diff --git a/heartkit/datasets/augmentation.py b/heartkit/datasets/augmentation.py index 4c3e28d5..41ead7b2 100644 --- a/heartkit/datasets/augmentation.py +++ b/heartkit/datasets/augmentation.py @@ -2,7 +2,31 @@ import numpy.typing as npt import physiokit as pk -from ..defines import AugmentationParams +from ..defines import AugmentationParams, PreprocessParams + + +def preprocess_pipeline(x: npt.NDArray, preprocesses: list[PreprocessParams], sample_rate: float) -> npt.NDArray: + """Apply preprocessing pipeline + + Args: + x (npt.NDArray): Signal + preprocesses (list[PreprocessParams]): Preprocessing pipeline + sample_rate (float): Sampling rate in Hz.
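+ + Example (illustrative only; assumes ``x`` is a 1-D numpy array): + + pipeline = [ + PreprocessParams(name="filter", args=dict(lowcut=0.5, highcut=30, order=3, axis=0)), + PreprocessParams(name="znorm", args=dict(eps=0.1, axis=None)), + ] + y = preprocess_pipeline(x, preprocesses=pipeline, sample_rate=200)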
+ + Returns: + npt.NDArray: Preprocessed signal + """ + for preprocess in preprocesses: + match preprocess.name: + case "filter": + x = pk.signal.filter_signal(x, sample_rate=sample_rate, **preprocess.args) + case "znorm": + x = pk.signal.normalize_signal(x, **preprocess.args) + case _: + raise ValueError(f"Unknown preprocess '{preprocess.name}'") + # END MATCH + # END FOR + return x def augment_pipeline(x: npt.NDArray, augmentations: list[AugmentationParams], sample_rate: float) -> npt.NDArray: @@ -18,55 +42,58 @@ def augment_pipeline(x: npt.NDArray, augmentations: list[AugmentationParams], sa """ for augmentation in augmentations: args = augmentation.args - if augmentation.name == "baseline_wander": - amplitude = args.get("amplitude", [0.05, 0.06]) - frequency = args.get("frequency", [0, 1]) - x = pk.signal.add_baseline_wander( - x, - amplitude=np.random.uniform(amplitude[0], amplitude[1]), - frequency=np.random.uniform(frequency[0], frequency[1]), - sample_rate=sample_rate, - ) - elif augmentation.name == "motion_noise": - amplitude = args.get("amplitude", [0.5, 1.0]) - frequency = args.get("frequency", [0.4, 0.6]) - x = pk.signal.add_motion_noise( - x, - amplitude=np.random.uniform(amplitude[0], amplitude[1]), - frequency=np.random.uniform(frequency[0], frequency[1]), - sample_rate=sample_rate, - ) - elif augmentation.name == "burst_noise": - amplitude = args.get("amplitude", [0.05, 0.5]) - frequency = args.get("frequency", [sample_rate / 4, sample_rate / 2]) - burst_number = args.get("burst_number", [0, 2]) - x = pk.signal.add_burst_noise( - x, - amplitude=np.random.uniform(amplitude[0], amplitude[1]), - frequency=np.random.uniform(frequency[0], frequency[1]), - burst_number=np.random.randint(burst_number[0], burst_number[1]), - sample_rate=sample_rate, - ) - elif augmentation.name == "powerline_noise": - amplitude = args.get("amplitude", [0.005, 0.01]) - frequency = args.get("frequency", [50, 60]) - x = pk.signal.add_powerline_noise( - x, - amplitude=np.random.uniform(amplitude[0], amplitude[1]), - frequency=np.random.uniform(frequency[0], frequency[1]), - sample_rate=sample_rate, - ) - elif augmentation.name == "noise_sources": - num_sources = args.get("num_sources", [1, 2]) - amplitude = args.get("amplitude", [0, 0.1]) - frequency = args.get("frequency", [0, sample_rate / 2]) - num_sources: int = np.random.randint(num_sources[0], num_sources[1]) - x = pk.signal.add_noise_sources( - x, - amplitudes=[np.random.uniform(amplitude[0], amplitude[1]) for _ in range(num_sources)], - frequencies=[np.random.uniform(frequency[0], frequency[1]) for _ in range(num_sources)], - sample_rate=sample_rate, - ) - # END IF + match augmentation.name: + case "baseline_wander": + amplitude = args.get("amplitude", [0.05, 0.06]) + frequency = args.get("frequency", [0, 1]) + x = pk.signal.add_baseline_wander( + x, + amplitude=np.random.uniform(amplitude[0], amplitude[1]), + frequency=np.random.uniform(frequency[0], frequency[1]), + sample_rate=sample_rate, + ) + case "motion_noise": + amplitude = args.get("amplitude", [0.5, 1.0]) + frequency = args.get("frequency", [0.4, 0.6]) + x = pk.signal.add_motion_noise( + x, + amplitude=np.random.uniform(amplitude[0], amplitude[1]), + frequency=np.random.uniform(frequency[0], frequency[1]), + sample_rate=sample_rate, + ) + case "burst_noise": + amplitude = args.get("amplitude", [0.05, 0.5]) + frequency = args.get("frequency", [sample_rate / 4, sample_rate / 2]) + burst_number = args.get("burst_number", [0, 2]) + x = pk.signal.add_burst_noise( + x, + 
amplitude=np.random.uniform(amplitude[0], amplitude[1]), + frequency=np.random.uniform(frequency[0], frequency[1]), + burst_number=np.random.randint(burst_number[0], burst_number[1]), + sample_rate=sample_rate, + ) + case "powerline_noise": + amplitude = args.get("amplitude", [0.005, 0.01]) + frequency = args.get("frequency", [50, 60]) + x = pk.signal.add_powerline_noise( + x, + amplitude=np.random.uniform(amplitude[0], amplitude[1]), + frequency=np.random.uniform(frequency[0], frequency[1]), + sample_rate=sample_rate, + ) + case "noise_sources": + num_sources = args.get("num_sources", [1, 2]) + amplitude = args.get("amplitude", [0, 0.1]) + frequency = args.get("frequency", [0, sample_rate / 2]) + num_sources: int = np.random.randint(num_sources[0], num_sources[1]) + x = pk.signal.add_noise_sources( + x, + amplitudes=[np.random.uniform(amplitude[0], amplitude[1]) for _ in range(num_sources)], + frequencies=[np.random.uniform(frequency[0], frequency[1]) for _ in range(num_sources)], + sample_rate=sample_rate, + ) + case _: + raise ValueError(f"Unknown augmentation '{augmentation.name}'") + # END MATCH # END FOR return x diff --git a/heartkit/datasets/icentia11k.py b/heartkit/datasets/icentia11k.py index 79b44e2d..650f4b15 100644 --- a/heartkit/datasets/icentia11k.py +++ b/heartkit/datasets/icentia11k.py @@ -15,13 +15,14 @@ import numpy.typing as npt import pandas as pd import physiokit as pk +import scipy.ndimage import sklearn.model_selection import sklearn.preprocessing from botocore import UNSIGNED from botocore.client import Config from tqdm import tqdm -from ..defines import HeartBeat, HeartRate, HeartRhythm, HeartTask +from ..defines import HeartBeat, HeartRate, HeartRhythm, HeartSegment, HeartTask from ..utils import download_file from .dataset import HeartKitDataset from .defines import PatientGenerator, SampleGenerator @@ -170,6 +171,11 @@ def task_data_generator( patient_generator=patient_generator, samples_per_patient=samples_per_patient, ) + if self.task == HeartTask.segmentation: + return self.segmentation_data_generator( + patient_generator=patient_generator, + samples_per_patient=samples_per_patient, + ) raise NotImplementedError() def _split_train_test_patients(self, patient_ids: npt.NDArray, test_size: float) -> list[list[int]]: @@ -301,6 +307,82 @@ def rhythm_data_generator( # END FOR # END FOR + def segmentation_data_generator( + self, + patient_generator: PatientGenerator, + samples_per_patient: int | list[int] = 1, + ) -> SampleGenerator: + """Generate frames with annotated segments.
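+ + For each patient, a random segment and a random frame within it are selected; beat annotations inside the frame are then converted into a per-sample QRS mask using a smoothed-gradient QRS detector (P-, T-, and U-waves are currently ignored).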
+ + Args: + patient_generator (PatientGenerator): Patient generator + samples_per_patient (int | list[int], optional): Samples per patient. Defaults to 1. + + Returns: + SampleGenerator: Sample generator + """ + assert not isinstance(samples_per_patient, Iterable) + input_size = int(np.round((self.sampling_rate / self.target_rate) * self.frame_size)) + + # For each patient + for _, segments in patient_generator: + for _ in range(samples_per_patient): + # Randomly pick a segment + seg_key = np.random.choice(list(segments.keys())) + # Randomly pick a frame + frame_start = np.random.randint(segments[seg_key]["data"].shape[0] - input_size) + frame_end = frame_start + input_size + # Get data and labels + data = segments[seg_key]["data"][frame_start:frame_end].squeeze() + + if self.sampling_rate != self.target_rate: + ds_ratio = self.target_rate / self.sampling_rate + data = pk.signal.resample_signal(data, self.sampling_rate, self.target_rate, axis=0) + else: + ds_ratio = 1 + + blabels = segments[seg_key]["blabels"] + blabels = blabels[(blabels[:, 0] >= frame_start) & (blabels[:, 0] < frame_end)] + # Create segment mask + mask = np.zeros_like(data, dtype=np.int32) + for i in range(blabels.shape[0]): + bidx = int((blabels[i, 0] - frame_start) * ds_ratio) + btype = blabels[i, 1] + if btype == IcentiaBeat.undefined: + continue + + # Extract QRS segment + qrs_window = 0.1 + avg_window = 1.0 + qrs_prom_weight = 1.5 + abs_grad = np.abs(np.gradient(data)) + qrs_kernel = int(np.rint(qrs_window * self.target_rate)) + avg_kernel = int(np.rint(avg_window * self.target_rate)) + # Smooth gradients + qrs_grad = scipy.ndimage.uniform_filter1d(abs_grad, qrs_kernel, mode="nearest") + avg_grad = scipy.ndimage.uniform_filter1d(qrs_grad, avg_kernel, mode="nearest") + min_qrs_height = qrs_prom_weight * avg_grad + qrs = qrs_grad - min_qrs_height + win_len = max(1, int(0.08 * self.target_rate)) # 80 ms + b_left = max(0, bidx - win_len) + b_right = min(data.shape[0], bidx + win_len) + onset = np.where(np.flip(qrs[b_left:bidx]) < 0)[0] + onset = onset[0] if onset.size else bidx - b_left + offset = np.where(qrs[bidx + 1 : b_right] < 0)[0] + offset = offset[0] if offset.size else b_right - bidx + mask[bidx - onset : bidx + offset] = self.class_map.get(HeartSegment.qrs.value, 0) + # Ignore P, T, and U waves for now + # END FOR + x = np.nan_to_num(data).astype(np.float32) + y = mask.astype(np.int32) + # if self.sampling_rate != self.target_rate: + # x = pk.signal.resample_signal(x, self.sampling_rate, self.target_rate, axis=0) + # y = pk.signal.filter.resample_signal(y, self.sampling_rate, self.target_rate, axis=0) + # # END IF + yield x, y + # END FOR + # END FOR + def beat_data_generator( self, patient_generator: PatientGenerator, @@ -500,7 +582,7 @@ def signal_generator(self, patient_generator: PatientGenerator, samples_per_pati segment_size = segment["data"].shape[0] frame_start = np.random.randint(segment_size - input_size) frame_end = frame_start + input_size - x = segment["data"][frame_start:frame_end] + x = segment["data"][frame_start:frame_end].squeeze() x = np.nan_to_num(x).astype(np.float32) if self.sampling_rate != self.target_rate: x = pk.signal.resample_signal(x, self.sampling_rate, self.target_rate, axis=0) diff --git a/heartkit/datasets/synthetic.py b/heartkit/datasets/synthetic.py index 427842b3..2b225f06 100644 --- a/heartkit/datasets/synthetic.py +++ b/heartkit/datasets/synthetic.py @@ -120,9 +120,15 @@ def signal_generator(self, patient_generator: PatientGenerator, samples_per_pati EcgPresets.high_take_off, ) preset_weights = (14, 1, 1, 1, 1, 1, 1) +
generate_funcs = ( + pk.ecg.generate_nsr, + pk.ecg.generate_afib, + ) + generate_weights = (5, 1) for _ in patient_generator: - _, syn_ecg, _, _, _ = pk.ecg.generate_nsr( + generate_func = random.choices(generate_funcs, generate_weights, k=1)[0] + _, syn_ecg, _, _, _ = generate_func( leads=num_leads, signal_frequency=self.sampling_rate, rate=np.random.uniform(40, 120), diff --git a/heartkit/defines.py b/heartkit/defines.py index a95144ab..56a66742 100644 --- a/heartkit/defines.py +++ b/heartkit/defines.py @@ -7,6 +7,13 @@ from pydantic import BaseModel, Extra, Field +class PreprocessParams(BaseModel, extra=Extra.allow): + """Preprocessing parameters""" + + name: str + args: dict[str, Any] + + class AugmentationParams(BaseModel, extra=Extra.allow): """Augmentation parameters""" @@ -164,6 +171,8 @@ class HeartTrainParams(BaseModel, extra=Extra.allow): epochs: int = Field(50, description="Number of epochs") steps_per_epoch: int | None = Field(None, description="Number of steps per epoch") val_metric: Literal["loss", "acc", "f1"] = Field("loss", description="Performance metric") + # Preprocessing/Augmentation arguments + preprocesses: list[PreprocessParams] = Field(default_factory=list, description="Preprocesses") augmentations: list[AugmentationParams] = Field(default_factory=list, description="Augmentations") # Extra arguments seed: int | None = Field(None, description="Random state seed") @@ -185,6 +194,7 @@ class HeartTestParams(BaseModel, extra=Extra.allow): default_factory=lambda: os.cpu_count() or 1, description="# of data loaders running in parallel", ) + preprocesses: list[PreprocessParams] = Field(default_factory=list, description="Preprocesses") # Model arguments model_file: str | None = Field(None, description="Path to model file") threshold: float | None = Field(None, description="Model output threshold") @@ -201,6 +211,7 @@ class HeartExportParams(BaseModel, extra=Extra.allow): sampling_rate: int = Field(250, description="Target sampling rate (Hz)") frame_size: int = Field(1250, description="Frame size") num_classes: int = Field(3, description="# of classes") + preprocesses: list[PreprocessParams] = Field(default_factory=list, description="Preprocesses") samples_per_patient: int | list[int] = Field(100, description="# test samples per patient") test_patients: float | None = Field(None, description="# or proportion of patients for testing") test_size: int = Field(100_000, description="# samples for testing") @@ -226,6 +237,7 @@ class HeartDemoParams(BaseModel, extra=Extra.allow): sampling_rate: int = Field(250, description="Target sampling rate (Hz)") frame_size: int = Field(1250, description="Frame size") num_classes: int = Field(3, description="# of classes") + preprocesses: list[PreprocessParams] = Field(default_factory=list, description="Preprocesses") # Model arguments model_file: str | None = Field(None, description="Path to model file") backend: Literal["pc", "evb"] = Field("pc", description="Backend") diff --git a/heartkit/models/tcn.py b/heartkit/models/tcn.py index 071a0b72..2ae6cb23 100644 --- a/heartkit/models/tcn.py +++ b/heartkit/models/tcn.py @@ -12,7 +12,7 @@ class TcnBlockParams(BaseModel): """TCN block parameters""" depth: int = Field(default=1, description="Layer depth") - branch: int | None = Field(default=None, description="Number of branches") + branch: int = Field(default=1, description="Number of branches") filters: int = Field(..., description="# filters") kernel: int | tuple[int, int] = Field(default=3, description="Kernel size") dilation: int | 
tuple[int, int] = Field(default=1, description="Dilation rate") @@ -27,6 +27,7 @@ class TcnParams(BaseModel): input_kernel: int | tuple[int, int] | None = Field(default=None, description="Input kernel size") input_norm: Literal["batch", "layer"] | None = Field(default="layer", description="Input normalization type") + block_type: Literal["lg", "mb", "sm"] = Field(default="mb", description="Block type") blocks: list[TcnBlockParams] = Field(default_factory=list, description="TCN blocks") output_kernel: int | tuple[int, int] = Field(default=3, description="Output kernel size") include_top: bool = Field(default=True, description="Include top") @@ -35,9 +36,25 @@ def norm_layer(norm: str, name: str) -> KerasLayer: - """Normalization layer""" + """Normalization layer + + Args: + norm (str): Normalization type + name (str): Name + + Returns: + KerasLayer: Layer + """ def layer(x: tf.Tensor) -> tf.Tensor: + """Functional normalization layer + + Args: + x (tf.Tensor): Input tensor + + Returns: + tf.Tensor: Output tensor + """ if norm == "batch": return tf.keras.layers.BatchNormalization(axis=-1, name=f"{name}.BN")(x) if norm == "layer": @@ -47,8 +64,8 @@ def layer(x: tf.Tensor) -> tf.Tensor: return layer -def tcn_block(params: TcnBlockParams, name: str) -> KerasLayer: - """TCN block +def tcn_block_lg(params: TcnBlockParams, name: str) -> KerasLayer: + """TCN large block Args: params (TcnBlockParams): Parameters @@ -112,8 +129,8 @@ def layer(x: tf.Tensor) -> tf.Tensor: return layer -def tcn_block_sm2(params: TcnBlockParams, name: str) -> KerasLayer: - """TCN block +def tcn_block_mb(params: TcnBlockParams, name: str) -> KerasLayer: + """TCN MBConv block Args: params (TcnBlockParams): Parameters @@ -171,7 +188,7 @@ def layer(x: tf.Tensor) -> tf.Tensor: y = tf.keras.layers.Activation("relu6", name=f"{lcl_name}.DW.RELU")(y) # Squeeze and excite - if params.se_ratio > 0: + if params.se_ratio and y.shape[-1] // params.se_ratio > 0: y = se_block(ratio=params.se_ratio, name=f"{lcl_name}.SE")(y) # END IF @@ -183,7 +200,6 @@ def layer(x: tf.Tensor) -> tf.Tensor: kernel_size=(1, 1), strides=(1, 1), padding="same", - # groups=int(params.se_ratio) if params.se_ratio > 0 else 1, use_bias=params.norm is None, kernel_initializer="he_normal", kernel_regularizer=tf.keras.regularizers.L2(1e-3), @@ -215,7 +231,7 @@ def layer(x: tf.Tensor) -> tf.Tensor: def tcn_block_sm(params: TcnBlockParams, name: str) -> KerasLayer: - """TCN block + """TCN small block Args: params (TcnBlockParams): Parameters @@ -284,7 +300,7 @@ def layer(x: tf.Tensor) -> tf.Tensor: # END FOR # Squeeze and excite - if params.se_ratio > 0: + if params.se_ratio and y.shape[-1] // params.se_ratio > 1: y = se_block(ratio=params.se_ratio, name=f"{name}.SE")(y) # END IF @@ -309,13 +325,20 @@ def tcn_core(params: TcnParams) -> KerasLayer: Returns: KerasLayer: Layer """ + if params.block_type == "lg": + tcn_block = tcn_block_lg + elif params.block_type == "mb": + tcn_block = tcn_block_mb + elif params.block_type == "sm": + tcn_block = tcn_block_sm + else: + raise ValueError(f"Invalid block type: {params.block_type}") def layer(x: tf.Tensor) -> tf.Tensor: y = x for i, block in enumerate(params.blocks): name = f"B{i+1}" - y = tcn_block_sm2(params=block, name=name)(y) - # y = tcn_block_sm(params=block, name=name)(y) + y = tcn_block(params=block, name=name)(y) # END FOR return y diff --git a/heartkit/segmentation/defines.py b/heartkit/segmentation/defines.py index 17294b36..d3ba8f77 100644 --- a/heartkit/segmentation/defines.py +++
b/heartkit/segmentation/defines.py @@ -10,7 +10,7 @@ def get_classes(nclasses: int = 4) -> list[str]: Returns: list[str]: List of class names """ - if 2 <= nclasses <= 5: + if 2 <= nclasses <= 4: return list(range(nclasses)) raise ValueError(f"Invalid number of classes: {nclasses}") @@ -49,14 +49,6 @@ def get_class_mapping(nclasses: int = 4) -> dict[int, int]: HeartSegment.twave: 3, HeartSegment.uwave: 0, } - case 5: - return { - HeartSegment.normal: 0, - HeartSegment.pwave: 1, - HeartSegment.qrs: 2, - HeartSegment.twave: 3, - HeartSegment.uwave: 4, - } case _: raise ValueError(f"Invalid number of classes: {nclasses}") # END MATCH @@ -78,7 +70,5 @@ def get_class_names(nclasses: int = 4) -> list[str]: return ["NONE", "QRS", "P/T-WAVE"] case 4: return ["NONE", "P-WAVE", "QRS", "T-WAVE"] - case 5: - return ["NONE", "P-WAVE", "QRS", "T-WAVE", "U-WAVE"] case _: raise ValueError(f"Invalid number of classes: {nclasses}") diff --git a/heartkit/segmentation/demo.py b/heartkit/segmentation/demo.py index 7b0bc7aa..a9941a25 100644 --- a/heartkit/segmentation/demo.py +++ b/heartkit/segmentation/demo.py @@ -58,7 +58,7 @@ def demo(params: HeartDemoParams): start, stop = x.size - params.frame_size, x.size else: start, stop = i, i + params.frame_size - xx = prepare(x[start:stop], sample_rate=params.sampling_rate) + xx = prepare(x[start:stop], sample_rate=params.sampling_rate, preprocesses=params.preprocesses) runner.set_inputs(xx) runner.perform_inference() yy = runner.get_outputs() @@ -77,7 +77,7 @@ def demo(params: HeartDemoParams): [{"colspan": 3, "type": "xy", "secondary_y": True}, None, None], [{"type": "xy"}, {"type": "bar"}, {"type": "table"}], ], - subplot_titles=("ECG Plot", "IBI Poincaré Plot", "HRV Frequency Bands"), + subplot_titles=("ECG Plot", "IBI Poincare Plot", "HRV Frequency Bands"), horizontal_spacing=0.1, vertical_spacing=0.2, ) diff --git a/heartkit/segmentation/train.py b/heartkit/segmentation/train.py index 01b27d71..ada158e9 100644 --- a/heartkit/segmentation/train.py +++ b/heartkit/segmentation/train.py @@ -69,7 +69,7 @@ def train(params: HeartTrainParams): y_true = np.argmax(np.concatenate(test_labels).squeeze(), axis=-1).flatten() class_weights = sklearn.utils.compute_class_weight("balanced", classes=classes, y=y_true) - class_weights = 0.25 + # class_weights = 0.25 with tfa.get_strategy().scope(): logger.info("Building model") diff --git a/heartkit/segmentation/utils.py b/heartkit/segmentation/utils.py index 6b90aaef..23299208 100644 --- a/heartkit/segmentation/utils.py +++ b/heartkit/segmentation/utils.py @@ -3,44 +3,47 @@ import numpy as np import numpy.typing as npt -import physiokit as pk import tensorflow as tf from rich.console import Console from ..datasets import ( HeartKitDataset, + IcentiaDataset, LudbDataset, QtdbDataset, SyntheticDataset, augment_pipeline, + preprocess_pipeline, +) +from ..defines import ( + HeartExportParams, + HeartTask, + HeartTestParams, + HeartTrainParams, + PreprocessParams, ) -from ..defines import HeartExportParams, HeartTask, HeartTestParams, HeartTrainParams from ..models import UNet, UNetBlockParams, UNetParams, generate_model console = Console() -def prepare(x: npt.NDArray, sample_rate: float) -> npt.NDArray: +def prepare(x: npt.NDArray, sample_rate: float, preprocesses: list[PreprocessParams]) -> npt.NDArray: """Prepare dataset Args: x (npt.NDArray): Input signal sample_rate (float): Sampling rate + preprocesses (list[PreprocessParams]): Preprocessing pipeline Returns: npt.NDArray: Prepared signal """ - x = pk.signal.filter_signal( - x, 
- lowcut=0.5, - highcut=30, - order=3, - sample_rate=sample_rate, - axis=0, - forward_backward=True, - ) - x = pk.signal.normalize_signal(x, eps=0.1, axis=None) - return x + if not preprocesses: + # Fall back to legacy defaults; sample_rate is supplied by preprocess_pipeline + preprocesses = [ + PreprocessParams(name="filter", args=dict(axis=0, lowcut=0.5, highcut=30, order=3, forward_backward=True)), + PreprocessParams(name="znorm", args=dict(axis=None, eps=0.1)), + ] + return preprocess_pipeline(x, preprocesses=preprocesses, sample_rate=sample_rate) def load_datasets( @@ -96,6 +99,16 @@ def load_datasets( class_map=class_map, ) ) + if "icentia11k" in dataset_names: + datasets.append( + IcentiaDataset( + ds_path, + task=HeartTask.segmentation, + frame_size=frame_size, + target_rate=sampling_rate, + class_map=class_map, + ) + ) return datasets @@ -104,8 +117,11 @@ def load_train_datasets( params: HeartTrainParams, ) -> tuple[tf.data.Dataset, tf.data.Dataset]: """Load segmentation train datasets. + Args: + datasets (list[HeartKitDataset]): Datasets params (HeartTrainParams): Train params + Returns: tuple[tf.data.Dataset, tf.data.Dataset]: Train and validation datasets """ @@ -114,7 +130,7 @@ def preprocess(x: npt.NDArray) -> npt.NDArray: xx = x.copy().squeeze() if params.augmentations: xx = augment_pipeline(xx, augmentations=params.augmentations, sample_rate=params.sampling_rate) - xx = prepare(xx, sample_rate=params.sampling_rate) + xx = prepare(xx, sample_rate=params.sampling_rate, preprocesses=params.preprocesses) return xx train_datasets = [] @@ -176,7 +192,7 @@ def load_test_datasets( def preprocess(x: npt.NDArray) -> npt.NDArray: xx = x.copy().squeeze() - xx = prepare(xx, sample_rate=params.sampling_rate) + xx = prepare(xx, sample_rate=params.sampling_rate, preprocesses=params.preprocesses) return xx with console.status("[bold green] Loading test dataset..."): diff --git a/mkdocs.yml b/mkdocs.yml index 40379cef..d50cc58f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -10,24 +10,24 @@ nav: - Overview: overview.md - Datasets: datasets.md - Results: results.md + - Tasks: + - Segmentation: + - Overview: segmentation/overview.md + - Methods: segmentation/methods.md + - Results: segmentation/results.md + - Demo: segmentation/demo.md - - Segmentation Task: - - Overview: segmentation/overview.md - - Methods: segmentation/methods.md - - Results: segmentation/results.md - - Demo: segmentation/demo.md + - Arrhythmia: + - Overview: arrhythmia/overview.md + - Methods: arrhythmia/methods.md + - Results: arrhythmia/results.md + - Demo: arrhythmia/demo.md - - Arrhythmia Task: - - Overview: arrhythmia/overview.md - - Methods: arrhythmia/methods.md - - Results: arrhythmia/results.md - - Demo: arrhythmia/demo.md - - - Beat Task: - - Overview: beat/overview.md - - Methods: beat/methods.md - - Results: beat/results.md - - Demo: beat/demo.md + - Beat: + - Overview: beat/overview.md + - Methods: beat/methods.md + - Results: beat/results.md + - Demo: beat/demo.md - Tutorials: - EVB Setup: ./tutorials/evb-setup.md