simpletensor.mnist_train
- class simpletensor.mnist_train.Model(**kwargs)
Bases: object
Creates a convolutional neural network (CNN) model
- fit(train_data, test_data, epochs, batch_size)
Fits the model to the training data, tracking loss and accuracy on both the training and test data
- Parameters:
- train_data : (ndarray, ndarray)
Training images and labels
- test_data : (ndarray, ndarray)
Testing images and labels
- epochs : int
Number of epochs to train
- batch_size : int
Minibatch size
- Returns:
- dict[str, list]
History of training and testing losses and accuracies over the epochs
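A minimal sketch of the kind of epoch/minibatch loop a method with this signature typically runs, assuming fit shuffles the training set each epoch, trains on chunks of batch_size samples, and evaluates on test_data once per epoch. The names fit_sketch, minibatches, train_step and evaluate, and the history keys ("train_loss", "train_acc", "test_loss", "test_acc"), are hypothetical; the actual keys returned by Model.fit are not documented here.

```python
import random
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical callback: takes a batch (xs, ys) and returns (loss, accuracy).
StepFn = Callable[[Sequence, Sequence], Tuple[float, float]]


def minibatches(n: int, batch_size: int, rng: random.Random) -> List[List[int]]:
    """Shuffle the indices 0..n-1 and split them into chunks of at most batch_size."""
    indices = list(range(n))
    rng.shuffle(indices)
    return [indices[i:i + batch_size] for i in range(0, n, batch_size)]


def fit_sketch(train_data, test_data, epochs: int, batch_size: int,
               train_step: StepFn, evaluate: StepFn,
               seed: int = 0) -> Dict[str, List[float]]:
    """Hypothetical stand-in for Model.fit; the history keys are assumptions."""
    x_train, y_train = train_data
    x_test, y_test = test_data
    rng = random.Random(seed)
    history: Dict[str, List[float]] = {
        "train_loss": [], "train_acc": [], "test_loss": [], "test_acc": [],
    }
    for _ in range(epochs):
        loss_sum = acc_sum = 0.0
        for idx in minibatches(len(x_train), batch_size, rng):
            # With NumPy arrays this would be x_train[idx]; lists keep the sketch dependency-free.
            xb = [x_train[i] for i in idx]
            yb = [y_train[i] for i in idx]
            loss, acc = train_step(xb, yb)
            # Weight each batch by its size so a short final batch is not over-counted.
            loss_sum += loss * len(idx)
            acc_sum += acc * len(idx)
        history["train_loss"].append(loss_sum / len(x_train))
        history["train_acc"].append(acc_sum / len(x_train))
        # One evaluation pass over the test set per epoch.
        test_loss, test_acc = evaluate(x_test, y_test)
        history["test_loss"].append(test_loss)
        history["test_acc"].append(test_acc)
    return history
```

Each list in the returned dict has one entry per epoch, which matches the dict[str, list] shape described above; per-sample weighting means the epoch metric equals the average over all training samples even when the last batch is smaller.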
- class simpletensor.mnist_train.TqdmProgBar(*_, **__)
Bases: tqdm
Progress bar that displays the download progress of mnist.npz
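urllib.request.urlretrieve reports progress by calling reporthook(block_num, block_size, total_size), where total_size is -1 if the server sends no length. A common way to drive a tqdm bar from that callback is sketched below, assuming TqdmProgBar is used in a similar way (the source does not say how it is wired up). make_reporthook is a hypothetical name, and the bar only needs tqdm's n, total and update() members.

```python
from typing import Callable, Optional, Protocol


class ProgressLike(Protocol):
    """The subset of the tqdm interface this sketch relies on."""
    n: float
    total: Optional[float]

    def update(self, n: float = 1) -> Optional[bool]: ...


def make_reporthook(bar: ProgressLike) -> Callable[[int, int, int], None]:
    """Hypothetical adapter from urlretrieve's reporthook signature to a tqdm-style bar."""
    def hook(block_num: int, block_size: int, total_size: int) -> None:
        downloaded = block_num * block_size
        if total_size > 0:
            # urlretrieve passes -1 when the size is unknown; only set total when known.
            bar.total = total_size
            # The last block is usually partial, so clamp to the file size.
            downloaded = min(downloaded, total_size)
        # tqdm.update() takes an increment, so pass only the newly received bytes.
        bar.update(downloaded - bar.n)
    return hook
```

With tqdm installed, this could be used as `with tqdm(unit="B", unit_scale=True) as bar: urlretrieve(url, path, reporthook=make_reporthook(bar))`. Subclassing tqdm and adding a method with the same arithmetic is an equivalent design.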
- simpletensor.mnist_train.download_mnist(output_path)
Downloads the MNIST dataset (mnist.npz) and saves it to output_path
- Parameters:
- output_path : filepath
File path to which mnist.npz is downloaded
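A minimal sketch of a download step like this one, built on urllib.request.urlretrieve. The URL that download_mnist fetches is not given here, so the sketch takes it as a parameter; download_file is a hypothetical helper, not the library's implementation. It writes to a temporary ".part" file and renames it only after the transfer completes, so an interrupted download does not leave a truncated file at output_path.

```python
import os
import tempfile
import urllib.request
from typing import Callable, Optional


def download_file(url: str, output_path: str,
                  reporthook: Optional[Callable[[int, int, int], None]] = None) -> str:
    """Hypothetical helper: fetch url into output_path, replacing it atomically."""
    directory = os.path.dirname(os.path.abspath(output_path))
    os.makedirs(directory, exist_ok=True)
    # Reserve a temporary file in the same directory so os.replace stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".part")
    os.close(fd)
    try:
        # reporthook, if given, is called as reporthook(block_num, block_size, total_size).
        urllib.request.urlretrieve(url, tmp_path, reporthook=reporthook)
        os.replace(tmp_path, output_path)
    except BaseException:
        # Clean up the partial file on any failure, including KeyboardInterrupt.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return output_path
```

Any callable accepting (block_num, block_size, total_size) can be passed as reporthook, for example one that advances a tqdm progress bar.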