gbnet.models.survival.hazard_survival
Classes
HazardSurvivalModel | Gradient Boosting Hazard Integration Survival Model.
Functions
Module Contents
- class gbnet.models.survival.hazard_survival.HazardSurvivalModel(nrounds=None, params=None, module_type='LGBModule', min_hess=0.0, input_is_expanded=False, integration_method='trapezoid')[source]
Bases: sklearn.base.BaseEstimator, sklearn.base.RegressorMixin
Gradient Boosting Hazard Integration Survival Model.
This model combines gradient boosting with hazard integration for continuous survival analysis. It uses either XGBoost or LightGBM as the underlying boosting engine wrapped in a PyTorch module.
The model supports both static and time-varying datasets:
- Static datasets: Each unit has one observation with static features.
- Time-varying datasets: Each unit has multiple observations over time.
- Parameters:
nrounds (int, optional) – Number of boosting rounds. The signature default is None; the docstring gives 100 as the default number of rounds.
params (dict, optional) – Additional parameters passed to the gradient boosting model.
module_type (str, optional) – Type of gradient boosting module to use, either “XGBModule” or “LGBModule”. Defaults to “LGBModule” (per the class signature).
min_hess (float, optional) – Minimum hessian value for numerical stability. Defaults to 0.0.
input_is_expanded (bool, optional) – If True, trust X already contains the intended unit-time rows and skip internal expansion. Defaults to False.
integration_method (str, optional) – Method for integrating hazards and survival estimates. One of “trapezoid”, “stepwise_left”, or “stepwise_right”. Defaults to “trapezoid”.
- Variables:
integrator (HazardIntegrator) – Trained hazard integrator module. Set after fitting.
losses (list) – List of loss values recorded at each training iteration.
data_format (str) – Detected data format: ‘static’ or ‘time_varying’.
- predict_hazard(X, times)
Predicts hazard values for given times.
Notes
The model uses hazard integration to model continuous survival times. The gradient boosting model learns hazard rates for each time point, which are then integrated to compute survival probabilities.
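The relationship between hazard rates and survival probabilities can be illustrated with a small, library-independent sketch. It assumes the usual identity S(t) = exp(-H(t)), where H is the cumulative hazard, and it assumes that “stepwise_left” and “stepwise_right” mean holding the hazard at the left or right endpoint of each interval. The helper names `cumulative_hazard` and `survival` are hypothetical and are not part of gbnet; the library's own integrator may differ in detail.

```python
import math


def cumulative_hazard(times, hazards, method="trapezoid"):
    """Integrate hazards sampled at increasing `times`, with H(times[0]) = 0.

    Illustrative only: the rule names mirror the `integration_method`
    options, but their exact semantics in gbnet are an assumption here.
    """
    total = 0.0
    out = [0.0]
    for i in range(1, len(times)):
        dt = times[i] - times[i - 1]
        if method == "trapezoid":
            # Average of the two endpoint hazards over the interval.
            total += 0.5 * (hazards[i - 1] + hazards[i]) * dt
        elif method == "stepwise_left":
            # Hazard held constant at its left-endpoint value.
            total += hazards[i - 1] * dt
        elif method == "stepwise_right":
            # Hazard held constant at its right-endpoint value.
            total += hazards[i] * dt
        else:
            raise ValueError(f"unknown integration method: {method!r}")
        out.append(total)
    return out


def survival(times, hazards, method="trapezoid"):
    """Survival probabilities S(t) = exp(-H(t)) at each time point."""
    return [math.exp(-h) for h in cumulative_hazard(times, hazards, method)]


times = [0.0, 1.0, 2.0, 3.0]
hazards = [0.10, 0.20, 0.30, 0.40]
for method in ("trapezoid", "stepwise_left", "stepwise_right"):
    print(method, [round(s, 4) for s in survival(times, hazards, method)])
```

For an increasing hazard such as this one, the left-endpoint rule gives the smallest cumulative hazard and the right-endpoint rule the largest, with the trapezoid rule in between.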
Supported data formats:
- Static: X is a DataFrame with static features; y is a DataFrame with ‘time’, ‘event’, and ‘unit_id’ columns.
- Time-varying: X is a DataFrame with time-varying features; y has ‘time’, ‘event’, and ‘unit_id’ columns.
For survival data, y must be a DataFrame containing:
- ‘time’: observed time (continuous)
- ‘event’: event indicator (0=censored, 1=event)
- ‘unit_id’: unique identifier for each unit/subject
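As a rough illustration of the two layouts, the following sketch builds small pandas DataFrames with the required ‘time’, ‘event’, and ‘unit_id’ columns. The feature names (`age`, `treated`, `biomarker`) and all values are made up. How per-row ‘time’ and ‘event’ values are interpreted for time-varying data is not specified in this page, so the second layout is an assumption rather than a documented convention.

```python
import pandas as pd

# Static layout: one row per unit, features do not change over time.
X_static = pd.DataFrame({"age": [61, 45, 70], "treated": [1, 0, 1]})
y_static = pd.DataFrame(
    {
        "time": [5.2, 3.1, 8.0],  # observed time (continuous)
        "event": [1, 0, 1],       # 1 = event, 0 = censored
        "unit_id": [0, 1, 2],     # one id per subject
    }
)

# Time-varying layout (assumed): several rows per unit, one per observation.
X_tv = pd.DataFrame({"biomarker": [1.0, 1.4, 2.1, 0.8, 0.9]})
y_tv = pd.DataFrame(
    {
        "time": [1.0, 2.0, 2.7, 1.0, 1.5],
        "event": [0, 0, 1, 0, 0],
        "unit_id": [0, 0, 0, 1, 1],  # unit 0 has three rows, unit 1 has two
    }
)

print(y_static["unit_id"].value_counts().max())  # rows per unit, static
print(y_tv["unit_id"].value_counts().max())      # rows per unit, time-varying
```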
- _validate_and_convert_input_data(X, y)[source]
Validate the input data, detect whether it is static or time-varying, and return X, modified where needed.
- Parameters:
X (pd.DataFrame) – Input features
y (pd.DataFrame) – Survival data with ‘time’, ‘event’, ‘unit_id’ columns
- Returns:
(data_format, modified_X) where data_format is ‘static’ or ‘time_varying’ and modified_X is the potentially modified X DataFrame
- Return type:
tuple
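The general shape of such a check can be sketched as follows. This is not the library's implementation: the function name `detect_data_format` is hypothetical, and the detection rule (any ‘unit_id’ appearing on more than one row of y means time-varying) is an assumption based on the format descriptions above. The real method may apply different checks and may modify X in ways not shown here.

```python
import pandas as pd

REQUIRED_Y_COLUMNS = {"time", "event", "unit_id"}


def detect_data_format(X, y):
    """Return (data_format, X_copy); a hypothetical stand-in for the real check."""
    if not isinstance(X, pd.DataFrame) or not isinstance(y, pd.DataFrame):
        raise TypeError("X and y must be pandas DataFrames")
    missing = REQUIRED_Y_COLUMNS - set(y.columns)
    if missing:
        raise ValueError(f"y is missing required columns: {sorted(missing)}")
    if len(X) != len(y):
        raise ValueError("X and y must have the same number of rows")
    # Assumed rule: more than one row for any unit means time-varying data.
    rows_per_unit = y.groupby("unit_id").size()
    data_format = "time_varying" if (rows_per_unit > 1).any() else "static"
    # Return a copy so the caller's DataFrame is left untouched.
    return data_format, X.copy()


X = pd.DataFrame({"x1": [0.5, 1.5, 2.5]})
y_one_row_each = pd.DataFrame(
    {"time": [2.0, 4.0, 1.0], "event": [1, 0, 1], "unit_id": [0, 1, 2]}
)
y_repeated_unit = pd.DataFrame(
    {"time": [1.0, 2.0, 1.0], "event": [0, 1, 0], "unit_id": [0, 0, 1]}
)

print(detect_data_format(X, y_one_row_each)[0])
print(detect_data_format(X, y_repeated_unit)[0])
```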