simpletensor.mnist_train

class simpletensor.mnist_train.Model(**kwargs)

Bases: object

Creates a convolutional neural network (CNN) model.

eval()
fit(train_data, test_data, epochs, batch_size)

Fits the model to the training data.

Parameters:
train_data : (ndarray, ndarray)

Training images and labels

test_data : (ndarray, ndarray)

Testing images and labels

epochs : int

Number of epochs to train

batch_size : int

Minibatch size

Returns:
dict[str, list]

History of training and testing losses and accuracies over the epochs
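The `epochs` and `batch_size` parameters describe a standard minibatch training loop. The following is a minimal NumPy sketch of such a loop, not simpletensor's implementation: it assumes the training set is reshuffled every epoch, and it uses hypothetical `step_fn` and `eval_fn` callbacks (each returning `(loss, accuracy)`) in place of the model's forward and backward passes. The history keys (`train_loss`, `train_acc`, `test_loss`, `test_acc`) are also hypothetical.

```python
import numpy as np


def iterate_minibatches(x, y, batch_size, rng):
    """Yield shuffled (x_batch, y_batch) pairs; the last batch may be smaller."""
    indices = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        batch_idx = indices[start:start + batch_size]
        yield x[batch_idx], y[batch_idx]


def fit_sketch(step_fn, eval_fn, train_data, test_data, epochs, batch_size, seed=0):
    """Hypothetical training loop returning a history dict of per-epoch lists.

    step_fn(x, y) -> (loss, accuracy): trains on one minibatch (assumed callback).
    eval_fn(x, y) -> (loss, accuracy): scores a dataset without training (assumed).
    """
    x_train, y_train = train_data
    x_test, y_test = test_data
    rng = np.random.default_rng(seed)
    history = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}
    for _ in range(epochs):
        losses, accs, sizes = [], [], []
        for xb, yb in iterate_minibatches(x_train, y_train, batch_size, rng):
            loss, acc = step_fn(xb, yb)
            losses.append(loss)
            accs.append(acc)
            sizes.append(len(xb))
        # Weight each batch by its size so a short final batch is not over-counted
        history["train_loss"].append(float(np.average(losses, weights=sizes)))
        history["train_acc"].append(float(np.average(accs, weights=sizes)))
        # Evaluate on the test set once per epoch
        test_loss, test_acc = eval_fn(x_test, y_test)
        history["test_loss"].append(float(test_loss))
        history["test_acc"].append(float(test_acc))
    return history
```

With 10 samples and `batch_size=4`, each epoch sees batches of 4, 4 and 2 samples, and every list in the returned dict holds one value per epoch.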

train()
class simpletensor.mnist_train.TqdmProgBar(*_, **__)

Bases: tqdm

Progress bar that displays the download progress of mnist.npz.

update_to(b=1, bsize=1, tsize=None)
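The signature `update_to(b, bsize, tsize)` matches the `reporthook` callback of `urllib.request.urlretrieve`, which reports the number of blocks transferred, the block size in bytes, and the total size (or -1 when unknown). It also mirrors the `TqdmUpTo` recipe in the tqdm documentation. The sketch below is a dependency-free stand-in that shows the usual arithmetic; `SimpleProgress` is hypothetical and not part of simpletensor or tqdm, and clamping to the total is an extra safeguard added here.

```python
class SimpleProgress:
    """Hypothetical stand-in for a tqdm bar: tracks bytes received out of a total."""

    def __init__(self):
        self.n = 0          # bytes counted so far
        self.total = None   # total size in bytes, if the server reports it

    def update(self, increment):
        # tqdm bars advance by an increment, not to an absolute position
        self.n += increment

    def update_to(self, b=1, bsize=1, tsize=None):
        # b: blocks transferred so far; bsize: bytes per block;
        # tsize: total size in bytes (None or -1 when unknown)
        if tsize is not None and tsize > 0:
            self.total = tsize
        done = b * bsize
        # b * bsize can overshoot on the final, partial block
        if self.total is not None:
            done = min(done, self.total)
        # Convert the absolute position into an increment
        self.update(done - self.n)
```

Converting the absolute byte count into an increment (`done - self.n`) is what lets a callback that receives cumulative block counts drive a bar whose `update` method expects deltas.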
simpletensor.mnist_train.download_mnist(output_path)

Downloads the MNIST dataset (mnist.npz) and saves it to a file.

Parameters:
output_path : filepath

File path for mnist.npz to be downloaded to
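This page does not document where `download_mnist` fetches the file from. The sketch below therefore takes the source URL as an explicit argument. `download_file` is a hypothetical helper built on the standard-library `urllib.request.urlretrieve`. That function calls `reporthook(block_num, block_size, total_size)` once when the connection opens and once after each block. This is the same argument order as `TqdmProgBar.update_to(b, bsize, tsize)`, so a bound `update_to` method appears designed to be passed as the hook.

```python
import os
import urllib.request


def download_file(url, output_path, reporthook=None):
    """Hypothetical helper: fetch `url` into `output_path`, creating parent dirs.

    reporthook(block_num, block_size, total_size) is forwarded to urlretrieve
    so a progress bar can be updated as blocks arrive.
    """
    parent = os.path.dirname(os.path.abspath(output_path))
    os.makedirs(parent, exist_ok=True)
    # urlretrieve returns (local_filename, headers); only the filename is needed
    path, _headers = urllib.request.urlretrieve(url, output_path, reporthook=reporthook)
    return path
```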

simpletensor.mnist_train.load_data(location)

Loads mnist.npz

Parameters:
location : filepath

Location of mnist.npz

Returns:
((x_train, y_train), (x_test, y_test))

Training and testing data in the form of packed tuples
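Below is a minimal sketch of a loader that returns the documented `((x_train, y_train), (x_test, y_test))` structure. It assumes the archive stores its arrays under the keys `x_train`, `y_train`, `x_test` and `y_test`, which is the layout of the widely used Keras `mnist.npz`. This page does not confirm that layout, and `load_mnist_npz` is a hypothetical name.

```python
import numpy as np


def load_mnist_npz(location):
    """Hypothetical loader; assumes keys x_train, y_train, x_test, y_test."""
    # np.load on an .npz returns a lazily-reading NpzFile; indexing it reads the
    # array into memory, so the arrays stay valid after the file is closed
    with np.load(location) as data:
        x_train, y_train = data["x_train"], data["y_train"]
        x_test, y_test = data["x_test"], data["y_test"]
    return (x_train, y_train), (x_test, y_test)
```

The nested tuples unpack in a single statement: `(x_train, y_train), (x_test, y_test) = load_mnist_npz("mnist.npz")`.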

simpletensor.mnist_train.main(*args)
simpletensor.mnist_train.parse_args()

Parses path and hyperparameter arguments.

Returns:
Tuple

Tuple of options
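The flags `parse_args` accepts are not listed on this page. The `argparse` sketch below shows one way to parse a data path and the two hyperparameters that `Model.fit` takes into a plain tuple. The flag names `--data-path`, `--epochs` and `--batch-size`, their defaults, and the function name `parse_args_sketch` are all hypothetical.

```python
import argparse


def parse_args_sketch(argv=None):
    """Hypothetical CLI parser; flag names and defaults are illustrative only."""
    parser = argparse.ArgumentParser(description="Train a CNN on MNIST (sketch).")
    parser.add_argument("--data-path", default="mnist.npz",
                        help="Location of mnist.npz (hypothetical flag)")
    parser.add_argument("--epochs", type=int, default=5,
                        help="Number of training epochs (hypothetical default)")
    parser.add_argument("--batch-size", type=int, default=64,
                        help="Minibatch size (hypothetical default)")
    # argv=None makes argparse read sys.argv[1:]
    args = parser.parse_args(argv)
    # Return a plain tuple to mirror the documented "Tuple of options"
    return args.data_path, args.epochs, args.batch_size
```

Passing an explicit list, for example `parse_args_sketch(["--epochs", "2"])`, makes the parser easy to exercise without touching `sys.argv`.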