{-|
Module: MachineLearning.NeuralNetwork
Description: Neural Network
Copyright: (c) Alexander Ignatyev, 2016
License: BSD-3
Stability: experimental
Portability: POSIX

Simple Neural Networks.
-}

module MachineLearning.NeuralNetwork
(
    Model(..)
  , NeuralNetworkModel(..)
  , MLC.calcAccuracy
  , T.Topology
  , T.initializeTheta
  , T.initializeThetaIO
  , T.initializeThetaM
  , Regularization(..)
)

where

import qualified Numeric.LinearAlgebra as LA
import MachineLearning.Types (R, Vector, Matrix)
import MachineLearning.Utils (reduceByRowsV)
import qualified MachineLearning.Classification.Internal as MLC
import MachineLearning.Model (Model(..))
import qualified MachineLearning.NeuralNetwork.Topology as T
import MachineLearning.Regularization (Regularization(..))


-- | Neural Network Model.
-- Takes neural network topology as a constructor argument.
--
-- The wrapped 'T.Topology' is the only state the model carries; the
-- 'Model' instance of this type hands it to the @T.*@ functions together
-- with the flat parameter vector supplied by the caller.
newtype NeuralNetworkModel = NeuralNetwork T.Topology


instance Model NeuralNetworkModel where
  -- Predictions: unflatten theta, compute the network scores for x and
  -- reduce them with 'reduceByRowsV', taking 'LA.maxIndex' of each vector
  -- it is handed (presumably one row of scores per sample - confirm
  -- against 'reduceByRowsV'). The index is converted with 'fromIntegral'.
  hypothesis (NeuralNetwork topo) x theta =
    reduceByRowsV (fromIntegral . LA.maxIndex) outputs
    where
      outputs = calcScores topo x (T.unflatten topo theta)

  -- Cost: 'T.loss' evaluated on the network scores for x, the unflattened
  -- parameters and the one-vs-all encoded targets. lambda is passed
  -- through to 'T.loss' unchanged.
  cost (NeuralNetwork topo) lambda x y theta =
    T.loss topo lambda outputs params targets
    where
      (targets, params) = processParams topo y theta
      outputs = calcScores topo x params

  -- Gradient: forward propagation yields the scores together with a cache
  -- list; both are fed to 'T.propagateBackward', whose result is packed
  -- back into a single vector by 'T.flatten'.
  gradient (NeuralNetwork topo) lambda x y theta =
    T.flatten (T.propagateBackward topo lambda outputs caches targets)
    where
      (targets, params) = processParams topo y theta
      (outputs, caches) = T.propagateForward topo x params


-- | Score function.
-- Runs 'T.propagateForward' on the given topology, input matrix X and
-- theta list, and keeps only the first component of its result (the
-- scores); the second component is discarded.
calcScores :: T.Topology -> Matrix -> [(Matrix, Matrix)] -> Matrix
calcScores topology x thetaList = scores
  where
    (scores, _) = T.propagateForward topology x thetaList


-- | Prepares the inputs shared by 'cost' and 'gradient'.
--
-- Returns a pair of:
--
-- * the target matrix, built by passing @y@ and the topology's number of
--   outputs to 'MLC.processOutputOneVsAll' and assembling the resulting
--   vectors as columns;
-- * the theta list obtained by unflattening @theta@ for the topology.
processParams :: T.Topology -> Vector -> Vector -> (Matrix, [(Matrix, Matrix)])
processParams topology y theta = (targets, T.unflatten topology theta)
  where
    encoded = MLC.processOutputOneVsAll (T.numberOutputs topology) y
    targets = LA.fromColumns encoded