from typing import Optional
import warnings
import torch
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_is_fitted
from .hazard_integrator import HazardIntegrator
class HazardSurvivalModel(BaseEstimator, RegressorMixin):
    """Gradient Boosting Hazard Integration Survival Model.

    This model combines gradient boosting with hazard integration for continuous
    survival analysis. It uses either XGBoost or LightGBM as the underlying
    boosting engine wrapped in a PyTorch module.

    The model supports both static and time-varying datasets:

    - Static datasets: Each unit has one observation with static features
    - Time-varying datasets: Each unit has multiple observations over time

    Parameters
    ----------
    nrounds : int, optional
        Number of boosting rounds. Defaults to 50 when ``module_type`` is
        "XGBModule" and to 100 otherwise.
    params : dict, optional
        Additional parameters passed to the gradient boosting model. Defaults
        to ``{"max_delta_step": 1}`` for "XGBModule" and
        ``{"max_delta_step": 5}`` otherwise.
    module_type : str, optional
        Type of gradient boosting module to use, either "XGBModule" or
        "LGBModule". Defaults to "LGBModule".
    min_hess : float, optional
        Minimum hessian value for numerical stability. Defaults to 0.0.
    input_is_expanded : bool, optional
        If True, trust X already contains the intended unit-time rows and skip
        internal expansion. Defaults to False.
    integration_method : str, optional
        Method for integrating hazards and survival estimates. One of
        "trapezoid" (default), "stepwise_left", or "stepwise_right".

    Attributes
    ----------
    integrator_ : HazardIntegrator
        Trained hazard integrator module. Set after fitting.
    losses_ : list of float
        Training loss (mean negative log likelihood) recorded at each
        boosting round.
    data_format_ : str
        Detected data format: 'static' or 'time_varying'.
    event_indicators_ : np.ndarray
        Per-unit event flag taken from the last row of each unit in ``y``.
    n_samples_ : int
        Number of units seen during fit.
    max_time : float
        Largest ``y["time"]`` seen during fit; upper end of the default
        prediction grid.

    Methods
    -------
    fit(X, y)
        Trains the model using input features X and survival data y.
    predict_times(X, times=None)
        Returns per-row hazard/survival and per-unit summaries on a time grid.
    predict_survival(X, times=None)
        Returns per-row survival probabilities and hazards on a time grid.
    predict(X, times=None)
        Returns per-unit expected survival time and a grid-based median time.
    predict_base(X, y)
        Returns the raw integrator output for (X, y).
    score(X, y)
        Returns the mean negative log likelihood (lower is better).

    Notes
    -----
    The model uses hazard integration to model continuous survival times.
    The gradient boosting model learns hazard rates for each time point,
    which are then integrated to compute survival probabilities.

    Supported data formats:

    - Static: X is DataFrame with static features, y is DataFrame with 'time', 'event', 'unit_id'
    - Time-varying: X is DataFrame with time-varying features, y has 'time', 'event', 'unit_id'

    For survival data, y must be a DataFrame containing:

    - 'time': observed time (continuous)
    - 'event': event indicator (0=censored, 1=event)
    - 'unit_id': unique identifier for each unit/subject
    """

    def __init__(
        self,
        nrounds=None,
        params=None,
        module_type="LGBModule",
        min_hess=0.0,
        input_is_expanded=False,
        integration_method="trapezoid",
    ):
        if integration_method not in {"trapezoid", "stepwise_left", "stepwise_right"}:
            raise ValueError(
                "integration_method must be one of 'trapezoid', 'stepwise_left' "
                f"or 'stepwise_right'; got {integration_method!r}"
            )
        # Engine-specific defaults are resolved here so that the stored values
        # are exactly the ones fit() uses.
        if params is None:
            params = {"max_delta_step": 1 if module_type == "XGBModule" else 5}
        if nrounds is None:
            nrounds = 50 if module_type == "XGBModule" else 100
        # sklearn's get_params()/clone() read every constructor argument back
        # from an attribute of the same name, so all of them must be stored.
        self.nrounds = nrounds
        self.params = params
        self.module_type = module_type
        self.min_hess = min_hess
        self.input_is_expanded = input_is_expanded
        self.integration_method = integration_method

    def _static_to_minimal_time_varying_dataset(self, X):
        """Turn one-row-per-unit data into a two-rows-per-unit dataset.

        For every unit, a copy of its row with ``time`` set to 0 is added, so
        each unit spans [0, observed time] with constant covariates.

        Parameters
        ----------
        X : pd.DataFrame
            Must contain ``time`` and ``unit_id`` columns, with unique
            ``unit_id`` values.

        Returns
        -------
        pd.DataFrame
            Rows sorted by ``unit_id`` then ``time`` with a fresh RangeIndex.

        Raises
        ------
        ValueError
            If a required column is missing or ``unit_id`` has duplicates.
        """
        if "time" not in X or "unit_id" not in X:
            raise ValueError("X must contain 'time' and 'unit_id' columns")
        if not X["unit_id"].is_unique:
            raise ValueError("'unit_id' must be unique for static data")
        start_rows = X.copy()
        start_rows["time"] = 0
        return (
            pd.concat([start_rows, X], axis=0)
            .sort_values(["unit_id", "time"])
            .reset_index(drop=True)
        )

    @staticmethod
    def _negative_log_likelihood(out, event_mask):
        """Mean negative log likelihood of right-censored survival data.

        Computes ``(sum_i H_i - sum_i d_i * log h_i) / n``, where ``H_i`` is the
        unit's integrated hazard, ``h_i`` its hazard at its last time, and
        ``d_i`` the float event flag in ``event_mask``.

        Parameters
        ----------
        out : dict
            Integrator output with ``unit_integrated_hazard`` and
            ``unit_last_hazard`` tensors.
        event_mask : torch.Tensor
            1.0 for units with an observed event, 0.0 for censored units.
        """
        n_units = event_mask.shape[0]
        return (
            out["unit_integrated_hazard"].sum()
            - (torch.log(out["unit_last_hazard"]) * event_mask).sum()
        ) / n_units

    def fit(self, X, y):
        """Fit the hazard integration survival model.

        Parameters
        ----------
        X : pd.DataFrame
            Training features. Can be static or time-varying.
        y : pd.DataFrame
            Survival data with 'time', 'event', 'unit_id' columns.

        Returns
        -------
        self : object
            Returns self.

        Raises
        ------
        ValueError
            If X or y is not a DataFrame, or y contains no units.
        """
        if not isinstance(X, pd.DataFrame):
            raise ValueError("X must be a pandas DataFrame")
        if not isinstance(y, pd.DataFrame):
            raise ValueError("y must be a pandas DataFrame")
        self.max_time = y["time"].max()
        # NOTE(review): _validate_and_convert_input_data is not defined in this
        # class body as shown; confirm where it is provided.
        self.data_format_, self.exp_df, self.y = self._validate_and_convert_input_data(
            X, y
        )
        # One event flag per unit, taken from the unit's final row.
        # NOTE(review): assumes the integrator emits per-unit outputs in the
        # same (sorted unit_id) order as this groupby; TODO confirm.
        self.event_indicators_ = self.y.groupby("unit_id")["event"].last().values
        self.n_samples_ = len(self.event_indicators_)
        if self.n_samples_ == 0:
            raise ValueError("y must contain at least one unit")

        covariate_cols = [
            col for col in X.columns if col not in ["unit_id", "time", "event"]
        ]
        self.integrator_ = HazardIntegrator(
            covariate_cols=covariate_cols,
            params=self.params,
            min_hess=self.min_hess,
            module_type=self.module_type,
            integration_method=self.integration_method,
        )

        # The event mask is the same in every round, so build it once.
        event_mask = torch.tensor(self.event_indicators_ == 1, dtype=torch.float32)
        self.losses_ = []
        for _ in range(self.nrounds):
            self.integrator_.train()
            self.integrator_.zero_grad()
            out = self.integrator_(self.exp_df, return_survival_estimates=False)
            loss = self._negative_log_likelihood(out, event_mask)
            # create_graph=True keeps the autograd graph after backward;
            # presumably gb_step() needs it for second-order terms. Confirm
            # against HazardIntegrator.
            loss.backward(create_graph=True)
            self.losses_.append(loss.item())
            self.integrator_.gb_step()
        self.integrator_.eval()
        return self

    def predict_base(self, X, y):
        """Run the fitted integrator on (X, y) after input conversion.

        Returns whatever the integrator's forward call returns for the
        converted rows.

        Raises
        ------
        NotFittedError
            If the model has not been fitted.
        ValueError
            If X or y is not a DataFrame.
        """
        check_is_fitted(self, "integrator_")
        if not isinstance(X, pd.DataFrame):
            raise ValueError("X must be a pandas DataFrame")
        if not isinstance(y, pd.DataFrame):
            raise ValueError("y must be a pandas DataFrame")
        _, exp_df, _ = self._validate_and_convert_input_data(X, y)
        return self.integrator_(exp_df)

    def predict_times(self, X, times=None):
        """Evaluate the fitted model on a time grid.

        Parameters
        ----------
        X : pd.DataFrame
            Features of the units to predict.
        times : array-like, optional
            Time points to evaluate. Defaults to 100 evenly spaced points from
            0 to the largest training time.

        Returns
        -------
        exp_df : pd.DataFrame
            Rows produced by input conversion, with added ``hazard`` and
            ``survival`` columns.
        udf : pd.DataFrame
            One row per unit (in order of first appearance in ``exp_df``) with
            ``last_hazard``, ``integrated_hazard`` and ``expected_time``.
        """
        check_is_fitted(self, "integrator_")
        if not isinstance(X, pd.DataFrame):
            raise ValueError("X must be a pandas DataFrame")
        if times is None:
            times = np.linspace(0, self.max_time, 100)
        # Defensive copy in case input conversion mutates its argument.
        X = X.copy()
        _, exp_df, _ = self._validate_and_convert_input_data(X, times)
        exp_df = exp_df.reset_index(drop=True).copy()
        output = self.integrator_(exp_df)
        exp_df["hazard"] = output["hazard"]
        exp_df["survival"] = output["survival"]
        # NOTE(review): assumes per-unit integrator outputs follow the order
        # in which units first appear in exp_df; TODO confirm.
        udf = exp_df[["unit_id"]].drop_duplicates().reset_index(drop=True).copy()
        udf["last_hazard"] = output["unit_last_hazard"]
        udf["integrated_hazard"] = output["unit_integrated_hazard"]
        udf["expected_time"] = output["unit_expected_time"]
        return exp_df, udf

    def predict_survival(self, X, times=None):
        """Return per-row survival probabilities and hazards on a time grid.

        Returns
        -------
        pd.DataFrame
            Columns ``unit_id``, ``time``, ``survival`` and ``hazard``.
        """
        check_is_fitted(self, "integrator_")
        exp_df, _ = self.predict_times(X, times)
        return exp_df[["unit_id", "time", "survival", "hazard"]]

    def predict(self, X, times=None):
        """Predict per-unit expected and median survival times.

        Returns
        -------
        pd.DataFrame
            Columns ``unit_id``, ``expected_time`` and
            ``predicted_median_time``. ``predicted_median_time`` is the
            largest grid time at which predicted survival is still above 0.5,
            so it is only as precise as the ``times`` grid. It is 0 for units
            whose survival never exceeds 0.5 on the grid.
        """
        check_is_fitted(self, "integrator_")
        exp_df, udf = self.predict_times(X, times)
        median = (
            exp_df[exp_df["survival"] > 0.5]
            .groupby("unit_id")["time"]
            .max()
            .rename("predicted_median_time")
            .reset_index()
        )
        output = udf[["unit_id", "expected_time"]].merge(
            median, on="unit_id", how="left", validate="one_to_one"
        )
        output["predicted_median_time"] = output["predicted_median_time"].fillna(0)
        return output

    def score(self, X, y):
        """Return the mean negative log likelihood on (X, y).

        Lower is better. NOTE: this is the opposite of scikit-learn's
        "greater is better" score convention. Utilities that maximize
        ``score`` (e.g. default scoring in model search) would therefore
        prefer worse fits.
        """
        check_is_fitted(self, "integrator_")
        if not isinstance(X, pd.DataFrame):
            raise ValueError("X must be a pandas DataFrame")
        if not isinstance(y, pd.DataFrame):
            raise ValueError("y must be a pandas DataFrame")
        _, exp_df, y = self._validate_and_convert_input_data(X, y)
        out = self.integrator_(exp_df, return_survival_estimates=False)
        event_indicators = y.groupby("unit_id")["event"].last().values
        event_mask = torch.tensor(event_indicators == 1, dtype=torch.float32)
        loss = self._negative_log_likelihood(out, event_mask)
        return loss.detach().item()
def expand_overlapping_units_locf(
    df: pd.DataFrame,
    y: Optional[np.ndarray] = None,
    unit_col: str = "unit_id",
    time_col: str = "time",
) -> pd.DataFrame:
    """Expand per-unit observations onto a shared time grid with LOCF.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format data containing ``unit_col``, ``time_col`` and any number
        of covariate columns.
    y : array-like, optional
        Extra time points to add to the grid.
    unit_col : str
        Name of the unit identifier column.
    time_col : str
        Name of the time column.

    Returns
    -------
    pd.DataFrame
        Columns ``unit_col`` and ``time_col`` followed by the remaining columns
        of ``df``, sorted by unit then time, with a fresh RangeIndex.

        * If ``y`` is None, each unit gets every time observed anywhere in
          ``df`` that lies within that unit's own [first, last] window.
        * Otherwise each unit gets every time in the union of
          ``df[time_col]`` and ``y``. This includes times before the unit's
          first observation; covariates stay NaN there because there is
          nothing to carry forward.

        Covariates are forward-filled within each unit (last observation
        carried forward).
    """
    # Unique times observed anywhere (plus y, if given), sorted ascending.
    if y is None:
        all_times = np.sort(df[time_col].unique())
    else:
        all_times = np.sort(
            np.unique(np.concatenate([df[time_col].values, np.asarray(y)]))
        )

    # Skeleton of the (unit, time) rows the output should contain.
    if y is None:
        bounds = df.groupby(unit_col)[time_col].agg(["min", "max"])
        pieces = [
            pd.DataFrame(
                {
                    unit_col: unit,
                    time_col: all_times[(all_times >= lo) & (all_times <= hi)],
                }
            )
            for unit, lo, hi in bounds.itertuples(name=None)
        ]
        # pd.concat([]) raises, so an empty df gets an empty typed skeleton.
        skeleton = (
            pd.concat(pieces, ignore_index=True)
            if pieces
            else df[[unit_col, time_col]].iloc[:0]
        )
    else:
        skeleton = (
            df[[unit_col]]
            .drop_duplicates()
            .merge(pd.DataFrame({time_col: all_times}), how="cross")
        )

    # Attach observed values; a stable sort keeps duplicate keys in df order.
    out = (
        skeleton.merge(df, on=[unit_col, time_col], how="left")
        .sort_values([unit_col, time_col], kind="mergesort")
        .reset_index(drop=True)
    )

    # LOCF: forward-fill covariates within each unit only.
    covariate_cols = [col for col in df.columns if col not in {unit_col, time_col}]
    if covariate_cols:
        out[covariate_cols] = out.groupby(unit_col)[covariate_cols].ffill()
    return out