RLearn_Model Class

This class provides methods for splitting the data into training and test sets for the SAR model, preprocessing simple observations for the attacker, training the RLearn model, and visualizing the learned Q-values.

Initialization

__init__(state_def, config, seed=42, num_process=4, input_path=None, output_path=None)

Initializes the RLearn model with the specified configuration.

Parameters:
  • state_def – (str) State definition to use for the model.

  • config – (str) Path to a JSON configuration file.

  • seed – (int, optional) Random seed for reproducibility. Defaults to 42.

  • num_process – (int, optional) Number of processes to use for parallelization. Defaults to 4.

  • input_path – (str, optional) Path to the input data. Defaults to None.

  • output_path – (str, optional) Path to save the output data. Defaults to None.
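The `config` and `seed` parameters suggest that initialization reads a JSON file and seeds the random number generators. Below is a minimal sketch of those two steps. It assumes the config file holds a single JSON object and that only Python's built-in `random` module needs seeding; the real class may also seed NumPy or PyTorch. `load_config` and `seed_python_rng` are hypothetical helpers, not part of RLearn_Model:

```python
import json
import random
from typing import Any


def load_config(config_path: str) -> dict[str, Any]:
    """Read a JSON configuration file into a dictionary (hypothetical helper)."""
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Assumption: the configuration is a JSON object, not a list or scalar.
    if not isinstance(config, dict):
        raise ValueError("expected a JSON object at the top level of the config file")
    return config


def seed_python_rng(seed: int = 42) -> random.Random:
    """Seed Python's global RNG and return a dedicated generator with the same seed."""
    random.seed(seed)
    return random.Random(seed)
```

Returning a dedicated `random.Random` instance, rather than relying only on the global state, keeps later random draws reproducible even if other code also consumes the global generator.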

Methods

split_train_test()

Splits the input data into training and testing sets.
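A common way to avoid leakage between training and test data in match-based datasets is to split by match rather than by individual row, so that all sequences from one match fall on the same side. The sketch below illustrates that idea. It assumes each record carries a match ID and uses a hypothetical 20% test ratio. `split_match_ids` is an illustrative helper and does not describe how `split_train_test()` is actually implemented:

```python
import random
from typing import Sequence


def split_match_ids(
    match_ids: Sequence[str], test_ratio: float = 0.2, seed: int = 42
) -> tuple[list[str], list[str]]:
    """Split unique match IDs into disjoint (train, test) lists.

    Hypothetical helper: grouping by match ID and the default ratio are
    assumptions, not documented behavior of split_train_test().
    """
    if not 0.0 < test_ratio < 1.0:
        raise ValueError("test_ratio must be strictly between 0 and 1")
    # Deduplicate and sort first so the seeded shuffle is reproducible.
    unique_ids = sorted(set(match_ids))
    rng = random.Random(seed)
    rng.shuffle(unique_ids)
    n_test = max(1, round(len(unique_ids) * test_ratio))
    return unique_ids[n_test:], unique_ids[:n_test]
```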

preprocess_observations(batch_size)

Preprocesses the input data to generate the observations for the attacker.

Parameters:
  • batch_size – (int) Batch size for processing the data.
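The `batch_size` parameter presumably controls how many records are converted into observations at a time. Below is a minimal sketch of that chunking step. It assumes the input is an in-memory sequence, and `iter_batches` is a hypothetical helper rather than part of the class:

```python
from collections.abc import Iterator, Sequence
from typing import TypeVar

T = TypeVar("T")


def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most batch_size items; the last may be shorter."""
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]
```

Processing in fixed-size batches bounds peak memory use. It also gives natural units of work if the preprocessing is spread across the `num_process` workers set at initialization.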

train_model(exp_name, run_name, accelerator, devices, strategy)

Trains the RLearn model using the preprocessed data.

Parameters:
  • exp_name – (str) Name of the experiment. Defaults to ‘sarsa_attacker’.

  • run_name – (str) Name under which the run is logged by experiment trackers such as TensorBoard or MLflow.

  • accelerator – (str) Accelerator to use for training. Defaults to ‘auto’.

  • devices – (int) Number of devices to use for training. Defaults to 1.

  • strategy – (str) Strategy the Trainer uses to distribute the model across devices during training, evaluation, and prediction. Defaults to ‘auto’.
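The default experiment name `sarsa_attacker` suggests a SARSA-style learning target. If that assumption holds, the Q-values being learned follow the on-policy update Q(s, a) ← Q(s, a) + α[r + γ·Q(s′, a′) − Q(s, a)], where a′ is the action actually taken in the next state. The sketch below is a tabular version of that update for illustration only. The real model is trained through a Trainer on batches, likely with a function approximator. `sarsa_update`, `alpha`, and `gamma` are hypothetical names:

```python
from collections import defaultdict
from typing import DefaultDict, Hashable, Tuple

StateAction = Tuple[Hashable, Hashable]


def sarsa_update(
    q: DefaultDict[StateAction, float],
    state: Hashable,
    action: Hashable,
    reward: float,
    next_state: Hashable,
    next_action: Hashable,
    done: bool,
    alpha: float = 0.1,
    gamma: float = 0.99,
) -> float:
    """Apply one tabular SARSA update to q in place and return the TD error."""
    # On-policy target: bootstrap from the action actually taken next,
    # unless the episode (here, a sequence) has ended.
    next_q = 0.0 if done else q[(next_state, next_action)]
    td_error = reward + gamma * next_q - q[(state, action)]
    q[(state, action)] += alpha * td_error
    return td_error


# Unvisited state-action pairs start at a Q-value of 0.0.
q_table: DefaultDict[StateAction, float] = defaultdict(float)
```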

visualize_q_values(model_name, checkpoint_path, match_id, sequence_id)

Visualizes the Q-values for the trained RLearn model.

Parameters:
  • model_name – (str) Name of the model used for training.

  • checkpoint_path – (str) Path to the saved model checkpoint.

  • match_id – (str) ID of the match to visualize.

  • sequence_id – (str) ID of the sequence to visualize.
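Visualizing one sequence first requires selecting the Q-values that belong to the given `match_id` and `sequence_id` and ordering them in time. Below is a minimal sketch of that selection step. It assumes a hypothetical record schema (`QRecord`) with a step index per row. Neither `QRecord` nor `q_trace` comes from the class itself:

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass(frozen=True)
class QRecord:
    """One predicted Q-value row (hypothetical schema)."""

    match_id: str
    sequence_id: str
    step: int
    q_value: float


def q_trace(records: Iterable[QRecord], match_id: str, sequence_id: str) -> list[float]:
    """Return the Q-values of one sequence, ordered by step."""
    selected = [
        r for r in records if r.match_id == match_id and r.sequence_id == sequence_id
    ]
    if not selected:
        raise KeyError(f"no Q-values for match {match_id!r}, sequence {sequence_id!r}")
    return [r.q_value for r in sorted(selected, key=lambda r: r.step)]
```

The returned list can be passed directly to a plotting call such as matplotlib's `plt.plot` to draw how the Q-value evolves over the sequence.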

run_rlearn(run_split_train_test, run_preprocess_observation, run_train_and_test, run_visualize_data, batch_size, exp_name, run_name, accelerator, devices, strategy)

Runs the RLearn pipeline: data splitting, observation preprocessing, model training and evaluation, and visualization. Each stage is enabled or disabled by its corresponding flag.

Parameters:
  • run_split_train_test – (bool) Whether to run the train/validation/test split.

  • run_preprocess_observation – (bool) Whether to run the observation preprocessing.

  • run_train_and_test – (bool) Whether to run the model training and inference.

  • run_visualize_data – (bool) Whether to run the data visualization.

  • batch_size – (int) Batch size for processing the data.

  • exp_name – (str) Name of the experiment.

  • run_name – (str) Name of the run.

  • accelerator – (str) Accelerator to use for training.

  • devices – (int) Number of devices to use for training.

  • strategy – (str) Strategy to use for training.
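Because each `run_*` flag gates one stage, a stage can be rerun on its own, for example to retrain without re-splitting the data. Below is a minimal sketch of that gating logic. It assumes the stages run in the order listed in the description above. `STAGE_ORDER` and `run_enabled_stages` are hypothetical and do not reproduce `run_rlearn` itself:

```python
from typing import Callable

# Assumed pipeline order; the keys reuse the flag names of run_rlearn.
STAGE_ORDER = (
    "run_split_train_test",
    "run_preprocess_observation",
    "run_train_and_test",
    "run_visualize_data",
)


def run_enabled_stages(
    flags: dict[str, bool], stages: dict[str, Callable[[], None]]
) -> list[str]:
    """Call each enabled stage in pipeline order and return the names that ran."""
    ran = []
    for name in STAGE_ORDER:
        # A missing flag is treated as disabled.
        if flags.get(name, False):
            stages[name]()
            ran.append(name)
    return ran
```

In the real class, the stage callables would presumably correspond to the methods documented above, with `batch_size` and the training arguments forwarded to the relevant stages.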