Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,762 changes: 918 additions & 844 deletions models/analyze_model.ipynb

Large diffs are not rendered by default.

217 changes: 175 additions & 42 deletions models/generate_patient_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,144 @@
from google.oauth2 import service_account
import os
import json
import random

def load_existing_records(filename="records.json"):
    """Load previously saved records so their distributions can be analyzed.

    Args:
        filename: Path of the JSON file that holds a list of records.

    Returns:
        The list stored in ``filename``. An empty list is returned when the
        file is missing, empty, unreadable, not valid JSON, or when its
        top-level JSON value is not a list.
    """
    try:
        with open(filename, "r", encoding="utf-8") as json_file:
            content = json_file.read()
    except FileNotFoundError:
        # No file yet simply means no records have been generated so far.
        return []
    except (OSError, UnicodeDecodeError) as e:
        print(f"Warning: Error reading {filename}: {str(e)}. Starting fresh.")
        return []

    # A zero-length (or whitespace-only) file is treated as "no records".
    if not content.strip():
        return []

    try:
        records = json.loads(content)
    except json.JSONDecodeError:
        print(f"Warning: {filename} is not properly formatted. Starting fresh.")
        return []

    # Callers iterate over the result as a list of records; any other JSON
    # root (object, string, number) would be silently mis-processed.
    if not isinstance(records, list):
        print(f"Warning: {filename} is not properly formatted. Starting fresh.")
        return []
    return records

def save_record_to_json(record, filename="records.json"):
    """Append one patient record to the JSON list stored in ``filename``.

    The existing file is read first; if it is missing or empty the list
    starts empty, and if it cannot be read or parsed a warning is printed
    and the list also starts empty. The combined list is then written back.
    If writing fails, the single record is written to a backup file named
    ``record_<patientId>.json`` in the current working directory.

    Args:
        record: The record to append (must be JSON-serializable).
        filename: Path of the JSON file holding the list of records.

    Raises:
        OSError, TypeError, ValueError: Only if the fallback backup write
            itself fails.
    """
    records = []
    try:
        with open(filename, "r", encoding="utf-8") as json_file:
            content = json_file.read()
        if content.strip():
            records = json.loads(content)
    except FileNotFoundError:
        # First record: there is nothing to load yet.
        records = []
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Error reading existing records: {str(e)}. Starting fresh.")
        records = []

    # A JSON root that is not a list cannot be appended to.
    if not isinstance(records, list):
        print("Warning: Error reading existing records: top-level JSON value is not a list. Starting fresh.")
        records = []

    records.append(record)

    try:
        # Serialize BEFORE opening the file for writing: opening with "w"
        # truncates it, so a serialization error afterwards would destroy
        # every previously saved record.
        payload = json.dumps(records, indent=4)
        # Ensure the directory exists
        os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
        with open(filename, "w", encoding="utf-8") as json_file:
            json_file.write(payload)
    except (OSError, TypeError, ValueError) as e:
        print(f"Error saving record: {str(e)}")
        # Best effort: keep at least this record in a standalone backup file.
        patient_id = record.get('patientId', 'unknown') if isinstance(record, dict) else 'unknown'
        backup_filename = f"record_{patient_id}.json"
        with open(backup_filename, "w", encoding="utf-8") as backup_file:
            json.dump([record], backup_file, indent=4)
        print(f"Saved record to backup file: {backup_filename}")



def analyze_distributions(records):
    """Count how existing records are spread over age, conditions and gender.

    Args:
        records: Iterable of record dicts. Each is expected to carry a
            ``'patientRecord'`` dict with ``'Age'`` and ``'Gender'`` entries
            and an optional comma-separated ``'Existing medical conditions'``
            string. ``None`` or an empty collection yields all-zero counts.

    Returns:
        A dict with three entries:
            ``'age_groups'``: counts for ``'0-18'``, ``'19-40'``, ``'41-65'``
                and ``'65+'``;
            ``'conditions_count'``: count per distinct condition name;
            ``'gender_count'``: count per gender value (``'Male'`` and
                ``'Female'`` are always present).

    Records with missing or malformed fields are reported and skipped as a
    whole, so the three tallies always describe the same set of records.
    """
    distributions = {
        'age_groups': {'0-18': 0, '19-40': 0, '41-65': 0, '65+': 0},
        'conditions_count': {},
        'gender_count': {'Male': 0, 'Female': 0},
    }
    age_groups = distributions['age_groups']
    conditions_count = distributions['conditions_count']
    gender_count = distributions['gender_count']

    for record in records or ():
        # Extract and validate every field first; counting only happens once
        # the whole record is known to be usable, so a record that fails
        # halfway never leaves a partial contribution behind.
        try:
            patient = record['patientRecord']
            age = int(patient['Age'])
            gender = patient['Gender']
            raw_conditions = patient.get('Existing medical conditions', '')
            conditions = [part.strip() for part in raw_conditions.split(',')]
        except (KeyError, ValueError, TypeError, AttributeError) as e:
            print(f"Warning: Error processing record: {str(e)}. Skipping.")
            continue

        # Analyze age distribution
        if age <= 18:
            age_groups['0-18'] += 1
        elif age <= 40:
            age_groups['19-40'] += 1
        elif age <= 65:
            age_groups['41-65'] += 1
        else:
            age_groups['65+'] += 1

        # Analyze conditions (blank entries from "a,,b" or "" are ignored)
        for condition in conditions:
            if condition:
                conditions_count[condition] = conditions_count.get(condition, 0) + 1

        # Analyze gender
        gender_count[gender] = gender_count.get(gender, 0) + 1

    return distributions

def generate_prompt_constraints(distributions, patient_id, base_prompt=None):
    """Build a generation prompt that steers the next record toward balance.

    Args:
        distributions: Dict with ``'age_groups'``, ``'gender_count'`` and
            ``'conditions_count'`` entries, each mapping a label to a count
            (the shape returned by ``analyze_distributions``).
        patient_id: Identifier of the record about to be generated; it is
            only interpolated into the prompt text.
        base_prompt: Prompt text appended after the constraints. When
            omitted, the module-level ``text1`` is used, which is what this
            function always used before the parameter existed.

    Returns:
        The constraint header followed by a blank line, the base prompt and
        a trailing newline.
    """
    if base_prompt is None:
        base_prompt = text1

    age_groups = distributions['age_groups']
    total_records = sum(age_groups.values())

    # (age bucket, minimum desired share, wording used in the prompt),
    # checked in this order; the first bucket below its share is targeted.
    age_targets = (
        ('0-18', 0.15, "between 5 and 18"),
        ('19-40', 0.35, "between 19 and 40"),
        ('41-65', 0.35, "between 41 and 65"),
        ('65+', 0.15, "above 65"),
    )

    if total_records == 0:
        # For the first record or when no records exist, use defaults.
        target_age_group = "between 19 and 40"  # Start with middle age range
        gender_preference = "any gender"
        common_conditions = []
    else:
        # Determine which age group is underrepresented
        target_age_group = next(
            (wording for bucket, minimum, wording in age_targets
             if age_groups.get(bucket, 0) / total_records < minimum),
            "varied",
        )

        # Adjust gender balance: ask for whichever gender is currently behind
        gender_count = distributions['gender_count']
        males = gender_count.get('Male', 0)
        females = gender_count.get('Female', 0)
        if males > females:
            gender_preference = "Female"
        elif females > males:
            gender_preference = "Male"
        else:
            gender_preference = "any gender"

        # Avoid common conditions: anything present in over 10% of records
        threshold = total_records * 0.1
        common_conditions = [
            cond for cond, count in distributions['conditions_count'].items()
            if count > threshold
        ]

    avoid_text = ', '.join(common_conditions) if common_conditions else 'None yet'
    prompt_lines = [
        "Please generate a medical record with these specific constraints:",
        f"1. Patient age should be {target_age_group}",
        f"2. Gender should be {gender_preference}",
        f"3. Please avoid these common conditions that are overrepresented: {avoid_text}",
        f"4. For record #{patient_id}, ensure high uniqueness in:",
        "- Height (vary between 150-190 cm for adults, adjust appropriately for children)",
        "- Weight (vary appropriately based on height and age)",
        "- Blood type (maintain realistic distribution)",
        "- Symptoms (avoid repetition from common patterns)",
        "",
        f"{base_prompt}",
        "",
    ]
    return "\n".join(prompt_lines)

def parse_patient_record(text):
"""Parse the Patient Record section into the desired format."""
Expand Down Expand Up @@ -108,44 +233,52 @@ def generate():
model = GenerativeModel("gemini-1.5-flash-002")

total_records = 1000
for patient_id in range(552, total_records + 1):
for patient_id in range(1, total_records + 1):
print(f"Generating record {patient_id}/{total_records}")

response = model.generate_content(
[text1],
generation_config=generation_config,
safety_settings=safety_settings,
stream=False
)
# Load and analyze existing records
existing_records = load_existing_records()
distributions = analyze_distributions(existing_records)

# Split into Patient Record and Diagnosis Report sections
full_text = response.text.replace('\n\n', '\n').strip()
# Generate constraints based on current distributions
constrained_prompt = generate_prompt_constraints(distributions, patient_id)

# Extract Patient Record section
if "#Patient Record" in full_text:
parts = full_text.split("#Patient Record")
if len(parts) > 1:
record_part = parts[1].split("#Diagnosis Report")[0].strip()
diagnosis_part = parts[1].split("#Diagnosis Report")[1].strip()
else:
# Alternative format
parts = full_text.split("# Patient Record")
if len(parts) > 1:
record_part = parts[1].split("# Diagnosis Report")[0].strip()
diagnosis_part = parts[1].split("# Diagnosis Report")[1].strip()

# Parse sections into desired format
patient_record = parse_patient_record(record_part)
diagnosis_report = parse_diagnosis_report(diagnosis_part)

# Create final record
record = {
"patientId": str(patient_id),
"patientRecord": patient_record,
"diagnosisReport": diagnosis_report
}

save_record_to_json(record)
try:
response = model.generate_content(
[constrained_prompt],
generation_config=generation_config,
safety_settings=safety_settings,
stream=False
)

# Rest of your parsing logic remains the same
full_text = response.text.replace('\n\n', '\n').strip()

if "#Patient Record" in full_text:
parts = full_text.split("#Patient Record")
if len(parts) > 1:
record_part = parts[1].split("#Diagnosis Report")[0].strip()
diagnosis_part = parts[1].split("#Diagnosis Report")[1].strip()
else:
parts = full_text.split("# Patient Record")
if len(parts) > 1:
record_part = parts[1].split("# Diagnosis Report")[0].strip()
diagnosis_part = parts[1].split("# Diagnosis Report")[1].strip()

patient_record = parse_patient_record(record_part)
diagnosis_report = parse_diagnosis_report(diagnosis_part)

record = {
"patientId": str(patient_id),
"patientRecord": patient_record,
"diagnosisReport": diagnosis_report
}

save_record_to_json(record)

except Exception as e:
print(f"Error generating record {patient_id}: {str(e)}")
continue

print("Record generation completed.")

Expand All @@ -172,7 +305,7 @@ def generate():
- Various specialties involved
- Different levels of diagnostic certainty

Generate a single detailed male patient record and a diagnosis report with the following structure:
Generate a single detailed patient record and a diagnosis report with the following structure:

#Patient Record
- Gender
Expand Down Expand Up @@ -212,7 +345,7 @@ def generate():

generation_config = {
"max_output_tokens": 8192,
"temperature": 2,
"temperature": 1.3,
"top_p": 0.95,
}

Expand Down
Loading