class opendataval.dataval.DVRL(*args, **kwargs)

Data valuation using reinforcement learning (DVRL), implemented with PyTorch.
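The reinforcement-learning part of DVRL can be sketched as a REINFORCE (policy-gradient) step: the Value Estimator outputs a selection probability per training point, a 0/1 selection is sampled from those probabilities, and the loss scales the log-likelihood of that selection by how much the resulting predictor beat a baseline. The sketch below is a simplified, standard-library-only illustration under that assumption; `reinforce_step` is a hypothetical helper, not part of the opendataval API, and the reward is passed in directly rather than computed by training a real model.

```python
import math
import random


def reinforce_step(probs, reward, baseline, rng):
    """Sample a data selection and compute a REINFORCE loss (illustrative sketch).

    probs:    per-point selection probabilities in (0, 1), e.g. VE outputs
    reward:   validation performance of a predictor trained on the selection
    baseline: reference performance, e.g. a moving average of past rewards
    rng:      random.Random instance, for reproducible sampling
    """
    # Sample s_i ~ Bernoulli(p_i): point i is selected with probability p_i.
    selection = [1 if rng.random() < p else 0 for p in probs]
    # Log-likelihood of the sampled selection under the current probabilities.
    log_prob = sum(
        math.log(p) if s else math.log(1.0 - p)
        for p, s in zip(probs, selection)
    )
    # Minimizing this loss makes selections that beat the baseline more likely.
    loss = -(reward - baseline) * log_prob
    return selection, loss


rng = random.Random(0)
selection, loss = reinforce_step([0.9, 0.2, 0.6], reward=0.85, baseline=0.80, rng=rng)
print(selection, loss)
```

Subtracting a baseline (such as a moving average of past rewards) instead of using the raw reward reduces the variance of the gradient estimate without changing its expectation.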



hidden_dim : int, optional

Hidden dimension of the multilayer perceptron Value Estimator (VE) used for RL, by default 100. See the DataValueEstimatorRL class for details.

layer_number : int, optional

Number of hidden layers for the Value Estimator (VE), by default 5

comb_dim : int, optional

Hidden dimension of the layers applied after the inputs are concatenated; should be much smaller than hidden_dim, by default 10

rl_epochs : int, optional

Number of training epochs for the VE, by default 1000

rl_batch_size : int, optional

Batch size for training the VE, by default 32

lr : float, optional

Learning rate for the VE, by default 0.01

threshold : float, optional

Search rate threshold. The VE may get stuck with values close to the bounds 0 and 1, so outside of \([1 - \text{threshold}, \text{threshold}]\) searching is encouraged, by default 0.9

device : torch.device, optional

Tensor device for acceleration, by default torch.device("cpu")

random_state : RandomState, optional

Random initial state, by default None
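The threshold parameter can be read as a soft constraint on the VE's output. A minimal sketch, assuming (as in the original DVRL reference code) a hinge penalty on the batch mean of the estimated values, scaled by a large weight and added to the policy-gradient loss; `exploration_penalty` and the default weight of 1e3 are assumptions for illustration, not opendataval's API.

```python
def exploration_penalty(values, threshold=0.9, weight=1e3):
    """Hinge penalty keeping the mean estimated value inside [1 - threshold, threshold].

    values:    estimated selection probabilities for one batch (each in [0, 1])
    threshold: upper bound on the mean; 1 - threshold is the lower bound
    weight:    scale of the penalty relative to the policy-gradient loss (assumed)
    """
    mean = sum(values) / len(values)
    too_high = max(mean - threshold, 0.0)         # mean drifted toward 1
    too_low = max((1.0 - threshold) - mean, 0.0)  # mean drifted toward 0
    return weight * (too_high + too_low)


print(exploration_penalty([0.5, 0.6, 0.7]))    # mean inside the band: 0.0
print(exploration_penalty([0.99, 0.98, 1.0]))  # mean above threshold: positive penalty
```

Because the penalty is zero inside the band, it only acts when the estimator collapses toward selecting (or rejecting) almost everything.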

__init__(hidden_dim: int = 100, layer_number: int = 5, comb_dim: int = 10, rl_epochs: int = 1000, rl_batch_size: int = 32, lr: float = 0.01, threshold: float = 0.9, device: device = device(type='cpu'), random_state: RandomState | None = None)
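To see how hidden_dim, layer_number and comb_dim in the signature above interact, the sketch below lists the linear-layer shapes of one plausible Value Estimator layout. It assumes the structure of the original DVRL design (an MLP over concatenated features and labels, a projection to comb_dim, a combination layer that also receives a y_dim-wide prediction signal, and a sigmoid output); `value_estimator_layout` is hypothetical, and opendataval's DataValueEstimatorRL may differ in detail.

```python
def value_estimator_layout(x_dim, y_dim, hidden_dim=100, layer_number=5, comb_dim=10):
    """Return (in_features, out_features) pairs for a hypothetical VE layout.

    Assumption: the input, projection and combination layers count toward
    layer_number, so layer_number - 3 extra hidden layers are inserted.
    """
    layers = [(x_dim + y_dim, hidden_dim)]                            # input: features + labels
    layers += [(hidden_dim, hidden_dim)] * max(layer_number - 3, 0)  # extra hidden layers
    layers.append((hidden_dim, comb_dim))                             # project down to comb_dim
    layers.append((comb_dim + y_dim, comb_dim))                       # concat prediction signal
    layers.append((comb_dim, 1))                                      # sigmoid output: selection prob.
    return layers


for shape in value_estimator_layout(x_dim=4, y_dim=3):
    print(shape)
```

Keeping comb_dim small forces the combination layer to work from a compact summary of the inputs, which is why it is documented as much smaller than hidden_dim.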


__init__([hidden_dim, layer_number, ...])

evaluate(y, y_hat)

Evaluate the specified metric between labels and predictions.


Return data values for each training data point.

input_data(x_train, y_train, x_valid, y_valid)

Store and transform input data for DVRL.


Input data from a DataFetcher object.


Input the evaluation metric.


Input the prediction model.

input_model_metric(pred_model, metric)

Input the prediction model and the evaluation metric.

setup(fetcher[, pred_model, metric])

Input the model, metric, and data into the data evaluator.

train(fetcher[, pred_model, metric])

Store and transform data, then train model to predict data values.

train_data_values(*args[, num_workers])

Train the model to predict data values.
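train_data_values trains the VE for rl_epochs iterations using batches of rl_batch_size training points. A minimal sketch of that sampling schedule, assuming each iteration draws one fresh batch of distinct indices; `rl_batches` is a hypothetical helper, and the per-batch work is only described in a comment.

```python
import random


def rl_batches(num_points, rl_epochs=1000, rl_batch_size=32, seed=None):
    """Yield one batch of distinct training indices per RL iteration (hypothetical helper).

    Batches larger than the training set are truncated to num_points.
    """
    rng = random.Random(seed)
    batch_size = min(rl_batch_size, num_points)
    for _ in range(rl_epochs):
        # Sampling without replacement: a batch never repeats an index.
        yield rng.sample(range(num_points), batch_size)


for epoch, batch in enumerate(rl_batches(num_points=100, rl_epochs=3, rl_batch_size=5, seed=0)):
    # In DVRL, each batch would be scored by the VE, a selection sampled,
    # a predictor trained on the selected points, and the VE updated.
    print(epoch, batch)
```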




Cached data values.
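Data values are typically used to rank training points; in DVRL-style valuation, the lowest-valued points can help flag noisy or mislabeled examples. A minimal sketch, assuming the values are available as a plain sequence of floats aligned with the training data (convert the returned array first if needed); `lowest_valued` is a hypothetical helper and the values below are toy numbers, not real DVRL output.

```python
def lowest_valued(data_values, k):
    """Return indices of the k training points with the smallest data values.

    Ties are broken by index, so the result is deterministic.
    """
    order = sorted(range(len(data_values)), key=lambda i: (data_values[i], i))
    return order[:k]


values = [0.91, 0.12, 0.75, 0.05, 0.88]  # toy values for illustration
print(lowest_valued(values, k=2))  # → [3, 1]
```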