Source code for gbnet.models.survival.hazard_integrator

from typing import List, Dict, Any, Optional

import pandas as pd
import torch


def loadModule(module):
    """Return the gradient boosting module class selected by name.

    Parameters
    ----------
    module
        Either ``"XGBModule"`` or ``"LGBModule"``.

    Returns
    -------
    The ``XGBModule`` class from ``gbnet.xgbmodule`` or the ``LGBModule``
    class from ``gbnet.lgbmodule``.

    Raises
    ------
    AssertionError
        If ``module`` is not one of the two supported names.
    """
    supported = {"XGBModule", "LGBModule"}
    assert module in supported

    selected = None
    # The backend packages are imported lazily, so only the one that was
    # requested has to be importable.
    if module == "LGBModule":
        from gbnet import lgbmodule

        selected = lgbmodule.LGBModule
    elif module == "XGBModule":
        from gbnet import xgbmodule

        selected = xgbmodule.XGBModule
    return selected
class HazardIntegrator(torch.nn.Module):
    """Hazard model whose log-hazard comes from a gradient boosting module.

    The forward pass exponentiates the module output to get a hazard per row,
    integrates it over time within each ``unit_id`` to get cumulative hazard,
    and optionally derives survival curves and expected times from it.
    """

    # Methods understood by ``_integrate_slice``.
    _INTEGRATION_METHODS = ("trapezoid", "stepwise_left", "stepwise_right")

    # NOTE: the ``[]`` / ``{}` defaults are kept so the signature is unchanged;
    # they are only ever copied below, never mutated.
    def __init__(
        self,
        covariate_cols: List[str] = [],
        params: Dict = {},
        min_hess: float = 0.0,
        module_type: str = "XGBModule",
        integration_method: str = "trapezoid",
    ):
        """
        Parameters
        ----------
        covariate_cols
            Columns to feed into the model. "time" is always included.
        params
            Parameters for the gradient boosting model.
        min_hess
            Minimum hessian for the gradient boosting model.
        module_type
            Type of gradient boosting module to use, either "XGBModule" or
            "LGBModule". Defaults to "XGBModule".
        integration_method
            Method for integrating hazards and survival estimates. One of
            "trapezoid", "stepwise_left", or "stepwise_right".
        """
        super().__init__()
        assert integration_method in {"trapezoid", "stepwise_left", "stepwise_right"}

        self.params = params.copy()
        self.min_hess = min_hess
        self.module_type = module_type
        self.integration_method = integration_method
        # "time" is always the first model input; list() also accepts tuples.
        self.covariate_cols = ["time"] + list(covariate_cols)
        # Created lazily on the first forward pass, once the data shape is known.
        self.gb_module: Optional[object] = None
        self.Module = loadModule(module_type)
        # Pre-processed inputs (data matrix, shapes and index tensors) that are
        # reused across forward passes while in training mode.
        self.static_data: Dict[str, Any] = {}

    def _integrate_slice(self, values, dt, same_unit):
        """Return the per-row contribution to the integral of ``values``.

        Parameters
        ----------
        values
            Function values, one per row, in sorted (unit, time) order.
        dt
            Time elapsed since the previous row of the same unit.
        same_unit
            Boolean mask, True where the previous row belongs to the same unit.

        Returns
        -------
        Tensor shaped like ``values``; rows where ``same_unit`` is False are 0.

        Raises
        ------
        ValueError
            If ``self.integration_method`` is not a supported method.
        """
        method = self.integration_method
        slice_values = torch.zeros_like(values)
        # Value at the previous row; only read where ``same_unit`` is True, so
        # the wrap-around element produced by roll is never used.
        previous = values.roll(1, 0)

        if method == "trapezoid":
            slice_values[same_unit] = (
                0.5 * (values[same_unit] + previous[same_unit]) * dt[same_unit]
            )
        elif method == "stepwise_left":
            slice_values[same_unit] = previous[same_unit] * dt[same_unit]
        elif method == "stepwise_right":
            slice_values[same_unit] = values[same_unit] * dt[same_unit]
        else:
            # Reachable only if the attribute was changed after construction
            # or assertions are disabled; fail loudly instead of returning 0s.
            raise ValueError(f"Unsupported integration method: {method}")
        return slice_values

    @staticmethod
    def _sorted_layout(df: pd.DataFrame) -> tuple:
        """Sort ``df`` by (unit_id, time) and build the index tensors.

        Returns
        -------
        A ``(df_sorted, layout)`` pair. ``df_sorted`` is the sorted frame with
        a fresh 0..n-1 index. ``layout`` is a dict with:

        - ``unit_ids``: integer code per sorted row, 0..K-1 in sorted order.
        - ``dt``: time since the previous row of the same unit (0 for the
          first row of each unit).
        - ``same_unit``: True where the previous sorted row is the same unit.
        - ``unsort_idx``: permutation mapping sorted rows back to the row
          order of the input frame.
        - ``interleave_amts``: number of rows per unit, aligned with the codes.

        Raises
        ------
        ValueError
            If a required column is missing, the frame has no rows, or
            ``unit_id`` contains missing values.
        """
        if {"unit_id", "time"} - set(df.columns):
            raise ValueError("DataFrame must contain 'unit_id' and 'time'.")
        if df.empty:
            raise ValueError("DataFrame must contain at least one row.")

        # Replace the index by row positions before sorting. Relying on the
        # caller's index labels breaks for named, non-integer, duplicated or
        # non-monotonic indexes.
        df_sorted = df.reset_index(drop=True).sort_values(
            ["unit_id", "time"], kind="mergesort"
        )
        orig_pos = torch.as_tensor(
            df_sorted.index.to_numpy(copy=True), dtype=torch.long
        )
        df_sorted = df_sorted.reset_index(drop=True)

        unit_codes, _ = pd.factorize(df_sorted["unit_id"])
        unit_ids = torch.as_tensor(unit_codes, dtype=torch.long)
        if bool((unit_ids < 0).any()):
            # factorize marks missing ids with -1, which cannot be used as a
            # scatter index later on.
            raise ValueError("'unit_id' must not contain missing values.")
        times = torch.as_tensor(df_sorted["time"].to_numpy(), dtype=torch.float32)

        # The first row overall never has a predecessor of the same unit.
        same_unit = torch.zeros_like(unit_ids, dtype=torch.bool)
        same_unit[1:] = unit_ids[1:] == unit_ids[:-1]
        # Zero out differences that straddle two units.
        dt = torch.diff(times, prepend=times.new_zeros(1)) * same_unit

        layout = {
            "unit_ids": unit_ids,
            "dt": dt,
            "same_unit": same_unit,
            "unsort_idx": torch.argsort(orig_pos),
            # Counted from the same codes used for scattering so both always
            # agree (a separate groupby can differ, e.g. for categoricals
            # with unused categories).
            "interleave_amts": torch.bincount(unit_ids),
        }
        return df_sorted, layout

    def _prepare_data(self, df: pd.DataFrame):
        """
        Pre-processes and caches data that is static during training.

        Sorts the frame, builds the index tensors (see ``_sorted_layout``),
        wraps the model inputs in the backend's data container and stores
        everything in ``self.static_data``.

        Raises
        ------
        ValueError
            If ``df`` is invalid (see ``_sorted_layout``) or
            ``self.module_type`` is not supported.
        """
        df_sorted, layout = self._sorted_layout(df)

        # Input matrix for the gradient boosting model.
        X = df_sorted[self.covariate_cols]

        if self.module_type == "XGBModule":
            import xgboost as xgb

            dmatrix = xgb.DMatrix(X, enable_categorical=True)
        elif self.module_type == "LGBModule":
            import lightgbm as lgb

            dmatrix = lgb.Dataset(X)
        else:
            raise ValueError(f"Unsupported module type: {self.module_type}")

        self.static_data = {
            "dmatrix": dmatrix,
            "num_rows": X.shape[0],
            "num_cols": X.shape[1],
            **layout,
        }

    def forward(
        self, df: pd.DataFrame, return_survival_estimates: bool = True
    ) -> Dict[str, Any]:
        """Compute hazards and, optionally, survival estimates for ``df``.

        In training mode the pre-processed data from the first call is reused
        and ``df`` is ignored on later calls; in eval mode ``df`` is
        re-processed on every call.

        Parameters
        ----------
        df
            Frame with "unit_id", "time" and the covariate columns.
        return_survival_estimates
            When False, the survival curve and expected time are skipped and
            returned as None.

        Returns
        -------
        Dict with keys "hazard" and "survival" (one value per row, in the row
        order of ``df``) and "unit_last_hazard", "unit_integrated_hazard" and
        "unit_expected_time" (one value per unit, ordered by sorted unit_id).
        """
        if not self.static_data or not self.training:
            self._prepare_data(df)

        cache = self.static_data
        unit_ids = cache["unit_ids"]
        dt = cache["dt"]
        same_unit = cache["same_unit"]
        unsort_idx = cache["unsort_idx"]
        interleave_amts = cache["interleave_amts"]

        # Lazily initialize the gradient boosting module on the first pass.
        if self.gb_module is None:
            self.gb_module = self.Module(
                cache["num_rows"],
                cache["num_cols"],
                1,
                params=self.params,
                min_hess=self.min_hess,
            )

        # 1. Model inference: log-hazard per row, then the hazard itself.
        log_hazard = self.gb_module(cache["dmatrix"]).flatten()
        hazard = torch.exp(log_hazard)

        # 2. Per-row slice of the hazard integral.
        hazard_slice = self._integrate_slice(hazard, dt, same_unit)

        # 3. Total integrated hazard per unit. new_zeros keeps the dtype and
        #    device of the model output, which scatter_reduce_ requires.
        unit_Lambda = hazard.new_zeros(interleave_amts.size(0)).scatter_reduce_(
            0, unit_ids, hazard_slice, reduce="sum", include_self=False
        )

        # Hazard at the last (latest) row of each unit.
        is_last_in_group = torch.cat(
            (
                unit_ids[:-1] != unit_ids[1:],
                torch.tensor([True], device=unit_ids.device),
            )
        )
        last_hazard = hazard[is_last_in_group]

        survival = None
        unit_E = None
        if return_survival_estimates:
            # 4. Running cumulative hazard within each unit: take the global
            #    running sum and subtract the total of all preceding units.
            Lambda_global = torch.cumsum(hazard_slice, dim=0)
            cusum_prev_groups = torch.cat(
                [unit_Lambda.new_zeros(1), torch.cumsum(unit_Lambda, 0)[:-1]]
            )
            Lambda = Lambda_global - torch.repeat_interleave(
                cusum_prev_groups, interleave_amts
            )

            # 5. Survival function S(t) = exp(-Lambda(t)).
            survival = torch.exp(-Lambda)

            # 6. Expected time per unit: integral of S(t) over the observed
            #    time grid.
            surv_slice = self._integrate_slice(survival, dt, same_unit)
            unit_E = torch.zeros_like(unit_Lambda).scatter_reduce_(
                0, unit_ids, surv_slice, reduce="sum", include_self=False
            )
            survival = survival[unsort_idx]

        # 7. Row-level results are un-sorted back to the input row order.
        return {
            "hazard": hazard[unsort_idx],
            "unit_last_hazard": last_hazard,
            "unit_integrated_hazard": unit_Lambda,
            "survival": survival,
            "unit_expected_time": unit_E,
        }

    def gb_step(self):
        """Triggers the gradient boosting model to take a step.

        Does nothing if the module has not been created yet (no forward pass).
        """
        if self.gb_module is not None:
            self.gb_module.gb_step()