torch_frame.nn.models.Trompt

class Trompt(channels: int, out_channels: int, num_prompts: int, num_layers: int, col_stats: dict[str, dict[StatType, Any]], col_names_dict: dict[torch_frame.stype, list[str]], stype_encoder_dicts: list[dict[torch_frame.stype, StypeEncoder]] | None = None)

Bases: Module

The Trompt model introduced in the “Trompt: Towards a Better Deep Neural Network for Tabular Data” paper.

Note

For an example of using Trompt, see examples/trompt.py.

Parameters:
  • channels (int) – Hidden channel dimensionality.

  • out_channels (int) – Output channel dimensionality.

  • num_prompts (int) – Number of prompt columns.

  • num_layers (int) – Number of TromptConv layers.

  • col_stats (Dict[str,Dict[torch_frame.data.stats.StatType,Any]]) – A dictionary that maps each column name to its stats. Available as dataset.col_stats.

  • col_names_dict (Dict[torch_frame.stype, List[str]]) – A dictionary that maps each stype to a list of column names. The column names are ordered as they appear in tensor_frame.feat_dict. Available as tensor_frame.col_names_dict.

  • stype_encoder_dicts (list[dict[torch_frame.stype, torch_frame.nn.encoder.StypeEncoder]], optional) – A list of num_layers dictionaries, one per layer, each mapping stypes to their stype encoders. (default: None, in which case EmbeddingEncoder() is used for categorical features and LinearEncoder() for numerical features)

forward(tf: TensorFrame) → Tensor

Transforms a TensorFrame object into a series of output predictions, one per layer. Used during training to compute a layer-wise loss.

Parameters:

tf (torch_frame.TensorFrame) – Input TensorFrame object.

Returns:

Output predictions stacked across layers. The shape is [batch_size, num_layers, out_channels].

Return type:

torch.Tensor
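A layer-wise loss over this stacked output can be sketched in plain PyTorch as follows. This assumes a classification task trained with cross-entropy; a random tensor stands in for the result of forward(), and averaging over the layer dimension at inference time is shown as one reasonable choice rather than a prescribed part of the API.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_layers, num_classes = 4, 6, 3

# Stand-in for forward() output: [batch_size, num_layers, out_channels].
out = torch.randn(batch_size, num_layers, num_classes)
y = torch.randint(0, num_classes, (batch_size,))

# Layer-wise loss: flatten the layer dimension into the batch so that every
# layer's prediction is scored against the same target label.
pred = out.reshape(-1, num_classes)       # [batch_size * num_layers, num_classes]
target = y.repeat_interleave(num_layers)  # [batch_size * num_layers]
loss = F.cross_entropy(pred, target)

# At inference time, one option is to average the per-layer predictions.
final_logits = out.mean(dim=1)            # [batch_size, num_classes]
```

Because reshape keeps rows in (sample, layer) order and repeat_interleave repeats each label num_layers times consecutively, this loss equals the mean of the per-layer cross-entropy losses.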