PyTorch models
- class giskard.models.pytorch.PyTorchModel(model, model_type: SupportedModelTypes | Literal['classification', 'regression', 'text_generation'], torch_dtype: Literal['float32', 'float', 'float64', 'double', 'complex64', 'cfloat', 'float16', 'half', 'bfloat16', 'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long', 'bool'] = 'float32', device='cpu', name: str | None = None, data_preprocessing_function=None, model_postprocessing_function=None, feature_names=None, classification_threshold=0.5, classification_labels=None, iterate_dataset: bool = True, id: str | None = None, batch_size: int | None = None, **kwargs)
Automatically wraps a PyTorch model.
This class provides a default wrapper around the PyTorch library for use with Giskard.
- Parameters:
model (Any) – The PyTorch model to wrap.
model_type (ModelType) – The type of the model: classification, regression or text_generation.
torch_dtype (Optional[TorchDType]) – The data type to use for the input data. Default is 'float32'.
device (Optional[str]) – The device to use for the model. The model is moved to this device before inference is run. Default is 'cpu'. Make sure that your data_preprocessing_function returns tensors on the same device.
name (Optional[str]) – A name for the wrapper. Default is None.
data_preprocessing_function (Optional[Callable[[pd.DataFrame], Any]]) – A function applied to the incoming data before it is passed to the model, for example to convert the data to tensors. Default is None.
model_postprocessing_function (Optional[Callable[[Any], Any]]) – A function applied to the model's predictions. Default is None.
feature_names (Optional[Iterable]) – A list of feature names. Default is None.
classification_threshold (Optional[float]) – The probability threshold for classification. Default is 0.5.
classification_labels (Optional[Iterable]) – A list of classification labels. Default is None.
iterate_dataset (Optional[bool]) – Whether to iterate over the dataset. Default is True.
batch_size (Optional[int]) – The batch size to use for inference. Default is 1.
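As a sketch of how data_preprocessing_function, device and the model's output fit together, the example below assumes a tabular classifier; the feature columns (`age`, `income`, `score`), the labels, the toy network and the helper names `preprocess` and `build_wrapped_model` are all hypothetical. The preprocessing function converts the incoming DataFrame into a float32 tensor (the same type as the default torch_dtype) on the device that is also passed as `device`. The Giskard import is deferred into `build_wrapped_model` so the rest of the sketch runs without Giskard installed; the constructor call only uses parameters from the signature above.

```python
import pandas as pd
import torch

# Hypothetical feature columns, for illustration only.
FEATURES = ["age", "income", "score"]
DEVICE = "cpu"  # should match the `device` argument given to PyTorchModel


def preprocess(df: pd.DataFrame) -> torch.Tensor:
    """Hypothetical data_preprocessing_function: DataFrame -> float32 tensor on DEVICE."""
    values = df[FEATURES].to_numpy(dtype="float32")
    return torch.tensor(values, dtype=torch.float32, device=DEVICE)


# Toy classifier: 3 input features -> probabilities over 2 classes.
toy_model = torch.nn.Sequential(
    torch.nn.Linear(len(FEATURES), 2),
    torch.nn.Softmax(dim=-1),
)


def build_wrapped_model():
    # Requires giskard; imported lazily so the rest of this sketch runs without it.
    from giskard.models.pytorch import PyTorchModel

    return PyTorchModel(
        model=toy_model,
        model_type="classification",
        feature_names=FEATURES,
        classification_labels=["no", "yes"],  # hypothetical labels
        data_preprocessing_function=preprocess,
        device=DEVICE,
    )


if __name__ == "__main__":
    df = pd.DataFrame({"age": [30, 45], "income": [1.0, 2.5], "score": [0.2, 0.9]})
    with torch.no_grad():
        probs = toy_model(preprocess(df))
    print(probs.shape)  # torch.Size([2, 2])
```

Ending the network with a softmax means its output already has the (num_entries, num_classes) probability shape expected from a classification model, so no postprocessing function is needed in this sketch.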
- classmethod load_model(local_dir, model_py_ver: Tuple[str, str, str] | None = None, *_args, **_kwargs)
Loads the wrapped model object.
- Parameters:
local_dir (Union[str, Path]) – Path from which the model should be loaded.
model_py_ver (Optional[Tuple[str, str, str]]) – The Python version used to save the model, used to help diagnose the failure if loading the model fails.
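model_py_ver has the same shape as the value returned by the standard library's `platform.python_version_tuple()`, a tuple of three strings. The sketch below is a hypothetical helper, not part of the Giskard API, showing the kind of check such a tuple enables: comparing the major.minor version recorded at save time with the running interpreter, since serialized Python objects are generally not portable across minor versions. The saved version shown is made up for illustration.

```python
import platform
from typing import Optional, Tuple

PyVer = Tuple[str, str, str]  # e.g. ("3", "11", "4")


def python_minor_mismatch(saved_ver: PyVer, current_ver: Optional[PyVer] = None) -> bool:
    """Return True if the saving and loading interpreters differ in major.minor.

    Hypothetical helper for illustration; not Giskard's implementation.
    """
    current = current_ver if current_ver is not None else platform.python_version_tuple()
    return tuple(saved_ver[:2]) != tuple(current[:2])


saved = ("3", "10", "12")  # hypothetical version recorded when the model was saved
if python_minor_mismatch(saved):
    print(
        f"Model was saved with Python {'.'.join(saved)} but this is "
        f"Python {platform.python_version()}; the mismatch may explain a load failure."
    )
```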
- model_predict(data)
Performs the model inference/forward pass.
- Parameters:
data (Any) – The input data for making predictions. If you did not specify a data_preprocessing_function, this is a pd.DataFrame; otherwise it is whatever the data_preprocessing_function returns.
- Returns:
If the model is a classification model, an array of probabilities of shape (num_entries, num_classes). If the model is a regression or text_generation model, an array of num_entries predictions.
- Return type:
numpy.ndarray
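The return contract above can be checked without Giskard. The sketch below assumes a classifier that outputs raw logits rather than probabilities, and that a model_postprocessing_function receives them as something NumPy can convert (for example a NumPy array, or a CPU tensor that does not require gradients). The helper names `logits_to_probabilities`, `flatten_regression` and `check_classification_output` are hypothetical.

```python
import numpy as np


def logits_to_probabilities(logits) -> np.ndarray:
    """Hypothetical model_postprocessing_function for classification.

    Applies a row-wise softmax so raw scores of shape (num_entries, num_classes)
    become probabilities whose rows sum to 1.
    """
    z = np.asarray(logits, dtype=np.float64)
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    exp_z = np.exp(z)
    return exp_z / exp_z.sum(axis=1, keepdims=True)


def flatten_regression(preds) -> np.ndarray:
    """Hypothetical helper: turn a (num_entries, 1) regression output into (num_entries,)."""
    return np.asarray(preds, dtype=np.float64).reshape(-1)


def check_classification_output(probs: np.ndarray, num_entries: int, num_classes: int) -> None:
    """Raise ValueError if probs breaks the (num_entries, num_classes) probability contract."""
    if probs.shape != (num_entries, num_classes):
        raise ValueError(f"expected shape {(num_entries, num_classes)}, got {probs.shape}")
    if not np.allclose(probs.sum(axis=1), 1.0):
        raise ValueError("each row of probabilities must sum to 1")


probs = logits_to_probabilities([[2.0, 0.0], [0.0, 0.0]])
check_classification_output(probs, num_entries=2, num_classes=2)
print(probs[1])  # [0.5 0.5]
print(flatten_regression([[1.5], [2.0]]).shape)  # (2,)
```

Subtracting the row maximum before exponentiating leaves the softmax unchanged but avoids overflow for large logits.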