Part one imported NHCI visit-level data from the DHIS2 export and described patient dropout patterns over time, location, and patient characteristics. Visit data were deduplicated, and quality of care was described.
Daily LGA-level weather data and weekly state-level market prices, each calculated as trailing 10-day averages, were also introduced. These data were merged into the visit-level data by matching on the “target end date” for the patient’s next appointment, defined as 37 days after the current visit. They therefore serve as proxies for the climate and economic conditions during the expected visit window.
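As a minimal sketch of this merge logic (toy data only; the column names `cli_rain_mm` and `nextdate_37` mirror those used later, but the values and the single-LGA setup are illustrative):

```r
library(dplyr)
library(lubridate)
library(slider)

# toy daily weather series for one LGA (illustrative values)
weather <- tibble(
  date = seq(ymd("2024-01-01"), ymd("2024-03-01"), by = "day"),
  cli_rain_mm = rep(c(0, 5), length.out = 61)
) %>%
  # trailing 10-day average ending on each date
  mutate(cli_rain_mm_pre10day_mean = slide_dbl(cli_rain_mm, mean, .before = 9))

# toy visits: the "target end date" is 37 days after the current visit
visits <- tibble(event_date = ymd(c("2024-01-05", "2024-01-20"))) %>%
  mutate(nextdate_37 = event_date + days(37))

# attach the conditions expected around the target return date to each visit
visits_merged <- visits %>%
  left_join(weather, by = c("nextdate_37" = "date"))
```

The real pipeline additionally matches on LGA (weather) or State (prices); the join-by-target-date idea is the same.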
In this section we use the above dataset to build prediction models.
Implementers must decide how to define the outcome and how to put these predictions into patient followup practice:
1. Is it better to predict return within a “target” 37-day window, or return within a “timely” 90-day window of the latest visit?
2. If a patient returns to care >90 days after their latest visit, are they still considered a dropout? Note this is about 17% of all followup visits.
3. If a patient returns to care 1-21 days after the latest visit, is this considered a successful followup that falls within the target window? Note this is about 5% of all followup visits.
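To make these competing definitions concrete, a toy sketch (the gap values and the column name `days_to_next_visit` are illustrative, not from the NHCI data):

```r
library(dplyr)

# illustrative gaps (in days) between a visit and the patient's next visit;
# NA means the patient never returned
gap_days <- c(10, 30, 45, 95, NA)

visits <- tibble(days_to_next_visit = gap_days) %>%
  mutate(
    # returned within the 37-day "target" window
    kept_37 = !is.na(days_to_next_visit) & days_to_next_visit <= 37,
    # returned within the 90-day "timely" window
    kept_90 = !is.na(days_to_next_visit) & days_to_next_visit <= 90,
    # dropout under the 90-day definition
    dropout_90 = !kept_90
  )
```

Under the 37-day definition, the 45-day return counts as a failure; under the 90-day definition it does not, which is exactly the trade-off posed above.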
Goals Summary

- Use the tidymodels workflow for a series of three logistic regression models on patient dropout
- Predict patient failure to return after the baseline visit (zero followup visits) using all available variables
- Predict patient dropout-per-opportunity within 90 days (including baseline and subsequent visits), using all available variables
- Predict patient dropout-per-opportunity within 90 days (including baseline and subsequent visits), using a parsimonious selection of variables for maximum generalizability and transportability

The model described in the manuscript is Model 2.
Import, Cleaning, and Skim

Importing the full dataset, we remove the censored observations: the last visit before a patient is marked Transferred or Died, and visits after April 1, 2024 (i.e., with no opportunity to complete a followup).
Code
tots_timely_qual_cli_mkt <- readr::read_rds("tots_timely_qual_cli_mkt.rds")
pop_params_top <- readr::read_lines("cli_mkt_params_for_model.txt")

dta_prep_all_models <- tots_timely_qual_cli_mkt %>%
  # remove censored events
  filter(censor == 0) %>%
  filter(registration_date >= ymd("2021-01-01")) %>%
  mutate(event_year_month = make_date(year = year(event_date), month = month(event_date), day = "01")) %>%
  mutate(event_year = as.factor(year(event_date))) %>%
  mutate(LGA = str_remove_all(LGA, " Local Government Area")) %>%
  add_count(LGA, event_year_month, name = "LGA_monthly_visitLoad") %>%
  add_count(HC, event_year_month, name = "HC_monthly_visitLoad") %>%
  # here we remove all visits after a dropout
  mutate(visits_since_dropout = replace_na(visits_since_dropout, 0)) %>%
  filter(visits_since_dropout < 2) %>%
  mutate(bp_both_na = as.factor(if_else(is.na(bp_diastole_mmhg) & is.na(bp_systole_mmhg), 1, 0))) %>%
  mutate(age_group = as.factor(cut(age,
    breaks = c(0, 30, 45, 60, 75, 110, 100000),
    include.lowest = TRUE,
    labels = c("Other", "30-44", "45-59", "60-74", "75-110", "Other")))) %>%
  mutate(med_any = as.factor(if_else(!is.na(med_amlodipine) | !is.na(med_telmisartan) |
    !is.na(med_losartan) | !is.na(med_hydrochlorothiazide) |
    !is.na(other_hypertension_medication), 1, 0))) %>%
  # if there is any medication, assume it is 30 days
  mutate(days_of_medication = as.integer(str_remove(days_of_medication, "days"))) %>%
  mutate(days_of_medication = if_else(med_any == 1 & is.na(days_of_medication), 30, days_of_medication)) %>%
  mutate(days_of_medication = replace_na(days_of_medication, 0)) %>%
  mutate(meds_other = if_else(is.na(other_hypertension_medication), "none", "recorded")) %>%
  select(!c(starts_with("other_"))) %>%
  # convert to factors as appropriate
  mutate(across(starts_with("has_"), as.factor)) %>%
  mutate(across(starts_with("p_"), as.factor)) %>%
  mutate(across(starts_with("consent"), as.character)) %>%
  mutate(across(where(is.character), as.factor)) %>%
  arrange(reg_id, appt_tally) %>%
  group_by(reg_id) %>%
  # days since visit as expanding (cumulative) mean and SD per patient
  mutate(days_since_last_mean = map_dbl(seq_along(appt_tally),
                                        ~ mean(days_since_last[1:.x], na.rm = TRUE))) %>%
  mutate(days_since_last_sd = map_dbl(seq_along(appt_tally),
                                      ~ sd(days_since_last[1:.x], na.rm = TRUE))) %>%
  ungroup() %>%
  mutate(across(starts_with("days"), ~ ifelse(is.nan(.), NA, .)))

skimr::skim(dta_prep_all_models)
Data summary

| | |
|---|---|
| Name | dta_prep_all_models |
| Number of rows | 99740 |
| Number of columns | 115 |
| Column type frequency: Date | 11 |
| Column type frequency: factor | 45 |
| Column type frequency: numeric | 59 |
| Group variables | None |
Variable type: Date

| skim_variable | n_missing | complete_rate | min | max | median | n_unique |
|---|---|---|---|---|---|---|
| registration_date | 0 | 1.00 | 2021-01-01 | 2024-04-02 | 2022-05-20 | 1162 |
| event_date | 0 | 1.00 | 2021-01-01 | 2024-04-02 | 2022-10-17 | 1188 |
| prev_event_date | 28184 | 0.72 | 2021-01-01 | 2024-03-21 | 2022-10-18 | 1169 |
| lead_event_date | 13015 | 0.87 | 2021-01-04 | 2024-07-01 | 2022-12-19 | 1258 |
| nextdate_37 | 0 | 1.00 | 2021-02-07 | 2024-05-09 | 2022-11-23 | 1188 |
| nextdate_90 | 0 | 1.00 | 2021-04-01 | 2024-07-01 | 2023-01-15 | 1188 |
| last_updated_on | 0 | 1.00 | 2021-01-02 | 2024-07-16 | 2023-06-05 | 1242 |
| scheduled_date | 0 | 1.00 | 2012-12-13 | 2201-08-29 | 2022-12-30 | 1242 |
| incident_date | 0 | 1.00 | 2021-01-01 | 2024-07-14 | 2022-06-22 | 1153 |
| exp_year_month | 0 | 1.00 | 2021-04-01 | 2024-07-01 | 2023-01-01 | 40 |
| event_year_month | 0 | 1.00 | 2021-01-01 | 2024-04-01 | 2022-10-01 | 40 |
Variable type: factor

| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|---|---|---|---|---|---|
| reg_visit_id | 0 | 1.00 | FALSE | 99740 | 0_1: 1, 1_1: 1, 10_: 1, 10_: 1 |
| sex | 37 | 1.00 | FALSE | 2 | Fem: 75696, Mal: 24007 |
| State | 0 | 1.00 | FALSE | 2 | Kan: 63363, Ogu: 36377 |
| ncd_patient_status | 0 | 1.00 | FALSE | 5 | Act: 98792, Tra: 341, Die: 287, Ina: 264 |
| stored_by | 21670 | 0.78 | FALSE | 107 | Hot: 6196, Gwa: 4919, Jae: 4896, Daw: 3816 |
| created_by | 0 | 1.00 | FALSE | 106 | Aku: 14321, DAK: 7209, Rec: 4726, Rec: 4475 |
| last_updated_by | 0 | 1.00 | FALSE | 107 | Rec: 7400, DAK: 7209, Rec: 5818, Rec: 5611 |
| organisation_unit_name | 0 | 1.00 | FALSE | 104 | kn : 7981, kn : 7209, kn : 6029, kn : 5802 |
| Nat | 0 | 1.00 | FALSE | 1 | fg : 99740 |
| LGA | 0 | 1.00 | FALSE | 42 | Nas: 21269, Gwa: 9706, Dal: 5996, Abe: 5276 |
| Ward | 0 | 1.00 | FALSE | 100 | kn: 7981, kn: 7209, kn: 6029, kn: 5802 |
| HC | 0 | 1.00 | FALSE | 105 | kn: 7981, kn: 7209, kn: 6029, kn: 5802 |
| FACILITY | 93686 | 0.06 | FALSE | 10 | og: 1193, og: 1135, og: 778, og: 683 |
| program_status | 0 | 1.00 | FALSE | 3 | ACT: 99334, CAN: 362, COM: 44 |
| event_status | 0 | 1.00 | FALSE | 2 | COM: 96653, ACT: 3087 |
| organisation_unit | 0 | 1.00 | FALSE | 105 | Dxf: 7981, jpP: 7209, ehJ: 6029, xUO: 5802 |
| med_losartan | 61630 | 0.38 | FALSE | 2 | Los: 32105, Los: 6005 |
| hypertension_treated_in_the_past | 601 | 0.99 | FALSE | 2 | Fir: 83075, Was: 16064 |
| med_hydrochlorothiazide | 98817 | 0.01 | FALSE | 2 | Hyd: 886, Hyd: 37 |
| patient_followup_status | 7902 | 0.92 | FALSE | 3 | Ove: 73474, Upc: 17844, Cal: 520 |
| consent_to_send_sms_reminders | 83169 | 0.17 | FALSE | 1 | 1: 16571 |
| does_patient_have_diabetes | 6864 | 0.93 | FALSE | 2 | No: 90967, Yes: 1909 |
| relationship_to_patient | 21169 | 0.79 | FALSE | 9 | Son: 28699, Hus: 17704, Dau: 13760, Oth: 6426 |
| newage | 0 | 1.00 | FALSE | 96 | 50: 8330, 60: 7690, 70: 5594, 40: 4724 |
| med_amlodipine | 1991 | 0.98 | FALSE | 3 | Aml: 89511, Aml: 8229, Oth: 9 |
| med_telmisartan | 99452 | 0.00 | FALSE | 2 | Tel: 218, Tel: 70 |
| does_patient_have_hypertension | 0 | 1.00 | FALSE | 1 | Yes: 99740 |
| consent_to_record_data | 0 | 1.00 | FALSE | 1 | 1: 99740 |
| has_supporter | 0 | 1.00 | FALSE | 2 | 1: 79370, 0: 20370 |
| has_phone | 0 | 1.00 | FALSE | 2 | 1: 81111, 0: 18629 |
| p_step1_need | 0 | 1.00 | FALSE | 2 | 0: 87868, 1: 11872 |
| p_step2_need | 0 | 1.00 | FALSE | 2 | 0: 78454, 1: 21286 |
| p_step3_need | 0 | 1.00 | FALSE | 2 | 0: 89438, 1: 10302 |
| p_step4_need | 0 | 1.00 | FALSE | 2 | 0: 96040, 1: 3700 |
| p_step1_met | 0 | 1.00 | FALSE | 2 | 0: 88426, 1: 11314 |
| p_step2_met | 0 | 1.00 | FALSE | 2 | 0: 86110, 1: 13630 |
| p_step3_met | 0 | 1.00 | FALSE | 2 | 0: 97972, 1: 1768 |
| p_step4_met | 0 | 1.00 | FALSE | 2 | 0: 99732, 1: 8 |
| lga_37date | 0 | 1.00 | FALSE | 23130 | Nas: 107, Nas: 84, Nas: 78, Nas: 72 |
| state_37date | 0 | 1.00 | FALSE | 2352 | Kan: 207, Kan: 196, Kan: 189, Kan: 171 |
| event_year | 0 | 1.00 | FALSE | 4 | 202: 41930, 202: 33676, 202: 17550, 202: 6584 |
| bp_both_na | 0 | 1.00 | FALSE | 2 | 0: 99490, 1: 250 |
| age_group | 0 | 1.00 | FALSE | 5 | 45-: 41423, 30-: 23564, 60-: 23143, 75-: 6049 |
| med_any | 0 | 1.00 | FALSE | 2 | 1: 98043, 0: 1697 |
| meds_other | 0 | 1.00 | FALSE | 2 | non: 98575, rec: 1165 |
Variable type: numeric

| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| reg_id | 0 | 1.00 | 16084.09 | 9338.81 | 0.00 | 7910.00 | 16024.50 | 24157.00 | 32332.00 | ▇▇▇▇▇ |
| appt_tally | 0 | 1.00 | 5.05 | 5.13 | 1.00 | 1.00 | 3.00 | 7.00 | 43.00 | ▇▁▁▁▁ |
| total_patient_visits | 0 | 1.00 | 11.18 | 8.24 | 1.00 | 4.00 | 10.00 | 16.00 | 45.00 | ▇▅▂▁▁ |
| nextdate_37_kept | 0 | 1.00 | 0.52 | 0.50 | 0.00 | 0.00 | 1.00 | 1.00 | 1.00 | ▇▁▁▁▇ |
| nextdate_90_kept | 0 | 1.00 | 0.74 | 0.44 | 0.00 | 0.00 | 1.00 | 1.00 | 1.00 | ▃▁▁▁▇ |
| height_cm | 94800 | 0.05 | 82.76 | 80.31 | 0.45 | 1.60 | 140.00 | 160.00 | 758.00 | ▇▅▁▁▁ |
| bp_diastole_mmhg | 647 | 0.99 | 84.92 | 13.28 | 1.00 | 77.00 | 83.00 | 92.00 | 200.00 | ▁▆▇▁▁ |
| weight_kg | 94421 | 0.05 | 72.91 | 15.63 | 28.00 | 61.00 | 70.00 | 82.00 | 210.00 | ▃▇▁▁▁ |
| bmi | 94797 | 0.05 | 152304.17 | 322447.99 | 0.00 | 27.00 | 41.90 | 269098.80 | 4296296.30 | ▇▁▁▁▁ |
| bmi_measurement | 94079 | 0.06 | 1.00 | 0.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | ▁▁▇▁▁ |
| days_of_medication | 0 | 1.00 | 29.61 | 4.30 | 0.00 | 30.00 | 30.00 | 30.00 | 90.00 | ▁▇▁▁▁ |
| blood_sugar_measurement | 99478 | 0.00 | 1.00 | 0.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | ▁▁▇▁▁ |
| bp_systole_mmhg | 290 | 1.00 | 143.16 | 20.05 | 2.00 | 130.00 | 140.00 | 155.00 | 300.00 | ▁▁▇▁▁ |
| age | 0 | 1.00 | 53.97 | 14.20 | 18.00 | 45.00 | 53.00 | 63.00 | 124.00 | ▂▇▃▁▁ |
| days_since_entry | 0 | 1.00 | 140.18 | 171.33 | 0.00 | 0.00 | 80.00 | 210.00 | 1170.00 | ▇▂▁▁▁ |
| exp_year | 0 | 1.00 | 2022.53 | 0.86 | 2021.00 | 2022.00 | 2023.00 | 2023.00 | 2024.00 | ▂▇▁▇▂ |
| exp_month | 0 | 1.00 | 6.51 | 3.41 | 1.00 | 4.00 | 6.00 | 9.00 | 12.00 | ▇▆▆▅▇ |
| days_since_last | 28184 | 0.72 | 37.14 | 17.53 | 7.00 | 28.00 | 30.00 | 42.00 | 772.00 | ▇▁▁▁▁ |
| nextdate_90_dropout | 0 | 1.00 | 0.26 | 0.44 | 0.00 | 0.00 | 0.00 | 1.00 | 1.00 | ▇▁▁▁▃ |
| visits_since_dropout | 0 | 1.00 | 0.26 | 0.44 | 0.00 | 0.00 | 0.00 | 1.00 | 1.00 | ▇▁▁▁▃ |
| censor | 0 | 1.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | ▁▁▇▁▁ |
| sbp_baseline | 299 | 1.00 | 157.92 | 17.17 | 14.00 | 146.00 | 158.00 | 169.00 | 300.00 | ▁▁▇▁▁ |
| dbp_baseline | 443 | 1.00 | 92.48 | 13.30 | 1.00 | 85.00 | 92.00 | 100.00 | 200.00 | ▁▂▇▁▁ |
| sbp_change_from_baseline | 480 | 1.00 | -14.74 | 19.60 | -180.00 | -28.00 | -10.00 | 0.00 | 178.00 | ▁▂▇▁▁ |
| dbp_change_from_baseline | 913 | 0.99 | -7.55 | 13.22 | -143.00 | -15.00 | -3.00 | 0.00 | 130.00 | ▁▁▇▁▁ |
| cli_max_temp_c | 0 | 1.00 | 30.24 | 3.39 | 18.50 | 27.70 | 29.40 | 32.40 | 41.10 | ▁▃▇▃▁ |
| cli_max_heat_index_c | 0 | 1.00 | 33.43 | 3.26 | 19.00 | 31.10 | 33.50 | 35.60 | 43.60 | ▁▁▇▇▁ |
| cli_rain_mm | 0 | 1.00 | 4.02 | 7.39 | 0.00 | 0.00 | 0.60 | 5.20 | 154.80 | ▇▁▁▁▁ |
| cli_rain_hours | 0 | 1.00 | 4.87 | 6.00 | 0.00 | 0.00 | 2.00 | 8.00 | 24.00 | ▇▂▂▁▁ |
| cli_max_temp_c_pre10day_mean | 0 | 1.00 | 30.26 | 3.21 | 20.99 | 27.89 | 29.39 | 32.39 | 39.70 | ▁▆▇▃▂ |
| cli_max_heat_index_c_pre10day_mean | 0 | 1.00 | 33.44 | 2.89 | 22.46 | 31.43 | 33.49 | 35.27 | 42.35 | ▁▃▇▅▁ |
| cli_rain_mm_pre10day_mean | 0 | 1.00 | 4.02 | 4.39 | 0.00 | 0.11 | 2.87 | 6.35 | 42.86 | ▇▁▁▁▁ |
| cli_rain_hours_pre10day_mean | 0 | 1.00 | 4.77 | 4.64 | 0.00 | 0.40 | 3.70 | 7.60 | 20.60 | ▇▃▂▁▁ |
| mkt_diesel | 0 | 1.00 | 692.09 | 283.68 | 188.00 | 634.00 | 744.00 | 822.60 | 1470.00 | ▃▃▇▂▁ |
| mkt_gasoline | 0 | 1.00 | 313.80 | 189.21 | 161.00 | 179.00 | 199.00 | 488.00 | 1070.00 | ▇▁▂▁▁ |
| mkt_maize_grain_white | 0 | 1.00 | 307.89 | 136.43 | 151.54 | 212.00 | 267.03 | 322.41 | 814.00 | ▇▃▂▁▁ |
| mkt_maize_grain_yellow | 0 | 1.00 | 332.92 | 140.36 | 159.23 | 236.14 | 300.55 | 369.14 | 851.40 | ▇▅▂▁▁ |
| mkt_millet_pearl | 0 | 1.00 | 334.85 | 144.73 | 160.33 | 244.05 | 300.56 | 351.12 | 847.86 | ▇▆▁▁▁ |
| mkt_rice_5_percent_broken | 0 | 1.00 | 814.02 | 327.72 | 475.79 | 603.45 | 708.93 | 957.51 | 2160.50 | ▇▂▁▁▁ |
| mkt_rice_milled | 0 | 1.00 | 643.92 | 262.61 | 278.00 | 464.29 | 571.43 | 775.96 | 1641.30 | ▇▅▂▁▁ |
| mkt_sorghum_brown | 0 | 1.00 | 341.15 | 172.27 | 155.90 | 231.08 | 297.52 | 370.52 | 991.50 | ▇▃▁▁▁ |
| mkt_sorghum_white | 0 | 1.00 | 332.27 | 157.89 | 154.68 | 224.00 | 308.00 | 357.60 | 852.30 | ▇▆▁▁▁ |
| mkt_yams | 0 | 1.00 | 294.54 | 146.76 | 136.92 | 186.34 | 249.42 | 345.97 | 836.70 | ▇▃▁▁▁ |
| mkt_usd_exchange_rate | 0 | 1.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | ▁▃▁▂▇ |
| mkt_diesel_pre10day_mean | 0 | 1.00 | 688.30 | 283.55 | 188.00 | 634.00 | 744.00 | 819.20 | 1458.00 | ▃▃▇▂▁ |
| mkt_gasoline_pre10day_mean | 0 | 1.00 | 310.70 | 185.39 | 160.76 | 176.60 | 199.00 | 488.04 | 1070.00 | ▇▁▂▁▁ |
| mkt_maize_grain_white_pre10day_mean | 0 | 1.00 | 306.30 | 133.74 | 153.85 | 218.24 | 265.02 | 322.41 | 808.69 | ▇▃▁▁▁ |
| mkt_maize_grain_yellow_pre10day_mean | 0 | 1.00 | 331.23 | 138.20 | 162.11 | 234.46 | 303.58 | 366.94 | 851.40 | ▇▆▂▁▁ |
| mkt_millet_pearl_pre10day_mean | 0 | 1.00 | 333.40 | 143.20 | 161.07 | 241.27 | 294.28 | 347.75 | 846.81 | ▇▆▁▁▁ |
| mkt_rice_5_percent_broken_pre10day_mean | 0 | 1.00 | 810.16 | 323.26 | 478.97 | 598.28 | 708.93 | 957.50 | 2157.05 | ▇▂▁▁▁ |
| mkt_rice_milled_pre10day_mean | 0 | 1.00 | 640.55 | 258.59 | 281.60 | 460.71 | 565.03 | 773.77 | 1636.41 | ▇▅▂▁▁ |
| mkt_sorghum_brown_pre10day_mean | 0 | 1.00 | 339.25 | 170.23 | 159.15 | 228.69 | 298.90 | 360.60 | 988.98 | ▇▃▁▁▁ |
| mkt_sorghum_white_pre10day_mean | 0 | 1.00 | 330.54 | 156.19 | 154.68 | 223.92 | 296.00 | 354.24 | 852.30 | ▇▆▁▁▁ |
| mkt_yams_pre10day_mean | 0 | 1.00 | 292.48 | 144.11 | 139.12 | 184.84 | 251.56 | 344.47 | 763.26 | ▇▃▁▁▁ |
| mkt_usd_exchange_rate_pre10day_mean | 0 | 1.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | ▁▂▁▂▇ |
| LGA_monthly_visitLoad | 0 | 1.00 | 279.64 | 263.03 | 1.00 | 85.00 | 164.00 | 406.00 | 1108.00 | ▇▂▁▁▁ |
| HC_monthly_visitLoad | 0 | 1.00 | 130.25 | 125.51 | 1.00 | 37.00 | 75.00 | 185.00 | 532.00 | ▇▂▁▁▁ |
| days_since_last_mean | 28184 | 0.72 | 36.27 | 12.93 | 7.00 | 28.45 | 32.75 | 40.33 | 772.00 | ▇▁▁▁▁ |
| days_since_last_sd | 43611 | 0.56 | 10.85 | 10.05 | 0.00 | 3.39 | 9.70 | 15.89 | 522.55 | ▇▁▁▁▁ |
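The per-patient expanding mean and SD of `days_since_last` computed in the prep chunk above can be illustrated on toy data (three visits for one hypothetical patient; values are made up):

```r
library(dplyr)
library(purrr)

toy <- tibble(
  reg_id = c(1, 1, 1),
  appt_tally = 1:3,
  days_since_last = c(NA, 30, 40) # first visit has no prior gap
) %>%
  group_by(reg_id) %>%
  # expanding window: statistic over visits 1..x at each visit x
  mutate(days_since_last_mean = map_dbl(seq_along(appt_tally),
                                        ~ mean(days_since_last[1:.x], na.rm = TRUE))) %>%
  mutate(days_since_last_sd = map_dbl(seq_along(appt_tally),
                                      ~ sd(days_since_last[1:.x], na.rm = TRUE))) %>%
  ungroup() %>%
  # mean() over an all-NA window returns NaN; recode to NA as in the prep chunk
  mutate(across(starts_with("days"), ~ ifelse(is.nan(.), NA, .)))
```

At visit 3 the expanding mean is mean(30, 40) = 35; the SD only becomes defined once a patient has at least two observed gaps, which is why `days_since_last_sd` has more missingness than `days_since_last_mean` in the skim above.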
Model 1: Failure to Return after Baseline
First prepare the dataset.
Code
dta_model1 <- dta_prep_all_models %>%
  select(reg_id, appt_tally, nextdate_90_kept, event_year_month, total_patient_visits,
         sex, age, event_year, days_since_last,
         # KEEP FROM HERE
         contains("med"), contains("bp_"), ends_with("baseline"), contains("p_step"),
         relationship_to_patient, has_phone, does_patient_have_diabetes,
         State, LGA, age_group, HC_monthly_visitLoad, consent_to_send_sms_reminders,
         any_of(pop_params_top) # weather and market params
  ) %>%
  # define the target population (baseline)
  filter(appt_tally == 1) %>%
  # define the outcome
  mutate(patient_dropout = as.factor(if_else(total_patient_visits < 2, 1, 0))) %>%
  select(!any_of(c("visits_since_dropout", "nextdate_90_kept",
                   "total_patient_visits", "registration_date")))
M1.1. Split into training, validation, and test sets
Code
# library(tidymodels)
set.seed(124)

# Extract unique patient IDs
unique_groups <- unique(dta_model1$reg_id)

# Shuffle groups to ensure random distribution
shuffled_groups <- sample(unique_groups)

# Define the split ratio
split_ratio <- 0.8

# Determine split point
split_point <- floor(length(shuffled_groups) * split_ratio)

# Assign groups to training and test sets
train_groups <- shuffled_groups[1:split_point]
test_groups <- shuffled_groups[(split_point + 1):length(shuffled_groups)]

# Create training and test datasets based on group membership
training_data <- dta_model1 %>% filter(reg_id %in% train_groups)
testing_data <- dta_model1 %>% filter(reg_id %in% test_groups)

# Set up 5-fold cross-validation (one row per patient here, so no grouping needed)
cv_folds <- vfold_cv(training_data, v = 5)
M1.2. Build the model
Code
# with lasso regularization
lr_mod <- logistic_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")
M1.3. Build the “recipe” to preprocess the data
Code
lr_recipe <- recipe(patient_dropout ~ ., data = training_data) %>%
  update_role(reg_id, new_role = "ID") %>%
  step_date(event_year_month, features = c("year", "month")) %>%
  step_rm(event_year_month) %>%
  step_log(HC_monthly_visitLoad) %>%
  # handle factor levels unseen in training, then treat NA as its own level
  step_novel(all_factor_predictors(), new_level = "new") %>%
  step_unknown(all_factor_predictors(), new_level = "missing") %>%
  # impute missing numeric values with the median
  step_impute_median(all_numeric_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_ns(age) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())
M1.4. Add a workflow, combining the model and recipe
Code
lr_workflow <- workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(lr_recipe)

M1.5 Create penalty values for tuning
Code
lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 20))

M1.6 Apply penalty values to the workflow
Code
set.seed(3303)

lr_res <- lr_workflow %>%
  tune_grid(resamples = cv_folds,
            grid = lr_reg_grid,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(roc_auc))

lr_plot <- lr_res %>%
  collect_metrics() %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line() +
  ylab("Area under the ROC Curve") +
  scale_x_log10(labels = scales::label_number())

lr_plot

M1.7 Select best model
Let's try to find the best model.
Code
lr_res %>%
  collect_metrics() %>%
  arrange(penalty) %>%
  rowid_to_column()

# select optimal penalty
lr_best <- lr_res %>%
  select_best(metric = "roc_auc")

M1.8 ROC curve for best fit
Now let's see how the ROC curve looks for the best-fit logistic model.
Code
lr_preds1 <- lr_res %>%
  collect_predictions(parameters = lr_best) %>%
  mutate(model_name = "Model 1: First Visit Return Failure")

lr_auc <- lr_preds1 %>%
  roc_curve(patient_dropout, .pred_1, event_level = "second") %>%
  mutate(model = "Logistic Regression")

autoplot(lr_auc)

It's... ok! But we can see how it would fit onto the test data now.

Model 1: Final Fit
Code
last_log_mod <- logistic_reg(penalty = lr_best$penalty[1], mixture = 1) %>%
  set_engine("glmnet", importance = "BIC")

# the last workflow
last_log_workflow <- lr_workflow %>%
  update_model(last_log_mod)

# last fit
last_lr_fit <- last_log_workflow %>%
  fit(training_data)

# predictions on the held-out test set, at several classification thresholds
lr_last_preds_m1 <- predict(last_lr_fit, new_data = testing_data, type = "prob") %>%
  bind_cols(testing_data) %>%
  mutate(patient_dropout = as.factor(patient_dropout)) %>%
  mutate(predicted_class_60 = as.factor(if_else(.pred_1 > 0.6, "1", "0"))) %>%
  mutate(predicted_class_50 = as.factor(if_else(.pred_1 > 0.5, "1", "0"))) %>%
  mutate(predicted_class_38 = as.factor(if_else(.pred_1 > 0.38, "1", "0"))) %>%
  mutate(predicted_class_20 = as.factor(if_else(.pred_1 > 0.2, "1", "0"))) %>%
  mutate(predicted_prob_yes = .pred_1, predicted_prob_no = .pred_0) %>%
  mutate(model = "Model 1: Return Failure At Baseline")

About the same ROC as with the validation set.

M1 Top features
Code
vi_fit <- last_lr_fit %>%
  extract_fit_parsnip()

vi_fit %>%
  vip(num_features = 25, mapping = aes(fill = Sign)) +
  scale_y_continuous(expand = c(0, 0)) +
  theme_light() +
  labs(title = "Top predictors for Model1",
       subtitle = "Logistic Reg for Dropout Event after first visit")

LGAs seem to have the most impact.

Get the coefficients from this final fit as per the tidymodels docs:
Code
vi_fit %>% tidy() %>% DT::datatable()

M1 Calibration plot
Code
calibration_data <- lr_last_preds_m1 %>%
  mutate(bin = cut(.pred_1, breaks = seq(0, 1, by = 0.05), include.lowest = TRUE)) %>%
  group_by(bin) %>%
  summarize(mean_pred = mean(.pred_1),
            observed = mean(as.numeric(patient_dropout) - 1)) # convert factor to 0/1

ggplot(calibration_data, aes(x = mean_pred, y = observed)) +
  geom_point() +
  geom_line() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "red") +
  labs(x = "Predicted Probability",
       y = "Observed Proportion",
       title = "Calibration Plot") +
  theme_minimal()
Model 2: Dropouts Per Opportunity, Full Specs

First prepare the dataset.
Code
dta_model2 <- dta_prep_all_models %>%
  select(reg_id, appt_tally, nextdate_90_kept, sex, age, event_year_month,
         starts_with("days_since"),
         # KEEP FROM HERE
         contains("med"), contains("bp_"), ends_with("baseline"), contains("p_step"),
         relationship_to_patient, has_phone, does_patient_have_diabetes,
         "htn_past" = hypertension_treated_in_the_past,
         LGA, State, age_group, HC_monthly_visitLoad, consent_to_send_sms_reminders,
         any_of(pop_params_top) # weather and market params
  ) %>%
  # target population: all visits, not just baseline
  # define the outcome
  mutate(patient_dropout = as.factor(if_else(nextdate_90_kept == 1, 0, 1))) %>%
  group_by(reg_id) %>%
  # flag patients who ever drop out (note: comparing the factor level directly,
  # since as.numeric() on a factor returns level codes, not 0/1)
  mutate(patient_dropout_ever = if_else(any(patient_dropout == "1"), 1, 0)) %>%
  ungroup() %>%
  select(!any_of(c("visits_since_dropout", "nextdate_90_kept",
                   "total_patient_visits", "registration_date")))
M2.1. Split into training, validation, and test sets
Code
library(tidymodels)
set.seed(124)

# Extract unique patient IDs
unique_groups <- unique(dta_model2$reg_id)

# Shuffle groups to ensure random distribution
shuffled_groups <- sample(unique_groups)

# Define the split ratio
split_ratio <- 0.8

# Determine split point
split_point <- floor(length(shuffled_groups) * split_ratio)

# Assign groups to training and test sets
train_groups <- shuffled_groups[1:split_point]
test_groups <- shuffled_groups[(split_point + 1):length(shuffled_groups)]

# Create training and test datasets based on group membership
training_data <- dta_model2 %>% filter(reg_id %in% train_groups)
testing_data <- dta_model2 %>% filter(reg_id %in% test_groups)

# Set up 5-fold cross-validation, grouped by the patient ID variable
cv_folds <- group_vfold_cv(
  training_data, v = 5,
  group = reg_id,                # all visits for one patient go into the same fold
  strata = patient_dropout_ever  # similar share of eventual dropouts in each fold
)
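As a quick sanity check on the grouped resampling logic (toy data; `group_vfold_cv()` should keep every patient's visits together in a single fold, so no patient leaks between analysis and assessment sets):

```r
library(dplyr)
library(purrr)
library(rsample)

set.seed(1)

# ten hypothetical patients with three visits each
toy <- tibble(reg_id = rep(1:10, each = 3), visit = rep(1:3, times = 10))

folds <- group_vfold_cv(toy, group = reg_id, v = 5)

# no patient should appear in both the analysis and assessment set of any split
check <- map_lgl(folds$splits,
                 ~ length(intersect(assessment(.x)$reg_id, analysis(.x)$reg_id)) == 0)
```

With plain `vfold_cv()` on visit-level data, the same patient's visits would be split across folds and inflate the cross-validated performance estimates.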
M2.2. Build the model
Code
# with lasso regularization
lr_mod <- logistic_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")
M2.3. Build the “recipe” to preprocess the data
Code
lr_recipe <- recipe(patient_dropout ~ ., data = training_data) %>%
  update_role(reg_id, new_role = "ID") %>%
  step_date(event_year_month, features = c("month")) %>%
  step_rm(patient_dropout_ever) %>%
  step_rm(event_year_month) %>%
  # these have an exponential descent, so log them
  step_log(appt_tally) %>%
  step_log(HC_monthly_visitLoad) %>%
  # handle factor levels unseen in training, then treat NA as its own level
  step_novel(all_factor_predictors(), new_level = "new") %>%
  step_unknown(all_factor_predictors(), new_level = "missing") %>%
  # impute missing numeric values with the median
  step_impute_median(all_numeric_predictors()) %>%
  # step_rose(patient_dropout) was tried but did not help
  step_dummy(all_nominal_predictors()) %>%
  step_ns(age) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())
M2 Top features
Code
vi_fit <- last_lr_fit

vi_fit %>%
  vip(num_features = 30, mapping = aes(fill = Sign)) +
  scale_y_continuous(expand = c(0, 0)) +
  theme_light() +
  labs(title = "Top predictors for Full Model (90 day dropout risk per visit)",
       subtitle = "Coefficients for Dropout event after each visit, using LASSO regularization") +
  theme(plot.title.position = "plot") +
  scale_fill_manual(values = c(palette[2], palette[1]))
Code
my_ggsave("fig5_top_predictors_latest", 10)
Appointment Count and Days Since Entry (Registration) seemed to have the greatest impact.
Get the coefficients from this final fit, as per the tidymodels docs:
Code
last_lr_fit %>% broom::tidy() %>% DT::datatable()
M2 Calibration plot
Code
lr_last_preds_m2 %>%
  mutate(bin = cut(predicted_prob_yes, breaks = seq(0, 1, by = 0.05), include.lowest = TRUE)) %>%
  group_by(bin) %>%
  reframe(model = model,
          mean_pred = mean(predicted_prob_yes),
          observed = mean(as.numeric(patient_dropout) - 1)) %>% # convert factor to 0/1
  ggplot(aes(x = mean_pred, y = observed)) +
  geom_point() +
  geom_line() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "red") +
  labs(x = "Predicted Probability",
       y = "Observed Proportion",
       title = "Calibration Plot Full Model (90 day dropout risk per visit)")
Code
my_ggsave("fig4_model_calplot_latest",10)
What's the distribution of estimated probabilities?
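A quick histogram answers this; a sketch assuming a prediction frame shaped like `lr_last_preds_m2` (here simulated with Beta-distributed probabilities so the example is self-contained):

```r
library(dplyr)
library(ggplot2)

# stand-in for lr_last_preds_m2; replace with the real prediction frame
set.seed(42)
preds <- tibble(predicted_prob_yes = rbeta(1000, shape1 = 2, shape2 = 5))

p <- ggplot(preds, aes(x = predicted_prob_yes)) +
  geom_histogram(binwidth = 0.05, boundary = 0) +
  labs(x = "Predicted probability of dropout",
       y = "Visits",
       title = "Distribution of estimated dropout probabilities")
```

A heavy pile-up at low probabilities with a thin right tail would suggest the classification threshold matters a great deal for how many patients get flagged for followup.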
Model 3: Dropouts Per Opportunity, Parsimonious Specs

M3.1. Split into training, validation, and test sets
Code
library(tidymodels)
set.seed(124)

# Extract unique patient IDs
unique_groups <- unique(dta_model3$reg_id)

# Shuffle groups to ensure random distribution
shuffled_groups <- sample(unique_groups)

# Define the split ratio
split_ratio <- 0.8

# Determine split point
split_point <- floor(length(shuffled_groups) * split_ratio)

# Assign groups to training and test sets
train_groups <- shuffled_groups[1:split_point]
test_groups <- shuffled_groups[(split_point + 1):length(shuffled_groups)]

# Create training and test datasets based on group membership
training_data <- dta_model3 %>% filter(reg_id %in% train_groups)
testing_data <- dta_model3 %>% filter(reg_id %in% test_groups)

# Set up 5-fold cross-validation, grouped by the patient ID variable
cv_folds <- group_vfold_cv(
  training_data, v = 5,
  group = reg_id,                # all visits for one patient go into the same fold
  strata = patient_dropout_ever  # similar share of eventual dropouts in each fold
)
M3.2. Build the model
Code
# with lasso regularization
lr_mod <- logistic_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")
M3.3. Build the “recipe” to preprocess the data
Code
lr_recipe <- recipe(patient_dropout ~ ., data = dta_model3) %>%
  update_role(reg_id, new_role = "ID") %>%
  step_date(event_year_month, features = c("month")) %>%
  step_rm(patient_dropout_ever) %>%
  step_rm(event_year_month) %>%
  # appt_tally has an exponential descent, so log it
  step_log(appt_tally) %>%
  # handle factor levels unseen in training, then treat NA as its own level
  step_novel(all_factor_predictors(), new_level = "new") %>%
  step_unknown(all_factor_predictors(), new_level = "missing") %>%
  # impute missing numeric values with the median
  step_impute_median(all_numeric_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())
library(vip)

vi_fit <- last_lr_fit

vi_fit %>%
  vip(num_features = 25, mapping = aes(fill = Sign)) +
  scale_y_continuous(expand = c(0, 0)) +
  theme_light() +
  labs(title = "Top predictors for Model3",
       subtitle = "Logistic Reg for Dropout Event after each visit")
Appointment Count and Days Since Entry (Registration) still seem to have the greatest impact.
Get the coefficients from this final fit, as per the tidymodels docs.
---title: "Msc Thesis Pt2: Predicting NHCI Dropouts"subtitle: "Working document"format: html: code-fold: true code-tools: true embed-resources: true toc: true toc-location: left toc-title: Contents toc-depth: 3 toc-expand: 2editor: visualwarning: falseauthor: name: Brian ODonnell affiliation: Charité - Universitätsmedizin Berlin---[Part one](https://iambodo.github.io/projects/nhci_analyses/msc_thesis_pt1_desc.html) imported NHCI visit level data from the DHIS2 export, and described patient dropout patterns over time, location, and patient characteristics. Visit data were deduplicated, and quality of care was described.Daily LGA-level weather data and weekly state-level market prices, each calculated as trailing 10-day averages. were also introduced. These data were merged into the visit-level data by matching on the "target end date" for the patient's next appointment, defined as 37 days after the current visit. Thus, they act as instrumental variables for climate and economic conditions during the expected visit window.In this section we use the above dataset to create a prediction model.**Implementer decisions should be made on how to define the outcome and put these predictions into patient followup practice:**1. Is it better to predict return within a "target" 37 day window, or return within a "timely" 90 day window of latest visit?2. If a patient returns to care \>90 days after their latest visit, are they still considered a dropout? Note this is about 17% of all followup visits.3. If a patient returns to care 1-21 days after latest visit, is this considered a successful followup thats within the target window? 
Note this is about 5% of all followup visits.*Goals Summary*- Use the [tidymodels](https://www.tidymodels.org/start/case-study) workflow for a series of **three logistic regression models** on patient dropout- Predict patient failure to return after baseline visit (zero visits) using all available variables- Predict patient dropout-per-opportunity within 90 days (inc. baseline visit and subsequent visits), using all available variables- Predict patient dropout-per-opportunity within 90 days (incl. baseline visit and subsequent visits), using a parsimonious selection of variables for maximum generalizability and transportabilityThe model described in manuscript is ***MODEL 2***```{r}#| label: load-packages#| include: false#| echo: falselibrary(tidyverse)library(tidymodels)library(table1)library(slider)library(ggcorrplot)library(vip)library(summarytools)library(glmnet)library(readxl)library(kableExtra)library(effectsize)library(splines2)library(probably)library(RColorBrewer)library(skimr)library(DT)library(themis)library(ROSE)library(patchwork)palette <-brewer.pal(n =8, name ="Dark2")palette_reversed <-rev(palette)mytheme<-theme_minimal()+theme(panel.background =element_rect(fill ="white", color =NA),plot.background =element_rect(fill ="white", color =NA),legend.background =element_rect(fill ="white", color =NA),text =element_text(family="sans"),strip.background =element_rect(color ="grey70", linewidth =1))options(ggplot2.discrete.colour = palette, ggplot2.discrete.fill = palette)theme_set(mytheme)my_ggsave<-function(x, ...){ggsave(paste0(x,".png"), plot=last_plot(),path=here::here("figs"), dpi=200,width=ifelse(is.na(...),7, ...))ggsave(paste0(x,".tiff"), plot=last_plot(), path=here::here("figs"), dpi=200,width=ifelse(is.na(...),7, ...))}```# PREDICTION## Import, Cleaning, and SkimImporting full dataset, we remove the censored observations. 
These include the last visit before a patient is marked *Transferred* or *Died*, or those visits after April 1, 2024 (ie no opportunity to finish).```{r}tots_timely_qual_cli_mkt<-readr::read_rds("tots_timely_qual_cli_mkt.rds")pop_params_top<-readr::read_lines("cli_mkt_params_for_model.txt")dta_prep_all_models<-tots_timely_qual_cli_mkt %>%# remove censored events filter(censor==0) %>%filter(registration_date>=ymd("2021-01-01")) %>%mutate(event_year_month=make_date(year=year(event_date), month=month(event_date), day="01")) %>%mutate(event_year=as.factor(year(event_date))) %>%mutate(LGA=str_remove_all(LGA, " Local Government Area")) %>%add_count(LGA, event_year_month, name="LGA_monthly_visitLoad") %>%add_count(HC, event_year_month, name="HC_monthly_visitLoad") %>%# here we remove all visits after a dropoutmutate(visits_since_dropout=replace_na(visits_since_dropout,0)) %>%filter(visits_since_dropout<2) %>%mutate(bp_both_na=as.factor(if_else(is.na(bp_diastole_mmhg) &is.na(bp_systole_mmhg), 1, 0))) %>%mutate(age_group =as.factor(cut(age,breaks =c(0, 30, 45, 60, 75, 110, 100000),include.lowest = T,labels =c("Other", "30-44","45-59", "60-74", "75-110", "Other")))) %>%mutate(med_any=as.factor(if_else(!is.na(med_amlodipine) |!is.na(med_telmisartan) |!is.na(med_losartan) |!is.na(med_hydrochlorothiazide) |!is.na(other_hypertension_medication),1, 0))) %>%# if there is any medication, assume it is 30 daysmutate(days_of_medication=as.integer(str_remove(days_of_medication, "days"))) %>%mutate(days_of_medication=if_else(med_any==1&is.na(days_of_medication), 30, days_of_medication)) %>%mutate(days_of_medication=replace_na(days_of_medication, 0)) %>%mutate(meds_other=if_else(is.na(other_hypertension_medication),"none","recorded")) %>%select(!c(starts_with("other_"))) %>%# convert to factors as appropriatemutate(across(starts_with("has_"), as.factor)) %>%mutate(across(starts_with("p_"), as.factor)) %>%mutate(across(starts_with("consent"), as.character)) 
%>%mutate(across(where(is.character), as.factor)) %>%arrange(reg_id, appt_tally) %>%group_by(reg_id) %>%# days since visit as rolling meanmutate(days_since_last_mean =map_dbl(seq_along(appt_tally), ~mean(days_since_last[1:.x], na.rm =TRUE))) %>%mutate(days_since_last_sd =map_dbl(seq_along(appt_tally), ~sd(days_since_last[1:.x], na.rm =TRUE))) %>%ungroup() %>%mutate(across(starts_with("days"), ~ifelse(is.nan(.), NA, .)))skimr::skim(dta_prep_all_models)```## Model 1: Failure to Return after BaselineFirst prepare the dataset.```{r}dta_model1 <- dta_prep_all_models %>%select(reg_id, appt_tally, nextdate_90_kept, event_year_month, total_patient_visits, sex, age, event_year, days_since_last,#KEEP FROM HEREcontains("med"), contains("bp_"), ends_with("baseline"),contains("p_step"), relationship_to_patient, has_phone, does_patient_have_diabetes, State, LGA, age_group, HC_monthly_visitLoad, consent_to_send_sms_reminders,any_of(pop_params_top) # weather and market params ) %>%# define the target population (baseline)filter(appt_tally==1) %>%# define the outcomemutate(patient_dropout=as.factor(if_else(total_patient_visits <2, 1, 0))) %>%select(!any_of(c("visits_since_dropout","nextdate_90_kept","total_patient_visits","registration_date")))```### M1.1. 
Split into training and validation and test sets```{r}# library(tidymodels)set.seed(124)# Extract unique groupsunique_groups <-unique(dta_model1$reg_id)# Shuffle groups to ensure random distributionshuffled_groups <-sample(unique_groups)# Define the split ratiosplit_ratio <-0.8# Determine split pointsplit_point <-floor(length(shuffled_groups) * split_ratio)# Assign groups to training and validation setstrain_groups <- shuffled_groups[1:split_point]test_groups <- shuffled_groups[(split_point +1):length(shuffled_groups)]# Create training and validation datasets based on group membershiptraining_data <- dta_model1 %>%filter(reg_id %in% train_groups)testing_data <- dta_model1 %>%filter(reg_id %in% test_groups)# Set up 5-fold cross-validation, grouped by the patient ID variablecv_folds <-vfold_cv(training_data, v =5) ```### M1.2. Build the model```{r}# with lasso regularizationlr_mod <-logistic_reg(penalty =tune(), mixture=1) %>%set_engine("glmnet")```### M1.3. Build the "recipe" to preprocess the data```{r}lr_recipe <-recipe(patient_dropout ~ ., data = training_data) %>%update_role(reg_id, new_role ="ID") %>%# step_date(event_date, features = c("dow", "month")) %>% step_date(event_year_month, features =c("year","month")) %>%# step_rm(event_date) %>% step_rm(event_year_month) %>%step_log(HC_monthly_visitLoad) %>%# step_rm(exp_year_month) %>%# Step to impute missing values in factor_var (if NA is not a valid level)step_novel(all_factor_predictors(), new_level ="new") %>%step_unknown(all_factor_predictors(), new_level ="missing") %>%# Step to impute missing values in numeric variable x with meanstep_impute_median(all_numeric_predictors()) %>%# step_corr(all_numeric_predictors(), threshold = .5) %>%step_dummy(all_nominal_predictors()) %>%step_ns(age) %>%# # step_ns(sbp_change_from_baseline) %>%step_zv(all_predictors()) %>%step_normalize(all_predictors())```### M1.4. 
Add a workflow, combining the model and recipe```{r}lr_workflow<-workflow() %>%add_model(lr_mod) %>%add_recipe(lr_recipe)```### M1.5 Create penalty values for tuning```{r}lr_reg_grid <-tibble(penalty =10^seq(-4, -1, length.out =20))```### M1.6 Apply penalty values to the workflow```{r}set.seed(3303)# library(glmnet)# install.packages("glmnet")lr_res <- lr_workflow %>%tune_grid(resamples = cv_folds,grid = lr_reg_grid,control =control_grid(save_pred =TRUE),metrics =metric_set(roc_auc))lr_plot <- lr_res %>%collect_metrics() %>%ggplot(aes(x = penalty, y = mean)) +geom_point() +geom_line() +ylab("Area under the ROC Curve") +scale_x_log10(labels = scales::label_number())lr_plot```### M1.7 Select best modelLets try to find the best model...```{r}lr_res %>%collect_metrics() %>%arrange(penalty) %>%rowid_to_column()# select optimal penaltylr_best<- lr_res %>%select_best(metric="roc_auc")# collect_metrics() %>% # arrange(penalty) %>% # slice(12)# # lr_res %>% # collect_metrics()```### M1.8 ROC curve for best fitOK, now lets see how this ROC curve looks for the best fit log model....```{r}lr_preds1 <- lr_res %>%collect_predictions(parameters = lr_best) %>%mutate(model_name="Model 1: First Visit Return Failure")lr_auc<- lr_preds1 %>%roc_curve(patient_dropout, .pred_1, event_level ="second") %>%mutate(model ="Logistic Regression") autoplot(lr_auc)# Extract the glmnet model from the workflow and training data# final_fit <- fit(lr_workflow, data = appts_other)# # final_coefficients <- tidy(final_fit, penalty = lr_best$penalty[1], mixture=1)```its... 
It's... OK! Let's see how the model fits the held-out test data.

### Model 1: Final Fit

```{r}
last_log_mod <- logistic_reg(penalty = lr_best$penalty[1], mixture = 1) %>%
  set_engine("glmnet", importance = "BIC")

# the final workflow
last_log_workflow <- lr_workflow %>%
  update_model(last_log_mod)

# final fit on the training data
last_lr_fit <- last_log_workflow %>%
  fit(training_data)

# predictions on the test set, classified at several probability thresholds
lr_last_preds_m1 <- predict(last_lr_fit, new_data = testing_data, type = "prob") %>%
  bind_cols(testing_data) %>%
  mutate(patient_dropout = as.factor(patient_dropout)) %>%
  mutate(
    predicted_class_60 = as.factor(if_else(.pred_1 > 0.6, "1", "0")),
    predicted_class_50 = as.factor(if_else(.pred_1 > 0.5, "1", "0")),
    predicted_class_38 = as.factor(if_else(.pred_1 > 0.38, "1", "0")),
    predicted_class_20 = as.factor(if_else(.pred_1 > 0.2, "1", "0"))
  ) %>%
  mutate(predicted_prob_yes = .pred_1, predicted_prob_no = .pred_0) %>%
  mutate(model = "Model 1: Return Failure At Baseline")
```

About the same ROC as with the cross-validation folds.

### M1 Top features

```{r}
vi_fit <- last_lr_fit %>%
  extract_fit_parsnip()

vi_fit %>%
  vip(num_features = 25, mapping = aes(fill = Sign)) +
  scale_y_continuous(expand = c(0, 0)) +
  theme_light() +
  labs(title = "Top predictors for Model 1",
       subtitle = "Logistic regression for dropout event after first visit")
```

The LGA indicators appear to have the most impact.

Get the coefficients from this final fit [as per docs](https://www.tidymodels.org/learn/models/coefficients/#more-complex-a-glmnet-model)

```{r}
vi_fit %>% tidy() %>% DT::datatable()
```

```{r}
calibration_data <- lr_last_preds_m1 %>%
  mutate(bin = cut(.pred_1, breaks = seq(0, 1, by = 0.05), include.lowest = TRUE)) %>%
  group_by(bin) %>%
  summarize(
    mean_pred = mean(.pred_1),
    observed  = mean(as.numeric(patient_dropout) - 1)  # convert factor to 0/1
  )

ggplot(calibration_data, aes(x = mean_pred, y = observed)) +
  geom_point() +
  geom_line() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "red") +
  labs(
    x = "Predicted Probability",
    y = "Observed Proportion",
    title = "Calibration Plot"
  ) +
  theme_minimal()
```

## Model 2: Dropouts Per Opportunity, Full Specs

First prepare the dataset.

```{r}
dta_model2 <- dta_prep_all_models %>%
  select(
    reg_id, appt_tally, nextdate_90_kept, sex, age, event_year_month,
    starts_with("days_since"),
    contains("med"), contains("bp_"), ends_with("baseline"), contains("p_step"),
    relationship_to_patient, has_phone, does_patient_have_diabetes,
    "htn_past" = hypertension_treated_in_the_past,
    LGA, State, age_group, HC_monthly_visitLoad, consent_to_send_sms_reminders,
    any_of(pop_params_top)  # weather and market params
  ) %>%
  # define the outcome: dropout = failure to return within 90 days
  mutate(patient_dropout = if_else(nextdate_90_kept == 1, 0, 1)) %>%
  # flag patients who ever drop out, while the outcome is still numeric
  # (summing the factor's level codes would always exceed zero)
  group_by(reg_id) %>%
  mutate(patient_dropout_ever = if_else(sum(patient_dropout) > 0, 1, 0)) %>%
  ungroup() %>%
  mutate(patient_dropout = as.factor(patient_dropout)) %>%
  select(!any_of(c("visits_since_dropout", "nextdate_90_kept",
                   "total_patient_visits", "registration_date")))
```
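Before splitting, a few quick sanity checks on the per-opportunity dataset can catch outcome-definition mistakes early. A sketch, assuming the `dta_model2` object built above:

```{r}
# Class balance of the per-visit outcome
dta_model2 %>% janitor::tabyl(patient_dropout)

# Share of patients who ever drop out (one row per patient)
dta_model2 %>%
  distinct(reg_id, patient_dropout_ever) %>%
  janitor::tabyl(patient_dropout_ever)
```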
### M2.1. Split into training and test sets

```{r}
library(tidymodels)
set.seed(124)

# Extract unique patient IDs (groups)
unique_groups <- unique(dta_model2$reg_id)

# Shuffle groups to ensure random assignment
shuffled_groups <- sample(unique_groups)

# Define the split ratio
split_ratio <- 0.8

# Determine the split point
split_point <- floor(length(shuffled_groups) * split_ratio)

# Assign groups to the training and test sets
train_groups <- shuffled_groups[1:split_point]
test_groups  <- shuffled_groups[(split_point + 1):length(shuffled_groups)]

# Create training and test datasets based on group membership
training_data <- dta_model2 %>% filter(reg_id %in% train_groups)
testing_data  <- dta_model2 %>% filter(reg_id %in% test_groups)

# Set up 5-fold cross-validation, grouped by the patient ID variable
cv_folds <- group_vfold_cv(
  training_data,
  v = 5,
  group = reg_id,                # all visits for one patient go into the same fold
  strata = patient_dropout_ever  # balance eventual dropouts across folds
)
```

### M2.2. Build the model

```{r}
# with lasso regularization
lr_mod <- logistic_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")
```
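Since rows are visits nested within patients, it is worth verifying that no `reg_id` leaks across the train/test split or across the CV folds. A verification sketch using the objects created above:

```{r}
# No patient should appear in both the training and test sets
length(intersect(training_data$reg_id, testing_data$reg_id))  # expect 0

# With group_vfold_cv(), each patient's visits should sit entirely in one
# assessment set, so the assessment IDs across folds should be disjoint
assess_ids <- purrr::map(cv_folds$splits, ~ unique(assessment(.x)$reg_id))
sum(duplicated(unlist(assess_ids)))  # expect 0
```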
### M2.3. Build the "recipe" to preprocess the data

```{r}
lr_recipe <- recipe(patient_dropout ~ ., data = training_data) %>%
  update_role(reg_id, new_role = "ID") %>%
  step_date(event_year_month, features = c("month")) %>%
  step_rm(patient_dropout_ever) %>%
  step_rm(event_year_month) %>%
  step_log(appt_tally) %>%
  step_log(HC_monthly_visitLoad) %>%
  # handle factor levels unseen at training time, and missing factor values
  step_novel(all_factor_predictors(), new_level = "new") %>%
  step_unknown(all_factor_predictors(), new_level = "missing") %>%
  # impute missing numeric values with the median
  step_impute_median(all_numeric_predictors()) %>%
  # step_rose(patient_dropout) %>%  # oversampling didn't help
  step_dummy(all_nominal_predictors()) %>%
  step_ns(age) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())
```

### M2.4. Combine the model and recipe into a workflow

```{r}
lr_workflow <- workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(lr_recipe)
```

### M2.5. Create penalty values for tuning

```{r}
lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 20))
```
### M2.6. Apply penalty values to the workflow

```{r}
set.seed(3303)

lr_res <- lr_workflow %>%
  tune_grid(
    resamples = cv_folds,
    grid = lr_reg_grid,
    control = control_grid(save_pred = TRUE),
    metrics = metric_set(roc_auc)
  )

# Note: the vline below marks lr_best, which is selected in M2.7;
# re-run this plot after lr_best has been updated for Model 2.
lr_plot <- lr_res %>%
  collect_metrics() %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line() +
  ylab("Area under the ROC Curve") +
  scale_x_log10(labels = scales::label_number()) +
  geom_vline(xintercept = lr_best$penalty, linetype = "dashed", color = "blue") +
  annotate(geom = "text", x = lr_best$penalty + 0.0008, y = 0.721,
           color = "blue", size = 4,
           label = paste0("Optimum = ", round(lr_best$penalty, 6))) +
  labs(title = "Tune grid of penalty values")

lr_plot
```

### M2.7. Select best model

```{r}
lr_res %>%
  collect_metrics() %>%
  arrange(penalty) %>%
  rowid_to_column()

# select the optimal penalty
lr_best <- lr_res %>%
  select_best(metric = "roc_auc")

lr_best
```

### M2.8. ROC curve for best fit

```{r}
lr_preds2 <- lr_res %>%
  collect_predictions(parameters = lr_best) %>%
  mutate(model_name = "Model 2: Dropout Per Opp (Full Spec)")

lr_auc <- lr_preds2 %>%
  roc_curve(patient_dropout, .pred_1, event_level = "second") %>%
  mutate(model = "Logistic Regression")

autoplot(lr_auc)
```

```{r}
library(patchwork)

roc_plot <- autoplot(lr_auc) + labs(title = "ROC Curve")

lr_plot + roc_plot +
  plot_annotation(tag_levels = 'A') +
  theme_minimal()

my_ggsave("fig6_roc_penalty", 14)
```
It's... OK! Let's see how the model fits the held-out test data.

### Model 2: Final Fit

```{r}
last_log_mod <- logistic_reg(penalty = lr_best$penalty[1], mixture = 1) %>%
  set_engine("glmnet", importance = "BIC")

# the final workflow
last_log_workflow <- lr_workflow %>%
  update_model(last_log_mod)

# final fit on the training data
last_lr_fit <- last_log_workflow %>%
  fit(training_data)

lr_last_preds_m2 <- predict(last_lr_fit, new_data = testing_data, type = "prob") %>%
  bind_cols(testing_data) %>%
  mutate(patient_dropout = as.factor(patient_dropout)) %>%
  mutate(
    predicted_class_60 = as.factor(if_else(.pred_1 > 0.6, "1", "0")),
    predicted_class_50 = as.factor(if_else(.pred_1 > 0.5, "1", "0")),
    predicted_class_38 = as.factor(if_else(.pred_1 > 0.38, "1", "0")),
    predicted_class_20 = as.factor(if_else(.pred_1 > 0.2, "1", "0"))
  ) %>%
  mutate(predicted_prob_yes = .pred_1, predicted_prob_no = .pred_0) %>%
  mutate(model = "Model 2: Full Spec")
```

### Model 2: Performance Metrics

```{r}
# Define classification metrics
classification_metrics <- metric_set(accuracy, precision, recall, f_meas, roc_auc, j_index)

# Calculate performance metrics at the 0.5 threshold
performance_50 <- classification_metrics(
  lr_last_preds_m2,
  truth = patient_dropout, predicted_prob_yes,
  estimate = predicted_class_50, event_level = "second"
) %>%
  mutate(cutoff = 0.5)

# ...and at the 0.38 threshold
performance_38 <- classification_metrics(
  lr_last_preds_m2,
  truth = patient_dropout, predicted_prob_yes,
  estimate = predicted_class_38, event_level = "second"
) %>%
  mutate(cutoff = 0.38)

performance <- bind_rows(performance_38, performance_50) %>%
  pivot_wider(id_cols = cutoff, names_from = .metric, values_from = .estimate) %>%
  rename("sensitivity" = recall) %>%
  mutate(across(where(is.numeric), ~ round(., 3)))

performance

write_csv(performance, here::here("data", "model_performance_latest.csv"))
```

### M2 Subgroup analysis

```{r}
# Calculate performance metrics, by subgroup
metrics_by_subgroup <- function(data, sex_c, state_c, pred_class) {
  data %>%
    filter(case_when(
      sex_c == "Male"   ~ sex == sex_c,
      sex_c == "Female" ~ sex == sex_c,
      TRUE ~ TRUE
    )) %>%
    filter(case_when(
      state_c == "Kano" ~ State == state_c,
      state_c == "Ogun" ~ State == state_c,
      TRUE ~ TRUE
    )) %>%
    classification_metrics(
      ., truth = patient_dropout, predicted_prob_yes,
      estimate = pred_class, event_level = "second"
    ) %>%
    mutate(cutoff = pred_class, sex = sex_c, state = state_c) %>%
    mutate(sex = if_else(sex == "Male" | sex == "Female", sex, "both")) %>%
    mutate(state = if_else(state_c == "Kano" | state_c == "Ogun", state_c, "both"))
}

statgrid <- tidyr::expand_grid(
  sex = c("Male", "Female", "both"),
  state = c("Kano", "Ogun", "both"),
  pred = c("predicted_class_50", "predicted_class_38")
)

stats_all_groups <- map(1:nrow(statgrid), function(x) {
  metrics_by_subgroup(lr_last_preds_m2,
                      statgrid$sex[x], statgrid$state[x], statgrid$pred[x])
}) %>%
  bind_rows() %>%
  select(!.estimator) %>%
  mutate(.metric = if_else(.metric == "recall", "sensitivity", .metric)) %>%
  mutate(across(where(is.numeric), ~ round(., 3))) %>%
  mutate(threshold = str_replace_all(cutoff, "predicted_class_", "0."))

stats_all_groups_wide <- stats_all_groups %>%
  pivot_wider(names_from = .metric, values_from = .estimate) %>%
  select(state, sex, threshold, everything()) %>%
  arrange(state, sex, threshold)

write_csv(stats_all_groups_wide, here::here("data", "model_subgroupstats_latest.csv"))

stats_all_groups_wide %>% DT::datatable()
```

Subgroup analysis plot

```{r}
stats_all_groups %>%
  select(state, sex, threshold, everything()) %>%
  unite(col = "group", 1:2, sep = "_", remove = FALSE) %>%
  ggplot(aes(x = group, y = .estimate)) +
  geom_point(aes(shape = threshold, color = threshold), alpha = 0.7, size = 3) +
  coord_flip() +
  facet_wrap(~.metric, scales = "free_x", nrow = 2) +
  labs(title = "Performance Metrics for State and Sex Subgroups",
       subtitle = "By Probability Threshold", x = "", y = "")

my_ggsave("fig7_subgroup_analyses", 10)
```

### M2 Net Benefit Analysis

```{r}
lr_last_preds_m2 %>%
  reframe(
    tp = if_else(predicted_class_50 == 1 & patient_dropout == 1, 1, 0),
    fp = if_else(predicted_class_50 == 1 & patient_dropout == 0, 1, 0)
  ) %>%
  summarise(tp = sum(tp), fp = sum(fp), n = nrow(.)) %>%
  t()

# net benefit = TP/n - FP/n * (pt / (1 - pt)), at threshold pt
# (4195/19948) - (6896/19948) * (0.2/(1 - 0.2))   # pt = 0.2
# (2489/19948) - (2446/19948) * (0.38/(1 - 0.38)) # pt = 0.38
(1063/19948) - (822/19948) * (0.5/(1 - 0.5))      # pt = 0.5
```

About the same ROC as with the cross-validation folds.

### M2 Top features

```{r}
vi_fit <- last_lr_fit

vi_fit %>%
  vip(num_features = 30, mapping = aes(fill = Sign)) +
  scale_y_continuous(expand = c(0, 0)) +
  theme_light() +
  labs(title = "Top predictors for Full Model (90 day dropout risk per visit)",
       subtitle = "Coefficients for dropout event after each visit, using LASSO regularization") +
  theme(plot.title.position = "plot") +
  scale_fill_manual(values = c(palette[2], palette[1]))

my_ggsave("fig5_top_predictors_latest", 10)
```

Appointment Count and Days Since Entry (Registration) appear to have the greatest impact.
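The net benefit arithmetic above hard-codes counts from one run. Wrapping the formula (net benefit = TP/n − FP/n × pt/(1 − pt)) in a small function makes it easy to recompute across thresholds; a sketch, assuming the `lr_last_preds_m2` columns defined above:

```{r}
# Net benefit of intervening on all patients predicted above threshold pt
net_benefit <- function(preds, pt) {
  preds %>%
    summarise(
      tp = sum(predicted_prob_yes > pt & patient_dropout == "1"),
      fp = sum(predicted_prob_yes > pt & patient_dropout == "0"),
      n  = n()
    ) %>%
    mutate(pt = pt, nb = tp / n - fp / n * (pt / (1 - pt)))
}

purrr::map(c(0.2, 0.38, 0.5), ~ net_benefit(lr_last_preds_m2, .x)) %>%
  bind_rows()
```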
Get the coefficients from this final fit [as per docs](https://www.tidymodels.org/learn/models/coefficients/#more-complex-a-glmnet-model)

```{r}
last_lr_fit %>% broom::tidy() %>% DT::datatable()
```

### M2 Calibration plot

```{r}
lr_last_preds_m2 %>%
  mutate(bin = cut(predicted_prob_yes, breaks = seq(0, 1, by = 0.05), include.lowest = TRUE)) %>%
  group_by(bin) %>%
  reframe(
    model = model,
    mean_pred = mean(predicted_prob_yes),
    observed  = mean(as.numeric(patient_dropout) - 1)  # convert factor to 0/1
  ) %>%
  ggplot(aes(x = mean_pred, y = observed)) +
  geom_point() +
  geom_line() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "red") +
  labs(x = "Predicted Probability",
       y = "Observed Proportion",
       title = "Calibration Plot Full Model (90 day dropout risk per visit)")

my_ggsave("fig4_model_calplot_latest", 10)
```

What's the distribution of estimated probabilities?

```{r}
lr_last_preds_m2 %>%
  ggplot() +
  geom_histogram(aes(x = predicted_prob_yes))
```

Only 9% of predictions exceed a 50% probability of dropout, yet we know the observed attrition rate is roughly twice that.

```{r}
lr_last_preds_m2 %>% janitor::tabyl(predicted_class_50)

lr_last_preds_m2 %>%
  mutate(over70 = if_else(predicted_prob_yes >= 0.7, 1, 0)) %>%
  janitor::tabyl(over70)
```

## Model 3: Dropouts Per Opportunity, Parsimonious

First prepare the dataset.

```{r}
dta_model3 <- dta_prep_all_models %>%
  select(
    reg_id, appt_tally, nextdate_90_kept, event_date, event_year_month,
    days_since_entry, sex, age, event_year, days_since_last,
    contains("med"), contains("bp_"), ends_with("baseline"), contains("p_step"),
    relationship_to_patient, has_phone, does_patient_have_diabetes, State,
    age_group, HC_monthly_visitLoad, consent_to_send_sms_reminders,
    any_of(pop_params_top)  # weather and market params
  ) %>%
  # define the outcome
  mutate(patient_dropout = if_else(nextdate_90_kept == 1, 0, 1)) %>%
  group_by(reg_id) %>%
  mutate(patient_dropout_ever = if_else(sum(patient_dropout) > 0, 1, 0)) %>%
  ungroup() %>%
  mutate(patient_dropout = as.factor(patient_dropout)) %>%
  # PARSIMONIOUS VARIABLE SET
  select(reg_id, patient_dropout, patient_dropout_ever,
         starts_with("days_since"), appt_tally, event_year_month)
```

### M3.1. Split into training and test sets

```{r}
library(tidymodels)
set.seed(124)

# Extract unique patient IDs (groups)
unique_groups <- unique(dta_model3$reg_id)

# Shuffle groups to ensure random assignment
shuffled_groups <- sample(unique_groups)

# Define the split ratio
split_ratio <- 0.8

# Determine the split point
split_point <- floor(length(shuffled_groups) * split_ratio)

# Assign groups to the training and test sets
train_groups <- shuffled_groups[1:split_point]
test_groups  <- shuffled_groups[(split_point + 1):length(shuffled_groups)]

# Create training and test datasets based on group membership
training_data <- dta_model3 %>% filter(reg_id %in% train_groups)
testing_data  <- dta_model3 %>% filter(reg_id %in% test_groups)

# Set up 5-fold cross-validation, grouped by the patient ID variable
cv_folds <- group_vfold_cv(
  training_data,
  v = 5,
  group = reg_id,                # all visits for one patient go into the same fold
  strata = patient_dropout_ever  # balance eventual dropouts across folds
)
```

### M3.2. Build the model

```{r}
# with lasso regularization
lr_mod <- logistic_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")
```
### M3.3. Build the "recipe" to preprocess the data

```{r}
# use the training data as the recipe template, as in Models 1 and 2
lr_recipe <- recipe(patient_dropout ~ ., data = training_data) %>%
  update_role(reg_id, new_role = "ID") %>%
  step_date(event_year_month, features = c("month")) %>%
  step_rm(patient_dropout_ever) %>%
  step_rm(event_year_month) %>%
  # appt_tally has an exponential descent, so log it
  step_log(appt_tally) %>%
  # handle factor levels unseen at training time, and missing factor values
  step_novel(all_factor_predictors(), new_level = "new") %>%
  step_unknown(all_factor_predictors(), new_level = "missing") %>%
  # impute missing numeric values with the median
  step_impute_median(all_numeric_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())
```

### M3.4. Combine model and recipe into a workflow

```{r}
lr_workflow <- workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(lr_recipe)
```

### M3.5. Create penalty values for tuning

```{r}
lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 20))
```

### M3.6. Apply penalty values to the workflow

```{r}
set.seed(3303)
library(glmnet)

lr_res <- lr_workflow %>%
  tune_grid(
    resamples = cv_folds,
    grid = lr_reg_grid,
    control = control_grid(save_pred = TRUE),
    metrics = metric_set(roc_auc)
  )

lr_plot <- lr_res %>%
  collect_metrics() %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line() +
  ylab("Area under the ROC Curve") +
  scale_x_log10(labels = scales::label_number())

lr_plot
```

### M3.7. Select best model penalty

```{r}
lr_res %>%
  collect_metrics() %>%
  arrange(penalty) %>%
  rowid_to_column()

# select the optimal penalty
lr_best <- lr_res %>%
  select_best(metric = "roc_auc")
```
### M3.8. ROC curve for best fit

```{r}
lr_preds3 <- lr_res %>%
  collect_predictions(parameters = lr_best) %>%
  mutate(model_name = "Model 3: Dropout Per Opp (Minimal Spec)")

lr_auc <- lr_preds3 %>%
  roc_curve(patient_dropout, .pred_1, event_level = "second") %>%
  mutate(model = "Logistic Regression")
```

It's... OK!

### Model 3: Final Fit

```{r}
last_log_mod <- logistic_reg(penalty = lr_best$penalty[1], mixture = 1) %>%
  set_engine("glmnet", importance = "BIC")

# the final workflow
last_log_workflow <- lr_workflow %>%
  update_model(last_log_mod)

# final fit on the training data
last_lr_fit <- last_log_workflow %>%
  fit(training_data)

lr_last_preds_m3 <- predict(last_lr_fit, new_data = testing_data, type = "prob") %>%
  bind_cols(testing_data) %>%
  mutate(patient_dropout = as.factor(patient_dropout)) %>%
  mutate(
    predicted_class_60 = as.factor(if_else(.pred_1 > 0.6, "1", "0")),
    predicted_class_50 = as.factor(if_else(.pred_1 > 0.5, "1", "0")),
    predicted_class_38 = as.factor(if_else(.pred_1 > 0.38, "1", "0")),
    predicted_class_20 = as.factor(if_else(.pred_1 > 0.2, "1", "0"))
  ) %>%
  mutate(predicted_prob_yes = .pred_1, predicted_prob_no = .pred_0) %>%
  mutate(model = "Model 3: Minimal Spec")
```

About the same ROC as with the cross-validation folds.

### M3 Top features

```{r}
library(vip)

vi_fit <- last_lr_fit

vi_fit %>%
  vip(num_features = 25, mapping = aes(fill = Sign)) +
  scale_y_continuous(expand = c(0, 0)) +
  theme_light() +
  labs(title = "Top predictors for Model 3",
       subtitle = "Logistic regression for dropout event after each visit")
```

Appointment Count and Days Since Entry (Registration) STILL appear to have the greatest impact.

Get the coefficients from this final fit [as per docs](https://www.tidymodels.org/learn/models/coefficients/#more-complex-a-glmnet-model)

```{r}
vi_fit %>% tidy() %>% DT::datatable()
```

# Model Calibration Comparison

```{r}
theme_set(theme_minimal())

make_cal_plot <- function(predictions, basedata) {
  calibration_data <- predictions %>%
    mutate(appt_tally_level = case_when(
      appt_tally == 1 ~ "Visit = 1",
      appt_tally > 1 & appt_tally < 6 ~ "Visit = 2 to 5",
      TRUE ~ "Visit = 6+"
    )) %>%
    mutate(bin = cut(predicted_prob_yes, breaks = seq(0, 1, by = 0.02), include.lowest = TRUE)) %>%
    group_by(appt_tally_level, bin) %>%
    reframe(
      model = model,
      mean_pred = mean(.pred_1),
      observed  = mean(as.numeric(patient_dropout) - 1)  # convert factor to 0/1
    )

  ggplot(calibration_data, aes(x = mean_pred, y = observed)) +
    geom_point() +
    geom_line() +
    geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "red") +
    labs(x = "Predicted Probability",
         y = "Observed Proportion",
         title = paste0("Calibration Plot ", unique(calibration_data$model))) +
    facet_wrap(~appt_tally_level, ncol = 1)
}

make_cal_plot(lr_last_preds_m1, dta_model1)
make_cal_plot(lr_last_preds_m2, dta_model2)
make_cal_plot(lr_last_preds_m3, dta_model3)
```

# Model Performance Comparison

```{r}
theme_set(mytheme)

bind_rows(lr_last_preds_m1, lr_last_preds_m2, lr_last_preds_m3) %>%
  # wrap the long model labels for the legend; str_detect matches the full
  # labels stored in `model` (exact equality to "Model 1" etc. never holds,
  # which would have produced NA labels)
  mutate(model = case_when(
    str_detect(model, "Model 1") ~ str_wrap("Model 1: Return Failure Post Baseline", 15),
    str_detect(model, "Model 2") ~ str_wrap("Model 2: Dropout Per Opp (Full Spec)", 15),
    str_detect(model, "Model 3") ~ str_wrap("Model 3: Dropout Per Opp (Minimal Spec)", 15),
    TRUE ~ model
  )) %>%
  group_by(model) %>%
  roc_curve(patient_dropout, .pred_1, event_level = "second") %>%
  autoplot() +
  labs(title = "ROC Curve: comparison of discrimination across models") +
  theme(legend.key.size = unit(2, "cm"))

my_ggsave("roc_curve_compare", 9)
```

Show metrics for the final fits

```{r}
my_metrics <- function(predictions) {
  # calculate performance metrics at the 0.38 threshold
  classification_metrics(
    predictions,
    truth = patient_dropout, predicted_prob_yes,
    estimate = predicted_class_38, event_level = "second"
  ) %>%
    mutate(model = unique(predictions$model))
}

lr_last_all <- list(lr_last_preds_m1, lr_last_preds_m2, lr_last_preds_m3)

r3 <- function(x) round(x, 3)

metrics_comparison <- map(lr_last_all, my_metrics) %>%
  bind_rows() %>%
  select(model, .metric, .estimate) %>%
  pivot_wider(names_from = .metric, values_from = .estimate) %>%
  rename("sensitivity" = recall) %>%
  mutate(across(where(is.double), r3))

write_csv(metrics_comparison, here::here("data", "model_comparison_latest.csv"))

metrics_comparison %>% kable()
```

Session info below.

```{r}
sessionInfo()
```