opendataval.dataval.DVRL
- class opendataval.dataval.DVRL(*args, **kwargs)
Data valuation using reinforcement learning class, implemented with PyTorch.
Parameters
- hidden_dim : int, optional
Hidden dimensions for the RL Multilayer Perceptron Value Estimator (VE) (details in the DataValueEstimatorRL class), by default 100
- layer_number : int, optional
Number of hidden layers for the Value Estimator (VE), by default 5
- comb_dim : int, optional
Dimension of the layers applied after the inputs are concatenated; should be much smaller than hidden_dim, by default 10
- rl_epochs : int, optional
Number of training epochs for the VE, by default 1000
- rl_batch_size : int, optional
Batch size for training the VE, by default 32
- lr : float, optional
Learning rate for the VE, by default 0.01
- threshold : float, optional
Search rate threshold. The VE may get stuck at values close to the bounds of \([0, 1]\), so searching is encouraged whenever its estimates fall outside \([1-threshold, threshold]\), by default 0.9
- device : torch.device, optional
Tensor device for acceleration, by default torch.device("cpu")
- random_state : RandomState, optional
Random initial state, by default None
- __init__(hidden_dim: int = 100, layer_number: int = 5, comb_dim: int = 10, rl_epochs: int = 1000, rl_batch_size: int = 32, lr: float = 0.01, threshold: float = 0.9, device: device = device(type='cpu'), random_state: RandomState | None = None)
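The role of the threshold parameter can be made concrete with a small sketch. The function below is a hypothetical helper, not part of opendataval: it assumes the search term is a hinge-style penalty on the mean estimated data value, which is zero inside \([1-threshold, threshold]\) and grows linearly outside it. The exact form and weighting used by the library are not stated on this page, so treat this as an illustration only.

```python
def exploration_penalty(mean_value: float, threshold: float = 0.9) -> float:
    """Hypothetical hinge-style search penalty (an assumption, not the library's code).

    Returns 0.0 while mean_value lies in [1 - threshold, threshold] and a
    positive amount proportional to how far it has drifted outside that band.
    """
    if not 0.5 <= threshold <= 1.0:
        raise ValueError("threshold is expected to lie in [0.5, 1.0]")

    # Amount by which the mean estimate exceeds the upper bound.
    too_high = max(mean_value - threshold, 0.0)
    # Amount by which the mean estimate falls below the lower bound.
    too_low = max((1.0 - threshold) - mean_value, 0.0)
    return too_high + too_low


if __name__ == "__main__":
    for mean in (0.02, 0.5, 0.97):
        print(mean, exploration_penalty(mean))
```

With the default threshold of 0.9, a mean estimate of 0.5 incurs no penalty, while estimates that collapse towards 0 or 1 are pushed back into the band.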
Methods
- __init__([hidden_dim, layer_number, ...])
- evaluate(y, y_hat): Evaluate performance of the specified metric between label and predictions.
- evaluate_data_values(): Return data values for each training data point.
- input_data(x_train, y_train, x_valid, y_valid): Store and transform input data for DVRL.
- input_fetcher(fetcher): Input data from a DataFetcher object.
- input_metric(metric): Input the evaluation metric.
- input_model(pred_model): Input the prediction model.
- input_model_metric(pred_model, metric): Input the prediction model and the evaluation metric.
- setup(fetcher[, pred_model, metric]): Inputs model, metric and data into Data Evaluator.
- train(fetcher[, pred_model, metric]): Store and transform data, then train model to predict data values.
- train_data_values(*args[, num_workers]): Trains model to predict data values.
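The methods above suggest a step-by-step workflow: supply the data, supply the model and metric, train, then read out the values. The sketch below wires those calls together using only the method names listed on this page. The helper name `value_training_points` is hypothetical, and the order of calls is an assumption based on the method descriptions rather than a documented requirement.

```python
from typing import Any


def value_training_points(
    evaluator: Any,
    x_train: Any,
    y_train: Any,
    x_valid: Any,
    y_valid: Any,
    pred_model: Any,
    metric: Any,
) -> Any:
    """Run the documented input/train/evaluate steps on a data evaluator.

    `evaluator` is expected to expose the methods listed for DVRL:
    input_data, input_model_metric, train_data_values, evaluate_data_values.
    """
    # Store (and let the evaluator transform) the training and validation data.
    evaluator.input_data(x_train, y_train, x_valid, y_valid)
    # Register the prediction model together with the evaluation metric.
    evaluator.input_model_metric(pred_model, metric)
    # Train the value estimator; no extra arguments are passed in this sketch.
    evaluator.train_data_values()
    # One data value per training point.
    return evaluator.evaluate_data_values()
```

Assuming opendataval is installed, the evaluator would be created with something like `DVRL(rl_epochs=1000, rl_batch_size=32)` and passed in as the first argument; `train(fetcher[, pred_model, metric])` appears to cover the input and training steps in one call when a DataFetcher is available.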
Attributes
- Evaluators
- data_values: Cached data values.