opendataval.model.ModelFactory

opendataval.model.ModelFactory(model_name: str, fetcher: DataFetcher | None = None, device: torch.device = torch.device('cpu'), *args, **kwargs) -> Model

Factory to create prediction models from specified presets.

Model factory that creates the specified model based on the input parameters. It is recommended to import the specific model class and pass any additional arguments to it directly instead of relying on the factory.
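The factory idea can be sketched as a lookup table from preset names to constructors, raising ValueError for unknown names. This is a minimal sketch under stated assumptions, not opendataval's implementation: the registry _PRESETS, the placeholder classes LinearPreset and MLPPreset, and the function model_factory are all hypothetical.

```python
from typing import Any, Callable


class LinearPreset:
    """Hypothetical stand-in for a preset model."""

    def __init__(self, covar_dim: tuple[int, ...], label_dim: tuple[int, ...]):
        self.covar_dim = covar_dim
        self.label_dim = label_dim


class MLPPreset(LinearPreset):
    """Another hypothetical preset sharing the same constructor."""


# Hypothetical registry mapping preset names to model constructors.
_PRESETS: dict[str, Callable[..., Any]] = {
    "linear": LinearPreset,
    "mlp": MLPPreset,
}


def model_factory(
    model_name: str, covar_dim: tuple[int, ...], label_dim: tuple[int, ...]
) -> Any:
    """Build the preset registered under model_name, or raise ValueError."""
    try:
        constructor = _PRESETS[model_name]
    except KeyError:
        # Mirror the documented behavior: unknown names raise ValueError.
        raise ValueError(f"Unknown model name: {model_name!r}") from None
    return constructor(covar_dim, label_dim)


model = model_factory("mlp", covar_dim=(4,), label_dim=(3,))
print(type(model).__name__, model.covar_dim, model.label_dim)  # MLPPreset (4,) (3,)
```

Importing a concrete class directly, as recommended above, skips the registry entirely and exposes that model's full constructor signature.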

Parameters

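model_name : str

Name of prediction model

fetcher : DataFetcher, optional

DataFetcher providing the dataset the model is built for, by default None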

covar_dim : tuple[int, ...]

Dimensions of the covariates, typically the shape excluding the first (sample) dimension

label_dim : tuple[int, ...]

Dimensions of the labels, typically the shape excluding the first (sample) dimension
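Concretely, when the first axis of an array indexes samples, the dimension tuple is the array's shape with that axis dropped. The helper sample_dims below is hypothetical; it works on plain shape tuples such as a NumPy array's .shape, and the example shapes are illustrative only.

```python
def sample_dims(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Return the per-sample shape: everything except the first (sample) axis."""
    return tuple(shape[1:])


# Tabular covariates: 150 samples with 4 features each.
print(sample_dims((150, 4)))  # (4,)
# One-hot labels: 150 samples over 3 classes.
print(sample_dims((150, 3)))  # (3,)
# Image covariates: 60000 single-channel 28x28 images.
print(sample_dims((60000, 1, 28, 28)))  # (1, 28, 28)
```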

device : torch.device, optional

Tensor device for acceleration (some models do not use this argument), by default torch.device("cpu")

args : tuple[Any]

Additional positional arguments passed to the Model constructor

kwargs : dict[str, Any]

Additional keyword arguments passed to the Model constructor
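Forwarding follows ordinary Python semantics: extra positional arguments are appended after the factory's own arguments, and extra keyword arguments are passed by name. A minimal sketch, assuming a hypothetical preset ClassifierPreset and a hypothetical helper build; which extra arguments a real preset accepts depends on that model's constructor.

```python
from typing import Any


class ClassifierPreset:
    """Hypothetical preset whose constructor accepts extra hyperparameters."""

    def __init__(self, covar_dim, label_dim, hidden_dim=25, *, lr=0.01):
        self.covar_dim = covar_dim
        self.label_dim = label_dim
        self.hidden_dim = hidden_dim
        self.lr = lr


def build(model_cls: type, covar_dim, label_dim, *args: Any, **kwargs: Any):
    """Pass extra positional (args) and keyword (kwargs) arguments through unchanged."""
    return model_cls(covar_dim, label_dim, *args, **kwargs)


# 50 lands in hidden_dim positionally; lr is forwarded by keyword.
model = build(ClassifierPreset, (4,), (3,), 50, lr=0.001)
print(model.hidden_dim, model.lr)  # 50 0.001
```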

Returns

Model

Preset model with the specified dimensions on the specified tensor device

Raises

ValueError

Raised when model_name does not match any preset model