gbnet.models.survival.hazard_integrator

Classes

HazardIntegrator

Functions

loadModule(module)

Load the appropriate gradient boosting module.

Module Contents

gbnet.models.survival.hazard_integrator.loadModule(module)

Load the appropriate gradient boosting module.
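The page does not say how loadModule picks a backend. Loaders like this often use name-based dispatch through a registry. The sketch below assumes that pattern and uses stub classes in place of real backends. Only `XGBModule` appears on this page, as the default `module_type`. The name `LGBModule`, the stub classes, and `load_module_sketch` are hypothetical.

```python
from typing import Callable, Dict


class XGBModuleStub:
    """Hypothetical stand-in for the XGBoost-backed module."""


class LGBModuleStub:
    """Hypothetical stand-in for a LightGBM-backed module."""


# Hypothetical registry mapping module_type strings to backend classes.
_MODULE_REGISTRY: Dict[str, Callable[..., object]] = {
    "XGBModule": XGBModuleStub,
    "LGBModule": LGBModuleStub,
}


def load_module_sketch(module: str) -> Callable[..., object]:
    """Return the class registered under `module`, or fail with a clear error."""
    try:
        return _MODULE_REGISTRY[module]
    except KeyError:
        valid = ", ".join(sorted(_MODULE_REGISTRY))
        raise ValueError(
            f"Unknown module type {module!r}; expected one of: {valid}"
        ) from None
```

Raising on an unknown name, rather than silently falling back to a default, makes a misspelled `module_type` fail at construction time.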

class gbnet.models.survival.hazard_integrator.HazardIntegrator(covariate_cols=[], params={}, min_hess=0.0, module_type='XGBModule', integration_method='trapezoid')

Bases: torch.nn.Module

Parameters:
  • covariate_cols (List[str])

  • params (Dict)

  • min_hess (float)

  • module_type (str)

  • integration_method (str)
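The default integration_method='trapezoid' suggests that the hazard is integrated over time with the trapezoidal rule. The following is a sketch under assumptions, not the documented formula. It assumes one unit has hazard values h_0, ..., h_n at sorted times t_0 < ... < t_n and that integration starts at t_0. The cumulative hazard H and the survival estimate S would then be:

```latex
H(t_k) = \int_{t_0}^{t_k} h(s)\,ds
  \approx \sum_{j=1}^{k} \frac{(t_j - t_{j-1})\,(h_{j-1} + h_j)}{2},
\qquad H(t_0) = 0,
\qquad S(t_k) = \exp\!\bigl(-H(t_k)\bigr)
```

The relation S = exp(-H) is the standard link between cumulative hazard and survival. The exact discretization used by this class is not documented here.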

params
min_hess = 0.0
module_type = 'XGBModule'
integration_method = 'trapezoid'
covariate_cols = ['time']
gb_module: object | None = None
Module
static_data: Dict[str, torch.Tensor]
_integrate_slice(values, dt, same_unit)
_prepare_data(df)

Pre-processes and caches data that stays static during training. It sorts the rows, converts them to tensors, and computes time differences and group boundaries once, so this work is not repeated on every forward pass.

Parameters:
  • df (pandas.DataFrame)
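The preprocessing described above can be sketched without dependencies. The sketch assumes each row is identified by a unit id and a time. The names `prepare_static_data` and `unit_id` and the output keys are hypothetical. The real method works on a DataFrame and produces tensors; this sketch keeps plain lists so it runs with the standard library only.

```python
from typing import Dict, List, Tuple


def prepare_static_data(rows: List[Tuple[str, float]]) -> Dict[str, list]:
    """Sort (unit_id, time) rows once and precompute per-row quantities.

    Hypothetical sketch: names and output keys are illustrative only.
    """
    # Sort row indices by unit, then by time within each unit.
    order = sorted(range(len(rows)), key=lambda i: (rows[i][0], rows[i][1]))
    units = [rows[i][0] for i in order]
    times = [rows[i][1] for i in order]

    dt: List[float] = []           # gap to the previous row of the same unit
    same_unit: List[bool] = []     # False at the first row of each unit
    starts: List[int] = []         # index where each unit's block begins
    for k, (unit, t) in enumerate(zip(units, times)):
        if k > 0 and units[k - 1] == unit:
            same_unit.append(True)
            dt.append(t - times[k - 1])
        else:
            same_unit.append(False)
            dt.append(0.0)
            starts.append(k)
    return {
        "order": order,
        "time": times,
        "dt": dt,
        "same_unit": same_unit,
        "starts": starts,
    }
```

For example, rows for units "a" and "b" given in arbitrary order come back grouped by unit. `dt` is zero at the first row of each unit, which marks where a per-unit integral would restart.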

forward(df, return_survival_estimates=True)
Parameters:
  • df (pandas.DataFrame)

  • return_survival_estimates (bool)

Return type:

Dict[str, Any]
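One way forward could turn per-row hazards into a dictionary of outputs is to integrate them per unit with the trapezoidal rule and convert the result to survival with S = exp(-H). The sketch below assumes rows are already sorted by unit and time. It also assumes `dt` holds the gap to the previous row of the same unit and `same_unit` is False at each unit's first row. The function name and the output keys are hypothetical, not the class's actual return keys.

```python
import math
from typing import Any, Dict, List


def integrate_hazard_sketch(
    hazard: List[float], dt: List[float], same_unit: List[bool]
) -> Dict[str, Any]:
    """Trapezoid-rule cumulative hazard per unit, plus survival S = exp(-H).

    Hypothetical sketch: output keys are illustrative only.
    """
    cumulative: List[float] = []
    for k, h in enumerate(hazard):
        if same_unit[k]:
            # Trapezoid area between the previous row and this one.
            step = 0.5 * dt[k] * (hazard[k - 1] + h)
            cumulative.append(cumulative[k - 1] + step)
        else:
            # A new unit starts: its integral restarts at zero.
            cumulative.append(0.0)
    survival = [math.exp(-c) for c in cumulative]
    return {"hazard": hazard, "cumulative_hazard": cumulative, "survival": survival}
```

Restarting the integral at each unit boundary keeps one unit's hazard from leaking into the next unit's cumulative hazard. In this sketch, skipping survival estimates would only drop the final exp step.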

gb_step()

Triggers the gradient boosting model to take a step.