{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Custom Split Example\n", "\n", "Contributed by [Franz Görlich](https://github.com/frgoe003).\n", "\n", "Import core libraries:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "import plinder.core\n", "\n", "from plinder.data.splits import (\n", "    split,\n", "    get_default_config,\n", ")\n", "from plinder.core.scores import query_index\n", "\n", "plinder_cfg = plinder.core.get_config()\n", "plinder_local_storage = plinder_cfg.data.plinder_dir" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Split Config" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's modify the split config. We get the default config using `get_default_config()` and then change some of the parameters." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Validation set size: 1000\n", "Test set size: 1000\n", "Minimum size of each cluster in the validation set: 30\n" ] } ], "source": [ "cfg = get_default_config()\n", "print(f'Validation set size: {cfg.split.num_val}')\n", "print(f'Test set size: {cfg.split.num_test}')\n", "print(f'Minimum size of each cluster in the validation set: {cfg.split.min_val_cluster_size}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we reduce the maximum sizes of the validation and test sets. Since this reduces the total number of samples, let's also reduce the minimum validation set cluster size so that we avoid removing too many systems.
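
To see why lowering the minimum cluster size matters, here is a toy sketch. It assumes that the threshold simply excludes clusters with fewer members than `min_val_cluster_size` from the validation set, which may not match the exact logic in PLINDER. The cluster labels and the helper `eligible_clusters` are hypothetical and only serve as an illustration.

```python
from collections import Counter

# Hypothetical cluster labels for a handful of systems (toy data, not PLINDER output).
cluster_labels = ['c0'] * 40 + ['c1'] * 12 + ['c2'] * 6 + ['c3'] * 3


def eligible_clusters(labels, min_cluster_size):
    # Assumption: a cluster qualifies only if it has at least min_cluster_size members.
    sizes = Counter(labels)
    return sorted(name for name, size in sizes.items() if size >= min_cluster_size)


print(eligible_clusters(cluster_labels, 30))  # ['c0']
print(eligible_clusters(cluster_labels, 5))   # ['c0', 'c1', 'c2']
```

With a threshold of 30 only the largest toy cluster qualifies, while a threshold of 5 admits three of the four clusters, so fewer systems would be left out of a small custom dataset.
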
All config options are documented [here](https://plinder-org.github.io/plinder/dataset.html#splits-splits)." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "cfg.split.num_test = 500 # Reduce the max size of the test set\n", "cfg.split.num_val = 500 # Reduce the max size of the validation set\n", "cfg.split.min_val_cluster_size = 5 # Reduce the minimum required size of each cluster in the validation set" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Custom Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's build a custom dataset that we want to resplit. We first load the plindex with the columns we need and then filter it down to the systems of interest." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-10-25 11:14:50,015 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.53s\n", "2024-10-25 11:14:51,601 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.55s\n" ] } ], "source": [ "cols_of_interest = [\n", "    \"system_id\",\n", "    \"entry_pdb_id\",\n", "    \"ligand_ccd_code\",\n", "    \"ligand_binding_affinity\",\n", "    \"ligand_is_proper\",\n", "    \"ligand_molecular_weight\",\n", "    \"system_has_binding_affinity\",\n", "]\n", "custom_df = query_index(\n", "    columns=cols_of_interest, splits=[\"train\", \"val\", \"test\", \"removed\"]\n", ")" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(567394, 8)" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "custom_df.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's keep only the systems that have a binding affinity (`system_has_binding_affinity`), whose ligand is not an ion or artifact (`ligand_is_proper`), and whose ligand has a molecular weight above 400 g/mol (`ligand_molecular_weight`)." 
] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
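<table>
<thead>
<tr><th></th><th>system_id</th><th>entry_pdb_id</th><th>ligand_ccd_code</th><th>ligand_binding_affinity</th><th>ligand_is_proper</th><th>ligand_molecular_weight</th><th>system_has_binding_affinity</th><th>split</th></tr>
</thead>
<tbody>
<tr><th>20</th><td>2grt__1__1.A_2.A__1.C</td><td>2grt</td><td>GDS</td><td>6.079633</td><td>True</td><td>612.151962</td><td>True</td><td>train</td></tr>
<tr><th>22</th><td>2grt__1__1.A_2.A__2.C</td><td>2grt</td><td>GDS</td><td>6.079633</td><td>True</td><td>612.151962</td><td>True</td><td>train</td></tr>
<tr><th>74</th><td>8gr9__1__1.A_1.B__1.C_1.J</td><td>8gr9</td><td>COA</td><td>5.465907</td><td>True</td><td>767.115209</td><td>True</td><td>removed</td></tr>
<tr><th>85</th><td>1grn__1__1.A_1.B__1.C_1.D_1.E</td><td>1grn</td><td>GDP</td><td>3.428291</td><td>True</td><td>443.024330</td><td>True</td><td>train</td></tr>
</tbody>
</table>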
\n", "
" ], "text/plain": [ " system_id entry_pdb_id ligand_ccd_code \\\n", "20 2grt__1__1.A_2.A__1.C 2grt GDS \n", "22 2grt__1__1.A_2.A__2.C 2grt GDS \n", "74 8gr9__1__1.A_1.B__1.C_1.J 8gr9 COA \n", "85 1grn__1__1.A_1.B__1.C_1.D_1.E 1grn GDP \n", "\n", " ligand_binding_affinity ligand_is_proper ligand_molecular_weight \\\n", "20 6.079633 True 612.151962 \n", "22 6.079633 True 612.151962 \n", "74 5.465907 True 767.115209 \n", "85 3.428291 True 443.024330 \n", "\n", " system_has_binding_affinity split \n", "20 True train \n", "22 True train \n", "74 True removed \n", "85 True train " ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "custom_df = custom_df[\n", " (custom_df[\"ligand_is_proper\"] == True) &\n", " (custom_df[\"system_has_binding_affinity\"] == True) &\n", " (custom_df[\"ligand_molecular_weight\"] > 400)\n", "]\n", "custom_df.head(4)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(36247, 8)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "custom_df.shape" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "split\n", "train 25610\n", "removed 10483\n", "val 92\n", "test 62\n", "Name: count, dtype: int64" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "custom_df['split'].value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Resplitting the Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After filtering, only 92 systems remain in the validation set and 62 in the test set, while more than 10,000 systems are marked as removed. Let's resplit the dataset and see what the new split looks like." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__NOTE__: Resplitting the dataset requires a lot of memory and might only be feasible on an HPC cluster." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(33047, 13)" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_dir = Path(plinder_local_storage)\n", "custom_systems = set(custom_df['system_id'].unique())\n", "split_name = 'custom_1'\n", "\n", "new_split_df = split(\n", " data_dir=data_dir,\n", " cfg=cfg, # here we use the modified config from earlier\n", " relpath=split_name,\n", " selected_systems=custom_systems \n", ")\n", "new_split_df.shape" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
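<table>
<thead>
<tr><th></th><th>system_id</th><th>uniqueness</th><th>split</th><th>cluster</th><th>cluster_for_val_split</th><th>system_pass_validation_criteria</th><th>system_pass_statistics_criteria</th><th>system_proper_num_ligand_chains</th><th>system_proper_num_pocket_residues</th><th>system_proper_num_interactions</th><th>system_proper_ligand_max_molecular_weight</th><th>system_has_binding_affinity</th><th>system_has_apo_or_pred</th></tr>
</thead>
<tbody>
<tr><th>0</th><td>10gs__1__1.A_1.B__1.C</td><td>10gs__A_B__C_c101993</td><td>train</td><td>c62</td><td>c0</td><td>True</td><td>True</td><td>1</td><td>24</td><td>15</td><td>473.162057</td><td>True</td><td>False</td></tr>
<tr><th>1</th><td>10gs__1__1.A_1.B__1.E</td><td>10gs__A_B__E_c101949</td><td>train</td><td>c62</td><td>c0</td><td>True</td><td>True</td><td>1</td><td>24</td><td>13</td><td>473.162057</td><td>True</td><td>False</td></tr>
<tr><th>2</th><td>19gs__1__1.A_1.B__1.C_1.D</td><td>19gs__A_B__C_D_c147080</td><td>train</td><td>c62</td><td>c0</td><td>False</td><td>True</td><td>2</td><td>30</td><td>12</td><td>787.630334</td><td>True</td><td>False</td></tr>
<tr><th>3</th><td>19gs__1__1.A_1.B__1.F_1.G</td><td>19gs__A_B__F_G_c101954</td><td>train</td><td>c62</td><td>c0</td><td>False</td><td>True</td><td>2</td><td>30</td><td>14</td><td>787.630334</td><td>True</td><td>False</td></tr>
</tbody>
</table>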
\n", "
" ], "text/plain": [ " system_id uniqueness split cluster \\\n", "0 10gs__1__1.A_1.B__1.C 10gs__A_B__C_c101993 train c62 \n", "1 10gs__1__1.A_1.B__1.E 10gs__A_B__E_c101949 train c62 \n", "2 19gs__1__1.A_1.B__1.C_1.D 19gs__A_B__C_D_c147080 train c62 \n", "3 19gs__1__1.A_1.B__1.F_1.G 19gs__A_B__F_G_c101954 train c62 \n", "\n", " cluster_for_val_split system_pass_validation_criteria \\\n", "0 c0 True \n", "1 c0 True \n", "2 c0 False \n", "3 c0 False \n", "\n", " system_pass_statistics_criteria system_proper_num_ligand_chains \\\n", "0 True 1 \n", "1 True 1 \n", "2 True 2 \n", "3 True 2 \n", "\n", " system_proper_num_pocket_residues system_proper_num_interactions \\\n", "0 24 15 \n", "1 24 13 \n", "2 30 12 \n", "3 30 14 \n", "\n", " system_proper_ligand_max_molecular_weight system_has_binding_affinity \\\n", "0 473.162057 True \n", "1 473.162057 True \n", "2 787.630334 True \n", "3 787.630334 True \n", "\n", " system_has_apo_or_pred \n", "0 False \n", "1 False \n", "2 False \n", "3 False " ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_split_df.head(4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualizing the New Split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have our first custom split, let's use `SplitPropertiesPlotter` to visualize it." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from plinder.core.split.plot import SplitPropertiesPlotter" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plotter = SplitPropertiesPlotter.from_files(\n", "    data_dir=Path(plinder_local_storage),\n", "    split_file=Path(f'{plinder_local_storage}/splits/split_{split_name}.parquet'),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This will create a folder `split_plots` in the current working directory with the following plots:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`split_plots/split_proportions.png`\n", "![split_proportions.png](./split_plots/split_proportions.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`split_plots/chain_composition.png`\n", "![chain_composition.png](./split_plots/chain_composition.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`split_plots/domain_classifications.png`\n", "![domain_classifications.png](./split_plots/domain_classifications.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`split_plots/ligand_types.png`\n", "![ligand_types.png](./split_plots/ligand_types.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`split_plots/molecular_descriptors.png`\n", "![molecular_descriptors.png](./split_plots/molecular_descriptors.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`split_plots/priorities.png`\n", "![priorities.png](./split_plots/priorities.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`split_plots/plinder_clusters.png`\n", "![plinder_clusters.png](./split_plots/plinder_clusters.png)" ] } ], "metadata": { "kernelspec": { "display_name": "MPT_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" }, "mystnb": { "execution_mode": "off" } }, "nbformat": 4, "nbformat_minor": 2 }