In this example, we are going to use the MNIST data set. If you have cloned the MXNet repo and changed into the contrib/clojure-package directory, you can run a helper script to download the data.
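As a rough sketch, the data can then be loaded into iterators. The namespace aliases, file paths, and iterator options below are assumptions and may differ in your checkout:

```clojure
;; Assumed namespace requires; the data paths below depend on where the
;; helper script places the downloaded MNIST files.
(ns tutorial.module
  (:require [org.apache.clojure-mxnet.context :as context]
            [org.apache.clojure-mxnet.eval-metric :as eval-metric]
            [org.apache.clojure-mxnet.io :as mx-io]
            [org.apache.clojure-mxnet.module :as m]
            [org.apache.clojure-mxnet.symbol :as sym]))

(def train-data
  (mx-io/mnist-iter {:image "data/train-images-idx3-ubyte"
                     :label "data/train-labels-idx1-ubyte"
                     :label-name "softmax_label"
                     :data-shape [784]
                     :batch-size 10
                     :shuffle true}))

(def test-data
  (mx-io/mnist-iter {:image "data/t10k-images-idx3-ubyte"
                     :label "data/t10k-labels-idx1-ubyte"
                     :data-shape [784]
                     :batch-size 10}))
```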
To construct a module, we need a symbol as input. This symbol takes input data in the first layer, is followed by alternating fully connected and relu activation layers, and ends in a softmax layer for output.
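A minimal sketch of such a network, using the `sym` and `m` aliases assumed above (the layer sizes are illustrative):

```clojure
(defn get-symbol []
  (as-> (sym/variable "data") data
    (sym/fully-connected "fc1" {:data data :num-hidden 128})
    (sym/activation "relu1" {:data data :act-type "relu"})
    (sym/fully-connected "fc2" {:data data :num-hidden 64})
    (sym/activation "relu2" {:data data :act-type "relu"})
    (sym/fully-connected "fc3" {:data data :num-hidden 10})
    (sym/softmax-output "softmax" {:data data})))

(def out (get-symbol))

;; construct a module from the symbol (CPU context by default)
(def mod (m/module out))
```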
By default, the context is the CPU. If you need data parallelization, you can specify a GPU context or an array of GPU contexts, like this: `(m/module out {:contexts [(context/gpu)]})`.
Before you can compute with a module, you need to call bind to allocate the device memory and init-params or set-params to initialize the parameters. If you simply want to fit a module, you don’t need to call bind and init-params explicitly, because the fit function automatically calls them if they are needed.
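A sketch of the explicit path, assuming the `mod`, `train-data`, and `test-data` definitions above (the `fit` call at the end takes care of all of this for you):

```clojure
;; explicitly allocate device memory and initialize the parameters
(-> mod
    (m/bind {:data-shapes (mx-io/provide-data train-data)
             :label-shapes (mx-io/provide-label train-data)})
    (m/init-params))

;; or simply fit, which binds and initializes as needed
(m/fit mod {:train-data train-data :eval-data test-data :num-epoch 1})
```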
You can pass in batch-end callbacks using batch-end-callback and epoch-end callbacks using epoch-end-callback in the fit-params. You can also set options such as the optimizer and eval-metric in the fit-params. To learn more, see the fit-params function options. To predict with a module, call predict with a DataIter:
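For example, a sketch using the `mod` and `test-data` from above:

```clojure
;; run prediction over the whole test iterator
(def results (m/predict mod {:eval-data test-data}))

;; each element is an NDArray of predictions for one batch
(first results)
```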
The module collects and returns all of the prediction results. For more details about the format of the return values, see the documentation for the predict function.
When prediction results might be too large to fit in memory, use the predict-every-batch API.
```clojure
(let [preds (m/predict-every-batch mod {:eval-data test-data})]
  (mx-io/reduce-batches test-data
                        (fn [i batch]
                          (println (str "pred is " (first (get preds i))))
                          (println (str "label is " (mx-io/batch-label batch)))
                          ;;; do something
                          (inc i))))
```
If you need to evaluate on a test set and don’t need the prediction output, call the score function with a data iterator and an eval metric:
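For example, to compute accuracy on the test iterator (a sketch using the eval-metric namespace assumed above; the return value shown is illustrative):

```clojure
(m/score mod {:eval-data test-data
              :eval-metric (eval-metric/accuracy)})
;=> ["accuracy" 0.92]   ; illustrative value
```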
This runs predictions on each batch in the provided data iterator and computes the evaluation score using the provided eval metric. The evaluation results are stored in the metric so that you can query them later.
Saving and Loading
To save the module parameters in each training epoch, use a checkpoint function:
```clojure
(let [save-prefix "my-model"]
  (doseq [epoch-num (range 3)]
    (mx-io/do-batches train-data
                      (fn [batch]
                        ;; do something
                        ))
    (m/save-checkpoint mod {:prefix save-prefix
                            :epoch epoch-num
                            :save-opt-states true})))

;; INFO  org.apache.mxnet.module.Module: Saved checkpoint to my-model-0000.params
;; INFO  org.apache.mxnet.module.Module: Saved optimizer state to my-model-0000.states
;; INFO  org.apache.mxnet.module.Module: Saved checkpoint to my-model-0001.params
;; INFO  org.apache.mxnet.module.Module: Saved optimizer state to my-model-0001.states
;; INFO  org.apache.mxnet.module.Module: Saved checkpoint to my-model-0002.params
;; INFO  org.apache.mxnet.module.Module: Saved optimizer state to my-model-0002.states
```
To load the saved module parameters, call the load-checkpoint function:
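For example, to load the parameters (and optimizer state) saved for epoch 1 above, a sketch:

```clojure
(def new-mod
  (m/load-checkpoint {:prefix "my-model"
                      :epoch 1
                      :load-optimizer-states true}))
```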
To initialize parameters, first bind the symbols to construct executors with the bind function. Then initialize the parameters and auxiliary states by calling the init-params function.
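A sketch of those two steps for the loaded module, reusing the train-data iterator from above:

```clojure
(-> new-mod
    (m/bind {:data-shapes (mx-io/provide-data train-data)
             :label-shapes (mx-io/provide-label train-data)})
    (m/init-params))
```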
To resume training from a saved checkpoint, instead of calling set-params, call fit directly with the loaded parameters, so that fit knows to start from those parameters rather than initializing them randomly.
Create fit-params, and then use it to set begin-epoch so that fit knows to resume from a saved epoch.
```clojure
;; reset the training data before calling fit or you will get an error
(mx-io/reset train-data)
(mx-io/reset test-data)

(m/fit new-mod {:train-data train-data
                :eval-data test-data
                :num-epoch 2
                :fit-params (m/fit-params {:begin-epoch 1})})
```
If you are interested in trying out MXNet and exploring further on your own, check out the main page, which has installation instructions and other information.