Datasets and Dataloaders

from plinder.core.loader.dataset import PlinderDataset, get_torch_loader

PlinderDataset provides an interface for working with PLINDER data as a dataset. It is a subclass of torch.utils.data.Dataset, so subclassing and extending it should be familiar to most users. Flexibility and general applicability were our top concerns when designing this interface: PlinderDataset lets users not only define their own split but also bring their own featurizer. It can be initialized with the following parameters:

Parameters
----------
df : pd.DataFrame | None
    the split to use
split : str
    the split to sample from
split_parquet_path : str | Path, default=None
    split parquet file
input_structure_priority : str, default="apo"
    which alternate structure to prioritize
featurizer : Callable[[Structure, int], dict[str, torch.Tensor]], default=structure_featurizer
    transformation that turns a structure into input tensors
padding_value : int
    value used to pad uneven arrays
**kwargs : Any
    any other keyword arguments
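Because PlinderDataset is a map-style torch.utils.data.Dataset, the underlying pattern is a class with `__len__` and `__getitem__` that hands each loaded system to an injected featurizer together with the padding value. The sketch below illustrates that design with the standard library only. `ToySystem`, `ToyDataset`, and `chain_length_featurizer` are hypothetical stand-ins, not part of PLINDER; the real class loads `Structure` objects from the chosen split rather than reading an in-memory list.

```python
from dataclasses import dataclass
from typing import Any, Callable


# Hypothetical stand-in for a PLINDER system (the real dataset uses Structure).
@dataclass
class ToySystem:
    system_id: str
    chains: list[list[str]]  # residue names, one list per chain


class ToyDataset:
    """Map-style dataset: implements __len__ and __getitem__, like a torch Dataset."""

    def __init__(
        self,
        systems: list[ToySystem],
        featurizer: Callable[[ToySystem, int], dict[str, Any]],
        padding_value: int = -100,
    ) -> None:
        self.systems = systems
        self.featurizer = featurizer  # injected, so it can be swapped without subclassing
        self.padding_value = padding_value

    def __len__(self) -> int:
        return len(self.systems)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        system = self.systems[idx]
        return {
            "system_id": system.system_id,
            "features": self.featurizer(system, self.padding_value),
        }


def chain_length_featurizer(system: ToySystem, pad_value: int) -> dict[str, Any]:
    # Trivial hypothetical featurizer: residue count per chain (pad_value unused here)
    return {"chain_lengths": [len(chain) for chain in system.chains]}


dataset = ToyDataset(
    [
        ToySystem("sys_a", [["ALA", "GLY"]]),
        ToySystem("sys_b", [["SER"], ["LYS", "ARG", "HIS"]]),
    ],
    featurizer=chain_length_featurizer,
)
print(len(dataset))  # 2
print(dataset[1])    # {'system_id': 'sys_b', 'features': {'chain_lengths': [1, 3]}}
```

Passing the featurizer in as a callable keeps data loading and featurization separate, which is what lets a user bring their own featurizer without touching the loading logic.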

For an example of how to write your own featurizer, see Featurizer Example. The default featurizer's signature is shown below:

def structure_featurizer(
    structure: Structure, pad_value: int = -100
) -> dict[str, Any]:
    ...

It takes a Structure object and returns a dictionary of padded tensor features.
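A featurizer following the same pattern (a structure plus a pad value in, a dictionary of padded features out) might look like the hypothetical sketch below. To stay self-contained it takes a plain list of residue names per chain instead of a PLINDER `Structure`, returns nested lists instead of `torch.Tensor`s, and uses an invented residue vocabulary; it is not the encoding used by `structure_featurizer`.

```python
from typing import Any

# Hypothetical residue vocabulary for illustration only.
RESIDUE_INDEX = {"ALA": 0, "GLY": 1, "SER": 2, "LYS": 3}
UNKNOWN_INDEX = len(RESIDUE_INDEX)  # residues not in the vocabulary map here


def toy_featurizer(chains: list[list[str]], pad_value: int = -100) -> dict[str, Any]:
    """Encode residues per chain and right-pad every chain to the longest one."""
    max_len = max(len(chain) for chain in chains)
    encoded = []
    for chain in chains:
        ids = [RESIDUE_INDEX.get(res, UNKNOWN_INDEX) for res in chain]
        encoded.append(ids + [pad_value] * (max_len - len(ids)))
    return {"residue_feature": encoded}


print(toy_featurizer([["ALA", "GLY", "SER"], ["LYS"]]))
# {'residue_feature': [[0, 1, 2], [3, -100, -100]]}
```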

Note

This is where you would typically load a train dataset, but for demonstration purposes we start with val because of its smaller memory footprint, and load only a small subset of single-protein-chain systems containing ATP as a ligand. We also set use_alternate_structures=False to avoid downloading and loading alternate structures when building the docs.

val_dataset = PlinderDataset(
    split="val",
    filters=[
        ("system_num_protein_chains", "==", 1),
        ("ligand_unique_ccd_code", "in", {"ATP"}),
    ],
    use_alternate_structures=False,
)
len(val_dataset)
9
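Each entry in `filters` is a `(column, operator, value)` tuple, and the example above keeps only systems that satisfy all of them. This resembles the `filters` argument of `pyarrow.parquet.read_table`, where a flat list of tuples is combined with AND. The pure-Python sketch below illustrates only those semantics for the two operators used here; `apply_filters` and the example rows are hypothetical and are not PLINDER's implementation.

```python
import operator
from typing import Any

# Only the two operators used in the example above; the real loader may support more.
_OPS = {
    "==": operator.eq,
    "in": lambda value, allowed: value in allowed,
}


def apply_filters(
    rows: list[dict[str, Any]], filters: list[tuple[str, str, Any]]
) -> list[dict[str, Any]]:
    """Keep rows that satisfy every (column, op, value) filter (logical AND)."""
    return [
        row
        for row in rows
        if all(_OPS[op](row[column], value) for column, op, value in filters)
    ]


rows = [
    {"system_id": "s1", "system_num_protein_chains": 1, "ligand_unique_ccd_code": "ATP"},
    {"system_id": "s2", "system_num_protein_chains": 2, "ligand_unique_ccd_code": "ATP"},
    {"system_id": "s3", "system_num_protein_chains": 1, "ligand_unique_ccd_code": "ADP"},
]
filters = [
    ("system_num_protein_chains", "==", 1),
    ("ligand_unique_ccd_code", "in", {"ATP"}),
]
print([row["system_id"] for row in apply_filters(rows, filters)])  # ['s1']
```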
val_data = val_dataset[1]
val_loader = get_torch_loader(val_dataset)
for data in val_loader:
    test_torch = data
    break
2024-10-23 14:41:05,078 | plinder.core.structure.atoms:214 | WARNING : get_template_to_mol_matches: could not match template fully - retry with unmatched bonds set as UNSPECIFIED
test_torch.keys()
dict_keys(['system_ids', 'holo_structures', 'alternate_structures', 'paths', 'features_and_coords'])
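Each batch is a dictionary whose values hold one entry per system. PyTorch's `default_collate` treats a batch of dictionaries in a similar way: it collates each key separately, and a list of strings stays a list of strings. The minimal stdlib sketch below shows only that key-by-key grouping, with hypothetical keys and values; the collate function actually used by `get_torch_loader` may rename keys (note the plural `system_ids`) and also pads and stacks tensors.

```python
from typing import Any


def collate_dicts(samples: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Turn a list of per-sample dicts into one dict of per-key lists."""
    keys = samples[0].keys()
    return {key: [sample[key] for sample in samples] for key in keys}


batch = collate_dicts([
    {"system_id": "sys_a", "path": "/tmp/sys_a"},
    {"system_id": "sys_b", "path": "/tmp/sys_b"},
])
print(batch)  # {'system_id': ['sys_a', 'sys_b'], 'path': ['/tmp/sys_a', '/tmp/sys_b']}
```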
test_torch["system_ids"]
['4jxz__1__1.A__1.C', '4a8s__1__1.A__1.H_1.I']
for k, v in test_torch["features_and_coords"].items():
    print(k, v.shape)
sequence_atom_mask_feature torch.Size([2, 1, 4656])
input_sequence_residue_mask_feature torch.Size([2, 1, 665])
protein_coordinates torch.Size([2, 1, 5265, 3])
protein_calpha_coordinates torch.Size([2, 1, 664, 3])
input_sequence_full_atom_feature torch.Size([2, 1, 5272, 12])
protein_structure_residue_feature torch.Size([2, 1, 5265, 21])
input_conformer_ligand_feature torch.Size([2, 2, 31, 16])
input_conformer_ligand_coordinates torch.Size([2, 2, 31, 3])
resolved_ligand_mols_feature torch.Size([2, 2, 31, 3])
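In the shapes above the first dimension is the batch size (two systems), and the variable-length dimensions are presumably padded to the largest system in the batch. The sketch below shows that idea on plain lists: right-pad each sample to the batch maximum with the pad value, then recover a boolean mask of real entries. The helper names are hypothetical, and `-100` is taken from the default `pad_value` in the `structure_featurizer` signature.

```python
PAD = -100  # default pad_value from the structure_featurizer signature


def pad_and_stack(arrays: list[list[float]], pad_value: float = PAD) -> list[list[float]]:
    """Right-pad each sample to the longest one in the batch."""
    max_len = max(len(a) for a in arrays)
    return [a + [pad_value] * (max_len - len(a)) for a in arrays]


def mask_from_padding(batch: list[list[float]], pad_value: float = PAD) -> list[list[bool]]:
    """True where an entry holds real data, False where it is padding."""
    return [[v != pad_value for v in row] for row in batch]


batch = pad_and_stack([[0.5, 1.5, 2.5], [3.0]])
print(batch)                        # [[0.5, 1.5, 2.5], [3.0, -100, -100]]
print((len(batch), len(batch[0])))  # (2, 3)
print(mask_from_padding(batch))     # [[True, True, True], [True, False, False]]
```

Deriving a mask by comparing against the pad value only works if real data can never equal that value; storing an explicit mask alongside the features avoids that assumption.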