RLearn_Model Class
This class splits the data into training and test sets for the SAR model, preprocesses the simple observations for the attacker, trains the RLearn model, and visualizes the Q-values.
Initialization
- __init__(state_def, config, seed, num_process, input_path, output_path)
Initializes the RLearn model with the specified configuration.
- Parameters:
state_def – (str) State definition to use for the model.
config – (str) Path to a JSON configuration file.
seed – (int, optional) Random seed for reproducibility. Defaults to 42.
num_process – (int, optional) Number of processes to use for parallelization. Defaults to 4.
input_path – (str, optional) Path to the input data. Defaults to None.
output_path – (str, optional) Path to save the output data. Defaults to None.
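A minimal sketch of a constructor call, assuming only the signature and defaults listed above. The import path of RLearn_Model is not given here, so the class RLearnModelStandIn, the state definition name "example_state", the empty configuration file, and the paths are all hypothetical placeholders. The stand-in stores its arguments and parses the JSON file; it does not load data or build a model.

```python
import json
import tempfile
from pathlib import Path


class RLearnModelStandIn:
    """Hypothetical stand-in mirroring the documented __init__ signature."""

    def __init__(self, state_def, config, seed=42, num_process=4,
                 input_path=None, output_path=None):
        self.state_def = state_def
        # config is documented as a path to a JSON file; here it is
        # simply parsed into a dict.
        with open(config, encoding="utf-8") as fh:
            self.config = json.load(fh)
        self.seed = seed
        self.num_process = num_process
        self.input_path = input_path
        self.output_path = output_path


# Placeholder config: the real configuration keys are not documented here.
tmp_dir = Path(tempfile.mkdtemp())
config_path = tmp_dir / "config.json"
config_path.write_text(json.dumps({}), encoding="utf-8")

model = RLearnModelStandIn(
    state_def="example_state",            # hypothetical state definition name
    config=str(config_path),
    input_path=str(tmp_dir / "input"),    # hypothetical input location
    output_path=str(tmp_dir / "output"),  # hypothetical output location
)

# seed and num_process fall back to the documented defaults.
print(model.seed, model.num_process)
```

Only state_def and config have no documented default, so the other four arguments can be omitted.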
Methods
- split_train_test()
Splits the input data into training and testing sets.
- preprocess_observations(batch_size)
Preprocesses the input data to generate the observations for the attacker.
- Parameters:
batch_size – (int) Batch size for processing the data.
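To illustrate what a batch size means for preprocessing, the sketch below splits a list of items into consecutive chunks of at most batch_size elements. The helper iter_batches and the sequence identifiers are hypothetical; the actual batching inside preprocess_observations is not documented here and may differ.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive chunks of at most batch_size items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


# Hypothetical sequence identifiers standing in for the real input data.
sequence_ids = [f"seq_{i}" for i in range(7)]
batches = list(iter_batches(sequence_ids, batch_size=3))

# The last batch is smaller when the data does not divide evenly.
print([len(batch) for batch in batches])
```

A larger batch_size means fewer iterations at the cost of more memory held per step.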
- train_model(exp_name, run_name, accelerator, devices, strategy)
Trains the RLearn model using the preprocessed data.
- Parameters:
exp_name – (str) Name of the experiment. Defaults to ‘sarsa_attacker’.
run_name – (str) Name under which the run is saved in TensorFlow, MLflow, etc.
accelerator – (str) Accelerator to use for training. Defaults to ‘auto’.
devices – (int) Number of devices to use for training. Defaults to 1.
strategy – (str) Strategy the Trainer uses to distribute the model during training, evaluation, and prediction. Defaults to ‘auto’.
- visualize_q_values(model_name, checkpoint_path, match_id, sequence_id)
Visualizes the Q-values for the trained RLearn model.
- Parameters:
model_name – (str) Name of the model used for training.
checkpoint_path – (str) Path to the saved model checkpoint.
match_id – (str) Match ID you want to visualize.
sequence_id – (str) Sequence ID you want to visualize.
- run_rlearn(run_split_train_test, run_preprocess_observation, run_train_and_test, run_visualize_data, batch_size, exp_name, run_name, accelerator, devices, strategy)
Runs the RLearn pipeline: data splitting, observation preprocessing, model training and evaluation, and visualization. Each stage runs only when its flag is set.
- Parameters:
run_split_train_test – (bool) Whether to run the train/validation/test split.
run_preprocess_observation – (bool) Whether to run the observation preprocessing.
run_train_and_test – (bool) Whether to run the model training and inference.
run_visualize_data – (bool) Whether to run the data visualization.
batch_size – (int) Batch size for processing the data.
exp_name – (str) Name of the experiment.
run_name – (str) Name of the run.
accelerator – (str) Accelerator to use for training.
devices – (int) Number of devices to use for training.
strategy – (str) Strategy to use for training.
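The flag-driven pipeline described for run_rlearn can be sketched as follows. This is an assumption about how the flags gate the stages, not the actual implementation: PipelineStandIn is a hypothetical class whose methods only record that they were called, the stage order follows the order of the flags, and the visualize step is a placeholder because the visualization inputs used by run_rlearn are not documented. Defaults for run_rlearn are also not documented, so every argument is passed explicitly.

```python
class PipelineStandIn:
    """Hypothetical stand-in that records which stages would run."""

    def __init__(self):
        self.calls = []

    def split_train_test(self):
        self.calls.append("split_train_test")

    def preprocess_observations(self, batch_size):
        self.calls.append("preprocess_observations")

    def train_model(self, exp_name, run_name, accelerator, devices, strategy):
        self.calls.append("train_model")

    def visualize(self):
        # Placeholder: how run_rlearn obtains the checkpoint, match, and
        # sequence to visualize is not documented.
        self.calls.append("visualize")

    def run_rlearn(self, *, run_split_train_test, run_preprocess_observation,
                   run_train_and_test, run_visualize_data, batch_size,
                   exp_name, run_name, accelerator, devices, strategy):
        # Each stage runs only when its flag is True (assumed behavior).
        if run_split_train_test:
            self.split_train_test()
        if run_preprocess_observation:
            self.preprocess_observations(batch_size)
        if run_train_and_test:
            self.train_model(exp_name, run_name, accelerator, devices, strategy)
        if run_visualize_data:
            self.visualize()
        return self.calls


pipeline = PipelineStandIn()
stages = pipeline.run_rlearn(
    run_split_train_test=True,
    run_preprocess_observation=True,
    run_train_and_test=True,
    run_visualize_data=False,
    batch_size=64,               # hypothetical value
    exp_name="sarsa_attacker",   # documented default of train_model
    run_name="example_run",      # hypothetical run name
    accelerator="auto",
    devices=1,
    strategy="auto",
)
print(stages)
```

Under this assumption, a stage that has already been completed in an earlier run can be skipped by setting its flag to False, as long as its outputs are still available under output_path.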