Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,9 @@ leaderboard/credentials.json
leaderboard/rtd_token.txt

# locally pre-trained models
pyhealth/medcode/pretrained_embeddings/kg_emb/examples/pretrained_model
pyhealth/medcode/pretrained_embeddings/kg_emb/examples/pretrained_model

# local testing files
halo_testing/
halo_testing_script.py
test_halo_model.slurm
1 change: 1 addition & 0 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self, *args, **kwargs):
from .eicu import eICUDataset
from .isruc import ISRUCDataset
from .medical_transcriptions import MedicalTranscriptionsDataset
from .halo_mimic3 import HALO_MIMIC3Dataset
from .mimic3 import MIMIC3Dataset
from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
from .mimicextract import MIMICExtractDataset
Expand Down
149 changes: 149 additions & 0 deletions pyhealth/datasets/configs/hcup_ccs_2015_definitions_benchmark.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"Septicemia (except in labor)":
use_in_benchmark: True
type: "acute"
id: 2
codes: [ "0031", "0202", "0223", "0362", "0380", "0381", "03810", "03811", "03812", "03819", "0382", "0383", "03840", "03841", "03842", "03843", "03844", "03849", "0388", "0389", "0545", "449", "77181", "7907", "99591", "99592" ]

"Diabetes mellitus without complication":
use_in_benchmark: True
type: "chronic"
id: 49
codes: [ "24900", "25000", "25001", "7902", "79021", "79022", "79029", "7915", "7916", "V4585", "V5391", "V6546" ]

"Diabetes mellitus with complications":
use_in_benchmark: True
type: "chronic"
id: 50
codes: [ "24901", "24910", "24911", "24920", "24921", "24930", "24931", "24940", "24941", "24950", "24951", "24960", "24961", "24970", "24971", "24980", "24981", "24990", "24991", "25002", "25003", "25010", "25011", "25012", "25013", "25020", "25021", "25022", "25023", "25030", "25031", "25032", "25033", "25040", "25041", "25042", "25043", "25050", "25051", "25052", "25053", "25060", "25061", "25062", "25063", "25070", "25071", "25072", "25073", "25080", "25081", "25082", "25083", "25090", "25091", "25092", "25093" ]

"Disorders of lipid metabolism":
use_in_benchmark: True
type: "chronic"
id: 53
codes: [ "2720", "2721", "2722", "2723", "2724" ]

"Fluid and electrolyte disorders":
use_in_benchmark: True
type: "acute"
id: 55
codes: [ "2760", "2761", "2762", "2763", "2764", "2765", "27650", "27651", "27652", "2766", "27669", "2767", "2768", "2769", "9951" ]

"Essential hypertension":
use_in_benchmark: True
type: "chronic"
id: 98
codes: [ "4011", "4019" ]

"Hypertension with complications and secondary hypertension":
use_in_benchmark: True
type: "chronic"
id: 99
codes: [ "4010", "40200", "40201", "40210", "40211", "40290", "40291", "4030", "40300", "40301", "4031", "40310", "40311", "4039", "40390", "40391", "4040", "40400", "40401", "40402", "40403", "4041", "40410", "40411", "40412", "40413", "4049", "40490", "40491", "40492", "40493", "40501", "40509", "40511", "40519", "40591", "40599", "4372" ]

"Acute myocardial infarction":
use_in_benchmark: True
type: "acute"
id: 100
codes: [ "4100", "41000", "41001", "41002", "4101", "41010", "41011", "41012", "4102", "41020", "41021", "41022", "4103", "41030", "41031", "41032", "4104", "41040", "41041", "41042", "4105", "41050", "41051", "41052", "4106", "41060", "41061", "41062", "4107", "41070", "41071", "41072", "4108", "41080", "41081", "41082", "4109", "41090", "41091", "41092" ]

"Coronary atherosclerosis and other heart disease":
use_in_benchmark: True
type: "chronic"
id: 101
codes: [ "4110", "4111", "4118", "41181", "41189", "412", "4130", "4131", "4139", "4140", "41400", "41401", "41406", "4142", "4143", "4144", "4148", "4149", "V4581", "V4582" ]

"Conduction disorders":
use_in_benchmark: True
type: "chronic"
id: 105
codes: [ "4260", "42610", "42611", "42612", "42613", "4262", "4263", "4264", "42650", "42651", "42652", "42653", "42654", "4266", "4267", "42681", "42682", "42689", "4269", "V450", "V4500", "V4501", "V4502", "V4509", "V533", "V5331", "V5332", "V5339" ]

"Cardiac dysrhythmias":
use_in_benchmark: True
type: "chronic"
id: 106
codes: [ "4270", "4271", "4272", "42731", "42732", "42760", "42761", "42769", "42781", "42789", "4279", "7850", "7851" ]

"Congestive heart failure; nonhypertensive":
use_in_benchmark: True
type: "acute"
id: 108
codes: [ "39891", "4280", "4281", "42820", "42821", "42822", "42823", "42830", "42831", "42832", "42833", "42840", "42841", "42842", "42843", "4289" ]

"Acute cerebrovascular disease":
use_in_benchmark: True
type: "acute"
id: 109
codes: [ "34660", "34661", "34662", "34663", "430", "431", "4320", "4321", "4329", "43301", "43311", "43321", "43331", "43381", "43391", "4340", "43400", "43401", "4341", "43410", "43411", "4349", "43490", "43491", "436" ]

"Pneumonia (except that caused by tuberculosis or sexually transmitted disease)":
use_in_benchmark: True
type: "acute"
id: 122
codes: [ "00322", "0203", "0204", "0205", "0212", "0221", "0310", "0391", "0521", "0551", "0730", "0830", "1124", "1140", "1144", "1145", "11505", "11515", "11595", "1304", "1363", "4800", "4801", "4802", "4803", "4808", "4809", "481", "4820", "4821", "4822", "4823", "48230", "48231", "48232", "48239", "4824", "48240", "48241", "48242", "48249", "4828", "48281", "48282", "48283", "48284", "48289", "4829", "483", "4830", "4831", "4838", "4841", "4843", "4845", "4846", "4847", "4848", "485", "486", "5130", "5171" ]

"Chronic obstructive pulmonary disease and bronchiectasis":
use_in_benchmark: True
type: "chronic"
id: 127
codes: [ "490", "4910", "4911", "4912", "49120", "49121", "49122", "4918", "4919", "4920", "4928", "494", "4940", "4941", "496" ]

"Pleurisy; pneumothorax; pulmonary collapse":
use_in_benchmark: True
type: "acute"
id: 130
codes: [ "5100", "5109", "5110", "5111", "5118", "51189", "5119", "5120", "5128", "51281", "51282", "51283", "51284", "51289", "5180", "5181", "5182" ]

"Respiratory failure; insufficiency; arrest (adult)":
use_in_benchmark: True
type: "acute"
id: 131
codes: [ "5173", "5185", "51851", "51852", "51853", "51881", "51882", "51883", "51884", "7991", "V461", "V4611", "V4612", "V4613", "V4614", "V462" ]

"Other lower respiratory disease":
use_in_benchmark: True
type: "acute"
id: 133
codes: [ "5131", "514", "515", "5160", "5161", "5162", "5163", "51630", "51631", "51632", "51633", "51634", "51635", "51636", "51637", "5164", "5165", "51661", "51662", "51663", "51664", "51669", "5168", "5169", "5172", "5178", "5183", "5184", "51889", "5194", "5198", "5199", "7825", "78600", "78601", "78602", "78603", "78604", "78605", "78606", "78607", "78609", "7862", "7863", "78630", "78631", "78639", "7864", "78652", "7866", "7867", "7868", "7869", "7931", "79311", "79319", "7942", "V126", "V1260", "V1261", "V1269", "V426" ]

"Other upper respiratory disease":
use_in_benchmark: True
type: "acute"
id: 134
codes: [ "470", "4710", "4711", "4718", "4719", "4720", "4721", "4722", "4760", "4761", "4770", "4772", "4778", "4779", "4780", "4781", "47811", "47819", "47820", "47821", "47822", "47824", "47825", "47826", "47829", "47830", "47831", "47832", "47833", "47834", "4784", "4785", "4786", "47870", "47871", "47874", "47875", "47879", "4788", "4789", "5191", "51911", "51919", "5192", "5193", "7841", "78440", "78441", "78442", "78443", "78444", "78449", "7847", "7848", "7849", "78499", "7861", "V414", "V440", "V550" ]

"Other liver diseases":
use_in_benchmark: True
type: "acute"
id: 151
codes: [ "570", "5715", "5716", "5718", "5719", "5720", "5721", "5722", "5723", "5724", "5728", "5730", "5734", "5735", "5738", "5739", "7824", "7891", "7895", "78959", "7904", "7905", "7948", "V427" ]

"Gastrointestinal hemorrhage":
use_in_benchmark: True
type: "acute"
id: 153
codes: [ "4560", "45620", "5307", "53082", "53100", "53101", "53120", "53121", "53140", "53141", "53160", "53161", "53200", "53201", "53220", "53221", "53240", "53241", "53260", "53261", "53300", "53301", "53320", "53321", "53340", "53341", "53360", "53361", "53400", "53401", "53420", "53421", "53440", "53441", "53460", "53461", "5693", "5780", "5781", "5789" ]

"Acute and unspecified renal failure":
use_in_benchmark: True
type: "acute"
id: 157
codes: [ "5845", "5846", "5847", "5848", "5849", "586" ]

"Chronic kidney disease":
use_in_benchmark: True
type: "chronic"
id: 158
codes: [ "585", "5851", "5852", "5853", "5854", "5855", "5856", "5859", "7925", "V420", "V451", "V4511", "V4512", "V560", "V561", "V562", "V5631", "V5632", "V568" ]

"Complications of surgical procedures or medical care":
use_in_benchmark: True
type: "acute"
id: 238
codes: [ "27661", "27783", "27788", "2853", "28741", "3490", "3491", "34931", "41511", "4294", "4582", "45821", "45829", "5121", "5122", "5187", "5190", "51900", "51901", "51902", "51909", "53086", "53087", "53640", "53641", "53642", "53649", "53901", "53909", "53981", "53989", "5642", "5643", "5644", "5696", "56962", "56971", "56979", "5793", "59681", "78062", "78063", "78066", "9093", "99524", "9954", "99586", "9970", "99700", "99701", "99702", "99709", "9971", "9972", "9973", "99731", "99732", "99739", "9974", "99741", "99749", "9975", "99760", "99761", "99762", "99769", "99771", "99772", "99779", "9979", "99791", "99799", "9980", "99800", "99801", "99802", "99809", "9981", "99811", "99812", "99813", "9982", "9983", "99830", "99831", "99832", "99833", "9984", "9985", "99851", "99859", "9986", "9987", "9988", "99881", "99882", "99883", "99889", "9989", "9990", "9991", "9992", "9993", "99934", "99939", "9994", "99941", "99942", "99949", "9995", "99951", "99952", "99959", "9996", "99960", "99961", "99962", "99963", "99969", "9997", "99970", "99971", "99972", "99973", "99974", "99975", "99976", "99977", "99978", "99979", "9998", "99980", "99981", "99982", "99983", "99984", "99985", "99988", "99989", "9999", "V1553", "V1580", "V1583", "V9001", "V9009" ]

"Shock":
use_in_benchmark: True
type: "acute"
id: 249
codes: [ "78550", "78551", "78552", "78559" ]
134 changes: 134 additions & 0 deletions pyhealth/datasets/halo_mimic3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import logging
import yaml
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split

logger = logging.getLogger(__name__)


class HALO_MIMIC3Dataset:
"""
A dataset class for handling MIMIC-III data, specifically designed to be compatible with HALO.

This class is responsible for loading and managing the MIMIC-III dataset,
which includes tables such as patients, admissions, and icustays.

Attributes:
mimic3_dir (str): The root directory where the dataset is stored.
pkl_data_dir (str): The directory in which .pkl files related to the dataset object will be stored.
gzip (Optional[bool]): Determines whether the object will look for ".csv.gz" (True) or ".csv" (False) files in mimic3_dir.
"""

def __init__(
self,
mimic3_dir: str = "./",
pkl_data_dir: str = "./",
gzip: bool = False
) -> None:
self.gzip = gzip
self.mimic3_dir = mimic3_dir
self.pkl_data_dir = pkl_data_dir
self.build_dataset()

def build_dataset(self) -> None:
admissionFile = self.mimic3_dir + f"ADMISSIONS.csv{'.gz' if self.gzip else ''}"
diagnosisFile = self.mimic3_dir + f"DIAGNOSES_ICD.csv{'.gz' if self.gzip else ''}"

admissionDf = pd.read_csv(admissionFile, dtype=str)
admissionDf['ADMITTIME'] = pd.to_datetime(admissionDf['ADMITTIME'])
admissionDf = admissionDf.sort_values('ADMITTIME')
admissionDf = admissionDf.reset_index(drop=True)
diagnosisDf = pd.read_csv(diagnosisFile, dtype=str).set_index("HADM_ID")
diagnosisDf = diagnosisDf[diagnosisDf['ICD9_CODE'].notnull()]
diagnosisDf = diagnosisDf[['ICD9_CODE']]

data = {}
for row in tqdm(admissionDf.itertuples(), total=admissionDf.shape[0]):
#Extracting Admissions Table Info
hadm_id = row.HADM_ID
subject_id = row.SUBJECT_ID

# Extracting the Diagnoses
if hadm_id in diagnosisDf.index:
diagnoses = list(set(diagnosisDf.loc[[hadm_id]]["ICD9_CODE"]))
else:
diagnoses = []

# Building the hospital admission data point
if subject_id not in data:
data[subject_id] = {'visits': [diagnoses]}
else:
data[subject_id]['visits'].append(diagnoses)

code_to_index = {}
all_codes = list(set([c for p in data.values() for v in p['visits'] for c in v]))
np.random.shuffle(all_codes)
for c in all_codes:
code_to_index[c] = len(code_to_index)
print(f"VOCAB SIZE: {len(code_to_index)}")
index_to_code = {v: k for k, v in code_to_index.items()}

data = list(data.values())

with open("./configs/hcup_ccs_2015_definitions_benchmark.yaml") as definitions_file:
definitions = yaml.full_load(definitions_file)

code_to_group = {}
for group in definitions:
if definitions[group]['use_in_benchmark'] == False:
continue
codes = definitions[group]['codes']
for code in codes:
if code not in code_to_group:
code_to_group[code] = group
else:
assert code_to_group[code] == group

id_to_group = sorted([k for k in definitions.keys() if definitions[k]['use_in_benchmark'] == True])
group_to_id = dict((x, i) for (i, x) in enumerate(id_to_group))

# Add Labels
for p in data:
label = np.zeros(len(group_to_id))
for v in p['visits']:
for c in v:
if c in code_to_group:
label[group_to_id[code_to_group[c]]] = 1

p['labels'] = label

for p in data:
new_visits = []
for v in p['visits']:
new_visit = []
for c in v:
new_visit.append(code_to_index[c])

new_visits.append((list(set(new_visit))))

p['visits'] = new_visits

print(f"MAX LEN: {max([len(p['visits']) for p in data])}")
print(f"AVG LEN: {np.mean([len(p['visits']) for p in data])}")
print(f"MAX VISIT LEN: {max([len(v) for p in data for v in p['visits']])}")
print(f"AVG VISIT LEN: {np.mean([len(v) for p in data for v in p['visits']])}")
print(f"NUM RECORDS: {len(data)}")
print(f"NUM LONGITUDINAL RECORDS: {len([p for p in data if len(p['visits']) > 1])}")

# Train-Val-Test Split
train_dataset, test_dataset = train_test_split(data, test_size=0.2, random_state=4, shuffle=True)
train_dataset, val_dataset = train_test_split(train_dataset, test_size=0.1, random_state=4, shuffle=True)

# Save Everything
print("Saving Everything")
print(len(index_to_code))
pickle.dump(code_to_index, open(f"{self.pkl_data_dir}codeToIndex.pkl", "wb"))
pickle.dump(index_to_code, open(f"{self.pkl_data_dir}indexToCode.pkl", "wb"))
pickle.dump(id_to_group, open(f"{self.pkl_data_dir}idToLabel.pkl", "wb"))
pickle.dump(train_dataset, open(f"{self.pkl_data_dir}trainDataset.pkl", "wb"))
pickle.dump(val_dataset, open(f"{self.pkl_data_dir}valDataset.pkl", "wb"))
pickle.dump(test_dataset, open(f"{self.pkl_data_dir}testDataset.pkl", "wb"))

1 change: 1 addition & 0 deletions pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from .transformer import Transformer, TransformerLayer
from .transformers_model import TransformersModel
from .vae import VAE
from .generators.halo import HALO
Empty file.
Loading