How to run the demo
You can try collective learning for yourself using the simple demo in run_demo. This demo creates n learners for one of six learning tasks and co-ordinates the collective learning between them.
There are six potential models for the demo:
- KERAS_MNIST is a TensorFlow implementation of a small model for the standard handwritten digit recognition dataset
- KERAS_MNIST_RESNET is a TensorFlow implementation of a ResNet model for the standard handwritten digit recognition dataset
- KERAS_CIFAR10 is a TensorFlow implementation of a model for the classical CIFAR10 image recognition dataset
- PYTORCH_XRAY is a PyTorch implementation of a binary classification task that requires predicting pneumonia from images of chest X-rays. The data need to be downloaded from Kaggle
- PYTORCH_COVID_XRAY is a PyTorch implementation of a 3-class classification task that requires predicting no finding, COVID or pneumonia from images of chest X-rays. This dataset is not currently publicly available.
- FRAUD is a fraud detection task. The dataset consists of information about credit card transactions, and the task is to predict whether each transaction is fraudulent. The data need to be downloaded from Kaggle
Use the -h flag to see the options:
Arguments to run the demo:
--data_dir: Directory containing training data. Not required for MNIST and CIFAR10
--test_dir: Optional directory containing test data. A fraction of the training set will be used as a test set when not specified
--model: Model to train, options are KERAS_MNIST KERAS_MNIST_RESNET KERAS_CIFAR10 PYTORCH_XRAY PYTORCH_COVID_XRAY FRAUD
--n_learners: Number of individual learners
--n_rounds: Number of training rounds
--vote_threshold: Minimum fraction of positive votes to accept the new model
--train_ratio: Fraction of training dataset to be used as test set when no test set is specified
--seed: Seed for initialising model and shuffling datasets
--learning_rate: Learning rate for optimiser
--batch_size: Size of training batch
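The option list above maps naturally onto a command-line parser. The following is a minimal sketch of what such a parser might look like, not the code in run_demo itself: the types and every default value are assumptions chosen for illustration, and only the flag names and model names come from the list above.

```python
import argparse

# Model names as listed in this document.
MODELS = [
    "KERAS_MNIST", "KERAS_MNIST_RESNET", "KERAS_CIFAR10",
    "PYTORCH_XRAY", "PYTORCH_COVID_XRAY", "FRAUD",
]


def build_parser():
    """Build a parser with the documented flags (defaults are assumed, not official)."""
    parser = argparse.ArgumentParser(description="Sketch of the demo options")
    parser.add_argument("--data_dir", default=None,
                        help="Directory containing training data")
    parser.add_argument("--test_dir", default=None,
                        help="Optional directory containing test data")
    parser.add_argument("--model", choices=MODELS, default="KERAS_MNIST",
                        help="Model to train")
    parser.add_argument("--n_learners", type=int, default=5,
                        help="Number of individual learners")
    parser.add_argument("--n_rounds", type=int, default=15,
                        help="Number of training rounds")
    parser.add_argument("--vote_threshold", type=float, default=0.5,
                        help="Minimum fraction of positive votes to accept the new model")
    parser.add_argument("--train_ratio", type=float, default=None,
                        help="Train/test split fraction when no test set is given")
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for initialising model and shuffling datasets")
    parser.add_argument("--learning_rate", type=float, default=None,
                        help="Learning rate for optimiser")
    parser.add_argument("--batch_size", type=int, default=None,
                        help="Size of training batch")
    return parser


# Parse an example command line instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(["--model", "KERAS_CIFAR10", "--n_learners", "3"])
print(args.model, args.n_learners, args.n_rounds)
```

Restricting --model with `choices` means a misspelt model name is rejected before any data are loaded.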
Running MNIST
The simplest task to run is MNIST because the data are downloaded automatically from tensorflow_datasets.
The command below runs the MNIST task with five learners for 15 rounds.
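As an illustration, the snippet below assembles such a command from the documented flags and prints it. It assumes the demo script is saved as run_demo.py in the current directory; that file name and location are assumptions, so adjust the path to match your checkout.

```python
import shlex
import sys

# Assumption: the demo script lives at ./run_demo.py.
SCRIPT = "run_demo.py"

# Five learners, 15 rounds, MNIST model, using the flags documented above.
cmd = [
    sys.executable, SCRIPT,
    "--model", "KERAS_MNIST",
    "--n_learners", "5",
    "--n_rounds", "15",
]

# Print the equivalent shell command; pass cmd to subprocess.run to launch it.
print(shlex.join(cmd))
```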
As you can see, there are five learners, and initially they perform poorly. In round one, learner 0 is selected to propose a new set of weights.
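To make the propose-and-vote step concrete, here is a toy sketch of a single round. It is not the library's implementation: the "weights" are a single number, each learner scores them against a private target, and every learner (including the proposer) is assumed to vote. The function name `run_round` and the scoring rule are hypothetical. The acceptance rule follows the --vote_threshold description: the proposal is accepted when the fraction of positive votes reaches the threshold.

```python
def run_round(current, proposed, evaluators, vote_threshold):
    """Run one toy voting round and return (weights, votes, accepted)."""
    # Each learner votes yes if the proposal scores better on its own data.
    votes = [evaluate(proposed) > evaluate(current) for evaluate in evaluators]
    # Accept when the fraction of positive votes reaches the threshold.
    accepted = sum(votes) / len(votes) >= vote_threshold
    return (proposed if accepted else current), votes, accepted


# Five learners, each with a private target standing in for local data.
targets = [0.8, 0.9, 1.0, 1.1, 0.2]
# Higher score is better; t=t binds each target to its own evaluator.
evaluators = [lambda w, t=t: -abs(w - t) for t in targets]

# Learner 0 proposes moving the shared weights from 0.0 to 0.5.
new_weights, votes, accepted = run_round(0.0, 0.5, evaluators, vote_threshold=0.5)
print(votes, accepted, new_weights)
```

In this example four of the five learners see an improvement, so the proposal passes a threshold of 0.5 and all learners would adopt the new weights.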
Other datasets
To run the CIFAR10 dataset:
The Fraud and X-ray datasets need to be downloaded from Kaggle (this requires a Kaggle account).
To run the fraud dataset:
To run the X-ray dataset:
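The sketch below builds example command lines for these datasets and checks that --data_dir is supplied where the argument list says it is needed. The script name run_demo.py and the `path/to/...` directories are placeholders, not real locations, and the helper `demo_args` is hypothetical.

```python
import shlex

# Per the argument list, only the MNIST and CIFAR10 models run without --data_dir.
NEEDS_DATA_DIR = {"PYTORCH_XRAY", "PYTORCH_COVID_XRAY", "FRAUD"}


def demo_args(model, data_dir=None):
    """Return the demo arguments for a model, insisting on data_dir when required."""
    if model in NEEDS_DATA_DIR and data_dir is None:
        raise ValueError(f"{model} needs --data_dir pointing at the downloaded data")
    args = ["--model", model]
    if data_dir is not None:
        args += ["--data_dir", data_dir]
    return args


# Placeholder paths: replace them with wherever the Kaggle data were unpacked.
examples = [
    ("KERAS_CIFAR10", None),
    ("FRAUD", "path/to/fraud_data"),
    ("PYTORCH_XRAY", "path/to/xray_data"),
]
for model, data_dir in examples:
    print("python run_demo.py " + shlex.join(demo_args(model, data_dir)))
```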