gbnet.models.survival.hazard_survival

Classes

HazardSurvivalModel

Gradient Boosting Hazard Integration Survival Model.

Functions

expand_overlapping_units_locf(df[, y, unit_col, time_col])

Module Contents

class gbnet.models.survival.hazard_survival.HazardSurvivalModel(nrounds=None, params=None, module_type='LGBModule', min_hess=0.0, input_is_expanded=False, integration_method='trapezoid')[source]

Bases: sklearn.base.BaseEstimator, sklearn.base.RegressorMixin

Gradient Boosting Hazard Integration Survival Model.

This model combines gradient boosting with hazard integration for continuous survival analysis. It uses either XGBoost or LightGBM as the underlying boosting engine wrapped in a PyTorch module.

The model supports both static and time-varying datasets:

  • Static datasets: Each unit has one observation with static features

  • Time-varying datasets: Each unit has multiple observations over time

Parameters:
  • nrounds (int, optional) – Number of boosting rounds. The signature default is None; the docstring gives 100 as the number of rounds used when it is left unset.

  • params (dict, optional) – Additional parameters passed to the gradient boosting model.

  • module_type (str, optional) – Type of gradient boosting module to use, either “XGBModule” or “LGBModule”. Defaults to “LGBModule”.

  • min_hess (float, optional) – Minimum hessian value for numerical stability. Defaults to 0.0.

  • input_is_expanded (bool, optional) – If True, trust X already contains the intended unit-time rows and skip internal expansion. Defaults to False.

  • integration_method (str, optional) – Method for integrating hazards and survival estimates. One of “trapezoid”, “stepwise_left”, or “stepwise_right”. Defaults to “trapezoid”.

Variables:
  • integrator (HazardIntegrator) – Trained hazard integrator module. Set after fitting.

  • losses_ (list) – List of loss values recorded at each training iteration.

  • data_format_ (str) – Detected data format: ‘static’ or ‘time_varying’. None until the model is fitted.

fit(X, y)[source]

Trains the model using input features X and survival data y.

predict_survival(X, times)[source]

Predicts survival probabilities for given times.

predict_hazard(X, times)

Predicts hazard values for given times.

predict(X)[source]

Predicts the expected survival time.

score(X, y)[source]

Returns the negative log likelihood score.

Notes

The model uses hazard integration to model continuous survival times. The gradient boosting model learns hazard rates for each time point, which are then integrated to compute survival probabilities.
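The relationship between hazards and survival described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: it uses the standard identity S(t) = exp(-H(t)), where H is the cumulative hazard, and it assumes that “trapezoid”, “stepwise_left”, and “stepwise_right” correspond to the usual trapezoidal, left-endpoint, and right-endpoint quadrature rules. The function names `cumulative_hazard` and `survival_curve` are hypothetical.

```python
import math


def cumulative_hazard(times, hazards, method="trapezoid"):
    """Integrate hazard values sampled on a time grid.

    Returns the cumulative hazard H(t_k) at every grid point, with H(t_0) = 0.
    The mapping from method names to quadrature rules is an assumption.
    """
    if len(times) != len(hazards):
        raise ValueError("times and hazards must have the same length")
    cum = [0.0]
    for k in range(1, len(times)):
        dt = times[k] - times[k - 1]
        if method == "trapezoid":
            # Average of the two endpoint hazards over the interval.
            area = 0.5 * (hazards[k - 1] + hazards[k]) * dt
        elif method == "stepwise_left":
            # Hazard held constant at its left-endpoint value.
            area = hazards[k - 1] * dt
        elif method == "stepwise_right":
            # Hazard held constant at its right-endpoint value.
            area = hazards[k] * dt
        else:
            raise ValueError(f"unknown method: {method}")
        cum.append(cum[-1] + area)
    return cum


def survival_curve(times, hazards, method="trapezoid"):
    """Survival probabilities S(t_k) = exp(-H(t_k)) on the same grid."""
    return [math.exp(-h) for h in cumulative_hazard(times, hazards, method)]


# A constant hazard of 0.5 gives S(t) = exp(-0.5 * t) under all three rules.
times = [0.0, 1.0, 2.0]
surv = survival_curve(times, [0.5, 0.5, 0.5])
```

With a constant hazard the three rules agree; they differ only when the hazard changes between grid points, in which case the left and right rules bracket the trapezoid result for a monotone hazard.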

Supported data formats:

  • Static: X is DataFrame with static features, y is DataFrame with ‘time’, ‘event’, ‘unit_id’

  • Time-varying: X is DataFrame with time-varying features, y has ‘time’, ‘event’, ‘unit_id’

For survival data, y must be a DataFrame containing:

  • ‘time’: observed time (continuous)

  • ‘event’: event indicator (0=censored, 1=event)

  • ‘unit_id’: unique identifier for each unit/subject
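The two layouts can be illustrated with small pandas frames. This is a sketch under stated assumptions: the feature names (`age`, `dose`, `biomarker`) are invented, `check_survival_frame` is a hypothetical helper that is not part of gbnet, and the row-by-row alignment of X with y in the time-varying case is assumed rather than documented here.

```python
import pandas as pd

REQUIRED_Y_COLUMNS = ("time", "event", "unit_id")


def check_survival_frame(y):
    """Hypothetical helper (not part of gbnet): check the documented y columns."""
    missing = [c for c in REQUIRED_Y_COLUMNS if c not in y.columns]
    if missing:
        raise ValueError(f"y is missing columns: {missing}")
    if not y["event"].isin([0, 1]).all():
        raise ValueError("event must be 0 (censored) or 1 (event)")
    return y


# Static format: one row per unit.
X_static = pd.DataFrame({"age": [61, 47, 70], "dose": [1.0, 0.5, 1.5]})
y_static = check_survival_frame(
    pd.DataFrame(
        {"time": [5.2, 3.1, 8.0], "event": [1, 0, 1], "unit_id": [0, 1, 2]}
    )
)

# Time-varying format: several rows per unit, one per observation time.
# Assumption: rows of X line up with rows of y by position.
X_tv = pd.DataFrame({"biomarker": [0.9, 1.4, 2.0, 0.7, 0.8]})
y_tv = check_survival_frame(
    pd.DataFrame(
        {
            "time": [1.0, 2.0, 3.0, 1.0, 2.0],
            "event": [0, 0, 1, 0, 0],
            "unit_id": [0, 0, 0, 1, 1],
        }
    )
)

# Frames like these are what fit(X, y) is documented to accept.
```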

params = None[source]
module_type = 'LGBModule'[source]
nrounds = None[source]
min_hess = 0.0[source]
input_is_expanded = False[source]
integration_method = 'trapezoid'[source]
losses_ = [][source]
data_format_ = None[source]
_static_to_minimal_time_varying_dataset(X)[source]
_warn_if_expanded_input_missing_times(X, y)[source]
_validate_and_convert_input_data(X, y)[source]

Validate the input data, detect whether it is static or time-varying, and convert X where needed.

Parameters:
  • X (pd.DataFrame) – Input features

  • y (pd.DataFrame) – Survival data with ‘time’, ‘event’, ‘unit_id’ columns

Returns:

(data_format, modified_X) where data_format is ‘static’ or ‘time_varying’ and modified_X is the potentially modified X DataFrame

Return type:

tuple

fit(X, y)[source]

Fit the hazard integration survival model.

Parameters:
  • X (pd.DataFrame) – Training features. Can be static or time-varying.

  • y (pd.DataFrame) – Survival data with ‘time’, ‘event’, ‘unit_id’ columns.

Returns:

self – Returns self.

Return type:

object

predict_base(X, y)[source]
predict_times(X, times=None)[source]
predict_survival(X, times=None)[source]
predict(X, times=None)[source]
score(X, y)[source]

Return the negative log likelihood score.
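For orientation, the textbook negative log-likelihood for right-censored continuous-time data can be written in a few lines. This is a sketch, not gbnet's code: each unit contributes event * log h(T) - H(T), where h(T) is the hazard and H(T) the cumulative hazard at the observed time T. Whether the library sums or averages over units is not stated in this section, and the function name `negative_log_likelihood` is hypothetical.

```python
import math


def negative_log_likelihood(hazard_at_t, cum_hazard_at_t, events):
    """Negative log-likelihood for right-censored survival data.

    hazard_at_t:     h(T_i), hazard at each unit's observed time
    cum_hazard_at_t: H(T_i), cumulative hazard at each unit's observed time
    events:          1 if the event was observed, 0 if censored
    """
    log_lik = 0.0
    for h, H, e in zip(hazard_at_t, cum_hazard_at_t, events):
        if e:
            # Observed events contribute the log density term log h(T).
            log_lik += math.log(h)
        # Every unit contributes the log survival term -H(T).
        log_lik -= H
    return -log_lik


# Constant hazard 0.5: one event at T=2 (H=1.0), one censored unit at T=1 (H=0.5).
nll = negative_log_likelihood([0.5, 0.5], [1.0, 0.5], [1, 0])
```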

gbnet.models.survival.hazard_survival.expand_overlapping_units_locf(df, y=None, unit_col='unit_id', time_col='time')[source]
Parameters:
  • df (pandas.DataFrame)

  • y (Optional[numpy.ndarray])

  • unit_col (str)

  • time_col (str)
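The function's behavior is not described in this section. Its name suggests last-observation-carried-forward (LOCF) expansion, so the following generic pandas sketch shows what LOCF onto a shared time grid looks like. It is an assumption-laden illustration, not the library's implementation: `locf_on_grid` and its `grid` argument are hypothetical, and the way gbnet handles "overlapping" units may differ.

```python
import pandas as pd


def locf_on_grid(df, grid, unit_col="unit_id", time_col="time"):
    """Generic LOCF sketch (hypothetical, not gbnet's implementation).

    For each unit, place its rows on `grid` and carry the most recent
    observation forward. Grid times before a unit's first observation
    are dropped because there is nothing to carry forward yet.
    """
    pieces = []
    for unit, grp in df.groupby(unit_col):
        grp = grp.sort_values(time_col).set_index(time_col)
        features = grp.drop(columns=[unit_col])
        # method="ffill" picks the last row at or before each grid time.
        expanded = features.reindex(grid, method="ffill")
        expanded.index.name = time_col
        expanded = expanded.dropna(how="all").reset_index()
        expanded.insert(0, unit_col, unit)
        pieces.append(expanded)
    return pd.concat(pieces, ignore_index=True)


observations = pd.DataFrame(
    {
        "unit_id": [0, 0, 1],
        "time": [0.0, 2.0, 1.0],
        "x": [1.0, 3.0, 5.0],
    }
)
out = locf_on_grid(observations, grid=[0.0, 1.0, 2.0, 3.0])
```

Here unit 0 keeps its value from time 0.0 at time 1.0 and its value from time 2.0 at time 3.0, while unit 1 only appears from time 1.0 onward.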