MSc Thesis Pt 2: Predicting NHCI Dropouts

Working document

Author: Brian ODonnell

Affiliation: Charité - Universitätsmedizin Berlin

Part one imported NHCI visit-level data from the DHIS2 export and described patient dropout patterns over time, location, and patient characteristics. Visit data were deduplicated, and quality of care was described.

Daily LGA-level weather data and weekly state-level market prices, each calculated as trailing 10-day averages, were also introduced. These data were merged into the visit-level data by matching on the “target end date” for the patient’s next appointment, defined as 37 days after the current visit. Thus, they act as instrumental variables for climate and economic conditions during the expected visit window.

In this section we use the above dataset to create a prediction model.

Implementers should decide how to define the outcome and how to put these predictions into patient follow-up practice:

  1. Is it better to predict return within a “target” 37-day window, or return within a “timely” 90-day window of the latest visit?

  2. If a patient returns to care >90 days after their latest visit, are they still considered a dropout? Note this is about 17% of all follow-up visits.

  3. If a patient returns to care 1–21 days after the latest visit, is this considered a successful follow-up that falls within the target window? Note this is about 5% of all follow-up visits.
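The two candidate windows in (1) can be derived directly from consecutive visit dates. A minimal sketch with toy data — the column names `reg_id` and `event_date` mirror the dataset, and the two flags correspond to the `nextdate_37_kept` / `nextdate_90_kept` indicators used below:

```r
library(dplyr)

# toy visit history for two patients
visits <- tibble(
  reg_id = c(1, 1, 1, 2, 2),
  event_date = as.Date(c("2022-01-01", "2022-02-05", "2022-06-01",
                         "2022-01-10", "2022-01-20"))
)

visits %>%
  group_by(reg_id) %>%
  arrange(event_date, .by_group = TRUE) %>%
  mutate(
    days_to_next = as.integer(lead(event_date) - event_date),
    kept_37 = as.integer(days_to_next <= 37),  # "target" window
    kept_90 = as.integer(days_to_next <= 90)   # "timely" window
  ) %>%
  ungroup()
```

Note that under this simple rule an early return (question 3 above) also counts as kept; handling early returns differently would need an explicit lower bound such as `days_to_next >= 21`.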

Goals Summary

The model described in the manuscript is Model 2.

Code
 #| label: load-packages
 #| include: false
 #| echo: false

library(tidyverse)
library(tidymodels)
library(table1)
library(slider)
library(ggcorrplot)
library(vip)
library(summarytools)
library(glmnet)
library(readxl)
library(kableExtra)
library(effectsize)
library(splines2)
library(probably)
library(RColorBrewer)
library(skimr)
library(DT)
library(themis)
library(ROSE)
library(patchwork)

palette <- brewer.pal(n = 8, name = "Dark2")
palette_reversed <- rev(palette)


mytheme<-
  theme_minimal()+
  theme(
  panel.background = element_rect(fill = "white", color = NA),
  plot.background = element_rect(fill = "white", color = NA),
  legend.background = element_rect(fill = "white", color = NA),
  text = element_text(family="sans"),
  strip.background = element_rect(
    color = "grey70", 
    linewidth = 1))

options(ggplot2.discrete.colour = palette, ggplot2.discrete.fill = palette)

theme_set(mytheme)

my_ggsave <- function(x, width = 7){
  
  ggsave(paste0(x, ".png"), plot = last_plot(),
            path = here::here("figs"), 
            dpi = 200,
            width = width)
  
    ggsave(paste0(x, ".tiff"), plot = last_plot(), 
            path = here::here("figs"), 
            dpi = 200,
            width = width)
}

PREDICTION

Import, Cleaning, and Skim

Importing the full dataset, we remove censored observations. These include the last visit before a patient is marked Transferred or Died, and visits after April 1, 2024 (i.e., with no opportunity to complete the follow-up window).

Code
tots_timely_qual_cli_mkt<-readr::read_rds("tots_timely_qual_cli_mkt.rds")

pop_params_top<-readr::read_lines("cli_mkt_params_for_model.txt")


dta_prep_all_models<-tots_timely_qual_cli_mkt %>% 
 # remove censored events  
    filter(censor==0) %>% 
    filter(registration_date>=ymd("2021-01-01")) %>% 
    mutate(event_year_month=make_date(year=year(event_date), 
                                      month=month(event_date), day="01")) %>% 
    mutate(event_year=as.factor(year(event_date))) %>% 
    mutate(LGA=str_remove_all(LGA, " Local Government Area")) %>% 
    add_count(LGA, event_year_month, name="LGA_monthly_visitLoad") %>%
    add_count(HC, event_year_month, name="HC_monthly_visitLoad") %>% 
  # here we remove all visits after a dropout
    mutate(visits_since_dropout=replace_na(visits_since_dropout,0)) %>% 
    filter(visits_since_dropout<2) %>% 
      mutate(bp_both_na=as.factor(
                      if_else(is.na(bp_diastole_mmhg) & is.na(bp_systole_mmhg), 
                             1, 0))) %>% 
      mutate(age_group = as.factor(cut(age, # ages <30 or >110 pooled into "Other"
                         breaks = c(0, 30, 45, 60, 75, 110, 100000),
                         include.lowest = T,
                         labels = c("Other", "30-44","45-59", 
                                    "60-74", "75-110", "Other")))) %>%
      mutate(med_any=as.factor(if_else(!is.na(med_amlodipine) | 
                                        !is.na(med_telmisartan) | 
                                        !is.na(med_losartan) | 
                                        !is.na(med_hydrochlorothiazide) |
                                        !is.na(other_hypertension_medication),
                                        1, 0))) %>%
      # if any medication was dispensed but the day count is missing, assume 30 days
      mutate(days_of_medication=as.integer(
                                str_remove(days_of_medication, "days"))) %>% 
      mutate(days_of_medication=if_else(med_any==1 & is.na(days_of_medication), 
                                        30, days_of_medication)) %>% 
      mutate(days_of_medication=replace_na(days_of_medication, 0)) %>% 
      mutate(meds_other=if_else(is.na(other_hypertension_medication),
                                      "none","recorded")) %>% 
      select(!c(starts_with("other_"))) %>% 
   # convert to factors as appropriate
      mutate(across(starts_with("has_"), as.factor)) %>%
      mutate(across(starts_with("p_"), as.factor)) %>%
      mutate(across(starts_with("consent"), as.character)) %>% 
      mutate(across(where(is.character), as.factor)) %>% 
      arrange(reg_id, appt_tally) %>% 
      group_by(reg_id) %>% # days since visit as rolling mean
      mutate(days_since_last_mean = map_dbl(seq_along(appt_tally), 
                                  ~ mean(days_since_last[1:.x], 
                                     na.rm = TRUE))) %>% 
      mutate(days_since_last_sd = map_dbl(seq_along(appt_tally), 
                                  ~ sd(days_since_last[1:.x], 
                                     na.rm = TRUE))) %>% 
      ungroup() %>% 
      mutate(across(starts_with("days"), ~ ifelse(is.nan(.), NA, .)))
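The expanding mean/SD computed with `map_dbl` above can also be expressed with `slider` (already loaded), which is typically faster on ~100k rows. A sketch showing the equivalence on toy data:

```r
library(slider)
library(purrr)

x <- c(30, 28, 45, NA, 33)

# expanding (cumulative) mean, as in the map_dbl version above
m_map   <- map_dbl(seq_along(x), ~ mean(x[1:.x], na.rm = TRUE))

# same result with a single expanding window via slider
m_slide <- slide_dbl(x, ~ mean(.x, na.rm = TRUE), .before = Inf)

stopifnot(all.equal(m_map, m_slide))
```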


skimr::skim(dta_prep_all_models)
Data summary
Name dta_prep_all_models
Number of rows 99740
Number of columns 115
_______________________
Column type frequency:
Date 11
factor 45
numeric 59
________________________
Group variables None

Variable type: Date

skim_variable n_missing complete_rate min max median n_unique
registration_date 0 1.00 2021-01-01 2024-04-02 2022-05-20 1162
event_date 0 1.00 2021-01-01 2024-04-02 2022-10-17 1188
prev_event_date 28184 0.72 2021-01-01 2024-03-21 2022-10-18 1169
lead_event_date 13015 0.87 2021-01-04 2024-07-01 2022-12-19 1258
nextdate_37 0 1.00 2021-02-07 2024-05-09 2022-11-23 1188
nextdate_90 0 1.00 2021-04-01 2024-07-01 2023-01-15 1188
last_updated_on 0 1.00 2021-01-02 2024-07-16 2023-06-05 1242
scheduled_date 0 1.00 2012-12-13 2201-08-29 2022-12-30 1242
incident_date 0 1.00 2021-01-01 2024-07-14 2022-06-22 1153
exp_year_month 0 1.00 2021-04-01 2024-07-01 2023-01-01 40
event_year_month 0 1.00 2021-01-01 2024-04-01 2022-10-01 40

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
reg_visit_id 0 1.00 FALSE 99740 0_1: 1, 1_1: 1, 10_: 1, 10_: 1
sex 37 1.00 FALSE 2 Fem: 75696, Mal: 24007
State 0 1.00 FALSE 2 Kan: 63363, Ogu: 36377
ncd_patient_status 0 1.00 FALSE 5 Act: 98792, Tra: 341, Die: 287, Ina: 264
stored_by 21670 0.78 FALSE 107 Hot: 6196, Gwa: 4919, Jae: 4896, Daw: 3816
created_by 0 1.00 FALSE 106 Aku: 14321, DAK: 7209, Rec: 4726, Rec: 4475
last_updated_by 0 1.00 FALSE 107 Rec: 7400, DAK: 7209, Rec: 5818, Rec: 5611
organisation_unit_name 0 1.00 FALSE 104 kn : 7981, kn : 7209, kn : 6029, kn : 5802
Nat 0 1.00 FALSE 1 fg : 99740
LGA 0 1.00 FALSE 42 Nas: 21269, Gwa: 9706, Dal: 5996, Abe: 5276
Ward 0 1.00 FALSE 100 kn: 7981, kn: 7209, kn: 6029, kn: 5802
HC 0 1.00 FALSE 105 kn: 7981, kn: 7209, kn: 6029, kn: 5802
FACILITY 93686 0.06 FALSE 10 og: 1193, og: 1135, og: 778, og: 683
program_status 0 1.00 FALSE 3 ACT: 99334, CAN: 362, COM: 44
event_status 0 1.00 FALSE 2 COM: 96653, ACT: 3087
organisation_unit 0 1.00 FALSE 105 Dxf: 7981, jpP: 7209, ehJ: 6029, xUO: 5802
med_losartan 61630 0.38 FALSE 2 Los: 32105, Los: 6005
hypertension_treated_in_the_past 601 0.99 FALSE 2 Fir: 83075, Was: 16064
med_hydrochlorothiazide 98817 0.01 FALSE 2 Hyd: 886, Hyd: 37
patient_followup_status 7902 0.92 FALSE 3 Ove: 73474, Upc: 17844, Cal: 520
consent_to_send_sms_reminders 83169 0.17 FALSE 1 1: 16571
does_patient_have_diabetes 6864 0.93 FALSE 2 No: 90967, Yes: 1909
relationship_to_patient 21169 0.79 FALSE 9 Son: 28699, Hus: 17704, Dau: 13760, Oth: 6426
newage 0 1.00 FALSE 96 50: 8330, 60: 7690, 70: 5594, 40: 4724
med_amlodipine 1991 0.98 FALSE 3 Aml: 89511, Aml: 8229, Oth: 9
med_telmisartan 99452 0.00 FALSE 2 Tel: 218, Tel: 70
does_patient_have_hypertension 0 1.00 FALSE 1 Yes: 99740
consent_to_record_data 0 1.00 FALSE 1 1: 99740
has_supporter 0 1.00 FALSE 2 1: 79370, 0: 20370
has_phone 0 1.00 FALSE 2 1: 81111, 0: 18629
p_step1_need 0 1.00 FALSE 2 0: 87868, 1: 11872
p_step2_need 0 1.00 FALSE 2 0: 78454, 1: 21286
p_step3_need 0 1.00 FALSE 2 0: 89438, 1: 10302
p_step4_need 0 1.00 FALSE 2 0: 96040, 1: 3700
p_step1_met 0 1.00 FALSE 2 0: 88426, 1: 11314
p_step2_met 0 1.00 FALSE 2 0: 86110, 1: 13630
p_step3_met 0 1.00 FALSE 2 0: 97972, 1: 1768
p_step4_met 0 1.00 FALSE 2 0: 99732, 1: 8
lga_37date 0 1.00 FALSE 23130 Nas: 107, Nas: 84, Nas: 78, Nas: 72
state_37date 0 1.00 FALSE 2352 Kan: 207, Kan: 196, Kan: 189, Kan: 171
event_year 0 1.00 FALSE 4 202: 41930, 202: 33676, 202: 17550, 202: 6584
bp_both_na 0 1.00 FALSE 2 0: 99490, 1: 250
age_group 0 1.00 FALSE 5 45-: 41423, 30-: 23564, 60-: 23143, 75-: 6049
med_any 0 1.00 FALSE 2 1: 98043, 0: 1697
meds_other 0 1.00 FALSE 2 non: 98575, rec: 1165

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
reg_id 0 1.00 16084.09 9338.81 0.00 7910.00 16024.50 24157.00 32332.00 ▇▇▇▇▇
appt_tally 0 1.00 5.05 5.13 1.00 1.00 3.00 7.00 43.00 ▇▁▁▁▁
total_patient_visits 0 1.00 11.18 8.24 1.00 4.00 10.00 16.00 45.00 ▇▅▂▁▁
nextdate_37_kept 0 1.00 0.52 0.50 0.00 0.00 1.00 1.00 1.00 ▇▁▁▁▇
nextdate_90_kept 0 1.00 0.74 0.44 0.00 0.00 1.00 1.00 1.00 ▃▁▁▁▇
height_cm 94800 0.05 82.76 80.31 0.45 1.60 140.00 160.00 758.00 ▇▅▁▁▁
bp_diastole_mmhg 647 0.99 84.92 13.28 1.00 77.00 83.00 92.00 200.00 ▁▆▇▁▁
weight_kg 94421 0.05 72.91 15.63 28.00 61.00 70.00 82.00 210.00 ▃▇▁▁▁
bmi 94797 0.05 152304.17 322447.99 0.00 27.00 41.90 269098.80 4296296.30 ▇▁▁▁▁
bmi_measurement 94079 0.06 1.00 0.00 1.00 1.00 1.00 1.00 1.00 ▁▁▇▁▁
days_of_medication 0 1.00 29.61 4.30 0.00 30.00 30.00 30.00 90.00 ▁▇▁▁▁
blood_sugar_measurement 99478 0.00 1.00 0.00 1.00 1.00 1.00 1.00 1.00 ▁▁▇▁▁
bp_systole_mmhg 290 1.00 143.16 20.05 2.00 130.00 140.00 155.00 300.00 ▁▁▇▁▁
age 0 1.00 53.97 14.20 18.00 45.00 53.00 63.00 124.00 ▂▇▃▁▁
days_since_entry 0 1.00 140.18 171.33 0.00 0.00 80.00 210.00 1170.00 ▇▂▁▁▁
exp_year 0 1.00 2022.53 0.86 2021.00 2022.00 2023.00 2023.00 2024.00 ▂▇▁▇▂
exp_month 0 1.00 6.51 3.41 1.00 4.00 6.00 9.00 12.00 ▇▆▆▅▇
days_since_last 28184 0.72 37.14 17.53 7.00 28.00 30.00 42.00 772.00 ▇▁▁▁▁
nextdate_90_dropout 0 1.00 0.26 0.44 0.00 0.00 0.00 1.00 1.00 ▇▁▁▁▃
visits_since_dropout 0 1.00 0.26 0.44 0.00 0.00 0.00 1.00 1.00 ▇▁▁▁▃
censor 0 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 ▁▁▇▁▁
sbp_baseline 299 1.00 157.92 17.17 14.00 146.00 158.00 169.00 300.00 ▁▁▇▁▁
dbp_baseline 443 1.00 92.48 13.30 1.00 85.00 92.00 100.00 200.00 ▁▂▇▁▁
sbp_change_from_baseline 480 1.00 -14.74 19.60 -180.00 -28.00 -10.00 0.00 178.00 ▁▂▇▁▁
dbp_change_from_baseline 913 0.99 -7.55 13.22 -143.00 -15.00 -3.00 0.00 130.00 ▁▁▇▁▁
cli_max_temp_c 0 1.00 30.24 3.39 18.50 27.70 29.40 32.40 41.10 ▁▃▇▃▁
cli_max_heat_index_c 0 1.00 33.43 3.26 19.00 31.10 33.50 35.60 43.60 ▁▁▇▇▁
cli_rain_mm 0 1.00 4.02 7.39 0.00 0.00 0.60 5.20 154.80 ▇▁▁▁▁
cli_rain_hours 0 1.00 4.87 6.00 0.00 0.00 2.00 8.00 24.00 ▇▂▂▁▁
cli_max_temp_c_pre10day_mean 0 1.00 30.26 3.21 20.99 27.89 29.39 32.39 39.70 ▁▆▇▃▂
cli_max_heat_index_c_pre10day_mean 0 1.00 33.44 2.89 22.46 31.43 33.49 35.27 42.35 ▁▃▇▅▁
cli_rain_mm_pre10day_mean 0 1.00 4.02 4.39 0.00 0.11 2.87 6.35 42.86 ▇▁▁▁▁
cli_rain_hours_pre10day_mean 0 1.00 4.77 4.64 0.00 0.40 3.70 7.60 20.60 ▇▃▂▁▁
mkt_diesel 0 1.00 692.09 283.68 188.00 634.00 744.00 822.60 1470.00 ▃▃▇▂▁
mkt_gasoline 0 1.00 313.80 189.21 161.00 179.00 199.00 488.00 1070.00 ▇▁▂▁▁
mkt_maize_grain_white 0 1.00 307.89 136.43 151.54 212.00 267.03 322.41 814.00 ▇▃▂▁▁
mkt_maize_grain_yellow 0 1.00 332.92 140.36 159.23 236.14 300.55 369.14 851.40 ▇▅▂▁▁
mkt_millet_pearl 0 1.00 334.85 144.73 160.33 244.05 300.56 351.12 847.86 ▇▆▁▁▁
mkt_rice_5_percent_broken 0 1.00 814.02 327.72 475.79 603.45 708.93 957.51 2160.50 ▇▂▁▁▁
mkt_rice_milled 0 1.00 643.92 262.61 278.00 464.29 571.43 775.96 1641.30 ▇▅▂▁▁
mkt_sorghum_brown 0 1.00 341.15 172.27 155.90 231.08 297.52 370.52 991.50 ▇▃▁▁▁
mkt_sorghum_white 0 1.00 332.27 157.89 154.68 224.00 308.00 357.60 852.30 ▇▆▁▁▁
mkt_yams 0 1.00 294.54 146.76 136.92 186.34 249.42 345.97 836.70 ▇▃▁▁▁
mkt_usd_exchange_rate 0 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 ▁▃▁▂▇
mkt_diesel_pre10day_mean 0 1.00 688.30 283.55 188.00 634.00 744.00 819.20 1458.00 ▃▃▇▂▁
mkt_gasoline_pre10day_mean 0 1.00 310.70 185.39 160.76 176.60 199.00 488.04 1070.00 ▇▁▂▁▁
mkt_maize_grain_white_pre10day_mean 0 1.00 306.30 133.74 153.85 218.24 265.02 322.41 808.69 ▇▃▁▁▁
mkt_maize_grain_yellow_pre10day_mean 0 1.00 331.23 138.20 162.11 234.46 303.58 366.94 851.40 ▇▆▂▁▁
mkt_millet_pearl_pre10day_mean 0 1.00 333.40 143.20 161.07 241.27 294.28 347.75 846.81 ▇▆▁▁▁
mkt_rice_5_percent_broken_pre10day_mean 0 1.00 810.16 323.26 478.97 598.28 708.93 957.50 2157.05 ▇▂▁▁▁
mkt_rice_milled_pre10day_mean 0 1.00 640.55 258.59 281.60 460.71 565.03 773.77 1636.41 ▇▅▂▁▁
mkt_sorghum_brown_pre10day_mean 0 1.00 339.25 170.23 159.15 228.69 298.90 360.60 988.98 ▇▃▁▁▁
mkt_sorghum_white_pre10day_mean 0 1.00 330.54 156.19 154.68 223.92 296.00 354.24 852.30 ▇▆▁▁▁
mkt_yams_pre10day_mean 0 1.00 292.48 144.11 139.12 184.84 251.56 344.47 763.26 ▇▃▁▁▁
mkt_usd_exchange_rate_pre10day_mean 0 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 ▁▂▁▂▇
LGA_monthly_visitLoad 0 1.00 279.64 263.03 1.00 85.00 164.00 406.00 1108.00 ▇▂▁▁▁
HC_monthly_visitLoad 0 1.00 130.25 125.51 1.00 37.00 75.00 185.00 532.00 ▇▂▁▁▁
days_since_last_mean 28184 0.72 36.27 12.93 7.00 28.45 32.75 40.33 772.00 ▇▁▁▁▁
days_since_last_sd 43611 0.56 10.85 10.05 0.00 3.39 9.70 15.89 522.55 ▇▁▁▁▁

Model 1: Failure to Return after Baseline

First prepare the dataset.

Code
dta_model1 <- dta_prep_all_models %>% 
  select(reg_id, appt_tally, nextdate_90_kept,
         event_year_month, total_patient_visits,
         sex, age, event_year, days_since_last,
         #KEEP FROM HERE
         contains("med"), contains("bp_"), ends_with("baseline"),
         contains("p_step"), relationship_to_patient,
         has_phone, does_patient_have_diabetes,
         State, 
         LGA,
         age_group,
         HC_monthly_visitLoad,
         consent_to_send_sms_reminders,
         any_of(pop_params_top) # weather and market params
         ) %>% 
  # define the target population (baseline)
  filter(appt_tally==1) %>% 
  # define the outcome
  mutate(patient_dropout=as.factor(if_else(total_patient_visits < 2, 
                                          1, 0))) %>% 
  select(!any_of(c("visits_since_dropout","nextdate_90_kept",
              "total_patient_visits","registration_date")))

M1.1. Split into training and test sets

Code
# library(tidymodels)
set.seed(124)

# Extract unique groups
unique_groups <- unique(dta_model1$reg_id)

# Shuffle groups to ensure random distribution
shuffled_groups <- sample(unique_groups)

# Define the split ratio
split_ratio <- 0.8

# Determine split point
split_point <- floor(length(shuffled_groups) * split_ratio)

# Assign groups to training and test sets
train_groups <- shuffled_groups[1:split_point]
test_groups <- shuffled_groups[(split_point + 1):length(shuffled_groups)]

# Create training and test datasets based on group membership
training_data <- dta_model1 %>% filter(reg_id %in% train_groups)
testing_data <- dta_model1 %>% filter(reg_id %in% test_groups)


# Set up 5-fold cross-validation (Model 1 has one row per patient at baseline, so no grouping is needed)
cv_folds <- vfold_cv(training_data, v = 5) 
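The manual group-aware split above can also be written with rsample's grouped split helper (assuming rsample ≥ 1.0), which guarantees that no patient straddles the split:

```r
library(rsample)
library(tibble)

set.seed(124)
# toy data: 50 patients with 1-4 rows each (names illustrative)
toy <- tibble(reg_id = rep(1:50, times = sample(1:4, 50, replace = TRUE)))

split <- group_initial_split(toy, group = reg_id, prop = 0.8)
train <- training(split)
test  <- testing(split)

# no patient appears on both sides of the split
stopifnot(length(intersect(unique(train$reg_id), unique(test$reg_id))) == 0)
```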

M1.2. Build the model

Code
# with lasso regularization

lr_mod <- logistic_reg(penalty = tune(), mixture=1) %>% 
  set_engine("glmnet")

M1.3. Build the “recipe” to preprocess the data

Code
lr_recipe <- recipe(patient_dropout ~ ., data = training_data) %>% 
  update_role(reg_id, new_role = "ID") %>% 
  # step_date(event_date, features = c("dow", "month")) %>% 
  step_date(event_year_month, features = c("year","month")) %>%
  # step_rm(event_date) %>% 
  step_rm(event_year_month) %>% 
  step_log(HC_monthly_visitLoad) %>% 
  # step_rm(exp_year_month) %>%
  # handle factor levels unseen in training, and recode missing factor values
  step_novel(all_factor_predictors(), new_level = "new") %>%
  step_unknown(all_factor_predictors(), new_level = "missing") %>%
  # impute missing numeric values with the median
  step_impute_median(all_numeric_predictors()) %>% 
  # step_corr(all_numeric_predictors(), threshold = .5) %>%
  step_dummy(all_nominal_predictors()) %>% 
  step_ns(age) %>% # 
  # step_ns(sbp_change_from_baseline) %>%
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())
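Before tuning, it can help to sanity-check what a recipe like this emits. A minimal sketch with toy data (column names illustrative) using the same core steps:

```r
library(recipes)

# toy data: one factor with an NA, one numeric with an NA
toy <- data.frame(
  y   = factor(c("0", "1", "0", "1")),
  age = c(30, NA, 55, 62),
  grp = factor(c("a", "b", NA, "a"))
)

rec <- recipe(y ~ ., data = toy) %>%
  step_unknown(all_nominal_predictors(), new_level = "missing") %>%  # NA -> "missing"
  step_impute_median(all_numeric_predictors()) %>%                   # NA -> median
  step_dummy(all_nominal_predictors()) %>%
  step_normalize(all_numeric_predictors())

baked <- rec %>% prep() %>% bake(new_data = NULL)
baked  # no NAs remain; grp becomes dummy columns
```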

M1.4. Add a workflow, combining the model and recipe

Code
lr_workflow<-workflow() %>% 
  add_model(lr_mod) %>% 
  add_recipe(lr_recipe)

M1.5 Create penalty values for tuning

Code
lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 20))

M1.6 Apply penalty values to the workflow

Code
set.seed(3303)
# library(glmnet)
# install.packages("glmnet")

lr_res <- lr_workflow %>% 
  tune_grid(resamples = cv_folds,
            grid = lr_reg_grid,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(roc_auc))

lr_plot <- 
  lr_res %>% 
  collect_metrics() %>% 
  ggplot(aes(x = penalty, y = mean)) + 
  geom_point() + 
  geom_line() + 
  ylab("Area under the ROC Curve") +
  scale_x_log10(labels = scales::label_number())

lr_plot

M1.7 Select best model

Let’s find the best model…

Code
lr_res %>% 
  collect_metrics() %>% 
  arrange(penalty) %>% 
  rowid_to_column()
# A tibble: 20 × 8
   rowid  penalty .metric .estimator  mean     n std_err .config              
   <int>    <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
 1     1 0.0001   roc_auc binary     0.710     5 0.00278 Preprocessor1_Model01
 2     2 0.000144 roc_auc binary     0.710     5 0.00279 Preprocessor1_Model02
 3     3 0.000207 roc_auc binary     0.710     5 0.00281 Preprocessor1_Model03
 4     4 0.000298 roc_auc binary     0.710     5 0.00282 Preprocessor1_Model04
 5     5 0.000428 roc_auc binary     0.710     5 0.00285 Preprocessor1_Model05
 6     6 0.000616 roc_auc binary     0.710     5 0.00287 Preprocessor1_Model06
 7     7 0.000886 roc_auc binary     0.710     5 0.00290 Preprocessor1_Model07
 8     8 0.00127  roc_auc binary     0.710     5 0.00300 Preprocessor1_Model08
 9     9 0.00183  roc_auc binary     0.710     5 0.00316 Preprocessor1_Model09
10    10 0.00264  roc_auc binary     0.709     5 0.00333 Preprocessor1_Model10
11    11 0.00379  roc_auc binary     0.707     5 0.00352 Preprocessor1_Model11
12    12 0.00546  roc_auc binary     0.704     5 0.00388 Preprocessor1_Model12
13    13 0.00785  roc_auc binary     0.699     5 0.00397 Preprocessor1_Model13
14    14 0.0113   roc_auc binary     0.691     5 0.00377 Preprocessor1_Model14
15    15 0.0162   roc_auc binary     0.684     5 0.00344 Preprocessor1_Model15
16    16 0.0234   roc_auc binary     0.663     5 0.00424 Preprocessor1_Model16
17    17 0.0336   roc_auc binary     0.644     5 0.00264 Preprocessor1_Model17
18    18 0.0483   roc_auc binary     0.638     5 0.00260 Preprocessor1_Model18
19    19 0.0695   roc_auc binary     0.642     5 0.00253 Preprocessor1_Model19
20    20 0.1      roc_auc binary     0.5       5 0       Preprocessor1_Model20
Code
# select optimal penalty
lr_best<-  lr_res %>% 
  select_best(metric="roc_auc")
  # collect_metrics() %>% 
  # arrange(penalty) %>% 
  # slice(12)
# 
# lr_res %>% 
#   collect_metrics()

M1.8 ROC curve for best fit

OK, now let’s see how the ROC curve looks for the best-fit logistic model…

Code
lr_preds1 <- lr_res %>% 
  collect_predictions(parameters = lr_best) %>% 
  mutate(model_name="Model 1: First Visit Return Failure")
  
lr_auc<- lr_preds1 %>% 
  roc_curve(patient_dropout, .pred_1, event_level = "second") %>% 
  mutate(model = "Logistic Regression") 

autoplot(lr_auc)

Code
# Extract the glmnet model from the workflow and training data
# final_fit <- fit(lr_workflow, data = appts_other)
# 
# final_coefficients <- tidy(final_fit, penalty = lr_best$penalty[1], mixture=1)

It’s… OK! Cross-validated AUC is about 0.71.

Now we can see how it fits the held-out test data.

Model 1: Final Fit

Code
last_log_mod <-  logistic_reg(penalty = lr_best$penalty[1], mixture=1) %>% 
  set_engine("glmnet", importance="BIC")

# the last workflow
last_log_workflow <- lr_workflow %>% 
                    update_model(last_log_mod)

# last fit
last_lr_fit<- last_log_workflow %>% 
              fit(training_data)

# show metrics
lr_last_preds_m1 <- predict(last_lr_fit, new_data = testing_data, type="prob") %>% 
                bind_cols(testing_data) %>% 
                mutate(patient_dropout=as.factor(patient_dropout)) %>% 
                mutate(predicted_class_60 = as.factor(if_else(.pred_1 > 0.6, "1", "0"))) %>%
                mutate(predicted_class_50 = as.factor(if_else(.pred_1 > 0.5, "1", "0"))) %>%
                mutate(predicted_class_38 = as.factor(if_else(.pred_1 > 0.38, "1", "0"))) %>%
                mutate(predicted_class_20 = as.factor(if_else(.pred_1 > 0.2, "1", "0"))) %>%
                mutate(predicted_prob_yes = .pred_1, predicted_prob_no = .pred_0) %>% 
                mutate("model"="Model 1: Return Failure At Baseline")

About the same ROC as with the cross-validation folds…

M1 Top features

Code
vi_fit<- last_lr_fit %>% 
  extract_fit_parsnip() 

# vi_test<-vi_fit %>% pluck("fit")
# class(vi_test)

vi_fit %>% 
  vip(num_features = 25, mapping=aes(fill=Sign)) +
  scale_y_continuous(expand = c(0,0)) +
  theme_light() +
  labs(title="Top predictors for Model1",
  subtitle="Logistic Reg for Dropout Event after first visit")

Code
# pal <- palette.colors(2, palette = "Okabe-Ito")  # colorblind friendly palette
# vip(vi_fit, num_features = 20,  # Figure 3
#     geom = "point", horizontal = FALSE, mapping = aes(color = Sign)) +
#   scale_color_manual(values = unname(pal)) +
#   theme_light() +
#   theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
#   labs(title="Variable Importance in Logistic Regression Model")

The LGA indicators seem to have the most impact.

Get the coefficients from this final fit as per the docs.

Code
vi_fit %>% tidy() %>% DT::datatable()
Code
calibration_data <- lr_last_preds_m1 %>%
  mutate(bin = cut(.pred_1, breaks = seq(0, 1, by = 0.05), 
                   include.lowest = TRUE)) %>%
  group_by(bin) %>%
  summarize(
    mean_pred = mean(.pred_1),
    observed = mean(as.numeric(patient_dropout) - 1)  # Convert factor to 0/1
  )

ggplot(calibration_data, aes(x = mean_pred, y = observed)) +
  geom_point() +
  geom_line() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "red") +
  labs(
    x = "Predicted Probability",
    y = "Observed Proportion",
    title = "Calibration Plot"
  ) +
  theme_minimal()

Model 2: Dropouts Per Opportunity, Full Specs

First prepare the dataset.

Code
dta_model2 <- dta_prep_all_models %>% 
  select(reg_id, appt_tally, nextdate_90_kept,
         sex, age, event_year_month, starts_with("days_since"),
         #KEEP FROM HERE
         contains("med"), contains("bp_"), ends_with("baseline"),
         contains("p_step"), relationship_to_patient,
         has_phone, does_patient_have_diabetes,
         "htn_past"=hypertension_treated_in_the_past,
         LGA, 
         State, age_group,
         HC_monthly_visitLoad,
         consent_to_send_sms_reminders,
         any_of(pop_params_top) # weather and market params
         ) %>% 
  # define the target population (full)
  # filter(appt_tally==1) %>% 
  # define the outcome
  mutate(patient_dropout=as.factor(if_else(nextdate_90_kept==1, 
                                          0, 1))) %>% 
  group_by(reg_id) %>% 
  # nb: as.numeric() on a factor returns level codes (1/2), never 0,
  # so compare labels directly to flag patients who ever drop out
  mutate(patient_dropout_ever=if_else(any(patient_dropout=="1"), 
                                          1, 0)) %>% 
  ungroup() %>% 
  select(!any_of(c("visits_since_dropout","nextdate_90_kept",
              "total_patient_visits","registration_date")))
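The ever-dropout flag above is one place where a subtle R coercion bites. A quick base-R illustration:

```r
f <- factor(c("0", "0", "1"))

as.numeric(f)                # 1 1 2 -- level codes, not the labels
as.numeric(as.character(f))  # 0 0 1 -- the intended values
any(f == "1")                # TRUE -- robust test for "ever dropped out"
```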

M2.1. Split into training and test sets

Code
library(tidymodels)
set.seed(124)

# Extract unique groups
unique_groups <- unique(dta_model2$reg_id)

# Shuffle groups to ensure random distribution
shuffled_groups <- sample(unique_groups)

# Define the split ratio
split_ratio <- 0.8

# Determine split point
split_point <- floor(length(shuffled_groups) * split_ratio)

# Assign groups to training and test sets
train_groups <- shuffled_groups[1:split_point]
test_groups <- shuffled_groups[(split_point + 1):length(shuffled_groups)]

# Create training and test datasets based on group membership
training_data <- dta_model2 %>% filter(reg_id %in% train_groups)
testing_data <- dta_model2 %>% filter(reg_id %in% test_groups)




# Set up 5-fold cross-validation, grouped by the patient ID variable
cv_folds <- group_vfold_cv(training_data, v = 5, 
                           group=reg_id, # all visits for one patient go into the same fold
                           strata=patient_dropout_ever) # each fold gets a similar share of patients who ever drop out
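A quick check (toy data, illustrative names) that grouped CV really keeps each patient's visits in a single assessment fold:

```r
library(rsample)
library(dplyr)
library(purrr)

set.seed(1)
toy <- tibble(reg_id = rep(1:100, each = 3),
              y = factor(rbinom(300, 1, 0.3)))

folds <- group_vfold_cv(toy, v = 5, group = reg_id)

# each patient should appear in exactly one assessment set
assess_ids <- map(folds$splits, ~ unique(assessment(.x)$reg_id))
stopifnot(!any(duplicated(unlist(assess_ids))),
          length(unlist(assess_ids)) == 100)
```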

M2.2. Build the model

Code
# with lasso regularization

lr_mod <- logistic_reg(penalty = tune(), mixture=1) %>% 
  set_engine("glmnet")

M2.3. Build the “recipe” to preprocess the data

Code
lr_recipe <- recipe(patient_dropout ~ ., data = training_data) %>% 
  update_role(reg_id, new_role = "ID") %>% 
  # step_date(event_date, features = c("dow", "month")) %>%
  step_date(event_year_month, features = c("month")) %>%
  # step_date(exp_year_month, features = c("year", "month")) %>%
  # step_rm(event_date) %>% 
  step_rm(patient_dropout_ever) %>%
  step_rm(event_year_month) %>% 
  step_log(appt_tally) %>% 
  step_log(HC_monthly_visitLoad) %>% 
  # step_rm(exp_year_month) %>%
  # handle factor levels unseen in training, and recode missing factor values
  step_novel(all_factor_predictors(), new_level = "new") %>%
  step_unknown(all_factor_predictors(), new_level = "missing") %>%
  # impute missing numeric values with the median
  step_impute_median(all_numeric_predictors()) %>% 
  # step_rose(patient_dropout) %>% # this didnt help...
  # step_corr(all_numeric_predictors(), threshold = .5) %>%
  step_dummy(all_nominal_predictors()) %>% 
  step_ns(age) %>% # 
  # step_ns(sbp_change_from_baseline) %>%
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())

M2.4 Combining the model and recipe into workflow

Code
lr_workflow<-workflow() %>% 
  add_model(lr_mod) %>% 
  add_recipe(lr_recipe)

M2.5. Penalty values for tuning

Code
lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 20))

M2.6. Apply penalty values to the workflow

Code
set.seed(3303)
# library(glmnet)
# install.packages("glmnet")

lr_res <- lr_workflow %>% 
  tune_grid(resamples = cv_folds,
            grid = lr_reg_grid,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(roc_auc))

lr_plot <- 
  lr_res %>% 
  collect_metrics() %>% 
  ggplot(aes(x = penalty, y = mean)) + 
  geom_point() + 
  geom_line() + 
  ylab("Area under the ROC Curve") +
  scale_x_log10(labels = scales::label_number()) +
  # nb: lr_best here still holds the Model 1 selection; M2's optimum is chosen in M2.7
  geom_vline(xintercept=lr_best$penalty, 
             linetype="dashed", color="blue") +
  annotate(geom="text", 
           x=lr_best$penalty+0.0008, y=0.721, color="blue", size=4,
           label=paste0("Optimum = ",round(lr_best$penalty, 6))) +
  labs(title="Tune grid of penalty values")

lr_plot

M2.7. Select best model

Code
lr_res %>% 
  collect_metrics() %>% 
  arrange(penalty) %>% 
  rowid_to_column()
# A tibble: 20 × 8
   rowid  penalty .metric .estimator  mean     n std_err .config              
   <int>    <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
 1     1 0.0001   roc_auc binary     0.736     5 0.00167 Preprocessor1_Model01
 2     2 0.000144 roc_auc binary     0.736     5 0.00167 Preprocessor1_Model02
 3     3 0.000207 roc_auc binary     0.736     5 0.00165 Preprocessor1_Model03
 4     4 0.000298 roc_auc binary     0.736     5 0.00164 Preprocessor1_Model04
 5     5 0.000428 roc_auc binary     0.736     5 0.00162 Preprocessor1_Model05
 6     6 0.000616 roc_auc binary     0.736     5 0.00161 Preprocessor1_Model06
 7     7 0.000886 roc_auc binary     0.735     5 0.00159 Preprocessor1_Model07
 8     8 0.00127  roc_auc binary     0.735     5 0.00157 Preprocessor1_Model08
 9     9 0.00183  roc_auc binary     0.734     5 0.00155 Preprocessor1_Model09
10    10 0.00264  roc_auc binary     0.733     5 0.00154 Preprocessor1_Model10
11    11 0.00379  roc_auc binary     0.731     5 0.00147 Preprocessor1_Model11
12    12 0.00546  roc_auc binary     0.729     5 0.00136 Preprocessor1_Model12
13    13 0.00785  roc_auc binary     0.726     5 0.00131 Preprocessor1_Model13
14    14 0.0113   roc_auc binary     0.721     5 0.00131 Preprocessor1_Model14
15    15 0.0162   roc_auc binary     0.715     5 0.00142 Preprocessor1_Model15
16    16 0.0234   roc_auc binary     0.710     5 0.00134 Preprocessor1_Model16
17    17 0.0336   roc_auc binary     0.702     5 0.00185 Preprocessor1_Model17
18    18 0.0483   roc_auc binary     0.695     5 0.00177 Preprocessor1_Model18
19    19 0.0695   roc_auc binary     0.682     5 0.00194 Preprocessor1_Model19
20    20 0.1      roc_auc binary     0.682     5 0.00194 Preprocessor1_Model20
Code
# select optimal penalty
lr_best<-  lr_res %>% 
  select_best(metric="roc_auc")
  # collect_metrics() %>% 
  # arrange(penalty) %>% 
  # slice(12)

lr_best
# A tibble: 1 × 2
   penalty .config              
     <dbl> <chr>                
1 0.000428 Preprocessor1_Model05

M2.8. ROC curve for best fit

Code
lr_preds2 <- lr_res %>% 
  collect_predictions(parameters = lr_best) %>% 
  mutate(model_name="Model 2: Dropout Per Opp (Full Spec)")


lr_auc<- lr_preds2 %>% 
  roc_curve(patient_dropout, .pred_1, event_level = "second") %>% 
  mutate(model = "Logistic Regression") 

roc_plot<-lr_auc

autoplot(lr_auc)

Code
library(patchwork)
roc_plot<-autoplot(lr_auc) +
  labs(title="ROC Curve") 

lr_plot + 
  roc_plot +
   plot_annotation(tag_levels = 'A') +
   theme_minimal()

Code
my_ggsave("fig6_roc_penalty", 14)

It’s… OK! Cross-validated AUC is about 0.74.

Now we can see how it fits the held-out test data.

Model 2: Final Fit

Code
last_log_mod <-  logistic_reg(penalty = lr_best$penalty[1], mixture=1) %>% 
  set_engine("glmnet", importance="BIC")

# the last workflow
last_log_workflow <- lr_workflow %>% 
                    update_model(last_log_mod)


# last fit
last_lr_fit<- last_log_workflow %>% 
              fit(training_data) 

lr_last_preds_m2 <- predict(last_lr_fit, new_data = testing_data, type="prob") %>% 
                bind_cols(testing_data) %>% 
                mutate(patient_dropout=as.factor(patient_dropout)) %>% 
                  mutate(predicted_class_60 = as.factor(if_else(.pred_1 > 0.6, "1", "0"))) %>%
                mutate(predicted_class_50 = as.factor(if_else(.pred_1 > 0.5, "1", "0"))) %>%
                mutate(predicted_class_38 = as.factor(if_else(.pred_1 > 0.38, "1", "0"))) %>%
                mutate(predicted_class_20 = as.factor(if_else(.pred_1 > 0.2, "1", "0"))) %>%
                mutate(predicted_prob_yes = .pred_1, predicted_prob_no = .pred_0) %>% 
                mutate("model"="Model 2: Full Spec")

Model 2: Performance Metrics

Code
# Define classification metrics
classification_metrics <- metric_set(accuracy, precision, recall, 
                                     f_meas, roc_auc, j_index)


# Calculate performance metrics
performance_50 <- classification_metrics(lr_last_preds_m2, truth = patient_dropout, 
                                      predicted_prob_yes,
                                      estimate=predicted_class_50, 
                                      event_level="second") %>% mutate(cutoff=0.5)


# Calculate performance metrics
performance_38 <- classification_metrics(lr_last_preds_m2, truth = patient_dropout, 
                                      predicted_prob_yes,
                                      estimate=predicted_class_38, 
                                      event_level="second") %>% mutate(cutoff=0.38)


performance<-bind_rows(performance_38, performance_50) %>% 
             pivot_wider(id_cols = cutoff, 
                        names_from=.metric, values_from=.estimate) %>% 
             rename("sensitivity"=recall) %>% 
             mutate(across(everything(), ~round(., 3)))


performance
# A tibble: 2 × 7
  cutoff accuracy precision sensitivity f_meas j_index roc_auc
   <dbl>    <dbl>     <dbl>       <dbl>  <dbl>   <dbl>   <dbl>
1   0.38    0.741     0.504       0.477  0.491   0.311   0.742
2   0.5     0.751     0.564       0.204  0.3     0.148   0.742
Code
write_csv(performance, here::here("data","model_performance_latest.csv"))

M2 Subgroup analysis

Code
# Calculate performance metrics, by subgroups
metrics_by_subgroup <- function(data, sex_c, state_c,pred_class){
  
  data %>% 
    filter(
      case_when(
        sex_c=="Male" ~ sex==sex_c,
        sex_c=="Female" ~ sex==sex_c,
        TRUE ~ TRUE)) %>% 
      filter(
        case_when(
          state_c=="Kano" ~ State==state_c,
          state_c=="Ogun" ~ State==state_c,
          TRUE ~ TRUE)) %>% 
   classification_metrics(., truth = patient_dropout,
                                      predicted_prob_yes,
                                      estimate=pred_class,
                                      event_level="second") %>%
                                      mutate(cutoff=pred_class,
                                             sex=sex_c,
                                             state=state_c) %>%
    mutate(sex=if_else(sex=="Male"|sex=="Female",sex,"both")) %>%
    mutate(state=if_else(state_c=="Kano"|state_c=="Ogun",state_c,"both"))
  }

# metrics_by_subgroup(lr_last_preds_m2,"Male","Kano","predicted_class_50")
# metrics_by_subgroup(lr_last_preds_m2,statgrid[1,])
# metrics_by_subgroup(lr_last_preds_m2, "predicted_class_50","Male","Kano")

statgrid<-tidyr::expand_grid(sex=c("Male","Female","both"),
                   state=c("Kano","Ogun","both"),
                   pred=c("predicted_class_50","predicted_class_38"))



stats_all_groups<-map(1:nrow(statgrid), function(x)
                    metrics_by_subgroup(lr_last_preds_m2,
                        statgrid$sex[x],statgrid$state[x],statgrid$pred[x])) %>% 
                        bind_rows()  %>% 
                    select(!.estimator) %>%
                    mutate(.metric=if_else(.metric=="recall","sensitivity",.metric)) %>% 
                    mutate(across(where(is.numeric), ~round(., 3))) %>% 
                    mutate(threshold=str_replace_all(cutoff,"predicted_class_","0."))



stats_all_groups_wide <-stats_all_groups %>% 
              pivot_wider(names_from=.metric, values_from=.estimate) %>% 
              select(state,sex,threshold,everything()) %>% 
              arrange(state,sex,threshold)

write_csv(stats_all_groups_wide, 
          here::here("data","model_subgroupstats_latest.csv"))

stats_all_groups_wide %>% 
  DT::datatable()

Subgroup analysis plot

Code
# stats_all_groups %>% 
#         select(state,sex,cutoff,everything()) %>% 
#         # filter(cutoff=="0.38") %>% 
#         unite(col="group",1:2,sep="_",remove=F) %>% 
#         pivot_wider(names_from=cutoff, values_from=.estimate, names_prefix = "cutoff_") %>% 
#         ggplot(aes(x=group)) +
#         geom_point(aes(y=cutoff_0.50, fill="0.50"), alpha=0.5, size=2) +
#         geom_point(aes(y=cutoff_0.38, fill="0.38"), alpha=0.5, size=2) +
#         scale_fill_manual(values = c("red", "green"), name = "Cutoff", 
#                   guide = guide_legend(reverse = TRUE)) +       
#         geom_segment(aes(x=group, xend=group,
#                          y=cutoff_0.38, 
#                          yend=cutoff_0.50), alpha=0.5) +
#         coord_flip() +
#         facet_wrap(~.metric,scales = "free_x",nrow=2) +
#         labs(title="Performance Metrics by State and Age Subgroups",
#              subtitle="By Cutoff Point", x="", y="")



stats_all_groups %>% 
        select(state,sex,threshold,everything()) %>% 
        # filter(cutoff=="0.38") %>% 
        unite(col="group",1:2,sep="_",remove=F) %>% 
        ggplot(aes(x=group,y=.estimate)) +
        geom_point(aes(shape=threshold, color=threshold), alpha=0.7, size=3) +
        coord_flip() +
        facet_wrap(~.metric,scales = "free_x",nrow=2) +
        labs(title="Performance Metrics for State and Sex Subgroups",
             subtitle="By Probability Threshold", x="", y="")

Code
my_ggsave("fig7_subgroup_analyses",10)
Code
# # Extract the number of observations in each fold
# fold_sizes <- cv_folds %>%
#   mutate(
#     training_size = map_int(splits, ~ nrow(analysis(.))),
#     validation_size = map_int(splits, ~ nrow(assessment(.)))
#   )
# 
# # Print the results
# fold_sizes %>%
#   select(id, training_size, validation_size) %>% 
#   mutate(sum=training_size+validation_size)

# dta_model2 %>% 
#   janitor::tabyl(patient_dropout)

M2 Net Benefit Analysis

Code
lr_last_preds_m2 %>%
      reframe(tp=if_else(predicted_class_50==1 & patient_dropout==1, 1, 0),
         fp=if_else(predicted_class_50==1 & patient_dropout==0, 1, 0)) %>%
      summarise(tp=sum(tp),
                fp=sum(fp),
                n=nrow(.)) %>% t()
    [,1]
tp  1063
fp   822
n  19948
Code
# (4195/19948)-(6896/19948) * (0.2/(1-0.2))
# 
# (2489/19948)-(2446/19948) * (0.38/(1-0.38))

(1063/19948)-(822/19948) * (0.5/(1-0.5))
[1] 0.01208141
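The figure above is the standard decision-curve net benefit, NB = TP/n − FP/n × p_t/(1 − p_t), where p_t is the threshold probability (here 0.5). As a reusable base-R helper (the function name is ours):

Code
# Net benefit: value of true positives minus harm of false positives,
# weighted by the odds implied by the threshold probability pt
net_benefit <- function(tp, fp, n, pt) {
  tp / n - fp / n * (pt / (1 - pt))
}
net_benefit(tp = 1063, fp = 822, n = 19948, pt = 0.5)
[1] 0.01208141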

About the same ROC as on the validation set…

M2 Top features

Code
vi_fit<- last_lr_fit

vi_fit %>%
  vip(num_features = 30, mapping=aes(fill=Sign)) +
  scale_y_continuous(expand = c(0,0)) +
  theme_light() +
  labs(title="Top predictors for Full Model (90 day dropout risk per visit)",
  subtitle="Coefficients for Dropout event after each visit, using LASSO regularization") +
  theme(plot.title.position = "plot") +
  scale_fill_manual(values = c(palette[2],palette[1]))

Code
my_ggsave("fig5_top_predictors_latest", 10)

Appointment Count and Days Since Entry (Registration) seemed to have the greatest impact.

Get the coefficients from this final fit, as per the docs.

Code
last_lr_fit %>% broom::tidy() %>% 
 DT::datatable()

M2 Calibration plot

Code
lr_last_preds_m2 %>% 
  mutate(bin = cut(predicted_prob_yes, breaks = seq(0, 1, by = 0.05), 
                   include.lowest = TRUE)) %>%
  group_by(bin) %>%
  reframe(model=model,
    mean_pred = mean(predicted_prob_yes),
    observed = mean(as.numeric(patient_dropout) - 1)  # Convert factor to 0/1
  ) %>% 
ggplot(aes(x = mean_pred, y = observed)) +
  geom_point() +
  geom_line() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "red") +
  labs(
    x = "Predicted Probability",
    y = "Observed Proportion",
    title = paste0("Calibration Plot Full Model (90 day dropout risk per visit)"))

Code
my_ggsave("fig4_model_calplot_latest",10)

What's the distribution of estimated probabilities?

Code
lr_last_preds_m2 %>% 
  ggplot() +
  geom_histogram(aes(x=predicted_prob_yes))

Only 9% are greater than a 50% likelihood, but we know the attrition rate is about twice as high.

Code
lr_last_preds_m2 %>% 
  janitor::tabyl(predicted_class_50)
 predicted_class_50     n    percent
                  0 18063 0.90550431
                  1  1885 0.09449569
Code
lr_last_preds_m2 %>% 
  mutate(over70=if_else(predicted_prob_yes>=0.7, 1, 0)) %>% 
    janitor::tabyl(over70)
 over70     n     percent
      0 19833 0.994235011
      1   115 0.005764989

Model 3: Dropouts Per Opportunity, Parsimonious

First prepare the dataset.

Code
dta_model3 <- dta_prep_all_models %>% 
  select(reg_id, appt_tally, nextdate_90_kept, event_date,
         event_year_month, days_since_entry,
         sex, age, event_year, days_since_last,
         #KEEP FROM HERE
         contains("med"), contains("bp_"), ends_with("baseline"),
         contains("p_step"), relationship_to_patient,
         has_phone, does_patient_have_diabetes,
         State,age_group, 
         HC_monthly_visitLoad,
         consent_to_send_sms_reminders,
         any_of(pop_params_top) # weather and market params
         ) %>% 
  # define the target population (baseline)
  # filter(appt_tally==1) %>% 
  # define the outcome
  mutate(patient_dropout=if_else(nextdate_90_kept==1, 
                                          0, 1)) %>% 
  group_by(reg_id) %>% 
  mutate(patient_dropout_ever=if_else(sum(patient_dropout)>0, 
                                          1, 0)) %>% 
  ungroup() %>% 
  mutate(patient_dropout=as.factor(patient_dropout)) %>% 
# PARSIMONIOUS VARIABLE SET
  select(reg_id, patient_dropout, patient_dropout_ever,
          starts_with("days_since"),
          appt_tally,event_year_month)

M3.1. Split into training and validation and test sets

Code
library(tidymodels)
set.seed(124)

# Extract unique groups
unique_groups <- unique(dta_model3$reg_id)

# Shuffle groups to ensure random distribution
shuffled_groups <- sample(unique_groups)

# Define the split ratio
split_ratio <- 0.8

# Determine split point
split_point <- floor(length(shuffled_groups) * split_ratio)

# Assign groups to training and validation sets
train_groups <- shuffled_groups[1:split_point]
test_groups <- shuffled_groups[(split_point + 1):length(shuffled_groups)]

# Create training and validation datasets based on group membership
training_data <- dta_model3 %>% filter(reg_id %in% train_groups)
testing_data <- dta_model3 %>% filter(reg_id %in% test_groups)




# Set up 5-fold cross-validation, grouped by the patient ID variable
cv_folds <- group_vfold_cv(training_data, v = 5, 
                           group=reg_id, #all visits for one patient go into same fold once the visit is sampled
                           strata=patient_dropout_ever) #same number of patients who eventually dropout go into each fold
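The manual shuffle above keeps every visit for a given patient on one side of the split. rsample can express the same idea directly; a sketch assuming this rsample version provides grouped initial splits (the `*_2` objects are hypothetical and not used downstream):

Code
# Sketch: grouped 80/20 split so no reg_id appears in both partitions
set.seed(124)
grp_split <- group_initial_split(dta_model3, group = reg_id, prop = 0.8)
training_data2 <- training(grp_split)
testing_data2  <- testing(grp_split)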

M3.2. Build the model

Code
# with lasso regularization

lr_mod <- logistic_reg(penalty = tune(), mixture=1) %>% 
  set_engine("glmnet")

M3.3. Build the “recipe” to preprocess the data

Code
lr_recipe <- recipe(patient_dropout ~ ., data = dta_model3) %>% 
  update_role(reg_id, new_role = "ID") %>% 
  # step_date(event_date, features = c("dow", "month")) %>% 
  step_date(event_year_month, features = c("month")) %>%
  step_rm(patient_dropout_ever) %>%
  step_rm(event_year_month) %>%
# appointment tally has an exponential descent, so log it
  step_log(appt_tally) %>% 
  # step_rm(exp_year_month) %>%
  # Step to impute missing values in factor_var (if NA is not a valid level)
  step_novel(all_factor_predictors(), new_level = "new") %>%
  step_unknown(all_factor_predictors(), new_level = "missing") %>%
  # Step to impute missing values in numeric variable x with mean
  step_impute_median(all_numeric_predictors()) %>% 
  # step_corr(all_numeric_predictors(), threshold = .5) %>%
  step_dummy(all_nominal_predictors()) %>% 
  # step_ns(sbp_change_from_baseline) %>%
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())

M3.4. Combine model and recipe into a workflow

Code
lr_workflow<-workflow() %>% 
  add_model(lr_mod) %>% 
  add_recipe(lr_recipe)

M3.5. Create penalty values for tuning

Code
lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 20))

M3.6. apply penalty values to the workflow

Code
set.seed(3303)
library(glmnet)
# install.packages("glmnet")

lr_res <- lr_workflow %>% 
  tune_grid(resamples = cv_folds,
            grid = lr_reg_grid,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(roc_auc))

lr_plot <- 
  lr_res %>% 
  collect_metrics() %>% 
  ggplot(aes(x = penalty, y = mean)) + 
  geom_point() + 
  geom_line() + 
  ylab("Area under the ROC Curve") +
  scale_x_log10(labels = scales::label_number())

lr_plot

M3.7. Select best model penalty

Code
lr_res %>% 
  collect_metrics() %>% 
  arrange(penalty) %>% 
  rowid_to_column()
# A tibble: 20 × 8
   rowid  penalty .metric .estimator  mean     n std_err .config              
   <int>    <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
 1     1 0.0001   roc_auc binary     0.701     5 0.00384 Preprocessor1_Model01
 2     2 0.000144 roc_auc binary     0.701     5 0.00384 Preprocessor1_Model02
 3     3 0.000207 roc_auc binary     0.701     5 0.00384 Preprocessor1_Model03
 4     4 0.000298 roc_auc binary     0.701     5 0.00384 Preprocessor1_Model04
 5     5 0.000428 roc_auc binary     0.701     5 0.00384 Preprocessor1_Model05
 6     6 0.000616 roc_auc binary     0.701     5 0.00387 Preprocessor1_Model06
 7     7 0.000886 roc_auc binary     0.701     5 0.00389 Preprocessor1_Model07
 8     8 0.00127  roc_auc binary     0.701     5 0.00394 Preprocessor1_Model08
 9     9 0.00183  roc_auc binary     0.700     5 0.00395 Preprocessor1_Model09
10    10 0.00264  roc_auc binary     0.700     5 0.00408 Preprocessor1_Model10
11    11 0.00379  roc_auc binary     0.699     5 0.00411 Preprocessor1_Model11
12    12 0.00546  roc_auc binary     0.697     5 0.00412 Preprocessor1_Model12
13    13 0.00785  roc_auc binary     0.696     5 0.00405 Preprocessor1_Model13
14    14 0.0113   roc_auc binary     0.695     5 0.00403 Preprocessor1_Model14
15    15 0.0162   roc_auc binary     0.693     5 0.00400 Preprocessor1_Model15
16    16 0.0234   roc_auc binary     0.691     5 0.00408 Preprocessor1_Model16
17    17 0.0336   roc_auc binary     0.687     5 0.00400 Preprocessor1_Model17
18    18 0.0483   roc_auc binary     0.681     5 0.00396 Preprocessor1_Model18
19    19 0.0695   roc_auc binary     0.681     5 0.00396 Preprocessor1_Model19
20    20 0.1      roc_auc binary     0.681     5 0.00396 Preprocessor1_Model20
Code
# select optimal penalty
lr_best<-  lr_res %>% 
  select_best(metric="roc_auc")
  # collect_metrics() %>% 
  # arrange(penalty) %>% 
  # slice(12)

M3.8. ROC Curve for best fit

Code
lr_preds3 <- lr_res %>% 
  collect_predictions(parameters = lr_best) %>% 
  mutate(model_name="Model 3: Dropout Per Opp (Minimal Spec)")
  
lr_auc<- lr_preds3 %>% 
  roc_curve(patient_dropout, .pred_1, event_level = "second") %>% 
  mutate(model = "Logistic Regression") 

It's… OK!

Model 3: Final Fit

Code
last_log_mod <-  logistic_reg(penalty = lr_best$penalty[1], mixture=1) %>% 
  set_engine("glmnet", importance="BIC")

# the last workflow
last_log_workflow <- lr_workflow %>% 
                    update_model(last_log_mod)


# last fit
last_lr_fit<- last_log_workflow %>% 
              fit(training_data) 

lr_last_preds_m3 <- predict(last_lr_fit, new_data = testing_data, type="prob") %>% 
                bind_cols(testing_data) %>% 
                mutate(patient_dropout=as.factor(patient_dropout)) %>% 
                 mutate(predicted_class_60 = as.factor(if_else(.pred_1 > 0.6, "1", "0"))) %>%
                mutate(predicted_class_50 = as.factor(if_else(.pred_1 > 0.5, "1", "0"))) %>%
                mutate(predicted_class_38 = as.factor(if_else(.pred_1 > 0.38, "1", "0"))) %>%
                mutate(predicted_class_20 = as.factor(if_else(.pred_1 > 0.2, "1", "0"))) %>%
                mutate(predicted_prob_yes = .pred_1, predicted_prob_no = .pred_0) %>% 
                mutate("model"="Model 3: Minimal Spec")

About the same ROC as on the validation set…

M3 Top features

Inspecting variable importance from the final fit on the training data…

Code
library(vip)
vi_fit<- last_lr_fit


vi_fit %>% 
  vip(num_features = 25, mapping=aes(fill=Sign)) +
  scale_y_continuous(expand = c(0,0)) +
  theme_light() +
  labs(title="Top predictors for Model3",
  subtitle="Logistic Reg for Dropout Event after each visit")

Appointment Count and Days Since Entry (Registration) STILL seem to have the greatest impact.

Get the coefficients from this final fit, as per the docs.

Code
vi_fit %>% tidy() %>% DT::datatable()

Model Calibration Comparison

Code
theme_set(theme_minimal())

make_cal_plot <-function(predictions, basedata){
  
calibration_data <- predictions %>%
    mutate(appt_tally_level =
      case_when(
        appt_tally==1 ~ "Visit = 1",
        appt_tally>1 & appt_tally < 6 ~ "Visit = 2 to 5",
        TRUE ~ "Visit = 6+"
      )
    ) %>% 
  mutate(bin = cut(predicted_prob_yes, breaks = seq(0, 1, by = 0.02), 
                   include.lowest = TRUE)) %>%
  group_by(appt_tally_level, bin) %>%
  reframe(model=model,
    mean_pred = mean(.pred_1),
    observed = mean(as.numeric(patient_dropout) - 1)  # Convert factor to 0/1
  )
ggplot(calibration_data, aes(x = mean_pred, y = observed)) +
  geom_point() +
  geom_line() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "red") +
  labs(
    x = "Predicted Probability",
    y = "Observed Proportion",
    title = paste0("Calibration Plot ", unique(calibration_data$model))) +
  facet_wrap(~appt_tally_level, ncol=1)
}



make_cal_plot(lr_last_preds_m1, dta_model1)

Code
make_cal_plot(lr_last_preds_m2, dta_model2)

Code
make_cal_plot(lr_last_preds_m3, dta_model3)

Model Performance Comparison

Code
bind_rows(
  lr_last_preds_m1,
  lr_last_preds_m2,
  lr_last_preds_m3
    ) %>% 
    # match on prefix: the stored labels are e.g. "Model 2: Full Spec",
    # so an exact == "Model 2" test would yield NA
    mutate(model =
      case_when(
        str_detect(model, "Model 1") ~ str_wrap("Model 1: Return Failure Post Baseline", 15),
        str_detect(model, "Model 2") ~ str_wrap("Model 2: Dropout Per Opp (Full Spec)", 15),
        str_detect(model, "Model 3") ~ str_wrap("Model 3: Dropout Per Opp (Minimal Spec)", 15)
      )
    ) %>% 
  group_by(model) %>% 
  roc_curve(patient_dropout, .pred_1, event_level = "second") %>% 
  autoplot() +
  labs(title="ROC Curve: sensitivity comparison of discrimination") +
  theme_set(mytheme) +
  theme(legend.key.size = unit(2, "cm"))

Code
my_ggsave("roc_curve_compare",9)

Show metrics for last fits

Code
my_metrics <- function(predictions){
# Calculate performance metrics
classification_metrics(predictions, truth = patient_dropout, 
                                      predicted_prob_yes,
                                      estimate=predicted_class_38, 
                                      event_level="second") %>% 
                                      mutate(model=unique(predictions$model))
  }


lr_last_all<-list(lr_last_preds_m1,
                   lr_last_preds_m2,
                   lr_last_preds_m3)


r3<-function(x){round(x,3)}


metrics_comparison<-map(lr_last_all,my_metrics) %>% 
                    bind_rows() %>% 
                  select(model,.metric, .estimate) %>% 
                  pivot_wider(names_from=.metric, values_from=.estimate) %>% 
                  rename("sensitivity"=recall) %>% 
                  mutate(across(where(is.double), r3))

write_csv(metrics_comparison, 
          here::here("data","model_comparison_latest.csv"))

metrics_comparison %>% 
  kable()
model accuracy precision sensitivity f_meas j_index roc_auc
Model 1: Return Failure At Baseline 0.756 0.494 0.313 0.383 0.211 0.735
Model 2: Full Spec 0.741 0.504 0.477 0.491 0.311 0.742
Model 3: Minimal Spec 0.703 0.438 0.487 0.461 0.266 0.704

Session info below.

Code
sessionInfo()
R version 4.3.3 (2024-02-29 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 11 x64 (build 22631)

Matrix products: default


locale:
[1] LC_COLLATE=en_US.UTF-8  LC_CTYPE=en_US.UTF-8    LC_MONETARY=en_US.UTF-8
[4] LC_NUMERIC=C            LC_TIME=en_US.UTF-8    

time zone: Europe/Berlin
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] patchwork_1.2.0    ROSE_0.0-4         themis_1.0.2       DT_0.33           
 [5] skimr_2.1.5        RColorBrewer_1.1-3 probably_1.0.3     splines2_0.5.3    
 [9] effectsize_0.8.7   kableExtra_1.4.0   readxl_1.4.3       glmnet_4.1-8      
[13] Matrix_1.6-5       summarytools_1.0.1 vip_0.4.1          ggcorrplot_0.1.4.1
[17] slider_0.3.1       table1_1.4.3       yardstick_1.3.1    workflowsets_1.1.0
[21] workflows_1.1.4    tune_1.2.1         rsample_1.2.1      recipes_1.0.10    
[25] parsnip_1.2.1      modeldata_1.3.0    infer_1.0.7        dials_1.2.1       
[29] scales_1.3.0       broom_1.0.5        tidymodels_1.2.0   lubridate_1.9.3   
[33] forcats_1.0.0      stringr_1.5.1      dplyr_1.1.4        purrr_1.0.2       
[37] readr_2.1.5        tidyr_1.3.1        tibble_3.2.1       ggplot2_3.5.1     
[41] tidyverse_2.0.0   

loaded via a namespace (and not attached):
  [1] rstudioapi_0.16.0   jsonlite_1.8.8      shape_1.4.6.1      
  [4] datawizard_0.10.0   magrittr_2.0.3      TH.data_1.1-2      
  [7] estimability_1.5    magick_2.8.3        farver_2.1.2       
 [10] rmarkdown_2.26      ragg_1.3.0          vctrs_0.6.5        
 [13] base64enc_0.1-3     janitor_2.2.0       htmltools_0.5.8.1  
 [16] cellranger_1.1.0    Formula_1.2-5       sass_0.4.9         
 [19] parallelly_1.37.1   bslib_0.7.0         htmlwidgets_1.6.4  
 [22] plyr_1.8.9          sandwich_3.1-0      cachem_1.1.0       
 [25] emmeans_1.10.1      zoo_1.8-12          lifecycle_1.0.4    
 [28] iterators_1.0.14    pkgconfig_2.0.3     R6_2.5.1           
 [31] fastmap_1.2.0       snakecase_0.11.1    future_1.33.2      
 [34] digest_0.6.35       colorspace_2.1-0    furrr_0.3.1        
 [37] rprojroot_2.0.4     warp_0.2.1          textshaping_0.3.7  
 [40] crosstalk_1.2.1     labeling_0.4.3      fansi_1.0.6        
 [43] timechange_0.3.0    compiler_4.3.3      here_1.0.1         
 [46] bit64_4.0.5         withr_3.0.0         pander_0.6.5       
 [49] clock_0.7.0         backports_1.4.1     MASS_7.3-60.0.1    
 [52] lava_1.8.0          tools_4.3.3         future.apply_1.11.2
 [55] nnet_7.3-19         glue_1.7.0          grid_4.3.3         
 [58] checkmate_2.3.1     reshape2_1.4.4      generics_0.1.3     
 [61] gtable_0.3.5        tzdb_0.4.0          class_7.3-22       
 [64] data.table_1.15.4   hms_1.1.3           xml2_1.3.6         
 [67] utf8_1.2.4          foreach_1.5.2       pillar_1.9.0       
 [70] vroom_1.6.5         splines_4.3.3       lhs_1.1.6          
 [73] pryr_0.1.6          lattice_0.22-6      bit_4.0.5          
 [76] survival_3.5-8      tidyselect_1.2.1    knitr_1.46         
 [79] svglite_2.1.3       xfun_0.43           hardhat_1.3.1      
 [82] rapportools_1.1     timeDate_4032.109   matrixStats_1.2.0  
 [85] stringi_1.8.4       DiceDesign_1.10     yaml_2.3.8         
 [88] evaluate_0.23       codetools_0.2-19    tcltk_4.3.3        
 [91] cli_3.6.2           rpart_4.1.23        xtable_1.8-4       
 [94] parameters_0.21.6   systemfonts_1.1.0   jquerylib_0.1.4    
 [97] repr_1.1.7          munsell_0.5.1       Rcpp_1.0.12        
[100] globals_0.16.3      coda_0.19-4.1       parallel_4.3.3     
[103] ellipsis_0.3.2      gower_1.0.1         bayestestR_0.13.2  
[106] GPfit_1.0-8         listenv_0.9.1       viridisLite_0.4.2  
[109] mvtnorm_1.2-4       ipred_0.9-14        prodlim_2023.08.28 
[112] crayon_1.5.2        insight_0.19.10     rlang_1.1.4        
[115] multcomp_1.4-25