ml.mlp

Functions for the multi-layer perceptron classifier.

Functions

build_dataloaders(data_split)

A function to build the DataLoaders from the data split.
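`build_dataloaders` wraps a data split in PyTorch DataLoaders. The sketch below is a hypothetical stdlib stand-in (`iter_batches` is not part of this module). It shows the two things a DataLoader does on each pass over a split: it shuffles the sample order and yields fixed-size mini-batches, with a smaller final batch.

```python
import random
from typing import Iterator, List, Sequence, Tuple


def iter_batches(
    X: Sequence[Sequence[float]],
    y: Sequence[int],
    batch_size: int = 32,
    shuffle: bool = True,
    seed: int = 0,
) -> Iterator[Tuple[List[Sequence[float]], List[int]]]:
    """Yield (features, labels) mini-batches, like one pass over a DataLoader.

    Hypothetical helper standing in for torch.utils.data.DataLoader.
    """
    if len(X) != len(y):
        raise ValueError("X and y must have the same length")
    indices = list(range(len(X)))
    if shuffle:
        # Reorder samples, as DataLoader(shuffle=True) does at the start of each epoch
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        # The last batch may be smaller (DataLoader keeps it by default, drop_last=False)
        yield [X[i] for i in batch_idx], [y[i] for i in batch_idx]


# Usage: 10 samples with batch_size=4 give batches of 4, 4, and 2 samples
X = [[float(i), float(i) ** 2] for i in range(10)]
y = [i % 3 for i in range(10)]
sizes = [len(xb) for xb, _ in iter_batches(X, y, batch_size=4)]
print(sizes)  # [4, 4, 2]
```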

create_layers(input_size, n_neurons)

evaluate_model(feature, mlp_cls, ...)

A function to evaluate the model on test data.

format_plots()

Set general plotting parameters for the Kulik Lab.

gradient_step(model, dataloader, optimizer, ...)

A function to train on the entire dataset for one epoch.
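`gradient_step` takes a model, a DataLoader, and an optimizer, so the real function presumably runs a PyTorch training loop. As a dependency-free sketch of what one such epoch involves, the code below swaps in a logistic-regression model with hand-written gradients. For each mini-batch it computes predictions and the binary cross-entropy loss, takes one gradient-descent step, and accumulates a sample-weighted mean loss for the epoch. `train_one_epoch`, `sigmoid`, and the toy data are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Batch = Tuple[List[Sequence[float]], List[int]]


def sigmoid(z: float) -> float:
    # Logistic function, written to avoid overflow for large |z|
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def train_one_epoch(
    w: List[float], b: float, batches: List[Batch], lr: float = 0.1
) -> Tuple[List[float], float, float]:
    """One epoch of mini-batch gradient descent on a logistic-regression model.

    Hypothetical stand-in for gradient_step (which trains a PyTorch MLP).
    Returns the updated weights, the bias, and the mean training loss per sample.
    """
    total_loss, n_seen = 0.0, 0
    for xb, yb in batches:
        grad_w = [0.0] * len(w)
        grad_b = 0.0
        for x, t in zip(xb, yb):
            p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
            # Binary cross-entropy; clamp only inside the log to avoid log(0)
            p_safe = min(max(p, 1e-12), 1.0 - 1e-12)
            total_loss += -(t * math.log(p_safe) + (1 - t) * math.log(1.0 - p_safe))
            # For a sigmoid output, d(loss)/d(logit) = p - t
            for j, xj in enumerate(x):
                grad_w[j] += (p - t) * xj
            grad_b += p - t
        n = len(xb)
        # One parameter update per batch, using the batch-averaged gradient
        w = [wi - lr * gw / n for wi, gw in zip(w, grad_w)]
        b -= lr * grad_b / n
        n_seen += n
    return w, b, total_loss / n_seen


# Usage: two batches of 1-D data where larger x means class 1
batches = [([[0.0], [1.0]], [0, 0]), ([[3.0], [4.0]], [1, 1])]
w, b = [0.0], 0.0
losses = []
for _ in range(200):
    w, b, loss = train_one_epoch(w, b, batches, lr=0.1)
    losses.append(loss)
print(f"epoch 1 loss {losses[0]:.3f}, epoch 200 loss {losses[-1]:.3f}")
```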

load_data(mimos, include_esp, data_loc)

Load data from CSV files for each mimo in the given list.

optuna_mlp(data_split_type, include_esp, ...)

plot_confusion_matrices(cms, mimos)

Plot confusion matrices for distance and charge features.
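`plot_confusion_matrices` receives precomputed matrices (`cms`). To show what each matrix holds, the hypothetical helpers below build one from true and predicted class indices. They use the orientation of `sklearn.metrics.confusion_matrix` (rows are true classes, columns are predicted classes), and `normalize_rows` divides each row by its total so the diagonal gives per-class recall. Whether the module plots raw counts or normalized values is an open assumption here.

```python
from typing import List, Sequence


def confusion_counts(
    y_true: Sequence[int], y_pred: Sequence[int], n_classes: int
) -> List[List[int]]:
    """Return an n_classes x n_classes count matrix.

    Row i, column j counts samples whose true class is i and predicted class is j.
    """
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def normalize_rows(cm: List[List[int]]) -> List[List[float]]:
    # Turn each row into fractions of that true class; empty rows become all zeros
    out = []
    for row in cm:
        total = sum(row)
        out.append([c / total if total else 0.0 for c in row])
    return out


# Usage: three classes, with one class-2 sample misclassified as class 1
cm = confusion_counts([0, 0, 1, 2, 2], [0, 0, 1, 1, 2], n_classes=3)
print(cm)  # [[2, 0, 0], [0, 1, 0], [0, 1, 1]]
print(normalize_rows(cm)[2])  # [0.0, 0.5, 0.5]
```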

plot_data(df_charge, df_dist, mimos)

Plot the average charge and distance data for the given MIMO types.

plot_roc_curve(y_true, y_pred_proba, mimos, ...)

Plot the ROC curve for the test data of the charge and distance features.
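`plot_roc_curve` takes class probabilities (`y_pred_proba`) for several MIMO types, which suggests one-vs-rest curves, one per class. The stdlib sketch below shows how a single such curve is computed. It sweeps the decision threshold from the highest score down, records the false- and true-positive rates at each distinct threshold, and integrates with the trapezoidal rule to get the AUC. `roc_points` and `auc` are hypothetical helpers; in practice `sklearn.metrics.roc_curve` and `sklearn.metrics.auc` compute these quantities.

```python
from typing import List, Sequence, Tuple


def roc_points(y_true: Sequence[int], scores: Sequence[float]) -> List[Tuple[float, float]]:
    """Return (FPR, TPR) pairs for a binary (one-vs-rest) problem.

    Tied scores are processed together, so each distinct threshold adds one point.
    """
    pos = sum(1 for t in y_true if t == 1)
    neg = len(y_true) - pos
    if pos == 0 or neg == 0:
        raise ValueError("need at least one positive and one negative sample")
    pairs = sorted(zip(scores, y_true), key=lambda st: st[0], reverse=True)
    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Every sample scoring at or above this threshold is predicted positive
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / neg, tp / pos))
    return points


def auc(points: Sequence[Tuple[float, float]]) -> float:
    # Trapezoidal rule over consecutive (FPR, TPR) points
    return sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(points, points[1:]))


# Usage: one-vs-rest curve for class 1 out of three classes
y_true = [0, 1, 2, 1]
y_pred_proba = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
k = 1
y_bin = [1 if t == k else 0 for t in y_true]
scores = [row[k] for row in y_pred_proba]
pts = roc_points(y_bin, scores)
print(auc(pts))  # 1.0: both class-1 samples outscore every other sample
```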

plot_train_val_losses(train_loss_per_epoch, ...)

Plot the train and validation losses as a function of epoch number.

preprocess_data(df_charge, df_dist, mimos, ...)

Split the data into train, validation, and test sets based on the given test and validation fractions.
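A plain random split by fractions can be sketched as follows. The helper `split_indices` is hypothetical and makes two labeled assumptions: both fractions refer to the full dataset (not to what remains after the test split), and the split is neither stratified by MIMO type nor dependent on `data_split_type`, whose options are not documented here.

```python
import random
from typing import List, Tuple


def split_indices(
    n_samples: int, test_frac: float, val_frac: float, seed: int = 0
) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle sample indices and cut them into train/val/test lists.

    Assumes both fractions are of the full dataset and sum to less than 1.
    """
    if not (test_frac >= 0 and val_frac >= 0 and test_frac + val_frac < 1):
        raise ValueError("fractions must be non-negative and sum to less than 1")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_test = round(test_frac * n_samples)
    n_val = round(val_frac * n_samples)
    test = indices[:n_test]
    val = indices[n_test:n_test + n_val]
    # Everything left over is used for training
    train = indices[n_test + n_val:]
    return train, val, test


train_idx, val_idx, test_idx = split_indices(100, test_frac=0.2, val_frac=0.1)
print(len(train_idx), len(val_idx), len(test_idx))  # 70 10 20
```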

run_mlp(data_split_type, include_esp, ...)

shap_analysis(mlp_cls, train_loader, ...)

Plot SHAP dot plots for each mimochrome to identify feature importance.

train(feature, layers, lr, n_epochs, l2, ...)

A function to train and validate the model over all epochs.
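The outer loop over epochs can be sketched independently of the model. In the hypothetical `fit` below, the callables `step_fn` and `val_fn` stand in for one call each to `gradient_step` and `validate`. The two loss histories it returns have the per-epoch shape that `plot_train_val_losses` expects for its `train_loss_per_epoch` argument. Reporting the epoch with the lowest validation loss is an illustrative addition, not something the module is documented to do.

```python
from typing import Callable, List, Tuple


def fit(
    n_epochs: int,
    step_fn: Callable[[], float],
    val_fn: Callable[[], float],
) -> Tuple[List[float], List[float], int]:
    """Train for n_epochs, validating after each epoch.

    step_fn trains for one epoch and returns the mean training loss;
    val_fn evaluates without updating and returns the mean validation loss.
    Returns both histories and the 0-based epoch with the lowest validation loss.
    """
    if n_epochs < 1:
        raise ValueError("n_epochs must be at least 1")
    train_losses: List[float] = []
    val_losses: List[float] = []
    for _ in range(n_epochs):
        train_losses.append(step_fn())
        val_losses.append(val_fn())
    best_epoch = min(range(n_epochs), key=val_losses.__getitem__)
    return train_losses, val_losses, best_epoch


# Usage with canned losses: validation loss bottoms out at epoch index 2
train_seq = iter([0.9, 0.6, 0.4, 0.3])
val_seq = iter([1.0, 0.7, 0.65, 0.8])
tr, va, best = fit(4, lambda: next(train_seq), lambda: next(val_seq))
print(best)  # 2
```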

train_with_hyperparameters(trial, feature, ...)

validate(model, dataloader, device)

A function to validate on the validation dataset for one epoch.
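A validation epoch only queries the model and never updates it. The hypothetical `validate_epoch` below computes the mean cross-entropy loss and accuracy over all validation batches, given any callable that returns class probabilities. In PyTorch, such a loop would typically run under `model.eval()` and `torch.no_grad()`. The toy two-class model in the usage lines is also hypothetical.

```python
import math
from typing import Callable, List, Sequence, Tuple

Batch = Tuple[List[Sequence[float]], List[int]]


def validate_epoch(
    predict_proba: Callable[[Sequence[float]], List[float]],
    batches: List[Batch],
) -> Tuple[float, float]:
    """Return (mean cross-entropy loss, accuracy) over every validation sample."""
    total_loss, correct, n = 0.0, 0, 0
    for xb, yb in batches:
        for x, t in zip(xb, yb):
            probs = predict_proba(x)
            # Cross-entropy of the true class; clamp to avoid log(0)
            total_loss += -math.log(max(probs[t], 1e-12))
            # Predicted class = index of the largest probability
            pred = max(range(len(probs)), key=probs.__getitem__)
            correct += int(pred == t)
            n += 1
    return total_loss / n, correct / n


def toy_model(x: Sequence[float]) -> List[float]:
    # Hypothetical 2-class model: predicts class 1 when the first feature is positive
    return [0.2, 0.8] if x[0] > 0 else [0.9, 0.1]


batches = [([[1.0], [-1.0]], [1, 0]), ([[2.0]], [0])]
loss, acc = validate_epoch(toy_model, batches)
print(round(acc, 3))  # 0.667: the sample [2.0] with label 0 is misclassified
```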

Classes

MDDataset(X, y)

MimoMLP(layers)