(module websockets
  (
   ; parameters
   ping-interval close-timeout
   connection-timeout accept-connection
   drop-incoming-pings propagate-common-errors
   max-frame-size max-message-size

   ; high level API
   with-websocket with-concurrent-websocket
   send-message receive-message

   ; low level API
   ;; send-frame read-frame read-frame-payload
   ;; receive-fragments valid-utf8?
   ;; control-frame? upgrade-to-websocket
   ;; current-websocket unmask close-websocket
   ;; process-fragments

   ;; ; fragment
   ;; make-fragment fragment? fragment-payload fragment-length
   ;; fragment-masked? fragment-masking-key fragment-last?
   ;; fragment-optype
   )

(import chicken scheme data-structures extras ports posix foreign
        srfi-13 srfi-14 srfi-18)
(use srfi-1 srfi-4 spiffy intarweb uri-common base64 simple-sha1
     mailbox comparse)

(define-inline (neq? obj1 obj2) (not (eq? obj1 obj2)))

(define current-websocket (make-parameter #f))
(define ping-interval (make-parameter 15))
(define close-timeout (make-parameter 5))
(define connection-timeout (make-parameter 58)) ; a little grace period from 60s
(define accept-connection (make-parameter (lambda (origin) #t)))
(define drop-incoming-pings (make-parameter #t))
(define propagate-common-errors (make-parameter #f))
(define access-denied ; TODO test
  (make-parameter (lambda () (send-status 'forbidden "<h1>Access denied</h1>"))))

(define max-frame-size (make-parameter 1048576)) ; 1MiB
(define max-message-size
  (make-parameter 1048576 ; 1MiB
                  (lambda (v)
                    (if (> v 1073741823) ; max int size for unmask/utf8 check
                        (signal (make-property-condition 'out-of-range))
                        v))))

(define (make-websocket-exception . conditions)
  (apply make-composite-condition (append `(,(make-property-condition 'websocket))
                                          conditions)))

(define (make-protocol-violation-exception msg)
  (make-composite-condition (make-property-condition 'websocket)
                            (make-property-condition 'protocol-error 'msg msg)))

(define (opcode->optype op)
  (case op
    ((0) 'continuation)
    ((1) 'text)
    ((2) 'binary)
    ((8) 'connection-close)
    ((9) 'ping)
    ((10) 'pong)
    (else (signal (make-protocol-violation-exception "bad opcode")))))

(define (optype->opcode t)
  (case t
    ('continuation 0)
    ('text 1)
    ('binary 2)
    ('connection-close 8)
    ('ping 9)
    ('pong 10)
    (else (signal (make-websocket-exception
                   (make-property-condition 'invalid-optype))))))

(define (control-frame? optype)
  (or (eq? optype 'ping) (eq? optype 'pong) (eq? optype 'connection-close)))

(define-record-type websocket
  (make-websocket inbound-port outbound-port user-thread
                  send-mutex read-mutex last-message-timestamp
                  state send-mailbox read-mailbox concurrent)
  websocket?
  (inbound-port websocket-inbound-port)
  (outbound-port websocket-outbound-port)
  (user-thread websocket-user-thread)
  (send-mutex websocket-send-mutex)
  (read-mutex websocket-read-mutex)
  (last-message-timestamp websocket-last-message-timestamp
                          set-websocket-last-message-timestamp!)
  (state websocket-state set-websocket-state!)
  (send-mailbox websocket-send-mailbox)
  (read-mailbox websocket-read-mailbox)
  (concurrent websocket-concurrent?))

(define-record-type websocket-fragment
  (make-fragment payload length masked masking-key
                           fin optype)
  fragment?
  (payload fragment-payload)
  (length fragment-length)
  (masked fragment-masked? set-fragment-masked!)
  (masking-key fragment-masking-key)
  (fin fragment-last?)
  (optype fragment-optype))

(define (hex-string->string hexstr)
  ;; convert a string like "a745ff12" to a string
  (let ((result (make-string (/ (string-length hexstr) 2))))
    (let loop ((hexs (string->list hexstr))
               (i 0))
      (if (< (length hexs) 2)
          result
          (let ((ascii (string->number (string (car hexs) (cadr hexs)) 16)))
            (string-set! result i (integer->char ascii))
            (loop (cddr hexs)
                  (+ i 1)))))))


(define (send-frame ws optype data last-frame)
  ; TODO this sucks
  (when (u8vector? data) (set! data (blob->string (u8vector->blob/shared data))))
  (let* ((len (if (string? data) (string-length data) (u8vector-length data)))
         (frame-fin (if last-frame 1 0))
         (frame-rsv1 0)
         (frame-rsv2 0)
         (frame-rsv3 0)
         (frame-opcode (optype->opcode optype))
         (octet0 (bitwise-ior (arithmetic-shift frame-fin 7)
                              (arithmetic-shift frame-rsv1 6)
                              (arithmetic-shift frame-rsv2 5)
                              (arithmetic-shift frame-rsv3 4)
                              frame-opcode))

         (frame-masked 0)
         (frame-payload-length (cond ((< len 126) len)
                                     ((< len 65536) 126)
                                     (else 127)))
         (octet1 (bitwise-ior (arithmetic-shift frame-masked 7)
                              frame-payload-length))
         (outbound-port (websocket-outbound-port ws)))

    (write-u8vector (u8vector octet0 octet1) outbound-port)

    (write-u8vector
     (cond
      ((= frame-payload-length 126)
       (u8vector
        (arithmetic-shift (bitwise-and len 65280) -8)
        (bitwise-and len 255)))
      ((= frame-payload-length 127)
       (u8vector
        0 0 0 0
        (arithmetic-shift
         (bitwise-and len 4278190080) -24)
        (arithmetic-shift
         (bitwise-and len 16711680) -16)
        (arithmetic-shift
         (bitwise-and len 65280) -8)
        (bitwise-and len 255)))
      (else (u8vector)))
     outbound-port)

    (write-string data len outbound-port)
    #t))

(define (send-message data #!optional (optype 'text) (ws (current-websocket)))
  ;; TODO break up large data into multiple frames?
  (optype->opcode optype) ; triggers error if invalid
  (dynamic-wind
      (lambda () (mutex-lock! (websocket-send-mutex ws)))
      (lambda () (send-frame ws optype data #t))
      (lambda () (mutex-unlock! (websocket-send-mutex ws)))))

(define (websocket-unmask-frame-payload payload len frame-masking-key)
  (define tmaskkey (make-u8vector 4 #f #t #t))
  (u8vector-set! tmaskkey 0 (vector-ref frame-masking-key 0))
  (u8vector-set! tmaskkey 1 (vector-ref frame-masking-key 1))
  (u8vector-set! tmaskkey 2 (vector-ref frame-masking-key 2))
  (u8vector-set! tmaskkey 3 (vector-ref frame-masking-key 3))
  (define-external wsmaskkey blob (u8vector->blob/shared tmaskkey))

  (define-external wslen int len)

  ; TODO handle -1

  (define-external wsv scheme-pointer payload)
  ((foreign-lambda* void ()
"
    if (wslen > UINT_MAX) { return -1; }

    const unsigned char* maskkey2 = wsmaskkey;
    const unsigned int kd = *(unsigned int*)maskkey2;
    const unsigned char* __restrict kb = maskkey2;


    for (int i = wslen >> 2; i != 0; --i)
    {
        *((unsigned int*)wsv) ^= kd;
        wsv += 4;
    }

    const int rem = wslen & 3;
    for (int i = 0; i < rem; ++i)
    {
        *((unsigned int*)wsv++) ^= kb[i];
    }
"
))
  payload)

(define (unmask fragment)
  (if (fragment-masked? fragment)
      (let ((r (websocket-unmask-frame-payload
                (fragment-payload fragment)
                (fragment-length fragment)
                (fragment-masking-key fragment))))
             (set-fragment-masked! fragment #f)
             r)
      (fragment-payload fragment)))

(define (read-frame-payload inbound-port frame-payload-length)
  (let ((masked-data (make-string frame-payload-length)))
    (read-string! frame-payload-length masked-data inbound-port)
    masked-data)
;;   (let* ((masked-data (make-string frame-payload-length)))
;;     (read-string! frame-payload-length masked-data inbound-port)

;;     (define tmaskkey (make-u8vector 4 #f #t #t))
;;     (u8vector-set! tmaskkey 0 (vector-ref frame-masking-key 0))
;;     (u8vector-set! tmaskkey 1 (vector-ref frame-masking-key 1))
;;     (u8vector-set! tmaskkey 2 (vector-ref frame-masking-key 2))
;;     (u8vector-set! tmaskkey 3 (vector-ref frame-masking-key 3))
;;     (define-external wsmaskkey blob (u8vector->blob/shared tmaskkey))

;;     (define-external wslen int frame-payload-length)

;;     (define-external wsv scheme-pointer masked-data)

;;     (if frame-masked
;;         (begin
;;           ((foreign-lambda* void ()
;; "
;;     const unsigned char* maskkey2 = wsmaskkey;
;;     const unsigned int kd = *(unsigned int*)maskkey2;
;;     const unsigned char* __restrict kb = maskkey2;


;;     for (int i = wslen >> 2; i != 0; --i)
;;     {
;;         *((unsigned int*)wsv) ^= kd;
;;         wsv += 4;
;;     }

;;     const int rem = wslen & 3;
;;     for (int i = 0; i < rem; ++i)
;;     {
;;         *((unsigned int*)wsv++) ^= kb[i];
;;     }
;; "
;;           ))
;;           masked-data)
;;     masked-data))
  )

(define (read-frame total-size ws)
  (let* ((inbound-port (websocket-inbound-port ws))
         (b0 (read-byte inbound-port)))
    ; we don't support reserved bits yet
    (when (or (> (bitwise-and b0 64) 0)
              (> (bitwise-and b0 32) 0)
              (> (bitwise-and b0 16) 0))
          (signal (make-websocket-exception
                   (make-property-condition 'reserved-bits-not-supported)
                   (make-property-condition 'protocol-error))))
    (cond
     ((eof-object? b0) b0)
     (else
      (let* ((frame-fin (> (bitwise-and b0 128) 0))
             (frame-opcode (bitwise-and b0 15))
             (frame-optype (opcode->optype frame-opcode))
             ;; second byte
             (b1 (read-byte inbound-port))
             ; TODO die on unmasked frame?
             (frame-masked (> (bitwise-and b1 128) 0))
             (frame-payload-length (bitwise-and b1 127)))
        (cond ((= frame-payload-length 126)
               (let ((bl0 (read-byte inbound-port))
                     (bl1 (read-byte inbound-port)))
                 (set! frame-payload-length (+ (arithmetic-shift bl0 8) bl1))))
              ((= frame-payload-length 127)
               (define (shift i r)
                 (if (< i 0)
                     r
                     (shift (- i 1) (+ (arithmetic-shift (read-byte inbound-port) (* 8 i))
                                       r))))
               (set! frame-payload-length (shift 7 0))))
        (when (or (> frame-payload-length (max-frame-size))
                  (> (+ frame-payload-length total-size) (max-message-size)))
              (signal (make-websocket-exception
                       (make-property-condition 'message-too-large))))
        (let* ((frame-masking-key
                (if frame-masked
                    (let* ((fm0 (read-byte inbound-port))
                           (fm1 (read-byte inbound-port))
                           (fm2 (read-byte inbound-port))
                           (fm3 (read-byte inbound-port)))
                      (vector fm0 fm1 fm2 fm3))
                    #f)))
          (cond
           ((or (eq? frame-optype 'text) (eq? frame-optype 'binary)
                (eq? frame-optype 'continuation) (eq? frame-optype 'ping)
                (eq? frame-optype 'pong))
            (make-fragment
             (read-frame-payload inbound-port frame-payload-length)
             frame-payload-length frame-masked
             frame-masking-key frame-fin frame-optype))
           ((eq? frame-optype 'connection-close) ; TODO, same as above?
            (make-fragment
             (read-frame-payload inbound-port frame-payload-length)
             frame-payload-length frame-masked frame-masking-key
             frame-fin frame-optype))
           (else
            (signal (make-websocket-exception
                     (make-property-condition 'unhandled-optype
                                              'optype frame-optype)))))))))))

(include "utf8-grammar.scm")

(define (valid-utf8? s)
  (or (let ((len (string-length s)))
         ; Try to validate as an ascii string first. Its essentially
         ; free, doesn't generate garbage and is many, many times
         ; faster than the general purpose validator.
         (define-external ws_utlen int len)
         (define-external ws_uts scheme-pointer s)
         (= 1
            ((foreign-lambda* int ()
"
    if (ws_utlen > UINT_MAX) { return -1; }

    for (int i = ws_utlen; i != 0; --i)
    {
        if (*((unsigned char*)ws_uts++) > 127)
        {
            C_return(0);
        }
    }

    C_return(1);
"))))
      (parse utf8-string (->parser-input s))))

(define (close-code->integer s)
  (if (string-null? s)
      1000
      (+ (arithmetic-shift (char->integer (string-ref s 0)) 8)
         (char->integer (string-ref s 1)))))

(define (close-code-string->close-reason s)
  (let ((c (close-code->integer s)))
    (case c
      ((1000) 'normal)
      ((1001) 'going-away)
      ((1002) 'protocol-error)
      ((1003) 'unknown-data-type)
      ((1007) 'invalid-data)
      ((1008) 'violated-policy)
      ((1009) 'message-too-large)
      ((1010) 'extension-negotiation-failed)
      ((1011) 'unexpected-error)
      (else
       (if (and (>= c 3000) (< c 5000))
           'unknown
           'invalid-close-code)))))

(define (valid-close-code? s)
  (neq? 'invalid-close-code (close-code-string->close-reason s)))

(define (receive-fragments #!optional (ws (current-websocket)))
  (dynamic-wind
      (lambda () (mutex-lock! (websocket-read-mutex ws)))
      (lambda ()
        (if (or (eq? (websocket-state ws) 'closing)
                (eq? (websocket-state ws) 'closed)
                (eq? (websocket-state ws) 'error))
            (values #!eof #!eof)
            (let loop ((fragments '())
                       (first #t)
                       (type 'text)
                       (total-size 0))
              (let* ((fragment (read-frame total-size ws))
                     (optype (fragment-optype fragment))
                     (len (fragment-length fragment))
                     (last-frame (fragment-last? fragment)))
                (set-websocket-last-message-timestamp! ws (current-time))
                (cond
                 ((and (control-frame? optype) (> len 125))
                  (set-websocket-state! ws 'error)
                  (signal (make-protocol-violation-exception
                           "control frame bodies must be less than 126 octets")))

                 ; connection close
                 ((and (eq? optype 'connection-close) (= len 1))
                  (set-websocket-state! ws 'error)
                  (signal (make-protocol-violation-exception
                           "close frames must not have a length of 1")))
                 ((and (eq? optype 'connection-close)
                       (not (valid-close-code? (unmask fragment))))
                  (set-websocket-state! ws 'error)
                  (signal (make-protocol-violation-exception
                           (string-append
                            "invalid close code "
                            (number->string (close-code->integer (unmask fragment)))))))
                 ((eq? optype 'connection-close)
                  (set-websocket-state! ws 'closing)
                  (values `(,fragment) optype))

                 ; immediate response
                 ((and (eq? optype 'ping) last-frame (<= len 125))
                  (unless (drop-incoming-pings)
                          (send-message (unmask fragment) 'pong))
                  (loop fragments first type total-size))

                 ; protocol violation checks
                 ((or (and first (eq? optype 'continuation))
                      (and (not first) (neq? optype 'continuation)))
                  (set-websocket-state! ws 'error)
                  (signal (make-protocol-violation-exception
                           "continuation frame out-of-order")))
                 ((and (not last-frame) (control-frame? optype))
                  (set-websocket-state! ws 'error)
                  (signal (make-protocol-violation-exception
                           "control frames can't be fragmented")))

                 ((eq? optype 'pong)
                  (loop fragments first type total-size))

                 (else
                  (if last-frame
                      (values (cons fragment fragments) (if (null? fragments) optype type))
                      (loop (cons fragment fragments) #f
                            (if first optype type)
                            (+ total-size len)))))))))
      (lambda () (mutex-unlock! (websocket-read-mutex ws)))))

(define (process-fragments fragments optype #!optional (ws (current-websocket)))
  (let ((message-body (string-concatenate/shared
                       (reverse (map unmask fragments)))))
    (when (and (or (eq? optype 'text) (eq? optype 'connection-close))
               (not (valid-utf8?
                     (if (eq? optype 'text)
                         message-body
                         (if (> (string-length message-body) 2)
                             (substring message-body 2)
                             "")))))
          (set-websocket-state! ws 'error)
          (signal (make-websocket-exception
                   (make-property-condition
                    'invalid-data 'msg "invalid UTF-8"))))
    (values message-body optype)))

(define (receive-message #!optional (ws (current-websocket)))
  (if (websocket-concurrent? ws)
      (let ((msg (mailbox-receive! (websocket-read-mailbox ws))))
        (values (car msg) (cdr msg)))
      (receive (fragments optype) (receive-fragments ws)
               (if (eof-object? fragments)
                   (values #!eof optype)
                   (process-fragments fragments optype)))))

; TODO does #!optional and #!key work together?
(define (close-websocket #!optional (ws (current-websocket))
                         #!key (close-reason 'normal) (data (make-u8vector 0)))
  (define invalid-close-reason #f)
  (define (close-reason->close-code reason)
    (case reason
      ('normal 1000)
      ('going-away 1001)
      ('protocol-error 1002)
      ('unknown-data-type 1003)
      ('invalid-data 1007)
      ('violated-policy 1008)
      ('message-too-large 1009)
      ('unexpected-error 1011)
      (else (set! invalid-close-reason reason)
            (close-reason->close-code 'unexpected-error))))

  ; Use thread timeout to handle the close-timeout
  (let ((close-thread
         (make-thread
          (lambda ()
            (if (eq? (websocket-state ws) 'open)
                (begin
                  (set-websocket-state! ws 'closed)
                  (send-frame ws 'connection-close
                                        (u8vector 3 (close-reason->close-code close-reason))
                                        #t)
                  (let loop ()
                    (receive (data type) (receive-message ws)
                             (if (eq? type 'connection-close)
                                 (unless (valid-utf8? data)
                                         (set! close-reason 'invalid-data))
                                 (loop)))))
                (begin
                  (send-frame ws 'connection-close
                                        (u8vector 3 (close-reason->close-code close-reason))
                                        #t))))
          "close timeout thread")))
    (thread-start! close-thread)
    (if (> (close-timeout) 0)
        (unless (thread-join! close-thread (close-timeout) #f)
                ; TODO actually signal error?
                ;; (thread-signal! (websocket-user-thread (current-websocket))
                ;;                 (make-websocket-exception
                ;;                  (make-property-condition 'close-timeout)))
                )
        (thread-join! close-thread))))


(define (sha1-sum in-bv)
  (hex-string->string (string->sha1sum in-bv)))

(define (websocket-compute-handshake client-key)
  (let* ((key-and-magic
          (string-append client-key "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
         (key-and-magic-sha1 (sha1-sum key-and-magic)))
    (base64-encode key-and-magic-sha1)))

(define (sec-websocket-accept-unparser header-contents)
  (map (lambda (header-content)
         (car (vector-ref header-content 0)))
       header-contents))

(header-unparsers
 (alist-update! 'sec-websocket-accept
                sec-websocket-accept-unparser
                (header-unparsers)))

(define (websocket-accept #!optional (concurrent #f))
  (let* ((user-thread (current-thread))
         (headers (request-headers (current-request)))
         (client-key (header-value 'sec-websocket-key headers))
         (ws-handshake (websocket-compute-handshake client-key))
         (ws (make-websocket
              (request-port (current-request))
              (response-port (current-response))
              user-thread
              (make-mutex "send")
              (make-mutex "read")
              (current-time)
              'open               ; websocket state
              (make-mailbox "send")
              (make-mailbox "read")
              concurrent))
         (ping-thread
          (make-thread
           (lambda ()
             (let loop ()
               (thread-sleep! (ping-interval))
               (when (eq? (websocket-state ws) 'open)
                     (send-message "" 'ping ws)
                     (loop))))
           "ping thread")))

    ; make sure the request meets the spec for websockets
    (cond ((not (and (eq? (header-value 'connection headers #f) 'upgrade)
                     (string-ci= (car (header-value 'upgrade headers '(""))) "websocket")))
           (signal (make-websocket-exception
                    (make-property-condition 'missing-upgrade-header))))
          ((not (string= (header-value 'sec-websocket-version headers "") "13"))
           (with-headers ; TODO test
            `((sec-websocket-version "13"))
            (lambda () (send-status 'upgrade-required))))
          ((not ((accept-connection) (header-value 'origin headers "")))
           ((access-denied))))

    (with-headers
     `((upgrade ("WebSocket" . #f))
       (connection (upgrade . #t))
       (sec-websocket-accept (,ws-handshake . #t)))
     (lambda ()
       (send-response status: 'switching-protocols)))
    (flush-output (response-port (current-response)))

    ; connection timeout thread
    (when (> (connection-timeout) 0)
          (thread-start!
           (lambda ()
             (let loop ()
               (let ((t (websocket-last-message-timestamp ws)))
                  ; Add one to attempt to alleviate checking the timestamp
                  ; right before when the timeout should happen.
                 (thread-sleep! (+ 1 (connection-timeout)))
                 (when (eq? (websocket-state ws) 'open)
                       (if (< (- (time->seconds (current-time))
                                 (time->seconds (websocket-last-message-timestamp ws)))
                              (connection-timeout))
                           (loop)
                           (begin (thread-signal!
                                   (websocket-user-thread ws)
                                   (make-websocket-exception
                                    (make-property-condition 'connection-timeout)))
                                  (close-websocket ws close-reason: 'going-away)))))))))

    (when (> (ping-interval) 0)
          (thread-start! ping-thread))

    ws))

(define (with-websocket proc #!optional (concurrent #f))
  (define (handle-error close-reason exn)
    (set-websocket-state! (current-websocket) 'closing)
    (close-websocket (current-websocket) close-reason: close-reason)
    (unless (port-closed? (request-port (current-request)))
            (close-input-port (request-port (current-request))))
    (unless (port-closed? (response-port (current-response)))
            (close-output-port (response-port (current-response))))
    (when (propagate-common-errors)
          (signal exn)))
  (parameterize
   ((current-websocket (websocket-accept concurrent)))
   (condition-case
    (begin (proc)
           (close-websocket)
           (close-input-port (request-port (current-request)))
           (close-output-port (response-port (current-response))))
    (exn (websocket protocol-error) (handle-error 'protocol-error exn))
    (exn (websocket invalid-data) (handle-error 'invalid-data exn))
    (exn (websocket connection-timeout) (handle-error 'going-away exn))
    (exn (websocket message-too-large) (handle-error 'message-too-large exn))
    (exn () (handle-error 'unexpected-error exn)))))

(define (with-concurrent-websocket proc)
  (let ((parent-thread (current-thread)))
    (with-websocket
     (lambda ()
       (thread-start!
        (lambda ()
          (handle-exceptions
           exn
           (thread-signal! parent-thread exn)
           (let loop ()
             (receive (fragments optype) (receive-fragments)
                      (unless (eof-object? fragments)
                              (thread-start!
                               (lambda ()
                                 (handle-exceptions
                                  exn
                                  (thread-signal! parent-thread exn)
                                  (mailbox-send!
                                   (websocket-read-mailbox (current-websocket))
                                   (receive (msg-body optype)
                                            (process-fragments fragments optype)
                                            `(,msg-body . ,optype))))))
                              (loop)))))))
       (proc))
     #t)))

(define (upgrade-to-websocket #!optional (concurrent #f))
  (websocket-accept concurrent))

)