opendataval.model.ModelFactory
- opendataval.model.ModelFactory(model_name: str, fetcher: DataFetcher | None = None, device: device = device(type='cpu'), *args, **kwargs) → Model
Factory to create prediction models from specified presets
Model factory that creates the specified model from the input parameters. It is recommended to import the specific model class directly and pass any additional arguments to it, rather than relying on the factory.
Parameters
- model_name : str
Name of the prediction model preset
- covar_dim : tuple[int, …]
Dimensions of the covariates, typically the data shape excluding the first (sample) dimension
- label_dim : tuple[int, …]
Dimensions of the labels, typically the label shape excluding the first (sample) dimension
- device : torch.device, optional
Tensor device used for acceleration; some models ignore this argument. By default torch.device("cpu")
- args : tuple[Any]
Additional positional arguments passed to the Model constructor
- kwargs : dict[str, Any]
Additional keyword arguments passed to the Model constructor
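covar_dim and label_dim are described as the data shape without its first dimension. A minimal sketch of that convention, assuming the data is held in an array-like object whose .shape is a tuple of ints (as with NumPy arrays and torch tensors); the helper sample_dims and the example shapes are hypothetical and not part of opendataval.

```python
def sample_dims(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Return the per-sample shape: every dimension except the first (sample) one.

    Hypothetical helper; pass x.shape for a NumPy array or torch tensor.
    """
    if len(shape) < 1:
        raise ValueError("expected at least one (sample) dimension")
    # tuple() also normalizes subclasses such as torch.Size to a plain tuple
    return tuple(shape[1:])


# Hypothetical shapes: 100 RGB images of 32x32 pixels, one-hot labels over 10 classes
covar_dim = sample_dims((100, 3, 32, 32))
label_dim = sample_dims((100, 10))
print(covar_dim, label_dim)  # (3, 32, 32) (10,)
```

For tabular data with 150 rows and 4 features, the same rule gives a covariate dimension of (4,).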
Returns
- Model
Preset model with the specified dimensions on the specified tensor device
Raises
- ValueError
Raised when model_name does not match any preset model
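The behaviour described above (look up a preset by name, forward the dimensions, device, and extra arguments to the model constructor, and raise ValueError for an unknown name) follows the usual factory pattern. The sketch below illustrates that pattern only; it is not opendataval's implementation. The classes LogisticModel and MLPModel, the preset names "logistic" and "mlp", the PRESETS registry, and make_model are all hypothetical, and the device is a plain string so the example runs without torch.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class LogisticModel:
    """Hypothetical preset model that ignores the device argument."""
    covar_dim: tuple[int, ...]
    label_dim: tuple[int, ...]


@dataclass
class MLPModel:
    """Hypothetical preset model that is placed on a device."""
    covar_dim: tuple[int, ...]
    label_dim: tuple[int, ...]
    device: str = "cpu"
    hidden_dim: int = 25


def _build_logistic(covar_dim, label_dim, device, *args, **kwargs):
    # This preset does not use the device, mirroring "some models ignore this argument"
    return LogisticModel(covar_dim, label_dim, *args, **kwargs)


def _build_mlp(covar_dim, label_dim, device, *args, **kwargs):
    return MLPModel(covar_dim, label_dim, device, *args, **kwargs)


# Registry mapping each preset name to its builder
PRESETS: dict[str, Callable[..., Any]] = {
    "logistic": _build_logistic,
    "mlp": _build_mlp,
}


def make_model(model_name, covar_dim, label_dim, device="cpu", *args, **kwargs):
    """Create a preset model by name; raise ValueError if the name is unknown."""
    try:
        builder = PRESETS[model_name]
    except KeyError:
        raise ValueError(f"{model_name!r} is not a valid preset model") from None
    return builder(covar_dim, label_dim, device, *args, **kwargs)


# Extra keyword arguments are forwarded to the model constructor
model = make_model("mlp", (4,), (3,), hidden_dim=50)
print(model)  # MLPModel(covar_dim=(4,), label_dim=(3,), device='cpu', hidden_dim=50)
```

Calling MLPModel((4,), (3,), hidden_dim=50) directly produces the same object without the name lookup, which is the reason the factory documentation recommends importing the specific model when its arguments need customizing.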