opendataval.experiment.ExperimentMediator
- class opendataval.experiment.ExperimentMediator(fetcher: DataFetcher, pred_model: Model, train_kwargs: dict[str, Any] | None = None, metric_name: str | Metrics | Callable | None = None, output_dir: str | Path | None = None, raises_error: bool = False)
Set up an experiment to compare a group of DataEvaluators.
Attributes
timings : dict[str, timedelta]
Computation time of each DataEvaluator, keyed by a string identifying the evaluator.
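The following is a hypothetical, dependency-free sketch, not opendataval's implementation, of how a mapping like timings could be filled while honoring the raises_error behavior described under Parameters: each evaluator is timed, and a failure either propagates or becomes a warning. The compute_all function and the representation of evaluators as zero-argument callables are assumptions made only for this illustration.

```python
import time
import warnings
from datetime import timedelta
from typing import Callable


def compute_all(
    evaluators: dict[str, Callable[[], list[float]]],
    raises_error: bool = False,
) -> tuple[dict[str, list[float]], dict[str, timedelta]]:
    """Run each (hypothetical) evaluator and record how long it took.

    If raises_error is True, the first failure is re-raised; otherwise the
    failure is reported as a warning and the remaining evaluators still run.
    """
    values: dict[str, list[float]] = {}
    timings: dict[str, timedelta] = {}
    for name, evaluate in evaluators.items():
        start = time.perf_counter()
        try:
            values[name] = evaluate()
        except Exception as exc:
            if raises_error:
                raise
            warnings.warn(f"{name} failed: {exc!r}; continuing")
            continue
        # Only successful evaluators get a timing entry in this sketch.
        timings[name] = timedelta(seconds=time.perf_counter() - start)
    return values, timings


if __name__ == "__main__":
    def broken() -> list[float]:
        raise ValueError("bad input")

    # "broken" triggers a warning instead of stopping the run.
    values, timings = compute_all({"ok": lambda: [0.1, 0.9], "broken": broken})
    print(sorted(values))  # ['ok']
```

Whether failed evaluators should still appear in timings is a design choice; this sketch leaves them out so the mapping only describes completed runs.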
Parameters
- fetcher : DataFetcher
DataFetcher for the dataset used in the experiment. Every exper_func takes a DataFetcher as an argument so that it can access all data points and the noisy indices.
- pred_model : Model
Prediction model for the DataEvaluators.
- train_kwargs : dict[str, Any], optional
Keyword arguments used when training the prediction model, by default None.
- metric_name : str | Metrics | Callable[[Tensor, Tensor], float], optional
Performance metric used to evaluate the prediction model, given as a metric name, a Metrics member, or a callable; by default accuracy.
- output_dir : str | Path, optional
Directory where experiment outputs are saved, by default None.
- raises_error : bool, optional
If True, raises an exception when any DataEvaluator fails; otherwise warns the user and continues the computation. By default False.
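metric_name accepts several forms, with accuracy as the default. The sketch below is a hypothetical illustration of how such an argument could be resolved to a single callable; it is not the library's code. To stay dependency-free it uses plain lists instead of Tensors and a small name registry (METRICS) instead of the Metrics enum, so the Metrics-member case is not covered. The names resolve_metric, METRICS and MetricFn are assumptions made for this example.

```python
from typing import Callable, Union

# Hypothetical metric signature: (labels, predictions) -> score.
MetricFn = Callable[[list[int], list[int]], float]


def accuracy(y_true: list[int], y_pred: list[int]) -> float:
    """Fraction of positions where the prediction equals the label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("inputs must not be empty")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical name registry standing in for the library's Metrics enum.
METRICS: dict[str, MetricFn] = {"accuracy": accuracy}


def resolve_metric(metric_name: Union[str, MetricFn, None] = None) -> MetricFn:
    """Turn a name, a callable, or None into a metric function."""
    if metric_name is None:
        return accuracy  # documented default
    if callable(metric_name):
        return metric_name  # user-supplied metric is used as-is
    try:
        return METRICS[metric_name.lower()]
    except KeyError:
        raise ValueError(f"unknown metric {metric_name!r}") from None


if __name__ == "__main__":
    metric = resolve_metric()  # falls back to accuracy
    print(round(metric([1, 0, 1], [1, 1, 1]), 3))  # 0.667
```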
- __init__(fetcher: DataFetcher, pred_model: Model, train_kwargs: dict[str, Any] | None = None, metric_name: str | Metrics | Callable | None = None, output_dir: str | Path | None = None, raises_error: bool = False)
Methods
- __init__(fetcher, pred_model[, ...])
- compute_data_values(data_evaluators, *args, ...)
Compute the data values for the input data evaluators.
- evaluate(exper_func[, save_output])
Evaluate exper_func on each DataEvaluator.
- model_factory_setup(dataset_name[, ...])
Set up an ExperimentMediator from ModelFactory using an input string.
- plot(exper_func[, figure, row, col, save_output])
Evaluate exper_func on each DataEvaluator and plot the result in figure.
- save_output(file_name, df)
Save the DataFrame df to f"{self.output_directory}/{file_name}".
- set_output_directory(output_directory)
Set the directory where experiment output is saved.
- setup(dataset_name[, cache_dir, ...])
Create a DataFetcher from the arguments and pass it into __init__.
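The evaluate and save_output methods follow a simple pattern: apply one experiment function to every evaluator, gather the results, and write them to a file under the output directory. The sketch below is a hypothetical, standard-library-only illustration of that pattern, not the library's implementation. In particular, exper_func here receives only an evaluator name (the documented exper_func also takes a DataFetcher), results are plain dict rows rather than a DataFrame, and the CSV format and automatic directory creation are assumptions.

```python
import csv
from pathlib import Path
from typing import Callable, Union


def evaluate_all(
    exper_func: Callable[[str], dict[str, float]],
    evaluator_names: list[str],
) -> list[dict[str, object]]:
    """Apply exper_func to each (hypothetical) evaluator; one row per evaluator."""
    return [{"evaluator": name, **exper_func(name)} for name in evaluator_names]


def save_output(
    output_directory: Union[str, Path],
    file_name: str,
    rows: list[dict[str, object]],
) -> Path:
    """Write rows as CSV to output_directory/file_name and return the path."""
    out_dir = Path(output_directory)
    out_dir.mkdir(parents=True, exist_ok=True)  # assumption: create if missing
    path = out_dir / file_name
    # Union of all keys, in first-seen order, so ragged rows still fit.
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with path.open("w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    return path


if __name__ == "__main__":
    # Toy experiment: the "score" is just the length of the evaluator name.
    rows = evaluate_all(lambda name: {"score": float(len(name))}, ["loo", "knn_shap"])
    print(rows[0])  # {'evaluator': 'loo', 'score': 3.0}
```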