EXAMPLES OF NEURAL NETWORK SURVIVAL MODELS

The neural network software can also be used to model survival data.  In
this application, the network models the "hazard function", which
determines the instantaneous risk of "death".  The hazard may depend on
the values of inputs (covariates) for a particular case.  It may be
constant through time (for each case), or it may change over time.  The
data may be censored - ie, for some cases, it may be known only that the
individual survived for at least a certain time.  The details of these
models are described in model-spec.doc.

Command files for the examples below are found in the ex-surv directory.
Its sub-directory "pbc" contains files for the examples used in my talk
on "Survival analysis using a Bayesian neural network" at the 2001 Joint
Statistical Meetings in Atlanta.


Survival models with constant hazard.

I generated 700 cases of synthetic data in which the hazard function is
constant through time, but depends on one input, which was randomly
generated from a standard normal distribution.  When the input was x, the
hazard was h(x) = 1+(x-1)^2, which makes the distribution of the time of
death be exponential with mean 1/h(x).  For the first 200 cases, the time
of death was censored at 0.5; if the actual time of death was later than
this, the time recorded was minus the censoring time (ie, -0.5).  The
first 500 of these cases were used for training, and the remaining 200
for testing.

The network used to model this data had one input (x), and one output,
which was interpreted as log h(x).  One layer of eight hidden units was
used.  The following commands specify the network, model, and data
source:

    > net-spec clog.net 1 8 1 / ih=0.05:1 bh=0.05:1 ho=x0.05:1 bo=100
    > model-spec clog.net survival const-hazard
    > data-spec clog.net 1 1 / cdata@1:500 . cdata@-1:500 .

We can now train the network using the same general methods as for other
models.  Here is one set of commands that works reasonably well:

    > net-gen clog.net fix 0.5
    > mc-spec clog.net repeat 10 heatbath hybrid 100:10 0.2
    > net-mc clog.net 1
    > mc-spec clog.net repeat 4 sample-sigmas heatbath hybrid 500:10 0.4
    > net-mc clog.net 100

This takes 7.0 minutes on the system used (see Ex-system.doc).

Predictions for test cases can now be made using net-pred.  For example,
the following command produces the 10% and 90% quantiles of the
predictive distribution for time of death, based on iterations from 25
onward.  These predictions can be compared to the actual times of death
(the "targets"):

    > net-pred itq clog.net 25:

    Number of iterations used: 76

    Case  Inputs  Targets  10% Qnt  90% Qnt

       1   -0.01    0.29     0.05     1.05
       2   -1.43    0.04     0.02     0.37
       3   -0.68    0.03     0.02     0.52
       4    0.47    0.11     0.09     1.67
       5    0.63    0.97     0.07     1.74

            ( middle lines omitted )

     196    1.39    0.21     0.08     1.99
     197    0.04    0.07     0.06     1.11
     198    1.66    0.56     0.09     1.61
     199    1.33    1.48     0.07     1.76
     200   -1.41    0.08     0.01     0.39

Since there is only one input for this model, we can easily look at the
posterior distribution of the log of the hazard function with the
following command:

    > net-eval clog.net 25:%5 / -3 3 100 | plot

This plots sixteen functions drawn from the posterior distribution of the
log of the hazard function.  These functions mostly match the actual
function, which is log(1+(x-1)^2), for values of x up to about one.
There is considerable uncertainty for values of x above one, however.
Some of the functions from the posterior are a good match for the true
function in this region, but most are below the true function.  Note that
the data here are fairly sparse.
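As an aside, the synthetic data described at the start of this example is
easy to regenerate.  The following Python sketch follows the recipe
above; the random seed, the file name "cdata", and the two-column text
format are my assumptions, so the result will not exactly reproduce the
file used here:

    # Sketch of generating the constant-hazard synthetic data: x is
    # standard normal, the death time is exponential with rate
    # h(x) = 1+(x-1)^2, and the first 200 cases are censored at time
    # 0.5, recorded as -0.5 (minus the censoring time).

    import numpy as np

    rng = np.random.default_rng(1)        # seed is an assumption

    n, n_censored, c = 700, 200, 0.5

    x = rng.standard_normal(n)
    hazard = 1 + (x - 1)**2               # constant-in-time hazard h(x)
    t = rng.exponential(1 / hazard)       # death times, mean 1/h(x)

    # Censored times are recorded as minus the censoring time.
    t[:n_censored] = np.where(t[:n_censored] > c, -c, t[:n_censored])

    np.savetxt("cdata", np.column_stack([x, t]), fmt="%.6f")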
Survival models with piecewise-constant hazard.

One can also define models in which the hazard function depends on time,
as well as, perhaps, on various inputs.  To demonstrate this, I generated
synthetic data for which the hazard function was given by

    h(x,t) = (1+(x-1)^2) * t

which makes the cumulative distribution function for the time of death of
an individual for whom the covariate is x be 1-exp(-(1+(x-1)^2)*t^2/2).
I generated 1000 cases (none of which were censored), and used the first
700 for training and the remaining 300 for testing.

The network for this problem has two input units.  The first input is set
to the time, t, and the second to the value of the covariate, x.  There
is a single output unit, whose value is interpreted as the log of the
hazard function.  If we decide to use one layer of eight hidden units, we
might specify the network as follows:

    > net-spec vlog.net 2 8 1 / ih=0.05:1:1 bh=0.05:1 ho=x0.05:1 bo=100

The prior for the input-hidden weights uses separate hyperparameters for
the two input units (representing the time, t, and the covariate, x),
allowing the smoothness of the hazard function to be different with
respect to each.  In particular, if the hazard is actually constant, the
time input should end up being almost ignored.  (This can be demonstrated
by applying this model to the data used in the previous example.)

The hazard function determines the likelihood by means of its integral
from time zero up to the observed time of death, or the censoring time.
Computing this integral exactly would be infeasible if the hazard
function were defined to be the network output when the first input is
set to a given time.  To allow these integrals to be done in a reasonable
amount of compute time, the hazard function is instead taken to be
piecewise constant with respect to time (though continuous with respect
to the covariates).  The time points where the hazard function changes
are given in the model-spec command.  For example, we can use the
following specification:

    > model-spec vlog.net survival pw-const-hazard 0.05 0.1 0.2 0.35 \
                                                   0.5 0.7 1.0 1.5

The eight numbers above are the times at which the hazard function
changes.  The log of the value of the hazard before the first time point
(0.05) is found from the output of the network when the first input is
set to this time (and the other inputs are set to the covariates for the
case); similarly for the log of the value of the hazard function after
the last time point (1.5).  The log of the hazard for the other pieces is
found by setting the first input of the network to the average of the
times at the two ends of the piece.  For instance, the log of the value
of the hazard function in the time range 0.7 to 1.0 is found by setting
the first input to (0.7+1.0)/2 = 0.85.  Alternatively, if the "log"
option is used, everything is done in terms of the log of the time (see
model-spec.doc for details).

Note that the number of time points that can be specified is limited only
by the greater amount of computation required when there are many of
them.  The hyperparameters controlling the weight priors will be able to
control "overfitting" even if the hazard function has many pieces.

The data specification for this problem says that there is just one input
(the covariate) and one target (the time of death, or censoring time):

    > data-spec vlog.net 1 1 / vdata@1:700 . vdata@-1:700 .

Note that data-spec should be told that there is only one input, even
though net-spec is told that there are two - the extra input for the net
is the time, which is not an input read from the data file.
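Before moving on to sampling, it may help to see the piecewise-constant
scheme written out.  The following Python sketch (an illustration, not
the software's actual code) computes the log likelihood of one case from
a generic log-hazard function log_h(s, x), which stands in for the
network output, using the breakpoints given to model-spec above:

    import math

    breaks = [0.05, 0.1, 0.2, 0.35, 0.5, 0.7, 1.0, 1.5]
    edges = [0.0] + breaks + [math.inf]     # boundaries of the pieces

    def rep_time(lo, hi):
        """Time at which the network is evaluated for the piece (lo,hi)."""
        if lo == 0.0:         # first piece: use the first breakpoint
            return hi
        if hi == math.inf:    # last piece: use the last breakpoint
            return lo
        return (lo + hi) / 2  # interior pieces: midpoint of the piece

    def log_lik(t, x, log_h, censored=False):
        """A death at time t contributes log h(t,x) minus the integral
        of h from 0 to t; a case censored at time t contributes just
        the negative integral."""
        ll = 0.0
        for lo, hi in zip(edges, edges[1:]):
            if lo >= t:
                break
            lh = log_h(rep_time(lo, hi), x)
            ll -= math.exp(lh) * (min(hi, t) - lo)  # integral over piece
            if not censored and t <= hi:            # piece containing t
                ll += lh
        return ll

    # For example, with the true log hazard of this section:
    print(log_lik(0.6, 0.5, lambda s, x: math.log((1 + (x - 1)**2) * s)))

Because the hazard is constant within each piece, the integral over a
piece is just the hazard there times the length of the piece's overlap
with (0, t), which is what makes the computation cheap.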
The following commands sample from the posterior distribution:

    > net-gen vlog.net fix 0.5
    > mc-spec vlog.net repeat 10 heatbath hybrid 100:10 0.2
    > net-mc vlog.net 1
    > mc-spec vlog.net repeat 4 sample-sigmas heatbath hybrid 500:10 0.4
    > net-mc vlog.net 100

This takes 28 minutes on the system used (see Ex-system.doc).

The net-pred command can be used to make predictions for test cases.
Here is a command to display the median of the predictive distribution
for time of death, based on iterations from 25 on, along with the
covariates (inputs) and the actual times of death (targets) for the test
cases:

    > net-pred itd vlog.net 25:

    Number of iterations used: 76

    Case  Inputs  Targets  Medians  |Error|

       1    0.00    0.07     0.84    0.7711
       2   -0.22    0.42     0.74    0.3155
       3   -2.24    0.04     0.39    0.3480
       4   -0.81    0.26     0.58    0.3172
       5    1.02    3.12     1.04    2.0784

            ( middle lines omitted )

     296   -0.81    0.28     0.58    0.2966
     297    1.19    1.60     1.05    0.5503
     298   -1.53    0.57     0.48    0.0940
     299    0.99    2.67     1.09    1.5773
     300    0.28    1.13     0.92    0.2030

    Average abs. error guessing median:  0.32997+-0.01644

The predictive medians can be plotted against the inputs as follows:

    > net-pred idb vlog.net 25: | plot-points

The fairly small amount of scatter seen in this plot is a result of
computing the medians by Monte Carlo, which introduces some random error;
the exact median time of death as predicted by the model is a smooth
function of the input.  These predictions can be compared with the true
median times of death for individuals with the given covariates, as
determined by the way the data was generated, which are in the file
"vmedians".

The log of the true hazard function for this problem is the sum of a
function of x only and a function of t only - what is called a
"proportional hazards" model.  A network model specified as follows can
discover this:

    > net-spec v2log.net 2 8 8 1 / ih0=0.05:1:1 bh0=0.05:1 \
                                   ih1=0.05:1:1 bh1=0.05:1 \
                                   ho1=x0.05:1 ho0=x0.05:1 bo=100

The input-hidden weights for the two hidden layers have hierarchical
priors that allow one of the layers to look only at the time input and
the other to look only at the covariate.  There are no connections
between the hidden layers, so the output can be an additive function of
time and the covariate, if the hyperparameters are set in this way.

The remaining commands can be the same as above, except that the stepsize
needs to be reduced to get a reasonable acceptance rate (and the number
of leapfrog steps is increased to compensate):

    > model-spec v2log.net survival pw-const-hazard 0.05 0.1 0.2 0.35 \
                                                    0.5 0.7 1.0 1.5
    > data-spec v2log.net 1 1 / vdata@1:700 . vdata@-1:700 .
    > net-gen v2log.net fix 0.5
    > mc-spec v2log.net repeat 10 heatbath hybrid 100:10 0.1
    > net-mc v2log.net 1
    > mc-spec v2log.net repeat 4 sample-sigmas heatbath hybrid 1000:10 0.25
    > net-mc v2log.net 400

This takes 15 hours to run on the system used (see Ex-system.doc).

The model does in fact discover that an additive form for the log hazard
is appropriate, as can be seen by examining the hyperparameters with a
command such as the following:

    > net-plt t h1@h3@ v2log.net | plot

One can see from this that the Markov chain spends most of its time in
regions of the hyperparameter space in which one hidden layer looks
almost exclusively at the time input, and the other looks almost
exclusively at the covariate.
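Returning to the true medians mentioned above: these follow directly from
the cumulative distribution function given earlier.  Setting
1-exp(-(1+(x-1)^2)*t^2/2) equal to 1/2 and solving for t gives the median
time of death t = sqrt(2*log(2)/(1+(x-1)^2)).  Here is a minimal Python
sketch of this computation (the exact contents and format of "vmedians"
are not reproduced here):

    # True median time of death for covariate x, from solving
    #   1 - exp(-(1+(x-1)^2) * t^2 / 2) = 1/2
    # for t.  Values like those in "vmedians" can be derived this way.

    import math

    def true_median(x):
        return math.sqrt(2 * math.log(2) / (1 + (x - 1)**2))

    print(true_median(0.0))   # about 0.83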