melisgl / mgl

Common Lisp machine learning library.

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

1d convolution clump

jccerrillo opened this issue · comments

Here is an attempt at writing a 1-D convolution lump, based on ->V*M. How would I go on to implement backpropagation for it?

(defclass-now ->conv (lump)
  ((x :initarg :x :reader x
      :documentation "The input lump whose nodes are convolved with
      WEIGHTS.")
   (weights
    :initform (->weight :dimensions '(1 1))
    :type ->weight :initarg :weights :reader weights
    :documentation "A ->WEIGHT lump holding the filters. Its DIMENSIONS
    are overwritten when the ->CONV is initialized, based on
    NUM-FILTERS, KERNEL-SIZE and TRANSPOSE-WEIGHTS-P.")
   (transpose-weights-p
    :initform nil :initarg :transpose-weights-p
    :reader transpose-weights-p
    :documentation "If true, the single stripe of WEIGHTS is laid out as
    KERNEL-SIZE x NUM-FILTERS instead of NUM-FILTERS x KERNEL-SIZE.
    FORWARD does not support this yet and signals an error.")
   (kernel-size
    :initform nil
    :initarg :kernel-size
    :reader kernel-size
    :documentation "Length of each filter. There is no usable default;
    it must be supplied as a positive integer.")
   (num-filters
    :initform nil
    :initarg :num-filters
    :reader num-filters
    :documentation "Number of filters (rows of WEIGHTS when not
    transposed). There is no usable default; it must be supplied as a
    positive integer.")
   (stride
    ;; Was NIL, which made DEFAULT-SIZE and the :AFTER initializer fail
    ;; in FLOOR whenever :STRIDE was omitted. 1 = apply the kernel at
    ;; every input position.
    :initform 1
    :initarg :stride
    :reader stride
    :documentation "Distance between consecutive input positions at
    which the filters are applied. Defaults to 1.")
   (out-size
    :initform 1
    :accessor out-size
    :documentation "Number of output positions, computed at
    initialization as (1+ (FLOOR (1- (SIZE X)) STRIDE)), i.e. the
    number of positions 0, STRIDE, 2*STRIDE, ... inside the input."))
  (:documentation "Perform (CONV X WEIGHTS) where X (the input) is of
  size M and WEIGHTS is a ->WEIGHT whose single stripe is taken to
  be of dimensions NUM-FILTERS x KERNEL-SIZE. The kernel is applied
  every STRIDE elements of the input, starting at index 0."))

;; Define the ->CONV constructor. X and WEIGHTS are passed positionally
;; and the remaining initargs as keywords, as in
;; (->conv input w :kernel-size ... :num-filters ... :stride ...).
(defmaker (->conv :unkeyword-args (x weights)))

(defmethod initialize-instance :after ((lump ->conv) &key &allow-other-keys)
  "Shape the filter bank of LUMP and record its output size.

  Overwrites the DIMENSIONS of (WEIGHTS LUMP) with (NUM-FILTERS
  KERNEL-SIZE), or (KERNEL-SIZE NUM-FILTERS) if TRANSPOSE-WEIGHTS-P. Sets
  OUT-SIZE to the number of positions 0, STRIDE, 2*STRIDE, ... inside
  the input. Signals a TYPE-ERROR if KERNEL-SIZE, NUM-FILTERS or STRIDE
  is not a positive integer."
  (let ((kernel-size (kernel-size lump))
        (num-filters (num-filters lump))
        (stride (stride lump)))
    ;; Fail early with a clear message. Otherwise a missing initarg puts
    ;; NIL into the weight dimensions or into FLOOR below.
    (check-type kernel-size (integer 1))
    (check-type num-filters (integer 1))
    (check-type stride (integer 1))
    (setf (slot-value (weights lump) 'dimensions)
          (if (transpose-weights-p lump)
              (list kernel-size num-filters)
              (list num-filters kernel-size)))
    ;; Force reshaping: assigning MAX-N-STRIPES (even to its current
    ;; value) presumably makes the weight lump rebuild its matrices with
    ;; the new DIMENSIONS. Same trick as in ->V*M.
    (setf (max-n-stripes (weights lump)) (max-n-stripes (weights lump)))
    (setf (out-size lump) (1+ (floor (1- (size (x lump))) stride)))))

(defmethod default-size ((lump ->conv))
  "The size of a ->CONV lump defaults to the number of positions at
  which the kernel is applied: one every STRIDE elements of the input,
  starting at index 0."
  (let ((input-size (size (x lump)))
        (step (stride lump)))
    (+ 1 (floor (1- input-size) step))))

(defmethod forward ((lump ->conv))
  "Overwrite the nodes of LUMP with the convolution of the nodes of its
  input X with the nodes of WEIGHTS. The kernel is applied every STRIDE
  elements along the second axis, starting at index 0, with the kernel
  anchored at its first element. Returns the nodes of LUMP."
  (let* ((x-lump (x lump))
         (x (nodes x-lump))
         (weights (nodes (weights lump)))
         (y (nodes lump))
         (stride (list 1 (stride lump)))
         (start '(0 0))
         (anchor '(0 0)))
    ;; FIXEXT:
    (assert (stripedp x-lump))
    (if (transpose-weights-p lump)
        ;; a = x *conv w'
        (error "->conv not implemented for transposed weights")
        (progn
          ;; CONVOLVE! accumulates (Y = Y + conv(X, W)). Clear Y first,
          ;; or the results of earlier forward passes would pile up.
          (fill! 0 y)
          ;; NOTE(review): without :BATCHED, X (stripes x size) is
          ;; treated as one 2-d map. The NUM-FILTERS rows of WEIGHTS then
          ;; seem to slide along the stripe axis (stride 1), mixing
          ;; neighbouring instances. Confirm; a per-filter 1-d
          ;; convolution with :BATCHED T may be what is intended.
          ;; NOTE(review): for BACKWARD, MGL-MAT's DERIVE-CONVOLVE!
          ;; (taking the same START/STRIDE/ANCHOR) looks like the
          ;; counterpart of this call. Verify its signature.
          ;; a = conv(x, w)
          (convolve! x weights y :start start :stride stride
                                 :anchor anchor)))))

;; The class must be defined before the SET-INPUT method below. A
;; DEFMETHOD whose specializer names a not-yet-defined class fails when
;; loaded.
(defclass simple-conv (fnn) ()
  (:documentation "Empty FNN subclass used by MAKE-CONV-FNN. SET-INPUT
  is specialized on it."))

(defmethod set-input (instances (conv simple-conv))
  "Copy INSTANCES into the nodes of the clump named INPUT in CONV.

  INSTANCES is a list with one list of numbers per stripe. Element J of
  the I-th instance is stored at row I, column J. Entries not covered
  by INSTANCES are left at 0."
  (let ((input-nodes (nodes (find-clump 'input conv))))
    (fill! 0 input-nodes)
    (loop for instance in instances
          for i upfrom 0
          do (loop for j upfrom 0
                   for value in instance
                   do (setf (mref input-nodes i j) value)))))

(defun make-conv-fnn (&key (input-size 5) (num-filters 1) (kernel-size 3) (stride 2))
  "Build a SIMPLE-CONV network: an input of INPUT-SIZE elements, a ->CONV
  lump with a filter weight named FILTER, a per-position weight named
  BIAS added to the convolution output, and a ->RELU on top."
  ;; Number of positions 0, STRIDE, 2*STRIDE, ... inside the input. It is
  ;; the size of the convolution output, so BIAS must have this size too.
  (let ((n-outputs (1+ (floor (1- input-size) stride))))
    (build-fnn (:class 'simple-conv)
      ;; Named INPUT because SET-INPUT looks the clump up by that name.
      (input (->input :size input-size :name 'input))
      (bias (->weight :size n-outputs :name 'bias))
      ;; Its dimensions are reset by ->CONV from NUM-FILTERS and KERNEL-SIZE.
      (w (->weight :name 'filter :size num-filters))
      (conv (->conv input w
                    :kernel-size kernel-size
                    :num-filters num-filters
                    :stride stride))
      (conv-act (->+ (list conv bias) :name 'conv-act))
      (relu (->relu conv-act)))))

To try it out, one can then run the following:

(let* ((input-size 7)
       (num-filters 3)
       (tconv (make-conv-fnn :input-size input-size
                             :num-filters num-filters
                             :kernel-size 4
                             :stride 2)))
  ;; Initialize every weight matrix with values LIMIT-shifted from a
  ;; UNIFORM-RANDOM! sample with :LIMIT 2*LIMIT, where
  ;; LIMIT = sqrt(6 / first dimension). Assuming UNIFORM-RANDOM! samples
  ;; from [0, :LIMIT), this gives values in [-LIMIT, LIMIT).
  ;; NOTE(review): for the FILTER weight (NUM-FILTERS x KERNEL-SIZE) the
  ;; first dimension is NUM-FILTERS, not the kernel size. Confirm that
  ;; this is the intended fan-in.
  (flet ((init-uniformly (weights)
           (let* ((mat (nodes weights))
                  (fan-in (mat-dimension mat 0))
                  (limit (sqrt (/ 6 fan-in))))
             (uniform-random! mat :limit (* 2 limit))
             (.+! (- limit) mat))))
    (map-segments #'init-uniformly tconv))
  ;; NOTE(review): the batch size (number of stripes and of instances)
  ;; is tied to NUM-FILTERS here. It is just a quirk of this demo.
  (setf (max-n-stripes tconv) num-filters)
  (let ((instances
          (loop repeat num-filters
                collect (alexandria:shuffle
                         (loop repeat input-size collect (random 10.0))))))
    (set-input instances tconv))
  (forward tconv)
  tconv)