|
- package main
-
- import (
- "log"
- "net/http"
- "time"
-
- "github.com/gorilla/websocket"
- )
-
- func registerWSHandlers(mux *http.ServeMux, h *hub) {
- upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
- origin := r.Header.Get("Origin")
- if origin == "" || origin == "null" {
- return true
- }
- return true
- }}
-
- mux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
- conn, err := upgrader.Upgrade(w, r, nil)
- if err != nil {
- log.Printf("ws upgrade failed: %v (origin: %s)", err, r.Header.Get("Origin"))
- return
- }
- c := &client{conn: conn, send: make(chan []byte, 32), done: make(chan struct{})}
- h.add(c)
- defer func() {
- h.remove(c)
- _ = conn.Close()
- }()
- conn.SetReadDeadline(time.Now().Add(60 * time.Second))
- conn.SetPongHandler(func(string) error {
- conn.SetReadDeadline(time.Now().Add(60 * time.Second))
- return nil
- })
- go func() {
- ping := time.NewTicker(30 * time.Second)
- defer ping.Stop()
- for {
- select {
- case msg, ok := <-c.send:
- if !ok {
- return
- }
- _ = conn.SetWriteDeadline(time.Now().Add(200 * time.Millisecond))
- if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil {
- return
- }
- case <-ping.C:
- _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
- if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
- log.Printf("ws ping error: %v", err)
- return
- }
- }
- }
- }()
- for {
- _, _, err := conn.ReadMessage()
- if err != nil {
- return
- }
- }
- })
- }
|