Hidden dimensions for the RL Multilayer Perceptron Value Estimator (VE)
(details in DataValueEstimatorRL class), by default 100
layer_number : int, optional
Number of hidden layers for the Value Estimator (VE), by default 5
comb_dim : int, optional
Dimension of the layers applied after the inputs are concatenated; should be
much smaller than hidden_dim, by default 10
rl_epochs : int, optional
Number of training epochs for the VE, by default 1000
rl_batch_size : int, optional
Batch size for training the VE, by default 32
lr : float, optional
Learning rate for the VE, by default 0.01
threshold : float, optional
Search rate threshold. The VE may get stuck with outputs close to the bounds
of \([0, 1]\), so searching is encouraged whenever its outputs fall outside
\([1-threshold, threshold]\), by default 0.9
device : torch.device, optional
Tensor device for acceleration, by default torch.device("cpu")
Here, we assume a simple multi-layer perceptron architecture for the data
value evaluator model. For tabular data, a multi-layer perceptron
is already efficient at extracting the relevant information.
For high-dimensional data types like images or text,
it is important to introduce inductive biases to the architecture to
extract information efficiently. In such cases, there are two options:
(i) Input the encoded representations (e.g. the last layer activations of
ResNet for images, or the last layer activations of BERT for text) and use
the multi-layer perceptron on top of them. The encoded representations can
simply come from a pre-trained predictor model using the entire dataset.
(ii) Modify the data value evaluator model definition below to have the
appropriate inductive bias (e.g. using convolutional layers for images,
or attention layers for text).
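As an illustration of option (i), the sketch below passes a batch of images through a frozen encoder and keeps the resulting embeddings as inputs for the multi-layer perceptron. The encoder is a small, randomly initialised stand-in and the names `encoder`, `images` and `embeddings` are assumptions made for this example; in practice the encoder would be a pre-trained model such as a ResNet trunk.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pre-trained encoder (randomly initialised here).
encoder = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),  # global average pool -> (N, 8, 1, 1)
    nn.Flatten(),             # -> (N, 8)
)

# Freeze the encoder so only the value estimator on top would be trained.
encoder.eval()
for param in encoder.parameters():
    param.requires_grad_(False)

# A batch of 16 fake RGB images of size 32x32.
images = torch.randn(16, 3, 32, 32)
with torch.no_grad():
    embeddings = encoder(images)

# `embeddings` now plays the role of tabular covariates for the MLP.
print(embeddings.shape)
```

The embeddings can be computed once and cached, since the encoder is not updated while the value estimator trains.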
Forward pass of the inputs through the value estimator to obtain their data values.
Forward pass through Value Estimator. Returns selection probabilities.
Concatenates the difference between labels and predicted labels to compute
selection probabilities.
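A minimal sketch of such a forward pass is given below, assuming a two-stage MLP: the covariates and labels are encoded down to `comb_dim` features, the label/prediction difference is concatenated, and a sigmoid produces selection probabilities. The class name `ValueEstimatorSketch`, the layer layout and the argument name `y_hat_diff` are assumptions for illustration and may differ from the actual DataValueEstimatorRL class.

```python
import torch
import torch.nn as nn


class ValueEstimatorSketch(nn.Module):
    """Hypothetical MLP value estimator, written for illustration only."""

    def __init__(self, x_dim, y_dim, hidden_dim=100, layer_number=5, comb_dim=10):
        super().__init__()
        # Stage 1: encode the concatenated (x, y) pair down to comb_dim features.
        layers = [nn.Linear(x_dim + y_dim, hidden_dim), nn.ReLU()]
        for _ in range(layer_number - 3):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers += [nn.Linear(hidden_dim, comb_dim), nn.ReLU()]
        self.encoder = nn.Sequential(*layers)
        # Stage 2: combine with the label/prediction difference and squash to (0, 1).
        self.head = nn.Sequential(
            nn.Linear(comb_dim + y_dim, comb_dim),
            nn.ReLU(),
            nn.Linear(comb_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, y, y_hat_diff):
        hidden = self.encoder(torch.cat([x, y], dim=1))
        return self.head(torch.cat([hidden, y_hat_diff], dim=1))


torch.manual_seed(0)
ve = ValueEstimatorSketch(x_dim=4, y_dim=2)
x = torch.randn(8, 4)
y = torch.eye(2)[torch.randint(0, 2, (8,))]       # one-hot labels
y_pred = torch.softmax(torch.randn(8, 2), dim=1)  # fake predictor outputs
probs = ve(x, y, torch.abs(y - y_pred))           # one probability per sample
```

Each row of `probs` is the probability that the corresponding sample is selected for training the predictor.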
threshold : float, optional
Search rate threshold. The VE may get stuck with outputs close to the bounds
of \([0, 1]\), so searching is encouraged whenever its outputs fall outside
\([1-threshold, threshold]\), by default 0.9
exploration_weight : float, optional
Large constant to encourage exploration in the Value Estimator, by default 1e3
Uses the REINFORCE algorithm to compute a loss for the Value Estimator.
pred_dataval holds the predicted data values. selector_input is a Bernoulli
random variable with p=pred_dataval. Computes the BCE between pred_dataval and
selector_input and multiplies it by the reward signal. Adds an additional loss
if the Value Estimator is getting stuck outside \([1-threshold, threshold]\).
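The loss described above can be sketched as follows. This is an illustration under stated assumptions rather than the library's implementation: the function name `dve_loss_sketch` is hypothetical, the BCE is summed over the batch, the penalty is applied to the mean of pred_dataval, and a positive reward is taken to mean the sampled selection performed well.

```python
import torch
import torch.nn.functional as F


def dve_loss_sketch(pred_dataval, selector_input, reward,
                    threshold=0.9, exploration_weight=1e3):
    # Negative log-likelihood of the sampled selection under Bernoulli(pred_dataval).
    bce = F.binary_cross_entropy(pred_dataval, selector_input, reduction="sum")
    # REINFORCE term: with a positive reward, minimising this makes the sampled
    # selection more likely (assumed sign convention).
    reward_loss = reward * bce
    # Exploration penalty: non-zero only when the mean data value leaves
    # the interval [1 - threshold, threshold].
    mean_val = pred_dataval.mean()
    search_loss = F.relu(mean_val - threshold) + F.relu((1 - threshold) - mean_val)
    return reward_loss + exploration_weight * search_loss


# Data values inside the interval: only the REINFORCE term contributes.
pred = torch.tensor([0.5, 0.5, 0.5, 0.5])
sel = torch.tensor([1.0, 0.0, 1.0, 0.0])  # in training: torch.bernoulli(pred)
inside = dve_loss_sketch(pred, sel, reward=1.0)

# Data values stuck near 1: the exploration penalty dominates.
stuck = torch.tensor([0.99, 0.99, 0.99, 0.99])
outside = dve_loss_sketch(stuck, torch.ones(4), reward=0.0)
```

With the default exploration_weight of 1e3, even a small excursion of the mean beyond the threshold outweighs a typical REINFORCE term, which pushes the estimator back toward the interior of the interval.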