From 3ac86a40e3ab6488d543e66abdcbfe2845077522 Mon Sep 17 00:00:00 2001
From: FilippoOlivo
Date: Fri, 26 Dec 2025 12:26:39 +0100
Subject: [PATCH] refact base and input-target condition

---
 pina/condition/condition_base.py         | 231 +++++++++++++++++++++++
 pina/condition/condition_interface.py    |  94 +---------
 pina/condition/input_target_condition.py | 201 ++++++++++++++++---
 3 files changed, 414 insertions(+), 112 deletions(-)
 create mode 100644 pina/condition/condition_base.py

diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py
new file mode 100644
index 000000000..361352d0f
--- /dev/null
+++ b/pina/condition/condition_base.py
@@ -0,0 +1,231 @@
+"""Module for the base Condition class."""
+
+import torch
+from copy import deepcopy
+from .condition_interface import ConditionInterface
+from ..graph import Graph, LabelBatch
+from ..label_tensor import LabelTensor
+from ..data.dummy_dataloader import DummyDataloader
+from torch_geometric.data import Data, Batch
+from torch.utils.data import DataLoader
+from functools import partial
+
+
+class ConditionBase(ConditionInterface):
+    """
+    Base class for the conditions. It stores the condition data and provides
+    the collate functions and the dataloader creation shared by subclasses.
+    """
+
+    # Collate function associated with each supported data type
+    collate_fn_dict = {
+        "tensor": torch.stack,
+        "label_tensor": LabelTensor.stack,
+        "graph": LabelBatch.from_data_list,
+        "data": Batch.from_data_list,
+    }
+
+    def __init__(self, **kwargs):
+        """
+        Initialization of the :class:`ConditionBase` class.
+
+        :param kwargs: Keyword arguments representing the data to be stored.
+        """
+        super().__init__()
+        # The interface does not initialize the problem anymore: set it here
+        # so that the ``problem`` property can be read before being assigned.
+        self._problem = None
+        self.data = self._store_data(**kwargs)
+
+    @property
+    def problem(self):
+        return self._problem
+
+    @problem.setter
+    def problem(self, value):
+        self._problem = value
+
+    @staticmethod
+    def _check_graph_list_consistency(data_list):
+        """
+        Check the consistency of the list of Data | Graph objects.
+        The following checks are performed:
+
+        - All elements in the list must be of the same type (either
+          :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`).
+
+        - All elements in the list must have the same keys.
+
+        - The data type of each tensor must be consistent across all elements.
+
+        - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels
+          must also be consistent across all elements.
+
+        :param data_list: The list of Data | Graph objects to check.
+        :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
+        :raises ValueError: If the input types are invalid.
+        :raises ValueError: If all elements in the list do not have the same
+            keys.
+        :raises ValueError: If the type of each tensor is not consistent across
+            all elements in the list.
+        :raises ValueError: If the labels of the LabelTensors are not consistent
+            across all elements in the list.
+        """
+        # If the data is a Graph or Data object, perform no checks
+        if isinstance(data_list, (Graph, Data)):
+            return
+
+        # Check all elements in the list are of the same type
+        if not all(isinstance(i, (Graph, Data)) for i in data_list):
+            raise ValueError(
+                "Invalid input. Please, provide either Data or Graph objects."
+            )
+
+        # Store the keys, data types and labels of the first element
+        data = data_list[0]
+        keys = sorted(list(data.keys()))
+        data_types = {name: tensor.__class__ for name, tensor in data.items()}
+        labels = {
+            name: tensor.labels
+            for name, tensor in data.items()
+            if isinstance(tensor, LabelTensor)
+        }
+
+        # Iterate over the list of Data | Graph objects
+        for data in data_list[1:]:
+
+            # Check that all elements in the list have the same keys
+            if sorted(list(data.keys())) != keys:
+                raise ValueError(
+                    "All elements in the list must have the same keys."
+                )
+
+            # Iterate over the tensors in the current element
+            for name, tensor in data.items():
+                # Check that the type of each tensor is consistent
+                if tensor.__class__ is not data_types[name]:
+                    raise ValueError(
+                        f"Data {name} must be a {data_types[name]}, got "
+                        f"{tensor.__class__}"
+                    )
+
+                # Check that the labels of each LabelTensor are consistent
+                if isinstance(tensor, LabelTensor):
+                    if tensor.labels != labels[name]:
+                        raise ValueError(
+                            "LabelTensor must have the same labels"
+                        )
+
+    def _store_tensor_data(self, **kwargs):
+        """
+        Store data for standard tensor condition.
+
+        :param kwargs: Keyword arguments representing the data to be stored.
+        :return: A dictionary containing the stored data.
+        :rtype: dict
+        """
+        data = {}
+        for key, value in kwargs.items():
+            data[key] = value
+        return data
+
+    def _store_graph_data(self, graphs, tensors=None, key=None):
+        """
+        Store data for graph condition.
+
+        :param graphs: List of graphs to store data in.
+        :type graphs: list[Graph] | list[Data]
+        :param tensors: List of tensors to store in the graphs.
+        :type tensors: list[torch.Tensor] | list[LabelTensor]
+        :param key: Key under which to store the tensors in the graphs.
+        :type key: str
+        :return: A dictionary containing the stored data.
+        :rtype: dict
+        """
+        data = []
+        for i, graph in enumerate(graphs):
+            new_graph = deepcopy(graph)
+            tensor = tensors[i]
+            setattr(new_graph, key, tensor)
+            data.append(new_graph)
+        return {"data": data}
+
+    def _store_data(self, **kwargs):
+        return self._store_tensor_data(**kwargs)
+
+    def __len__(self):
+        return len(next(iter(self.data.values())))
+
+    def __getitem__(self, idx):
+        return {key: self.data[key][idx] for key in self.data}
+
+    @classmethod
+    def automatic_batching_collate_fn(cls, batch):
+        """
+        Collate function used when automatic batching is enabled.
+
+        :param batch: A list of items from the dataset.
+        :type batch: list
+        :return: A collated batch.
+        :rtype: dict
+        """
+
+        to_return = {key: [] for key in batch[0].keys()}
+        for item in batch:
+            for key, value in item.items():
+                to_return[key].append(value)
+        for key, values in to_return.items():
+            collate_function = cls.collate_fn_dict.get(
+                "label_tensor"
+                if isinstance(values[0], LabelTensor)
+                else (
+                    "tensor"
+                    if isinstance(values[0], torch.Tensor)
+                    else "graph" if isinstance(values[0], Graph) else "data"
+                )
+            )
+            to_return[key] = collate_function(values)
+        return to_return
+
+    @staticmethod
+    def collate_fn(batch, condition):
+        """
+        Collate function used when automatic batching is disabled.
+
+        :param batch: A list of items from the dataset.
+        :type batch: list
+        :param condition: The condition indexed with ``batch``.
+        :type condition: ConditionBase
+        :return: A collated batch.
+        :rtype: dict
+        """
+        data = condition[batch]
+        return data
+
+    def create_dataloader(
+        self, dataset, batch_size, shuffle, automatic_batching
+    ):
+        """
+        Create a DataLoader for the condition.
+
+        :param dataset: The dataset to be wrapped by the DataLoader.
+        :param int batch_size: The batch size for the DataLoader.
+        :param bool shuffle: Whether to shuffle the data.
+        :param bool automatic_batching: Whether to collate the items with
+            :meth:`automatic_batching_collate_fn`. If ``False``,
+            :meth:`collate_fn` is used instead.
+        :return: The DataLoader for the condition.
+        :rtype: torch.utils.data.DataLoader | DummyDataloader
+        """
+        if batch_size == len(dataset):
+            return DummyDataloader(dataset)
+        return DataLoader(
+            dataset=dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            collate_fn=(
+                partial(self.collate_fn, condition=self)
+                if not automatic_batching
+                else self.automatic_batching_collate_fn
+            ),
+        )
diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py
index b0264517c..427b85502 100644
--- a/pina/condition/condition_interface.py
+++ b/pina/condition/condition_interface.py
@@ -1,6 +1,6 @@
 """Module for the Condition interface."""
 
-from abc import ABCMeta
+from abc import ABCMeta, abstractmethod
 from torch_geometric.data import Data
 from ..label_tensor import LabelTensor
 from ..graph import Graph
@@ -15,13 +15,14 @@ class ConditionInterface(metaclass=ABCMeta):
     description of all available conditions and how to instantiate them.
     """
 
-    def __init__(self):
+    @abstractmethod
+    def __init__(self, **kwargs):
         """
         Initialization of the :class:`ConditionInterface` class.
         """
-        self._problem = None
 
     @property
+    @abstractmethod
     def problem(self):
         """
         Return the problem associated with this condition.
@@ -29,9 +30,9 @@ def problem(self):
         :return: Problem associated with this condition.
         :rtype: ~pina.problem.abstract_problem.AbstractProblem
         """
-        return self._problem
 
     @problem.setter
+    @abstractmethod
     def problem(self, value):
         """
         Set the problem associated with this condition.
@@ -39,88 +40,3 @@ def problem(self, value):
         :param pina.problem.abstract_problem.AbstractProblem value: The problem
             to associate with this condition
         """
-        self._problem = value
-
-    @staticmethod
-    def _check_graph_list_consistency(data_list):
-        """
-        Check the consistency of the list of Data | Graph objects.
-        The following checks are performed:
-
-        - All elements in the list must be of the same type (either
-          :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`).
-
-        - All elements in the list must have the same keys.
-
-        - The data type of each tensor must be consistent across all elements.
-
-        - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels
-          must also be consistent across all elements.
-
-        :param data_list: The list of Data | Graph objects to check.
-        :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
-        :raises ValueError: If the input types are invalid.
-        :raises ValueError: If all elements in the list do not have the same
-            keys.
-        :raises ValueError: If the type of each tensor is not consistent across
-            all elements in the list.
-        :raises ValueError: If the labels of the LabelTensors are not consistent
-            across all elements in the list.
-        """
-        # If the data is a Graph or Data object, perform no checks
-        if isinstance(data_list, (Graph, Data)):
-            return
-
-        # Check all elements in the list are of the same type
-        if not all(isinstance(i, (Graph, Data)) for i in data_list):
-            raise ValueError(
-                "Invalid input. Please, provide either Data or Graph objects."
-            )
-
-        # Store the keys, data types and labels of the first element
-        data = data_list[0]
-        keys = sorted(list(data.keys()))
-        data_types = {name: tensor.__class__ for name, tensor in data.items()}
-        labels = {
-            name: tensor.labels
-            for name, tensor in data.items()
-            if isinstance(tensor, LabelTensor)
-        }
-
-        # Iterate over the list of Data | Graph objects
-        for data in data_list[1:]:
-
-            # Check that all elements in the list have the same keys
-            if sorted(list(data.keys())) != keys:
-                raise ValueError(
-                    "All elements in the list must have the same keys."
-                )
-
-            # Iterate over the tensors in the current element
-            for name, tensor in data.items():
-                # Check that the type of each tensor is consistent
-                if tensor.__class__ is not data_types[name]:
-                    raise ValueError(
-                        f"Data {name} must be a {data_types[name]}, got "
-                        f"{tensor.__class__}"
-                    )
-
-                # Check that the labels of each LabelTensor are consistent
-                if isinstance(tensor, LabelTensor):
-                    if tensor.labels != labels[name]:
-                        raise ValueError(
-                            "LabelTensor must have the same labels"
-                        )
-
-    def __getattribute__(self, name):
-        """
-        Get an attribute from the object.
-
-        :param str name: The name of the attribute to get.
-        :return: The requested attribute.
-        :rtype: Any
-        """
-        to_return = super().__getattribute__(name)
-        if isinstance(to_return, (Graph, Data)):
-            to_return = [to_return]
-        return to_return
diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py
index 07b07bb7b..965eeecfc 100644
--- a/pina/condition/input_target_condition.py
+++ b/pina/condition/input_target_condition.py
@@ -3,13 +3,15 @@
 """
 
 import torch
+from copy import deepcopy
 from torch_geometric.data import Data
 from ..label_tensor import LabelTensor
-from ..graph import Graph
-from .condition_interface import ConditionInterface
+from ..graph import Graph, LabelBatch
+from .condition_base import ConditionBase
+from torch_geometric.data import Batch
 
 
-class InputTargetCondition(ConditionInterface):
+class InputTargetCondition(ConditionBase):
     """
     The :class:`InputTargetCondition` class represents a supervised condition
     defined by both ``input`` and ``target`` data. The model is trained to
@@ -55,7 +57,7 @@ class InputTargetCondition(ConditionInterface):
     """
 
     # Available input and target data types
-    __slots__ = ["input", "target"]
+    __fields__ = ["input", "target"]
     _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
     _avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
 
@@ -109,16 +111,6 @@ def __new__(cls, input, target):
             subclass = GraphInputTensorTargetCondition
             return subclass.__new__(subclass, input, target)
 
-        # Graph - Graph
-        if isinstance(input, (Graph, Data, list, tuple)) and isinstance(
-            target, (Graph, Data, list, tuple)
-        ):
-            cls._check_graph_list_consistency(input)
-            cls._check_graph_list_consistency(target)
-            subclass = GraphInputGraphTargetCondition
-            return subclass.__new__(subclass, input, target)
-
-        # If the input and/or target are not of the correct type raise an error
         raise ValueError(
             "Invalid input | target types."
             "Please provide either torch_geometric.data.Data, Graph, "
@@ -143,10 +135,8 @@ def __init__(self, input, target):
         objects, all elements in the list must share the same structure,
         with matching keys and consistent data types.
         """
-        super().__init__()
         self._check_input_target_len(input, target)
-        self.input = input
-        self.target = target
+        super().__init__(input=input, target=target)
 
     @staticmethod
     def _check_input_target_len(input, target):
@@ -181,6 +171,26 @@ class TensorInputTensorTargetCondition(InputTargetCondition):
     :class:`~pina.label_tensor.LabelTensor` objects.
     """
 
+    @property
+    def input(self):
+        """
+        Return the input data for the condition.
+
+        :return: The input data.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        return self.data["input"]
+
+    @property
+    def target(self):
+        """
+        Return the target data for the condition.
+
+        :return: The target data.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        return self.data["target"]
+
 
 class TensorInputGraphTargetCondition(InputTargetCondition):
     """
@@ -190,6 +200,82 @@ class TensorInputGraphTargetCondition(InputTargetCondition):
     :class:`~pina.graph.Graph` or a :class:`torch_geometric.data.Data` object.
     """
 
+    def __init__(self, input, target):
+        """
+        Initialization of the :class:`TensorInputGraphTargetCondition` class.
+
+        :param input: The input data for the condition.
+        :type input: torch.Tensor | LabelTensor
+        :param target: The target data for the condition.
+        :type target: Graph | Data | list[Graph] | list[Data] |
+            tuple[Graph] | tuple[Data]
+        """
+        super().__init__(input=input, target=target)
+        self.batch_fn = (
+            LabelBatch.from_data_list
+            if isinstance(target[0], Graph)
+            else Batch.from_data_list
+        )
+
+    def _store_data(self, **kwargs):
+        return self._store_graph_data(
+            kwargs["target"], kwargs["input"], key="x"
+        )
+
+    @property
+    def input(self):
+        """
+        Return the input data for the condition.
+
+        :return: The input data.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        inputs = []
+        is_lt = isinstance(self.data["data"][0].x, LabelTensor)
+        for graph in self.data["data"]:
+            inputs.append(graph.x)
+        return torch.stack(inputs) if not is_lt else LabelTensor.stack(inputs)
+
+    @property
+    def target(self):
+        """
+        Return the target data for the condition.
+
+        :return: The target data.
+        :rtype: list[Graph] | list[Data]
+        """
+        return self.data["data"]
+
+    def __getitem__(self, idx):
+        if isinstance(idx, list):
+            return self.get_multiple_data(idx)
+        return {"data": self.data["data"][idx]}
+
+    def get_multiple_data(self, indices):
+        data = self.batch_fn([self.data["data"][i] for i in indices])
+        x = data.x
+        del data.x  # Avoid duplication of x on GPU memory
+        return {
+            "input": x,
+            "target": data,
+        }
+
+    @classmethod
+    def automatic_batching_collate_fn(cls, batch):
+        """
+        Collate function to be used in DataLoader.
+
+        :param batch: A list of items from the dataset.
+        :type batch: list
+        :return: A collated batch.
+        :rtype: dict
+        """
+        collated_graphs = super().automatic_batching_collate_fn(batch)
+        x = collated_graphs["data"].x
+        del collated_graphs["data"].x  # Avoid duplication of x on GPU memory
+        to_return = {"input": x, "target": collated_graphs["data"]}
+        return to_return
+
 
 class GraphInputTensorTargetCondition(InputTargetCondition):
     """
@@ -199,10 +285,79 @@ class GraphInputTensorTargetCondition(InputTargetCondition):
     :class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object.
     """
 
+    def __init__(self, input, target):
+        """
+        Initialization of the :class:`GraphInputTensorTargetCondition` class.
 
-class GraphInputGraphTargetCondition(InputTargetCondition):
-    """
-    Specialization of the :class:`InputTargetCondition` class for the case where
-    both ``input`` and ``target`` are either :class:`~pina.graph.Graph` or
-    :class:`torch_geometric.data.Data` objects.
-    """
+        :param input: The input data for the condition.
+        :type input: Graph | Data | list[Graph] | list[Data] |
+            tuple[Graph] | tuple[Data]
+        :param target: The target data for the condition.
+        :type target: torch.Tensor | LabelTensor
+        """
+        super().__init__(input=input, target=target)
+        self.batch_fn = (
+            LabelBatch.from_data_list
+            if isinstance(input[0], Graph)
+            else Batch.from_data_list
+        )
+
+    def _store_data(self, **kwargs):
+        return self._store_graph_data(
+            kwargs["input"], kwargs["target"], key="y"
+        )
+
+    @property
+    def input(self):
+        """
+        Return the input data for the condition.
+
+        :return: The input data.
+        :rtype: list[Graph] | list[Data]
+        """
+        return self.data["data"]
+
+    @property
+    def target(self):
+        """
+        Return the target data for the condition.
+
+        :return: The target data.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        targets = []
+        is_lt = isinstance(self.data["data"][0].y, LabelTensor)
+        for graph in self.data["data"]:
+            targets.append(graph.y)
+
+        return torch.stack(targets) if not is_lt else LabelTensor.stack(targets)
+
+    def __getitem__(self, idx):
+        if isinstance(idx, list):
+            return self.get_multiple_data(idx)
+        return {"data": self.data["data"][idx]}
+
+    def get_multiple_data(self, indices):
+        data = self.batch_fn([self.data["data"][i] for i in indices])
+        y = data.y
+        del data.y  # Avoid duplication of y on GPU memory
+        return {
+            "input": data,
+            "target": y,
+        }
+
+    @classmethod
+    def automatic_batching_collate_fn(cls, batch):
+        """
+        Collate function to be used in DataLoader.
+
+        :param batch: A list of items from the dataset.
+        :type batch: list
+        :return: A collated batch.
+        :rtype: dict
+        """
+        collated_graphs = super().automatic_batching_collate_fn(batch)
+        y = collated_graphs["data"].y
+        del collated_graphs["data"].y  # Avoid duplication of y on GPU memory
+        to_return = {"target": y, "input": collated_graphs["data"]}
+        return to_return