
Interpretable generalized neural additive models for mortality prediction of COVID-19 hospitalized patients in Hamadan, Iran

Abstract

Background

The high number of COVID-19 deaths is a serious threat worldwide. Demographic characteristics and clinical biomarkers are significantly associated with the mortality risk of this disease. This study aimed to implement a Generalized Neural Additive Model (GNAM), an interpretable machine learning method, to predict mortality in COVID-19 patients.

Methods

This cohort study included 2181 COVID-19 patients admitted from February 2020 to July 2021 to Sina and Besat hospitals in Hamadan, west of Iran. A total of 22 baseline features, including patients' demographic information and clinical biomarkers, were collected. Four strategies were used to deal with missing data: removing missing values, and mean, K-Nearest Neighbor (KNN), and Multivariate Imputation by Chained Equations (MICE) imputation. First, the important features for predicting the binary outcome (1: death, 0: recovery) were selected using the Random Forest (RF) method, and the synthetic minority over-sampling technique (SMOTE) was used to handle imbalanced data. Next, using the selected features, the predictive performance of GNAM for the mortality outcome was compared with that of logistic regression, RF, generalized additive models (GAMs), gradient boosting decision trees (GBDT), and deep neural networks (DNNs). Each model was trained on fifty random train-test splits of the dataset to obtain a stable estimate of its performance. The average accuracy, F1-score, and area under the curve (AUC) were used to compare the predictive performance of the models.

Results

Of the 2181 COVID-19 patients, 624 died during hospitalization and 1557 recovered. The overall missing rate was 3 percent per patient. The mean age of patients who died (71.17 ± 14.44 years) was significantly higher than that of recovered patients (58.25 ± 16.52 years). Based on RF, the 10 features with the highest relative importance were selected as the most influential: blood urea nitrogen (BUN), lymphocytes (Lym), age, blood sugar (BS), serum glutamic-oxaloacetic transaminase (SGOT), monocytes (Mono), blood creatinine (CR), neutrophils (NUT), alkaline phosphatase (ALP), and hematocrit (HCT). GNAM had the best predictive performance, with mean test accuracy, F1-score, and AUC of 0.847, 0.691, and 0.774, respectively. The smooth functions learned by the GNAM were decreasing for Lym and increasing for the other important features.

Conclusions

Interpretable GNAM can perform well in predicting the mortality of COVID-19 patients. Such a reliable model can therefore help physicians prioritize important demographic characteristics and clinical biomarkers by identifying the influential features and the shape of their predictive trends in disease progression.


Background

Since late 2019, SARS-CoV-2 pneumonia, known as COVID-19, has spread from Wuhan, China, to become a worldwide pandemic [1]. In treating COVID-19 patients, assessing demographic characteristics, laboratory biomarkers, and clinical risk factors, and identifying predictors of death, has been a persistent challenge for researchers [2]. Several studies have evaluated changes in the levels of, and relationships between, laboratory biomarkers such as aspartate aminotransferase (AST), alanine aminotransferase (ALT), lymphocytes (LYM), neutrophils (NEU), and lactate dehydrogenase (LDH) in patients with COVID-19 [3, 4].

Recently, advanced methods for analyzing medical information have been developed that can help interpret complex biological relationships between clinical measurements and patient outcomes. Machine learning is a powerful tool for identifying patterns, supporting clinical decision making, and identifying features of medical data that are relevant to clinical outcomes [5].

Prediction and classification models are designed to help healthcare professionals with decisions such as choosing diagnostic tests, starting or stopping treatments, and allocating available resources efficiently, and they can help avoid some common biases in clinical decision making [6]. Risk prediction models are used to estimate the probability that a specific outcome, such as death, will occur. These models use patient characteristics, and their accuracy depends on the model's ability to capture the complex relationships between patient characteristics and the outcome [7, 8]. Recently, various machine learning methods, including logistic regression, RF, and GBDT, have been used to predict and classify COVID-19 mortality [9,10,11].

Many preprocessing techniques can be applied to improve the performance of analyses such as identifying the most important features and classification. One of the most important preprocessing stages is dealing with missing values. Some methods require complete data, and conducting the analysis without accounting for missing values can bias the results or make some analyses impossible [12]. Strategies for dealing with missing values include deleting incomplete cases, which may lead to bias; univariate statistical imputation with the mode, mean, or zero, which gives suboptimal results; and imputation with KNN or MICE [13,14,15,16].

In most cases, the relationship between features and clinical outcomes is nonlinear, so classic models such as linear regression are inappropriate. Methods such as GAMs can capture nonlinear patterns [17]. GAMs are advantageous in several respects: they aim to maximize the accuracy of response prediction and to discover nonlinear relationships between the predictors and the response while maintaining explainability [18].

Machine learning algorithms are well suited as nonlinear methods for modeling data and for predicting and classifying responses, because they automatically discover relationships in the data and can generate suitable outputs with minimal error [19]. Machine learning methods used for prediction include neural network-based methods such as DNNs and GNAMs [20, 21].

Despite the remarkable results of DNNs in predicting the effects of clinical biomarkers on viral infections [22] and in biomedical studies [23], these models are considered black boxes: it is difficult to understand how they arrive at their predictions and how those predictions can be interpreted. This lack of interpretability is a serious obstacle to applying these methods in fields such as healthcare. Interpretable machine learning inspired by generalized additive models is an emerging research topic that seeks to solve this problem.

GNAMs, which are inherently interpretable, are well suited to debugging neural network predictions. Their advantages include belonging to the broader family of classic GAMs, making a neural network model interpretable, and producing learned shape plots that accurately describe how predictions are made [20].

This study aimed to predict the outcomes (dead/recovered) of COVID-19 patients admitted to hospitals in Hamadan, Iran. GAMs are used as a classical method, and DNNs and GNAMs as machine learning methods. The fitted models are compared using the accuracy and AUC classification indices, and the behavior of each feature in mortality risk is visualized and interpreted based on GNAMs.

Methods

The Methods section is divided into several subsections. First, the data used in this study are introduced in the subsection COVID-19 dataset. The imputation algorithms are then introduced in the subsection Data Imputation, and the statistical methods used to describe the data are presented in the subsection Data Description. The subsection Feature selection using random forest describes how features were selected with the random forest algorithm. The classification models used in this study are explained in the subsections Logistic regression model, Gradient Boosting Decision Tree (GBDT), Generalized additive models, Deep Neural Networks (DNN), and Generalized Neural Additive Models. Finally, the subsection Evaluation metrics and Class imbalanced issue describes how model performance was evaluated.

COVID-19 dataset

In this cohort study, data from 2181 COVID-19 patients admitted to Sina (a COVID-19 treatment center) and Besat hospitals, affiliated with Hamadan University of Medical Sciences, Iran, were used. Patients with a positive real-time reverse transcriptase polymerase chain reaction (RT-PCR) test on upper respiratory nasopharyngeal swab samples were enrolled in the study. The study was approved by the Ethical Committee of Hamadan University of Medical Sciences (ethical code: IR.UMSHA.REC.1400.366). The dataset was collected from patient records from February 2020 to July 2021 and includes baseline demographic characteristics and clinical biomarkers. The demographic characteristics were age, sex, smoking, compromised immune system (Com.immune.sys), renal insufficiency, diabetes, and hypertension, and the clinical biomarkers were erythrocyte sedimentation rate (ESR), blood urea nitrogen (BUN), blood sugar (BS), blood creatinine (CR), prothrombin (PT), serum glutamic-pyruvic transaminase (SGPT), serum glutamic-oxaloacetic transaminase (SGOT), alkaline phosphatase (Alp), thromboplastin or partial thromboplastin time (PTT), platelets (Plat), hematocrit (HCT), hemoglobin (Hb), lymphocytes (Lym), monocytes (Mono), and neutrophils (NUT). In total, 22 input features were retrieved: 15 clinical biomarkers and 7 demographic characteristics. For all classification models, the patient's outcome was treated as binary (death = 1, recovery = 0).

Data Imputation

The missing rate in this study was 3 percent, computed over all features for each patient. The missing rate for each feature is reported in Table 1. Four strategies were used to deal with missing data. In the first, entire rows (cases) containing missing values were discarded before the subsequent analysis, leaving 2117 patients (complete case dataset). In the other three, missing values were imputed by the mean, KNN imputation, or MICE before the subsequent analysis (imputed datasets). In mean imputation, the missing values of a feature are replaced by the average of the observed (non-missing) values of that feature. In KNN imputation, missing values are replaced by the average of the corresponding values of the k nearest neighbors (the k most similar patients), where similarity is measured by a distance such as the Euclidean distance. In this study, k was varied from 5 to 100, and the best value for imputation was k = 10. MICE imputation, based on fully conditional specification, first replaces the missing values of each feature with that feature's mean. Then, in a chain of equations, each feature with missing values is regressed (by linear or logistic regression) on the other features, and its missing values are updated with the model's predictions. This cycle was repeated for 100 iterations.
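The four strategies can be sketched with scikit-learn as below. This is a minimal sketch, not the authors' code: it assumes the features sit in a float matrix `X` with `np.nan` marking gaps, and the helper name `make_datasets` is hypothetical. Note that scikit-learn's `IterativeImputer` uses Bayesian ridge regression for every feature by default rather than the linear/logistic pair described above, so it is only an approximation of the MICE setup.

```python
import numpy as np
from sklearn.experimental import enable_iterative_imputer  # noqa: F401, unlocks IterativeImputer
from sklearn.impute import IterativeImputer, KNNImputer, SimpleImputer


def make_datasets(X):
    """Return the four datasets for a float matrix X with np.nan gaps (hypothetical helper)."""
    complete_rows = ~np.isnan(X).any(axis=1)
    return {
        # Strategy 1: drop every row (patient) with at least one missing value
        "complete_case": X[complete_rows],
        # Strategy 2: replace each gap with the observed mean of its column
        "mean": SimpleImputer(strategy="mean").fit_transform(X),
        # Strategy 3: average the values of the 10 most similar rows
        # (Euclidean distance computed on the entries both rows observe)
        "knn": KNNImputer(n_neighbors=10).fit_transform(X),
        # Strategy 4: chained equations, starting from mean imputation, up to 100 rounds
        "mice": IterativeImputer(initial_strategy="mean", max_iter=100,
                                 random_state=0).fit_transform(X),
    }
```

For example, `make_datasets(X)["knn"]` returns the KNN-imputed matrix with the same shape as `X`, while `"complete_case"` has fewer rows.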

Table 1 Descriptive statistics of demographic characteristics and laboratory biomarkers of COVID-19 patients based on the complete dataset

After the different missing-data strategies were applied, the complete case dataset (first strategy) and the imputed datasets were prepared separately for use in the following steps (Fig. 1).

Fig. 1

The steps of study design for COVID-19 mortality prediction

Data Description

Quantitative features were described as mean ± standard deviation and qualitative features as frequency and percentage. Independent two-sample t-tests were used to compare the means of quantitative features between the two groups (dead and recovered), and Pearson's chi-square test was used to examine the pairwise relationships between qualitative features. The significance level was set at 0.05 in all analyses.
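A minimal sketch of these two tests with SciPy, assuming a numeric array `values`, a 0/1 outcome array `died`, and categorical arrays for the chi-square test; the helper names are hypothetical. `correction=False` is set so that 2×2 tables get the plain Pearson statistic rather than SciPy's default Yates continuity correction; whether the study used the correction is not stated.

```python
import numpy as np
from scipy.stats import chi2_contingency, ttest_ind


def compare_quantitative(values, died):
    """Mean ± SD per group and the independent two-sample t-test p-value."""
    dead, recovered = values[died == 1], values[died == 0]
    summary = {
        "dead": (dead.mean(), dead.std(ddof=1)),
        "recovered": (recovered.mean(), recovered.std(ddof=1)),
    }
    return summary, ttest_ind(dead, recovered).pvalue


def compare_qualitative(a, b):
    """Pearson chi-square test of association between two categorical arrays."""
    levels_a, levels_b = np.unique(a), np.unique(b)
    # Contingency table of counts for every (level of a, level of b) pair
    table = np.array([[np.sum((a == i) & (b == j)) for j in levels_b] for i in levels_a])
    chi2, p, dof, expected = chi2_contingency(table, correction=False)
    return table, p
```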

Feature selection using random forest

Because of the large number of input features, the RF algorithm, one of the most common approaches with acceptable results, was used to identify the important features. RF is a supervised learning algorithm for classification and regression: an ensemble of decision trees, each grown by recursive partitioning of a bootstrap sample. RF provides several indices of the importance of features in predicting the outcome; one of them is the Gini importance (mean decrease in Gini impurity), whose relative values lie between zero and one [24]. In this study, an RF with 600 decision trees and the Gini index was used to calculate the importance of each feature, and features with a relative importance above 4% were chosen for further analysis. RF was also used as a classification model for comparison with the other models.
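The selection rule can be sketched with scikit-learn's `RandomForestClassifier`, whose `feature_importances_` attribute is the impurity-based (Gini) importance normalized to sum to one. This is a minimal sketch assuming a feature matrix `X`, labels `y`, and a list of feature `names`; the helper name `select_by_gini_importance` is hypothetical.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier


def select_by_gini_importance(X, y, names, threshold=0.04, n_trees=600, seed=0):
    """Rank features by Gini importance and keep those above `threshold` (4% by default)."""
    rf = RandomForestClassifier(n_estimators=n_trees, criterion="gini", random_state=seed)
    rf.fit(X, y)
    # Mean decrease in Gini impurity over all trees, normalized to sum to 1
    importance = rf.feature_importances_
    order = np.argsort(importance)[::-1]
    ranked = [(names[i], importance[i]) for i in order]
    selected = [name for name, score in ranked if score > threshold]
    return ranked, selected
```

Because the importances are relative, the 4% cut-off depends on how many features compete; with 22 features, a uniform split would give each about 4.5%.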

Logistic regression model

Logistic regression is a traditional statistical model used in the classification task. For a binary outcome, the logistic regression model is shown below:

$$\mathrm{log}\left(\frac{{p}_{i}}{1-{p}_{i}}\right)={\beta }_{0}+\sum_{j=1}^{q}{\beta }_{j}{x}_{ij},$$
(1)

where, for a given input \({x}_{i}=\left({x}_{i1}, \dots , {x}_{iq}\right)\), \(i=1,\dots ,n\), n is the number of training samples, q is the number of input features, and \({p}_{i}\) is the probability that sample i belongs to class 1. The logarithm of the odds of this class, \(\mathrm{log}\left(\frac{{p}_{i}}{1-{p}_{i}}\right)\), is called the logit function and is a linear function of the input features. The regression coefficients \({\beta }_{j}\) are estimated from the data by maximum likelihood [25].
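To make Eq. (1) concrete, the sketch below fits a logistic regression on simulated data and checks that inverting the logit of the fitted linear predictor reproduces the model's predicted probabilities. The data are synthetic (an assumption for illustration only), and a very large `C` is used so that scikit-learn's default L2 penalty becomes negligible and the fit is effectively plain maximum likelihood.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Simulated data with known coefficients (illustration only)
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
true_logit = -0.5 + X @ np.array([1.0, -2.0, 0.0])
y = (rng.random(500) < 1 / (1 + np.exp(-true_logit))).astype(int)

# Huge C => negligible penalty, i.e. (approximately) maximum likelihood
model = LogisticRegression(C=1e6, max_iter=1000).fit(X, y)
beta0, beta = model.intercept_[0], model.coef_[0]

logit = beta0 + X @ beta      # right-hand side of Eq. (1)
p = 1 / (1 + np.exp(-logit))  # invert the logit to get p_i
```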

Gradient Boosting Decision Tree (GBDT)

GBDT is a boosting algorithm in which several weak classifiers (individual decision trees) are combined to form a strong classifier; the final prediction is obtained by combining the outputs of the weak classifiers. For a binary outcome, GBDT uses decision trees as weak learners, and each new tree is fitted to the negative gradient of the loss, so that the algorithm performs gradient descent in function space [26]. Let \({x}_{i}=\left({x}_{i1}, \dots , {x}_{ip}\right)\), \(i=1,\dots ,n\), where n and p are the number of training samples and input features, respectively, and let \({y}_{i}\in \left\{0,1\right\}\) denote the binary target. The steps of GBDT are as follows:

Step I: initialize the model with the constant value \(\beta\) that minimizes the loss over the training set:

$${F}_{0}\left(x\right)=\mathrm{arg}\underset{\beta }{\mathrm{min}}\sum_{i=1}^{n}L\left({y}_{i},\beta \right),$$
(2)

Step II: compute the pseudo-residuals and fit a base learner. Let \({F}_{m}:{\mathbf{R}}^{p}\to \mathbf{R}\) be the predictive model at iteration \(m\), \(m=1,\dots ,M\), and let \(L\left({y}_{i},{F}_{m}({x}_{i})\right)\) be a differentiable loss function. The pseudo-residuals are the negative gradients of the loss with respect to the current model, and the parameters \({e}_{m}\) of the regression tree \(h({x}_{i};{e}_{m})\) are obtained by fitting it to these pseudo-residuals by least squares:

$${e}_{m}=\mathrm{arg}\underset{e,\beta }{\mathrm{min}}\sum_{i=1}^{n}{\left(-\frac{\partial L\left({y}_{i},{F}_{m-1}({x}_{i})\right)}{\partial {F}_{m-1}({x}_{i})}-\beta h\left({x}_{i};e\right)\right)}^{2},$$
(3)

Step III: line search. With the tree \(h(x;{e}_{m})\) from Step II fixed, the step size \({\beta }_{m}\) is chosen to minimize the loss of the updated model:

$${\beta }_{m}=\mathrm{arg}\underset{\beta }{\mathrm{min}}\sum_{i=1}^{n}L\left({y}_{i},{F}_{m-1}\left({x}_{i}\right)+\beta h\left({x}_{i};{e}_{m}\right)\right),$$
(4)

Step IV: update the model. To reduce overfitting, each update is shrunk by a learning rate \(\nu\) (\(0<\nu \le 1\)):

$${F}_{m}\left({x}_{i}\right)={F}_{m-1}\left({x}_{i}\right)+\nu {\beta }_{m} h\left({x}_{i};{e}_{m}\right),$$
(5)

These steps are repeated until M trees have been grown [27].

For a binary target, the logistic (cross-entropy) loss is used to grow the trees. The initial constant is therefore

$${F}_{0}\left(x\right)=\mathrm{arg}\underset{\beta }{\mathrm{min}}\sum_{i=1}^{n}-{y}_{i}\mathrm{log}\left({\widehat{p}}_{i}\right)-\left(1-{y}_{i}\right)\mathrm{log}\left(1-{\widehat{p}}_{i}\right),\quad {\widehat{p}}_{i}=\frac{1}{1+\mathrm{exp}\left(-\beta \right)},$$
(6)

whose solution is the log-odds of the proportion of deaths in the training set.
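Steps I to IV can be sketched from scratch for the binary (log-loss) case. This is a minimal illustration, not the library implementation used in the study: it uses scikit-learn's `DecisionTreeRegressor` as the base learner and SciPy's bounded scalar minimizer for the line search of Eq. (4); the search bounds of ±100 and the helper names are assumptions. Production libraries such as scikit-learn's `GradientBoostingClassifier` instead use a per-leaf Newton step.

```python
import numpy as np
from scipy.optimize import minimize_scalar
from sklearn.tree import DecisionTreeRegressor


def sigmoid(z):
    return 0.5 * (1.0 + np.tanh(0.5 * z))


def log_loss(y, F):
    """Cross-entropy of Eq. (6) in terms of the raw score F = logit(p), numerically stable."""
    return np.mean(y * np.logaddexp(0.0, -F) + (1 - y) * np.logaddexp(0.0, F))


def fit_gbdt(X, y, n_trees=200, learning_rate=0.05, max_depth=3):
    # Step I / Eq. (6): the best constant is the log-odds of the base rate
    F0 = np.log(y.mean() / (1.0 - y.mean()))
    F = np.full(len(y), F0)
    stages = []
    for _ in range(n_trees):
        # Step II / Eq. (3): the negative gradient of the log loss w.r.t. F is y - p
        residual = y - sigmoid(F)
        tree = DecisionTreeRegressor(max_depth=max_depth).fit(X, residual)
        h = tree.predict(X)
        # Step III / Eq. (4): one-dimensional line search for beta_m
        beta = minimize_scalar(lambda b: log_loss(y, F + b * h),
                               bounds=(-100.0, 100.0), method="bounded").x
        # Step IV / Eq. (5): shrunken update with learning rate nu
        F = F + learning_rate * beta * h
        stages.append((tree, beta))
    return {"F0": F0, "stages": stages, "learning_rate": learning_rate}


def decision_function(model, X):
    """Raw score F_M(x); apply sigmoid() to obtain the probability of death."""
    F = np.full(X.shape[0], model["F0"])
    for tree, beta in model["stages"]:
        F = F + model["learning_rate"] * beta * tree.predict(X)
    return F
```

Because the log loss is convex along each search direction, shrinking the optimal step by the learning rate can never increase the training loss, which is why small learning rates such as the 0.05 reported in the Results trade speed for stability rather than correctness.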

Generalized additive models

GAMs are a semi-parametric extension of generalized linear models, used when there is no a priori reason to choose a particular functional form (such as linear or quadratic) for the effect of each feature. GAMs are interpretable by design because of their additive functional form. Let \({x}_{ij}\), \(j=1,\dots ,p\), denote the value of the jth feature for the ith patient, \(i=1,\dots ,n\), where p is the number of features and n is the number of samples, and let the response \({y}_{i}\) follow an exponential-family distribution (here Bernoulli, since the response is binary). With a link function \(g\) (in binary classification, \(g\left(\pi \right)=\mathrm{log}\frac{\pi }{1-\pi }\), where \(\pi\) is the probability of death) and main effects \({f}_{j}\) for the jth feature, called smooth functions, with \(\mathrm{E}\left({f}_{j}\right)=0\), a GAM for a univariate response and multiple features is expressed as [18]:

$$g\left({\mu }_{i}\right)={\beta }_{0}+\sum_{j=1}^{p}{f}_{j}\left({x}_{ij}\right),$$
(7)

where \({\mu }_{i}=E\left({y}_{i}\right)\) and \({\beta }_{0}\) is the intercept. Because the response is binary, the randomness enters through the Bernoulli distribution of \({y}_{i}\) rather than through an additive error term. In GAMs, the linear or nonlinear relationships between the response and the features follow smooth patterns and are represented by unspecified smooth functions, typically built from splines (basis functions). The smoothness of each function is controlled by a smoothing (regularization) parameter \(\lambda\). In this study, GAMs were fitted with 30 to 300 basis functions, and \(\lambda\) was varied from 0.1 to 0.9.

Deep Neural Networks (DNN)

A shallow neural network consists of three layers: the input, hidden, and output layers. The connections between layers carry weights that determine how strongly each unit influences the next. The goal of training is to minimize the error (cost) function in classification or regression and bring the network output closer to the desired result. To achieve this, the connection weights are updated during training by algorithms such as backpropagation, and the size of each weight change is controlled by a hyperparameter called the learning rate. Another hyperparameter, the batch size, is the number of training examples used in one forward/backward pass. In DNNs there is more than one hidden layer, which may increase the classification and prediction accuracy of the network [28]. A schematic representation of the DNN architecture used in the current study, with 3 hidden layers of 100 nodes each, is given in Fig. 2. The rectified linear unit (ReLU) activation function was used in all hidden layers.

Fig. 2

The structure of the DNN

The mathematical form of this structure is as follows:

$$\mathrm{Y}=\mathrm{sigmoid}\left\{{\varvec{W}}^{(4)}\mathrm{ReLU}\left({\varvec{W}}^{(3)}\mathrm{ReLU}\left({\varvec{W}}^{(2)}\mathrm{ReLU}\left({\varvec{W}}^{(1)}{\varvec{X}}_{i}+{\varvec{b}}^{(1)}\right)+{\varvec{b}}^{(2)}\right)+{\varvec{b}}^{(3)}\right)+{\varvec{b}}^{(4)}\right\},$$
(8)

where \({\varvec{W}}^{(l)}\) and \({\varvec{b}}^{(l)}\) are the weight matrix and bias vector of layer l.

The initial values of the W matrices were drawn at random from a normal distribution with mean zero and variance 0.2, and the initial values of the b vectors were set to 1.
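The forward pass of Eq. (8) with the initialization just described can be sketched in NumPy. This is a minimal, untrained sketch: it uses the 86-64-16 hidden sizes reported in the Results (pass `hidden=(100, 100, 100)` for the schematic in Fig. 2), interprets "variance 0.2" as a standard deviation of \(\sqrt{0.2}\), and the helper names are hypothetical. Training with backpropagation and Adam, as in the study, would normally be done in a framework such as PyTorch.

```python
import numpy as np


def init_dnn(n_inputs, hidden=(86, 64, 16), seed=0):
    """Weights ~ N(0, variance 0.2) and biases = 1, one (W, b) pair per layer."""
    rng = np.random.default_rng(seed)
    sizes = [n_inputs, *hidden, 1]
    return [
        (rng.normal(0.0, np.sqrt(0.2), size=(n_out, n_in)), np.ones(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]


def relu(z):
    return np.maximum(z, 0.0)


def sigmoid(z):
    # tanh form avoids overflow warnings for large |z|
    return 0.5 * (1.0 + np.tanh(0.5 * z))


def dnn_forward(params, X):
    """Eq. (8): ReLU on the hidden layers, sigmoid on the single output unit."""
    a = X.T  # columns are patients, so W @ a + b matches W^(l) X_i + b^(l)
    for W, b in params[:-1]:
        a = relu(W @ a + b[:, None])
    W, b = params[-1]
    return sigmoid(W @ a + b[:, None]).ravel()
```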

Generalized Neural Additive Models

GNAMs belong to the GAM family and learn a linear combination of multi-layer perceptrons (MLPs), one per feature, each with an input layer, an output layer, and several hidden layers. The output of the jth MLP is

$${f}_{j}\left({x}_{j}\right)={\omega }_{1j}\mathrm{ExU}\left({x}_{j,1}\right)+\dots +{\omega }_{kj}\mathrm{ExU}\left({x}_{j,k}\right),$$
(9)

where \({\varvec{\omega}}\) denotes the network parameters, \(j=1,\dots ,p\), and p is the number of features. The initial values of these parameters can be drawn at random from a normal distribution with mean zero and variance 0.2. This weight distribution was selected based on the pre-trained source network proposed by Agarwal et al. [20], which was designed for a similar binary classification task.

Each neuron in the hidden layer uses the exp-centered (ExU) activation, computed as

$$\mathrm{ExU}\left({x}_{j,\tau }\right)=\left({x}_{j}-{\omega }_{0\tau j}\right)\mathrm{exp}\left({\omega }_{1\tau j}\right),$$
(10)

where \(\tau =1,\dots ,k\) and k is the number of neurons in the hidden layer of the jth MLP. The sum of all MLP outputs plus the model intercept gives \(g\left(\mu \right)\), the additive predictor of formula (7). In the last step, a sigmoid activation function, h(.), is applied to this sum to classify the binary outcome [20]. The architecture of a GNAM with one hidden layer is presented in Fig. 3.
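The additive forward pass of Eqs. (9) and (10) can be sketched in NumPy with one hidden layer of k ExU units per feature. This is a minimal, untrained sketch, and the class names are hypothetical. One assumption is flagged loudly: Eq. (10) on its own is linear in \({x}_{j}\), so, following Agarwal et al. [20], the ExU output is passed through a ReLU capped at 1 (ReLU-1); without some such nonlinearity each \({f}_{j}\) would collapse to a straight line. All parameters use the N(0, 0.2) initialization stated above.

```python
import numpy as np


def sigmoid(z):
    return 0.5 * (1.0 + np.tanh(0.5 * z))


class FeatureNet:
    """One-hidden-layer feature network f_j built from k ExU units (Eqs. 9-10)."""

    def __init__(self, k, rng):
        scale = np.sqrt(0.2)                    # variance 0.2, as in the text
        self.w0 = rng.normal(0.0, scale, k)     # centers  omega_{0 tau j}
        self.w1 = rng.normal(0.0, scale, k)     # log-scales omega_{1 tau j}
        self.omega = rng.normal(0.0, scale, k)  # output weights of Eq. (9)

    def __call__(self, x):
        h = (x[:, None] - self.w0) * np.exp(self.w1)  # Eq. (10), shape (n, k)
        h = np.clip(h, 0.0, 1.0)  # ASSUMPTION: ReLU-1 cap from Agarwal et al.
        return h @ self.omega     # Eq. (9), shape (n,)


class GNAM:
    def __init__(self, n_features, k=16, seed=0):
        rng = np.random.default_rng(seed)
        self.nets = [FeatureNet(k, rng) for _ in range(n_features)]
        self.beta0 = 0.0  # intercept

    def contributions(self, X):
        """Column j is f_j(x_j); it depends on X[:, j] only."""
        return np.column_stack([net(X[:, j]) for j, net in enumerate(self.nets)])

    def predict_proba(self, X):
        logit = self.beta0 + self.contributions(X).sum(axis=1)  # g(mu), Eq. (7)
        return sigmoid(logit)
```

The design keeps each feature in its own sub-network, so changing one patient's lymphocyte count, for example, moves only that feature's term in the sum; this separation is what makes the learned curves directly readable.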

Fig. 3

Example of a GNAM architecture

A key feature of GNAMs is that they can be interpreted by visualizing each learned smooth function \({f}_{j}({x}_{j})\) against \({x}_{j}\). We take advantage of this and plot \({f}_{j}({x}_{j})\) against \({x}_{j}\) for the features extracted in the feature selection step.
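The plotting step can be sketched as evaluating each fitted feature function on a grid spanning the observed values, centering it so that its mean over the data is zero (the \(\mathrm{E}\left({f}_{j}\right)=0\) convention of GAMs), and binning the observed values to shade data density. This is a minimal sketch: `f` stands for any fitted feature function, and the grid size, bin count, and helper names are illustrative assumptions, not the authors' settings. Averaging the centered curves of several fitted models gives an ensemble curve like those in the Results.

```python
import numpy as np


def shape_curve(f, x_observed, n_grid=100, n_bins=20):
    """Centered curve of a learned feature function plus a density histogram."""
    grid = np.linspace(x_observed.min(), x_observed.max(), n_grid)
    offset = f(x_observed).mean()  # centering: mean of f over the data becomes 0
    curve = f(grid) - offset
    counts, edges = np.histogram(x_observed, bins=n_bins)  # data density for shading
    return grid, curve, counts, edges


def ensemble_curve(fs, x_observed, n_grid=100):
    """Average the centered curves of several fitted models (e.g. fifty GNAMs)."""
    curves = [shape_curve(f, x_observed, n_grid)[1] for f in fs]
    return np.mean(curves, axis=0)
```

For instance, with `f = lambda x: 2 * x` and observed values `[0, 1, 2, 3]`, a 4-point grid gives the centered curve `[-3, -1, 1, 3]`.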

GNAMs are trained by minimizing the following loss function:

$$L\left(\theta \right)={E}_{x,y\sim D}\left(l\left(x,y;\theta \right)+{\lambda }_{1}\eta \left(x;\theta \right)\right)+{\lambda }_{2}\gamma \left(\theta \right),$$
(11)

where \(D={\{{x}_{(i)},{y}_{(i)}\}}_{i=1}^{n}\) is the training set of size n for the input (x) and target (y) features, \(l(x, y;\theta )\) is the loss function, \(\eta \left(x;\theta \right)=\frac{1}{K}{\sum }_{x}{\sum }_{k}{\left({f}_{k}^{\theta }\left({x}_{k}\right)\right)}^{2}\) is the output penalty (squared L2 norm of the feature-network outputs) over the K features, \(\gamma \left(\theta \right)\) is the weight decay (L2 penalty on the network weights) over the K features, and \({f}_{k}^{\theta }\) is the feature network for the kth feature. Training is further regularized by dropout (applied within each feature network, which regularizes the ExUs so that smooth functions are learned while jumps can still be represented) and feature dropout (randomly dropping entire feature networks during training, which helps when features are correlated), with coefficients λ3 and λ4, respectively.

The cross-entropy loss for a binary target is given by

$$l\left(x,y;\theta \right)=-y\,\mathrm{log}\left({p}_{\theta }\left(x\right)\right)-\left(1-y\right)\mathrm{log}\left(1-{p}_{\theta }\left(x\right)\right),$$
(12)

where \({p}_{\theta }\left(x\right)=\mathrm{sigmoid}\left({\beta }_{0}^{\theta }+\sum_{k=1}^{K}{f}_{k}^{\theta }\left({x}_{k}\right)\right)\) is the predicted probability output by the GNAM [20].
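The penalized objective of Eqs. (11) and (12) can be computed directly from the matrix of feature-network outputs. This is a minimal sketch, not the reference implementation: the output penalty is taken as the per-sample mean over features of \({f}_{k}^{2}\), averaged over the batch, and the weight decay \(\gamma\) is assumed to be the squared L2 norm of all feature-network weights divided by K; `gnam_loss` is a hypothetical helper. The cross-entropy uses `np.logaddexp`, since \(-\mathrm{log}\,p=\mathrm{log}(1+{e}^{-z})\) and \(-\mathrm{log}(1-p)=\mathrm{log}(1+{e}^{z})\) for logit z.

```python
import numpy as np


def gnam_loss(y, beta0, F, feature_weights, lam1, lam2):
    """Penalized GNAM loss of Eqs. (11)-(12).

    y: (n,) labels in {0, 1}; F: (n, K) outputs f_k(x_k) of the K feature networks;
    feature_weights: list of K weight arrays, one per feature network.
    """
    K = F.shape[1]
    z = beta0 + F.sum(axis=1)  # logit of p_theta(x)
    # Eq. (12) in a numerically stable form, averaged over the batch
    cross_entropy = np.mean(y * np.logaddexp(0.0, -z) + (1 - y) * np.logaddexp(0.0, z))
    # Output penalty eta: squared outputs averaged over features, then over samples
    output_penalty = np.mean(np.sum(F ** 2, axis=1) / K)
    # Weight decay gamma: ASSUMED to be the squared L2 norm of all weights divided by K
    weight_decay = sum(np.sum(w ** 2) for w in feature_weights) / K
    return cross_entropy + lam1 * output_penalty + lam2 * weight_decay
```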

The hyperparameters of the Adam optimizer used to train the networks were tuned: learning rates between 0.001 and 0.2 and batch sizes between 100 and 500 were tested. To avoid overfitting, the regularization parameters (λ) of the DNN and GNAM were tuned over the following candidate values: λ1 (output penalty coefficient) in {0.001, 0.01, 0.1}, λ2 (weight decay coefficient) in {0, 0.00001, 0.0001, 0.001, 0.01}, λ3 (dropout coefficient) in {0, 0.1, 0.3, 0.5, 0.7, 0.9}, and λ4 (feature dropout coefficient) in {0, 0.05, 0.1}. The λ2 and λ3 parameters of the DNNs were tuned in the same way as for the GNAMs, and both models were trained for 200 epochs. The search ranges for the other hyperparameters, such as the number of hidden-layer neurons, the learning rate, and the number of epochs, were based on the ranges proposed by Agarwal et al. [20]. However, the specific optimal value of each hyperparameter was determined from the data of the present study by cross-validation.
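The discrete candidate sets above define a full grid of 3 × 5 × 6 × 3 = 270 regularization settings. The sketch below enumerates that grid and picks the setting with the best cross-validated score. It is a minimal illustration: `cv_score` is a hypothetical callable (for example, returning the mean validation AUC over folds) that the caller must supply, and the study may have searched the grid differently.

```python
import itertools

# Candidate values listed in the text
GRID = {
    "lam1_output_penalty": [0.001, 0.01, 0.1],
    "lam2_weight_decay": [0, 0.00001, 0.0001, 0.001, 0.01],
    "lam3_dropout": [0, 0.1, 0.3, 0.5, 0.7, 0.9],
    "lam4_feature_dropout": [0, 0.05, 0.1],
}


def iter_configs(grid):
    """Yield every combination of the candidate values as a dict."""
    keys = list(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


def select_best(grid, cv_score):
    """Return the configuration with the highest score.

    cv_score is a HYPOTHETICAL callable: config -> mean validation metric over folds.
    """
    return max(iter_configs(grid), key=cv_score)
```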

Evaluation metrics and Class imbalanced issue

First, given the imbalance ratio of 4:10 (minority/majority) in this binary classification problem, the data were balanced using the synthetic minority over-sampling technique (SMOTE). For the subsequent analysis, the dataset was randomly divided into training and test subsets in a 7:3 ratio. This split was repeated fifty times; the prediction models were fitted on each training set, and the evaluation indices were calculated separately for the corresponding training and test sets. The final performance of each model is the average over these fifty repetitions.
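SMOTE creates each synthetic sample by picking a minority-class sample, choosing one of its k nearest minority-class neighbors, and interpolating at a random point on the segment between them. The from-scratch NumPy sketch below follows that description together with the repeated 70/30 splits. Assumptions: the paper does not name its SMOTE implementation (the imbalanced-learn package is a common choice), death = 1 is treated as the minority class, classes are balanced to 1:1, and k = 5; the helper names are hypothetical.

```python
import numpy as np


def smote(X_min, n_new, k=5, seed=0):
    """Generate n_new synthetic rows by interpolating between minority-class neighbors."""
    rng = np.random.default_rng(seed)
    d = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=2)
    np.fill_diagonal(d, np.inf)          # a sample is not its own neighbor
    nn = np.argsort(d, axis=1)[:, :k]    # indices of the k nearest minority neighbors
    base = rng.integers(0, len(X_min), size=n_new)
    neigh = nn[base, rng.integers(0, k, size=n_new)]
    gap = rng.random((n_new, 1))         # random position on each segment
    return X_min[base] + gap * (X_min[neigh] - X_min[base])


def balance(X, y, k=5, seed=0):
    """Oversample class 1 (assumed minority) until both classes have equal size."""
    n_new = int((y == 0).sum() - (y == 1).sum())
    X_syn = smote(X[y == 1], n_new, k, seed)
    return np.vstack([X, X_syn]), np.concatenate([y, np.ones(n_new, dtype=y.dtype)])


def repeated_splits(n, n_repeats=50, test_frac=0.3, seed=0):
    """Yield (train_idx, test_idx) for n_repeats random 70/30 splits."""
    rng = np.random.default_rng(seed)
    n_test = int(round(test_frac * n))
    for _ in range(n_repeats):
        perm = rng.permutation(n)
        yield perm[n_test:], perm[:n_test]
```

As described, oversampling precedes the split here; if `balance` is instead applied to each training split only, synthetic points derived from test patients cannot enter training.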

The logistic regression, RF, GBDT, GAM, DNN, and GNAM models were then trained on the selected features. Their predictive performance was compared using accuracy, F1-score, and the area under the receiver operating characteristic curve (AUC). The steps are shown graphically in Fig. 1. For all evaluation metrics, values closer to one indicate better predictive performance. The analyses were performed in Python 3.8 with the sklearn and torch modules on a Xeon® 4210 Core i32 CPU with 128 GB of RAM; the source code for NAM models is available at https://neural-additive-models.github.io. The original code was modified to suit the conditions and available data of this study.
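The three metrics can be written out directly, which makes their definitions explicit. This is a minimal NumPy sketch with hypothetical helper names; the 0.5 decision threshold for accuracy and F1 is an assumption, since the text does not state the threshold. The AUC is computed as the probability that a randomly chosen patient who died receives a higher score than a randomly chosen recovered patient, which equals the area under the ROC curve.

```python
import numpy as np


def accuracy(y, y_pred):
    return np.mean(y == y_pred)


def f1(y, y_pred):
    """Harmonic mean of precision and recall, written as 2TP / (2TP + FP + FN)."""
    tp = np.sum((y == 1) & (y_pred == 1))
    fp = np.sum((y == 0) & (y_pred == 1))
    fn = np.sum((y == 1) & (y_pred == 0))
    return 2 * tp / (2 * tp + fp + fn) if tp > 0 else 0.0


def auc(y, score):
    """P(score of a random death > score of a random recovery), ties counted as 1/2."""
    pos, neg = score[y == 1], score[y == 0]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


def evaluate(y, score, threshold=0.5):
    y_pred = (score >= threshold).astype(int)
    return {"accuracy": accuracy(y, y_pred), "f1": f1(y, y_pred), "auc": auc(y, score)}
```

For the reported averages, `evaluate` would be called once per train-test split and each metric averaged over the fifty results, e.g. `np.mean([r["auc"] for r in results])`.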

Results

In this study, of 2181 COVID-19 patients (53.6% male), 1557 recovered and 624 died. The mean age of recovered patients (58.25 ± 16.52 years) was significantly lower than that of patients who died (71.17 ± 14.44 years) (p < 0.001). The frequency, percentage, mean, and standard deviation of the clinical biomarkers, together with the statistical significance of the differences between the dead and recovered groups, are shown in Table 1.

After the four missing-data strategies were applied, the relative importance score of each feature was calculated with the Gini index in the RF classifier; the results are reported in Fig. 4. The four strategies led to the selection of the same important features, with BUN and Com.immune.sys the most and least important features, respectively. The ten most important features, each with a relative importance above 4%, were used as the final input features for the logistic regression, GBDT, RF, GAM, DNN, and GNAM models.

Fig. 4

The relative importance of all features selected by random forest classifier based on four imputation strategies (complete case dataset and imputed datasets)

The optimal hyperparameter values of each model were selected by cross-validation under the different missing-data scenarios, as follows:

For the first strategy (removing cases with missing values), the best GAM results were obtained with 250 basis functions and \(\lambda =0.8\). For the GBDT, the number of trees and the learning rate were set to 200 and 0.05, respectively. For the RF, the number of trees and the maximum tree depth were set to 300 and 10, respectively.

After testing various values of the regularization parameters, the DNN was optimized with λ2 = 0.001 and λ3 = 0.3, a batch size of 150, a learning rate of 0.005, and 3 hidden layers with 86, 64, and 16 neurons, respectively.

The GNAM was optimized with a batch size of 150, a learning rate of 0.005, and 3 hidden layers with 86, 64, and 16 neurons, respectively, and with the regularization parameters λ1, λ2, λ3, and λ4 set to 0.01, 0.001, 0.5, and 0.05, respectively.

Similar results were obtained for the three imputation strategies: the best GAM results were obtained with 230 basis functions and \(\lambda =0.75\). For the GBDT, the number of trees and the learning rate were set to 200 and 0.04, respectively. For the RF, the number of trees and the maximum tree depth were set to 300 and 10, respectively.

After testing various values of the regularization parameters, the DNN was optimized with λ2 = 0.001 and λ3 = 0.25, a batch size of 300, a learning rate of 0.002, and 3 hidden layers with 86, 64, and 16 neurons, respectively.

The GNAM was optimized with a batch size of 300, a learning rate of 0.002, and 3 hidden layers with 86, 64, and 16 neurons, respectively, and with the regularization parameters λ1, λ2, λ3, and λ4 set to 0.01, 0.001, 0.4, and 0.03, respectively.

The accuracy, F1-score, and AUC of the trained models under the four missing-data strategies are reported in Table 2. The GNAM model had the best performance, with test accuracy, F1-score, and AUC of 0.847, 0.691, and 0.774, respectively, for the strategy of removing missing values (a), and 0.855, 0.704, and 0.782, respectively, for the KNN-imputed dataset. The GAM model had the worst test accuracy, F1-score, and AUC in all datasets.

Table 2 Accuracy, F1-score and AUC of the different classification models for different imputation strategies, based on the train and test datasets

Feature smooth functions learned by an ensemble of fifty GNAMs, along with the data density of the COVID-19 dataset, are shown in Fig. 5. The deep teal bars indicate the data density for each feature: the darker the bar, the more data lie in that region. The learned feature smooth functions are shown as red lines. For example, the smooth function of age showed that the risk of death in COVID-19 patients increased with age, and the slope was steepest above roughly 60 years.

Fig. 5

Feature smooth functions learned by an ensemble of fifty GNAMs, along with the density of the COVID-19 dataset (the complete case dataset was used)

The smooth function of CR showed three distinct behaviors: the risk of death in COVID-19 patients was constant for CR values from 0 to 30, increased slowly from 30 to 80, and rose most steeply for CR values above about 80.

The smooth function of Lym showed that the risk of death in COVID-19 patients was higher at lower lymphocyte counts, so the probability of death decreased as the Lym count increased. The risk of death was high and approximately constant for Lym values from 0 to 50, after which it fell most rapidly.

Discussion

In this study, laboratory findings and demographic features were evaluated to predict the probability of death from COVID-19. Since the effects of some features on COVID-19 outcomes are generally nonlinear, the aim was to identify an appropriate model with the highest prediction accuracy.

The results showed that the risk of death from COVID-19 increases with age, consistent with studies on SARS [29], Middle East respiratory syndrome (MERS) [30], and COVID-19 [31]. Because conditions such as hypertension, hyperlipidemia, and hyperglycemia, as well as a history of surgery, are more common in older age, this group of patients is more susceptible to COVID-19 [32].

Further findings indicated that sex, smoking, compromised immune system, renal insufficiency, PTT, and Plat had no significant effect on death. In contrast, in the study by Abohamr et al., sex, smoking, renal insufficiency, and Plat had a significant effect on death. In the study by Liu et al., smoking, Plat, and PTT had no significant effect on death [33]. In the present study, only five patients with a compromised immune system were recorded, and only one of them died. However, Kostoff et al. and Yazdanpanah et al. examined the role of the immune system, and their results showed that elevated immune biomarkers lead to inflammation and, ultimately, further damage to other organs and even death [34, 35].

In addition, diabetes, hypertension, age, ESR, BUN, BS, CR, PT, SGPT, SGOT, ALP, HCT, Hb, Lym, Mono, and NUT showed significant effects on COVID-19 death. The same finding was reported by Bertsimas et al., Liu et al., Bahl et al., Guan et al., Cao et al., and Chen et al., in whose studies the probability of death rose markedly as these levels moved outside the normal range [36,37,38]. As vital components of the immune system, neutrophils and lymphocytes play a considerable role in host defense and the clearance of infections. A low blood lymphocyte count may be an important factor in disease severity and mortality in COVID-19 [39]. Ruan et al. and Chen et al. showed that lymphocyte, neutrophil, monocyte, and platelet counts outside the normal range may indicate viral replication and inflammation in COVID-19 patients [40, 41]. Studies of the factors affecting COVID-19 severity have also reported higher neutrophil counts in patients with severe disease than in those with moderate disease [42]. In this study, the mortality risk of COVID-19 patients increased with increasing neutrophil count and with decreasing lymphocyte count.

Several methods are available for selecting important features. Using the RF method, ten features whose relative importance exceeded 4% were selected as the most important predictors of COVID-19 death. In order of importance, these were BUN, Lym, age, BS, SGOT, Mono, CR, NUT, ALP, and HCT. Lym, age, NUT, and Mono were identified as the most important risk factors for COVID-19 death by Ma et al. [43]. Lym, NUT, and age were identified by AlJame et al. [44], and Lym, NUT, BS, and age by Subudhi et al. [45]. The results also showed that the ranking of the selected features by relative importance changed little whether missing values were removed or imputed.
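
The 4% cut-off used to retain RF features can be written as a short rule. The importance scores are normalized to sum to one, features whose share exceeds the threshold are kept, and the kept features are ranked. The sketch below assumes the raw scores come from a fitted RF, for example impurity-based importances. The function name is hypothetical, and the example scores are made up for illustration and are not the study's values.

```python
from typing import Dict, List


def select_by_relative_importance(
    importances: Dict[str, float], threshold: float = 0.04
) -> List[str]:
    """Return features whose share of total importance exceeds `threshold`.

    Raw scores are rescaled to sum to one, and the kept features are ordered
    from most to least important.
    """
    total = sum(importances.values())
    if total <= 0:
        raise ValueError("importance scores must have a positive sum")
    relative = {name: score / total for name, score in importances.items()}
    kept = [name for name, share in relative.items() if share > threshold]
    return sorted(kept, key=lambda name: relative[name], reverse=True)


if __name__ == "__main__":
    # Made-up raw scores, for illustration only.
    scores = {"BUN": 30.0, "Lym": 25.0, "age": 22.0, "BS": 20.0, "Plat": 3.0}
    print(select_by_relative_importance(scores))
```

With these invented scores, Plat holds a 3% share and is dropped. The remaining features are returned in decreasing order of importance.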

Like the present study, many studies have predicted COVID-19 mortality using supervised machine learning techniques. In our comparisons, GNAM had the best overall predictive performance, with a mean accuracy, F1-score, and AUC in the test dataset of 0.847, 0.691, and 0.774, respectively.
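
The reported figures are means over fifty train-test splits. A minimal pure-Python sketch of that evaluation loop follows. It assumes that death is coded 1, and that class predictions are obtained by thresholding predicted probabilities at 0.5. The study's actual threshold is not stated here. AUC is computed as the probability that a randomly chosen death receives a higher score than a randomly chosen recovery, with ties counted as one half. All function names are illustrative.

```python
from statistics import mean
from typing import Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of cases whose predicted class matches the observed class."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def f1_score(y_true, y_pred, positive=1):
    """Harmonic mean of precision and recall for the positive class (death)."""
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    fp = sum(t != positive and p == positive for t, p in zip(y_true, y_pred))
    fn = sum(t == positive and p != positive for t, p in zip(y_true, y_pred))
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def roc_auc(y_true, scores, positive=1):
    """AUC as P(score of a positive case > score of a negative case), ties count 1/2."""
    pos = [s for t, s in zip(y_true, scores) if t == positive]
    neg = [s for t, s in zip(y_true, scores) if t != positive]
    if not pos or not neg:
        raise ValueError("both classes must be present")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def summarize_splits(splits, threshold=0.5):
    """Average accuracy, F1 and AUC over repeated train-test splits.

    `splits` yields (y_true, y_score) pairs, one per test set, where y_score
    holds the predicted probabilities of death.
    """
    acc, f1, auc = [], [], []
    for y_true, y_score in splits:
        y_pred = [1 if s >= threshold else 0 for s in y_score]
        acc.append(accuracy(y_true, y_pred))
        f1.append(f1_score(y_true, y_pred))
        auc.append(roc_auc(y_true, y_score))
    return {"accuracy": mean(acc), "f1": mean(f1), "auc": mean(auc)}
```

When called with fifty `(y_true, y_score)` pairs, one per split, this returns the three means in the form reported above. The pairwise AUC definition gives the same value as a library routine such as scikit-learn's `roc_auc_score`.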

Li et al. used six important biomarkers, including D-dimer, blood oxygen, the Lym-to-NUT ratio, C-reactive protein (CRP), and lactate dehydrogenase, in a DNN model to predict the mortality of COVID-19 patients. Their model achieved an AUC of 0.95 [46]. Lin et al. used 30 demographic and laboratory biomarkers with machine learning methods, including a neural network, to predict the mortality of COVID-19 patients. The accuracy and AUC of their model were 0.91 and 0.88, respectively [47]. Morales et al. used ten important demographic and laboratory biomarkers, including age, blood pressure, and liver and kidney failure, and their neural network model predicted the death of COVID-19 patients with an accuracy of 0.88 [48].

An important advantage of GNAM over other neural networks, such as black-box DNNs, is its interpretability, which comes from the smooth function graph learned for each feature. Rather than relying on rigid parametric assumptions, GNAM expresses the relationship between the output and each input feature through a smooth function that can be fitted to virtually any form of data.
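
Written out, the model behind these graphs can be sketched as a neural additive model [20] with a logit link for the binary outcome (1: death, 0: recovery). This is a hedged restatement that is consistent with the log-odds interpretation used in this section. It is not a transcription of the authors' exact specification. In it, $\beta_0$ is an intercept, $f_j$ is the small neural network learned for the $j$-th of the ten selected features, and $\bar f_j$ is the ensemble curve averaged over the fifty fitted models.

$$
\log\frac{P(Y=1\mid\mathbf{x})}{1-P(Y=1\mid\mathbf{x})}=\beta_0+\sum_{j=1}^{10} f_j(x_j),
\qquad
\bar f_j(x_j)=\frac{1}{50}\sum_{m=1}^{50} f_j^{(m)}(x_j)
$$

Because each feature enters through its own term, each $f_j$ can be plotted separately. This is what allows the model to be read one feature at a time.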

Given the low (3%) missing rate, similar results were obtained across the different strategies for handling missing data, and GNAM had higher predictive performance than the other models. Therefore, the interpretable smooth function graphs of the GNAM were drawn and reported for the complete-case (non-missing) data. The smooth functions learned from the fifty fitted GNAMs are visualized in Fig. 5 to interpret the behavior of each feature. As the values of age, BUN, BS, CR, SGOT, ALP, HCT, Mono, and NUT increased, their smooth functions on the log-odds scale of COVID-19 mortality also increased. In contrast, the smooth function for Lym decreased as Lym increased.
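
Because the model is additive on the log-odds scale, a learned smooth function can be read directly as an odds ratio. Suppose one feature moves from value a to value b while all other features are held fixed. The odds of death are then multiplied by exp(f_j(b) − f_j(a)). The sketch below assumes the shape function has been tabulated on an increasing grid, and it uses linear interpolation between grid points. The function names are hypothetical. The grid and values in the example are invented for illustration and are not read from Fig. 5.

```python
import math
from bisect import bisect_right
from typing import Sequence


def interpolate(grid: Sequence[float], values: Sequence[float], x: float) -> float:
    """Linearly interpolate a shape function tabulated on an increasing grid.

    Values outside the grid are clamped to the nearest end point.
    """
    if x <= grid[0]:
        return values[0]
    if x >= grid[-1]:
        return values[-1]
    i = bisect_right(grid, x) - 1  # index of the grid point at or just below x
    x0, x1 = grid[i], grid[i + 1]
    y0, y1 = values[i], values[i + 1]
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


def odds_ratio(grid, values, a, b):
    """Odds ratio of the outcome when one feature moves from a to b, others fixed."""
    return math.exp(interpolate(grid, values, b) - interpolate(grid, values, a))


if __name__ == "__main__":
    # Hypothetical log-odds contributions for age (invented, not from Fig. 5).
    age_grid = [40.0, 60.0, 80.0]
    age_logodds = [-0.5, 0.0, 1.0]
    print(odds_ratio(age_grid, age_logodds, 60.0, 80.0))
```

With these made-up values, moving age from 60 to 80 years multiplies the odds of death by e (about 2.72). The clamping at the grid ends is a design choice for this sketch. Values outside the observed range, where Fig. 5 shows little data, should not be extrapolated.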

Because the critical period for disease progression has been reported to be 10 to 14 days after the onset of clinical symptoms, identifying the factors associated with patient mortality at hospital admission can support physicians' decision-making and prognostic assessment [41]. In addition, interpretable machine learning models such as GNAM, unlike common black-box neural networks, can give physicians a broad view of how influential variables increase or decrease the risk of patient mortality. This may support more accurate prediction of patient mortality.

Limitations

The samples in this study were limited to the city of Hamadan, the capital of Hamadan province, Iran, which limits the generalizability of the findings. In addition, some patient information, such as BMI, predisposing factors, and underlying comorbidities, was not used to identify the features affecting mortality because of high missing rates. These features may also play an important role in predicting patient mortality.

Conclusion

When the relationships between the predictor features and the response are nonlinear, machine learning models such as GNAM can perform well in predicting the mortality of COVID-19 patients. Interpretable machine learning models such as GNAM can therefore help physicians prioritize important demographic factors and laboratory findings, by identifying the influential features and the shape of their relationship with the risk of death.

Availability of data and materials

The dataset analyzed during the current study is not publicly available due to restrictions related to our internal review board policy. However, the dataset is available from the corresponding author (Maryam Farhadian) on reasonable request.

Abbreviations

GNAM:

Generalized Neural Additive Model

NAM:

Neural Additive Model

RF:

Random Forest

MICE:

Multivariate Imputation by Chained Equations

GAMs:

Generalized Additive Models

SMOTE:

Synthetic Minority Oversampling Technique

DNNs:

Deep Neural Networks

AUC:

Area under the Curve

AST:

Aspartate Aminotransferase

ALT:

Alanine Aminotransferase

LYM:

Lymphocytes

NEU:

Neutrophil

LDH:

Lactate Dehydrogenase

Comp.immune.sys:

Compromised immune system

ESR:

Erythrocyte Sedimentation Rate

BUN:

Blood Urea Nitrogen

BS:

Blood Sugar

CR:

Blood Creatinine

PT:

Prothrombin Time

SGPT:

Serum Glutamic-Pyruvic Transaminase

SGOT:

Serum Glutamic-Oxaloacetic Transaminase

Alp:

Alkaline Phosphatase

PTT:

Partial Thromboplastin Time

Plat:

Platelets

HCT:

Hematocrit

Hb:

Hemoglobin

Lym:

Lymphocytes

Mono:

Monocytes

NUT:

Neutrophils

MLP:

Multi-Layer Perceptron

ExU:

Exp-Centered

ROC:

Receiver Operating Characteristics

SE:

Standard Error

MERS:

Middle East Respiratory Syndrome

CRP:

C-Reactive Protein

References

  1. Dhama K, Khan S, Tiwari R, Sircar S, Bhat S, Malik YS, et al. Coronavirus disease 2019–COVID-19. Clin Microbiol Rev. 2020;33(4):e00028-20.

  2. Henry BM, De Oliveira MHS, Benoit S, Plebani M, Lippi G. Hematologic, biochemical and immune biomarker abnormalities associated with severe illness and mortality in coronavirus disease 2019 (COVID-19): a meta-analysis. Clin Chem Lab Med. 2020;58(7):1021–8.

  3. Chen L, Lin Z, Chen J, Liu S, Shi T, Xin Y. Can elevated concentrations of ALT and AST predict the risk of ‘recurrence’ of COVID-19? Epidemiol Infect. 2020;148:e218.

  4. Liu Y, Yang Y, Zhang C, Huang F, Wang F, Yuan J, et al. Clinical and biochemical indexes from 2019-nCoV infected patients linked to viral loads and lung injury. Science China Life Sciences. 2020;63(3):364–74.

  5. Deo RC. Machine learning in medicine. Circulation. 2015;132(20):1920–30.

  6. Harrell FE. Regression modeling strategies. Bios. 2018;2017(330):14.

  7. Hao B, Sotudian S. Early prediction of level-of-care requirements in patients with COVID-19. Elife. 2020;9:e60519.

  8. Moons KG, Kengne AP, Woodward M, Royston P, Vergouwe Y, Altman DG, et al. Risk prediction models: I. Development, internal validation, and assessing the incremental value of a new (bio) marker. Heart. 2012;98(9):683–90.

  9. Wollenstein-Betech S, Silva AA, Fleck JL, Cassandras CG, Paschalidis IC. Physiological and socioeconomic characteristics predict COVID-19 mortality and resource utilization in Brazil. PLoS ONE. 2020;15(10):e0240346.

  10. Gutierrez JM, Volkovs M, Poutanen T, Watson T, Rosella LC. Risk stratification for COVID-19 hospitalization: a multivariable model based on gradient-boosting decision trees. CMAJ Open. 2021;9(4):E1223–31.

  11. Wang T, Paschalidis A, Liu Q, Liu Y, Yuan Y, Paschalidis IC. Predictive models of mortality for hospitalized patients with COVID-19: retrospective cohort study. JMIR Med Inform. 2020;8(10):e21788.

  12. Beretta L, Santaniello A. Nearest neighbor imputation algorithms: a critical evaluation. BMC Med Inform Decis Mak. 2016;16(3):197–208.

  13. Lall R, Robinson T. The MIDAS touch: accurate and scalable missing-data imputation with deep learning. Political Analysis. 2022;30(2):179–96.

  14. Troyanskaya O, Cantor M, Sherlock G, Brown P, Hastie T, Tibshirani R, et al. Missing value estimation methods for DNA microarrays. Bioinformatics. 2001;17(6):520–5.

  15. Stekhoven DJ. missForest: Nonparametric missing value imputation using random forest. Bioinformatics. 2012;28(1):112–8.

  16. Little RJ, Rubin DB. Statistical analysis with missing data. 3rd ed. Wiley; 2019.

  17. Vaid A, Somani S, Russak AJ, De Freitas JK, Chaudhry FF, Paranjpe I, et al. Machine learning to predict mortality and critical events in a cohort of patients with COVID-19 in New York City: model development and validation. J Med Internet Res. 2020;22(11):e24018.

  18. Hastie TJ, Tibshirani RJ. Generalized additive models. 1st ed. Routledge; 2017.

  19. Chen Y, Ouyang L, Bao FS, Li Q, Han L, Zhang H, et al. A multimodality machine learning approach to differentiate severe and nonsevere COVID-19: model development and validation. J Med Internet Res. 2021;23(4):e23948.

  20. Agarwal R, Frosst N, Zhang X, Caruana R, Hinton GE. Neural additive models: Interpretable machine learning with neural nets. 2020. arXiv preprint arXiv:2004.13912.

  21. Bianchini M, Scarselli F. On the complexity of neural network classifiers: A comparison between shallow and deep architectures. IEEE transactions on neural networks and learning systems. 2014;25(8):1553–65.

  22. Deng L, Zhao J, Zhang J, editors. Predict the protein-protein interaction between virus and host through hybrid deep neural network. International Conference on Bioinformatics and Biomedicine (BIBM). 2020:11–6.

  23. Yoon W, So CH, Lee J, Kang J. Collabonet: collaboration of deep neural networks for biomedical named entity recognition. BMC Bioinformatics. 2019;20(10):55–65.

  24. Ishwaran H, Lu M. Standard errors and confidence intervals for variable importance in random forest regression, classification, and survival. Stat Med. 2019;38(4):558–82.

  25. Nusinovici S, Tham YC, Yan MYC, Ting DSW, Li J, Sabanayagam C, et al. Logistic regression was as good as machine learning for predicting major chronic diseases. J Clin Epidemiol. 2020;122:56–69.

  26. Rao H, Shi X, Rodrigue AK, Feng J, Xia Y, Elhoseny M, et al. Feature selection based on artificial bee colony and gradient boosting decision tree. Appl Soft Comput. 2019;74:634–42.

  27. Adler AI, Painsky A. Feature Importance in Gradient Boosting Trees with Cross-Validation Feature Selection. Entropy. 2022;24(5):687.

  28. Awad M, Khanna R. Efficient learning machines: theories, concepts, and applications for engineers and system designers. 1st ed. Berkeley: Apress; 2015.

  29. Chan JC, Tsui EL, Wong VC, Group HASC. Prognostication in severe acute respiratory syndrome: a retrospective time-course analysis of 1312 laboratory-confirmed patients in Hong Kong. Respirology. 2007;12(4):531–42.

  30. Assiri A, Al-Tawfiq JA, Al-Rabeeah AA, Al-Rabiah FA, Al-Hajjar S, Al-Barrak A, et al. Epidemiological, demographic, and clinical characteristics of 47 cases of Middle East respiratory syndrome coronavirus disease from Saudi Arabia: a descriptive study. Lancet Infect Dis. 2013;13(9):752–61.

  31. Gong J, Ou J, Qiu X, Jie Y, Chen Y, Yuan L, et al. A tool for early prediction of severe coronavirus disease 2019 (COVID-19): a multicenter study using the risk nomogram in Wuhan and Guangdong, China. Clin Infect Dis. 2020;71(15):833–40.

  32. Weng Z, Chen Q, Li S, Li H, Zhang Q, Lu S, et al. ANDC: an early warning score to predict mortality risk for patients with coronavirus disease 2019. J Transl Med. 2020;18(1):1–10.

  33. Liu Y, Du X, Chen J, Jin Y, Peng L, Wang HH, et al. Neutrophil-to-lymphocyte ratio as an independent risk factor for mortality in hospitalized patients with COVID-19. J Infect. 2020;81(1):e6–12.

  34. Kostoff RN, Briggs MB, Porter AL. COVID-19: Preventing Future Pandemics. Georgia Institute of Technology; 2020.

  35. Yazdanpanah F, Hamblin MR, Rezaei N. The immune system and COVID-19: Friend or foe? Life Sci. 2020;256:117900.

  36. Cao M, Zhang D, Wang Y, Lu Y, Zhu X, Li Y, et al. Clinical features of patients infected with the 2019 novel coronavirus (COVID-19) in Shanghai, China. MedRxiv. 2020. https://doi.org/10.1101/2020.03.04.20030395.

  37. Chen T, Wu D, Chen H, Yan W, Yang D, Chen G, et al. Clinical characteristics of 113 deceased patients with coronavirus disease 2019: retrospective study. BMJ. 2020;368:m1091.

  38. Guan W-J, Ni Z-Y, Hu Y, Liang W-H, Ou C-Q, He J-X, et al. Clinical characteristics of coronavirus disease 2019 in China. N Engl J Med. 2020;382(18):1708–20.

  39. Huang I, Pranata R. Lymphopenia in severe coronavirus disease-2019 (COVID-19): systematic review and meta-analysis. J Intensive Care. 2020;8:1–10.

  40. Chen N, Zhou M, Dong X, Qu J, Gong F, Han Y, et al. Epidemiological and clinical characteristics of 99 cases of 2019 novel coronavirus pneumonia in Wuhan, China: a descriptive study. Lancet. 2020;395(10223):507–13.

  41. Ruan Q, Yang K, Wang W, Jiang L, Song J. Clinical predictors of mortality due to COVID-19 based on an analysis of data of 150 patients from Wuhan, China. Intensive Care Med. 2020;46(5):846–8.

  42. Kong M, Zhang H, Cao X, Mao X, Lu Z. Higher level of neutrophil-to-lymphocyte is associated with severe COVID-19. Epidemiol Infect. 2020;148:e139.

  43. Ma X, Ng M, Xu S, Xu Z, Qiu H, Liu Y, et al. Development and validation of prognosis model of mortality risk in patients with COVID-19. Epidemiol Infect. 2020;148:e168.

  44. AlJame M, Imtiaz A, Ahmad I, Mohammed A. Deep forest model for diagnosing COVID-19 from routine blood tests. Sci Rep. 2021;11(1):16682.

  45. Subudhi S, Verma A, Patel AB, Hardin CC, Khandekar MJ, Lee H, et al. Comparing machine learning algorithms for predicting ICU admission and mortality in COVID-19. NPJ Digital Med. 2021;4(1):1–7.

  46. Li X, Ge P, Zhu J, Li H, Graham J, Singer A, et al. Deep learning prediction of likelihood of ICU admission and mortality in COVID-19 patients using clinical variables. PeerJ. 2020;8:e10337.

  47. Lin J-K, Chien T-W, Wang L-Y, Chou W. An artificial neural network model to predict the mortality of COVID-19 patients using routine blood samples at the time of hospital admission: Development and validation study. Medicine. 2021;100(28):e26532.

  48. Morales GRV, Monterrubio SMM, García JAR, Ger PM. Explainable Machine Learning Prediction for Mortality of COVID-19 in the Colombian Population. 2021.

Acknowledgements

This work is part of the PhD thesis of Samad Moslehi in Biostatistics at Hamadan University of Medical Sciences, Iran, with the approved ethical code IR.UMSHA.REC.1400.366. We would like to thank the Vice-Chancellor for Research and Technology of Hamadan University of Medical Sciences, Iran. This work was supported by the Vice-Chancellor for Research and Technology of Hamadan University of Medical Sciences, Iran (No 140008186884).

Funding

Not applicable.

Author information

Contributions

All authors contributed to the preparation and development of the manuscript. MF, HM, and SM contributed to the literature review, the analysis, and the drafting of the manuscript. MM contributed to the clinical aspects of the manuscript. MF, SM, HM, ARS, and MM reviewed the manuscript. All authors read and approved the final manuscript.

Corresponding author

Correspondence to Maryam Farhadian.

Ethics declarations

Ethics approval and consent to participate

The study was approved by the Ethical Committee of Hamadan University of Medical Sciences with the approved ethical code IR.UMSHA.REC.1400.366. Informed consent was obtained from all patients or their legal guardians. For patients under 18 years of age, consent was obtained from parents or legally authorized representatives, and for patients in severe condition, consent was given by a first-degree relative. The study adhered to relevant guidelines and regulations.

Consent for publication

Not applicable.

Competing interests

The authors declare that they have no competing interests.

Additional information

Publisher’s Note

Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.

Supplementary Information

Additional file 1.

Source code & dataset.

Rights and permissions

Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if changes were made. The images or other third party material in this article are included in the article's Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article's Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by/4.0/. The Creative Commons Public Domain Dedication waiver (http://creativecommons.org/publicdomain/zero/1.0/) applies to the data made available in this article, unless otherwise stated in a credit line to the data.

About this article

Cite this article

Moslehi, S., Mahjub, H., Farhadian, M. et al. Interpretable generalized neural additive models for mortality prediction of COVID-19 hospitalized patients in Hamadan, Iran. BMC Med Res Methodol 22, 339 (2022). https://doi.org/10.1186/s12874-022-01827-y

  • DOI: https://doi.org/10.1186/s12874-022-01827-y

Keywords