WebSockets in Go

WebSockets provide full-duplex communication channels over a single TCP connection, enabling real-time communication between clients and servers. Go’s standard library doesn’t include WebSocket support, but the gorilla/websocket package makes it easy to implement.

Basic WebSocket Server

First, install the gorilla/websocket package:

go get github.com/gorilla/websocket

Basic echo server:

package main

import (
    "log"
    "net/http"
    
    "github.com/gorilla/websocket"
)

// upgrader performs the HTTP -> WebSocket handshake for wsHandler.
// Its CheckOrigin accepts every request regardless of the Origin header:
// handy during development, but it lets any web page connect, so restrict
// it before deploying.
var upgrader = websocket.Upgrader{
    CheckOrigin: func(*http.Request) bool { return true },
}

// wsHandler upgrades the request to a WebSocket connection and echoes every
// message it receives back to the sender with the same message type.
// It returns when the peer disconnects or a read/write fails; the deferred
// Close releases the connection on every path.
func wsHandler(w http.ResponseWriter, r *http.Request) {
    // On failure Upgrade has already replied with an HTTP error, so just log.
    conn, err := upgrader.Upgrade(w, r, nil)
    if err != nil {
        log.Println("Failed to upgrade connection:", err)
        return
    }
    defer conn.Close()

    for {
        messageType, message, err := conn.ReadMessage()
        if err != nil {
            // A normal close or "going away" (e.g. the tab was closed) is the
            // expected way for a session to end; log anything else.
            if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
                log.Println("Read error:", err)
            }
            return
        }

        // %q escapes newlines and control characters so a client cannot
        // forge extra log lines through its message content.
        log.Printf("Received: %q", message)

        // Echo the message back unchanged.
        if err := conn.WriteMessage(messageType, message); err != nil {
            log.Println("Write error:", err)
            return
        }
    }
}

// main wires the echo endpoint to /ws, serves the static client files from
// ./static for every other path, and listens on :8080. ListenAndServe always
// returns a non-nil error, which is treated as fatal.
func main() {
    const addr = ":8080"

    http.HandleFunc("/ws", wsHandler)
    http.Handle("/", http.FileServer(http.Dir("./static"))) // HTML client and assets

    log.Println("Server starting on", addr)
    if err := http.ListenAndServe(addr, nil); err != nil {
        log.Fatal(err)
    }
}

HTML Client

Create a simple HTML client to test the WebSocket:

<!DOCTYPE html>
<html>
<head>
    <title>WebSocket Echo</title>
</head>
<body>
    <h1>WebSocket Echo Test</h1>
    <input type="text" id="messageInput" placeholder="Enter message">
    <button onclick="sendMessage()">Send</button>
    <div id="messages"></div>

    <script>
        // Connect to the Go echo server's /ws endpoint.
        const ws = new WebSocket('ws://localhost:8080/ws');

        ws.onopen = function(event) {
            console.log('Connected to WebSocket');
            addMessage('Connected to server');
        };

        ws.onmessage = function(event) {
            console.log('Received:', event.data);
            addMessage('Server: ' + event.data);
        };

        ws.onclose = function(event) {
            console.log('Disconnected from WebSocket', event.code, event.reason);
            // The close code (1000 = normal, 1006 = abnormal) helps diagnose drops.
            addMessage('Disconnected from server (code ' + event.code + ')');
        };

        ws.onerror = function(event) {
            // The error event carries no message text (concatenating it just
            // prints "[object Event]"); browsers log details to the console and
            // normally fire a close event afterwards.
            console.error('WebSocket error:', event);
            addMessage('Connection error (see console for details)');
        };

        // Sends the input's text over the socket and shows it locally.
        function sendMessage() {
            const input = document.getElementById('messageInput');
            const message = input.value;

            if (!message) {
                return;
            }
            // send() throws while the socket is still CONNECTING and silently
            // discards data once it is CLOSING/CLOSED, so only send when OPEN.
            if (ws.readyState !== WebSocket.OPEN) {
                addMessage('Not connected; message not sent');
                return;
            }
            ws.send(message);
            addMessage('You: ' + message);
            input.value = '';
        }

        // Appends one line to the log. textContent (not innerHTML) ensures
        // server-supplied text is never interpreted as HTML.
        function addMessage(message) {
            const messages = document.getElementById('messages');
            const div = document.createElement('div');
            div.textContent = message;
            messages.appendChild(div);
        }

        // Send message on Enter key ('keypress' is deprecated; use 'keydown').
        document.getElementById('messageInput').addEventListener('keydown', function(event) {
            if (event.key === 'Enter') {
                sendMessage();
            }
        });
    </script>
</body>
</html>

Chat Application

Let’s build a simple chat server that broadcasts messages to all connected clients:

package main

import (
    "log"
    "net/http"
    "sync"
    
    "github.com/gorilla/websocket"
)

// Client is a single WebSocket connection registered with the hub.
type Client struct {
    conn *websocket.Conn // the underlying WebSocket connection
    send chan []byte     // outbound messages drained by writePump; closed by the hub when the client is removed
}

// Hub tracks the connected clients and fans broadcast messages out to them.
// Clients join and leave by sending themselves on register/unregister; in
// this file only the run method reads or modifies the clients map.
type Hub struct {
    clients    map[*Client]bool // registered clients; values are always true (used as a set)
    broadcast  chan []byte      // messages to deliver to every registered client
    register   chan *Client     // clients asking to join
    unregister chan *Client     // clients asking to leave; their send channel is closed on removal
    mu         sync.RWMutex     // guards clients
}

// newHub builds a Hub with an empty client set and unbuffered broadcast,
// register and unregister channels. Sends on those channels block until
// the hub's run loop receives them, so run must be started (go hub.run()).
func newHub() *Hub {
    h := &Hub{}
    h.clients = make(map[*Client]bool)
    h.broadcast = make(chan []byte)
    h.register = make(chan *Client)
    h.unregister = make(chan *Client)
    return h
}

// run is the hub's event loop: it adds and removes clients and fans each
// broadcast message out to every registered client. It never returns, so
// start it in its own goroutine.
//
// A client whose send buffer is full is treated as stuck: it is removed and
// its send channel closed, which makes its writePump send a close frame and
// exit.
func (h *Hub) run() {
    for {
        select {
        case client := <-h.register:
            h.mu.Lock()
            h.clients[client] = true
            total := len(h.clients) // read under the lock
            h.mu.Unlock()
            log.Println("Client connected. Total clients:", total)

        case client := <-h.unregister:
            h.mu.Lock()
            // The client may already have been dropped during a broadcast;
            // closing its send channel a second time would panic.
            if _, ok := h.clients[client]; ok {
                delete(h.clients, client)
                close(client.send)
            }
            total := len(h.clients)
            h.mu.Unlock()
            log.Println("Client disconnected. Total clients:", total)

        case message := <-h.broadcast:
            // This loop deletes slow clients from the map, so it needs the
            // write lock: RLock admits concurrent readers, and mutating the
            // map under it would race with them.
            h.mu.Lock()
            for client := range h.clients {
                select {
                case client.send <- message:
                default:
                    close(client.send)
                    delete(h.clients, client)
                }
            }
            h.mu.Unlock()
        }
    }
}

// Package-level state shared by the handlers:
//   - upgrader performs the WebSocket handshake and accepts any Origin
//     (tighten CheckOrigin before deploying);
//   - hub is the single hub every client joins.
var (
    upgrader = websocket.Upgrader{
        CheckOrigin: func(*http.Request) bool { return true },
    }
    hub = newHub()
)

// wsHandler upgrades the request, registers a new Client with the hub, and
// hands the connection to a writer goroutine and a reader goroutine. It
// returns right away; from then on the two pumps own the connection.
func wsHandler(w http.ResponseWriter, r *http.Request) {
    conn, err := upgrader.Upgrade(w, r, nil)
    if err != nil {
        log.Println("Upgrade error:", err)
        return
    }

    // Up to 256 outbound messages can queue per client before the hub
    // considers it too slow and drops it.
    c := &Client{conn: conn, send: make(chan []byte, 256)}
    hub.register <- c

    go c.writePump()
    go c.readPump()
}

// readPump forwards every message read from the connection to the hub for
// broadcast. When a read fails (the peer closed or the connection broke) it
// unregisters the client and closes the connection.
func (c *Client) readPump() {
    defer func() {
        hub.unregister <- c
        c.conn.Close()
    }()

    for {
        _, msg, err := c.conn.ReadMessage()
        if err == nil {
            hub.broadcast <- msg
            continue
        }
        // Only close errors carrying a code other than "going away" or
        // "abnormal closure" are logged; routine disconnects stay quiet.
        if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
            log.Printf("WebSocket error: %v", err)
        }
        return
    }
}

// writePump writes each message queued on c.send to the connection as a text
// frame. When the hub closes c.send it sends a close frame and returns; on a
// write error it returns immediately. Either way the connection is closed.
func (c *Client) writePump() {
    defer c.conn.Close()

    // Ranging over the channel replaces a single-case select inside a for
    // loop (staticcheck S1000); the loop ends once the hub closes c.send.
    for message := range c.send {
        if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
            log.Println("Write error:", err)
            return
        }
    }

    // Best effort: the connection is being torn down, so the error is ignored.
    _ = c.conn.WriteMessage(websocket.CloseMessage, []byte{})
}

// main starts the hub's event loop, registers the WebSocket endpoint and the
// static file server on the default mux, and serves on :8080 until
// ListenAndServe returns an error, which is fatal.
func main() {
    go hub.run()

    mux := http.DefaultServeMux
    mux.HandleFunc("/ws", wsHandler)
    mux.Handle("/", http.FileServer(http.Dir("./static")))

    const addr = ":8080"
    log.Println("Chat server starting on", addr)
    log.Fatal(http.ListenAndServe(addr, mux))
}

Structured Messages

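For more complex applications, define message structures. The readPump below replaces the one in the chat server. When merging it in, combine these imports with the chat server's, which supply github.com/gorilla/websocket (used for IsUnexpectedCloseError):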

package main

import (
    "encoding/json"
    "log"
    "time"
)

// Message is the JSON envelope exchanged with clients.
type Message struct {
    // Type names the kind of message; its values are application-defined.
    Type      string    `json:"type"`
    // Username is the sender's name; omitted from the JSON when empty.
    Username  string    `json:"username,omitempty"`
    // Content is the message body; omitted from the JSON when empty.
    Content   string    `json:"content,omitempty"`
    // Timestamp is set by the server when it relays the message
    // (readPump overwrites any client-supplied value).
    Timestamp time.Time `json:"timestamp"`
}

// readPump reads JSON-encoded Message values from the connection, stamps each
// one with the server's current time, and forwards the re-encoded message to
// the hub for broadcast. Malformed messages are logged and skipped.
//
// Only pongs extend the read deadline, so this relies on the writer pinging
// more often than pongWait; a peer that stops answering is disconnected.
// On exit the client is unregistered and the connection closed.
func (c *Client) readPump() {
    const (
        // pongWait is how long the connection may go without a pong before
        // reads fail.
        pongWait = 60 * time.Second
        // maxMessageSize bounds a single incoming message; gorilla/websocket
        // reads the whole message into memory, so without a limit one client
        // could force arbitrarily large allocations. Adjust to your protocol.
        maxMessageSize = 64 << 10
    )

    defer func() {
        hub.unregister <- c
        c.conn.Close()
    }()

    c.conn.SetReadLimit(maxMessageSize)
    c.conn.SetReadDeadline(time.Now().Add(pongWait))
    c.conn.SetPongHandler(func(string) error {
        c.conn.SetReadDeadline(time.Now().Add(pongWait))
        return nil
    })

    for {
        _, data, err := c.conn.ReadMessage()
        if err != nil {
            if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
                log.Printf("WebSocket error: %v", err)
            }
            return
        }

        var msg Message
        if err := json.Unmarshal(data, &msg); err != nil {
            log.Printf("Invalid message format: %v", err)
            continue
        }

        // The server is authoritative for timestamps; ignore the client's.
        msg.Timestamp = time.Now()

        out, err := json.Marshal(msg)
        if err != nil {
            log.Printf("Failed to encode message: %v", err)
            continue
        }
        hub.broadcast <- out
    }
}

Authentication

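Add authentication to WebSocket connections. Browsers send cookies with cross-site WebSocket handshakes, and an upgrader whose CheckOrigin returns true accepts them. Cookie-based authentication must therefore be paired with an Origin check. Otherwise any website a logged-in user visits could open a WebSocket connection as that user: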

// wsHandler authenticates the request before upgrading it. Cross-origin
// browser requests get 403 and unauthenticated requests get 401; otherwise
// the connection is upgraded and the client joins the hub under its username.
//
// Assumes a username string field has been added to Client.
func wsHandler(w http.ResponseWriter, r *http.Request) {
    // Browsers attach cookies to cross-site WebSocket handshakes, and the
    // upgrader accepts any Origin. Without this check, any web page a
    // logged-in user visits could open a connection as that user
    // (cross-site WebSocket hijacking). Requests without an Origin header
    // (typically non-browser clients) are let through.
    if origin := r.Header.Get("Origin"); origin != "" &&
        origin != "http://"+r.Host && origin != "https://"+r.Host {
        http.Error(w, "Forbidden", http.StatusForbidden)
        return
    }

    // Check authentication (e.g., from cookie or token)
    username, authenticated := authenticateUser(r)
    if !authenticated {
        http.Error(w, "Unauthorized", http.StatusUnauthorized)
        return
    }

    conn, err := upgrader.Upgrade(w, r, nil)
    if err != nil {
        log.Println("Upgrade error:", err)
        return
    }

    client := &Client{
        conn:     conn,
        send:     make(chan []byte, 256),
        username: username, // Add username to Client struct
    }

    hub.register <- client

    go client.writePump()
    go client.readPump()
}

// authenticateUser resolves the request's "session" cookie to a username.
// It reports false when the cookie is absent; otherwise the result of
// validateSession is returned as-is.
func authenticateUser(r *http.Request) (string, bool) {
    sessionCookie, err := r.Cookie("session")
    if err != nil {
        // No session cookie: the request is anonymous.
        return "", false
    }

    // validateSession is defined elsewhere; presumably it checks the session
    // store and yields the owning username (confirm against its definition).
    return validateSession(sessionCookie.Value)
}

Connection Limits and Cleanup

Prevent resource exhaustion:

// maxConnections caps the number of simultaneously open WebSocket connections.
const maxConnections = 1000

// connectionCount is the number of connection slots currently in use
// (connections being upgraded or open). Access it only via sync/atomic.
var connectionCount int32

// wsHandler enforces maxConnections, upgrades the request, and serves the
// client until it disconnects.
//
// A slot is reserved with a single atomic increment. A separate load
// followed by an add would let concurrent requests overshoot the limit.
// The slot is released by a defer when the handler returns. The handler
// runs readPump itself, so it returns only when the connection has ended,
// however that happened: close frame, network failure, or read timeout.
func wsHandler(w http.ResponseWriter, r *http.Request) {
    if atomic.AddInt32(&connectionCount, 1) > maxConnections {
        atomic.AddInt32(&connectionCount, -1)
        http.Error(w, "Too many connections", http.StatusTooManyRequests)
        return
    }
    defer atomic.AddInt32(&connectionCount, -1)

    conn, err := upgrader.Upgrade(w, r, nil)
    if err != nil {
        log.Println("Upgrade error:", err)
        return
    }

    client := &Client{
        conn: conn,
        send: make(chan []byte, 256),
    }
    hub.register <- client

    // Don't decrement in SetCloseHandler. It only runs when the peer sends a
    // close frame, so abrupt disconnects would leak a slot. Replacing the
    // default handler also stops the server from echoing the close frame back.
    go client.writePump()
    client.readPump() // blocks until the connection is finished
}

Binary Data

WebSockets can send binary data:

// readPump reads messages until the connection fails and dispatches on the
// message type. On exit it unregisters the client and closes the connection,
// like the other readPump variants. Without this cleanup the hub would keep
// the client and its writePump would block on c.send forever, leaking the
// connection.
func (c *Client) readPump() {
    defer func() {
        hub.unregister <- c
        c.conn.Close()
    }()

    for {
        messageType, data, err := c.conn.ReadMessage()
        if err != nil {
            if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
                log.Printf("WebSocket error: %v", err)
            }
            return
        }

        switch messageType {
        case websocket.TextMessage:
            // Handle text message; %q keeps client text from forging log lines.
            log.Printf("Text message: %q", data)

        case websocket.BinaryMessage:
            // Handle binary message (e.g., file upload)
            log.Printf("Binary message: %d bytes", len(data))

            // Process binary data...

        default:
            // Defensive: ReadMessage handles control frames internally and
            // only returns text or binary messages.
            log.Printf("Unknown message type: %d", messageType)
        }
    }
}

Error Handling and Recovery

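Robust error handling: set a write deadline before every write, and send periodic pings so dead connections are detected and cleaned up. The 54-second ping interval must be shorter than the 60-second read deadline set in readPump, so that each pong arrives in time to extend it: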

// writePump sends queued messages as text frames and, every pingPeriod, a
// ping so that the peer's pongs keep the reader's deadline from expiring.
// Every write must finish within writeWait. It returns, closing the
// connection, when the hub closes c.send or any write fails.
func (c *Client) writePump() {
    const (
        writeWait  = 10 * time.Second // time allowed for a single write
        pingPeriod = 54 * time.Second // keep below the reader's 60s deadline
    )

    pinger := time.NewTicker(pingPeriod)
    defer func() {
        pinger.Stop()
        c.conn.Close()
    }()

    for {
        select {
        case <-pinger.C:
            c.conn.SetWriteDeadline(time.Now().Add(writeWait))
            if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
                return
            }

        case message, ok := <-c.send:
            c.conn.SetWriteDeadline(time.Now().Add(writeWait))
            if !ok {
                // The hub closed the channel: tell the peer we're done.
                c.conn.WriteMessage(websocket.CloseMessage, []byte{})
                return
            }
            if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
                log.Println("Write error:", err)
                return
            }
        }
    }
}

Testing WebSocket Connections

Test WebSocket functionality:

package main

import (
    "net/http"
    "net/http/httptest"
    "strings"
    "testing"
    "time"

    "github.com/gorilla/websocket"
)

// TestWebSocketEcho serves wsHandler from a test server, sends one text
// message, and expects the identical text message back within one second.
func TestWebSocketEcho(t *testing.T) {
    srv := httptest.NewServer(http.HandlerFunc(wsHandler))
    defer srv.Close()

    // httptest serves plain HTTP, so swapping the scheme yields a ws:// URL.
    endpoint := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws"

    conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
    if err != nil {
        t.Fatalf("Failed to connect: %v", err)
    }
    defer conn.Close()

    const want = "Hello, WebSocket!"
    if err := conn.WriteMessage(websocket.TextMessage, []byte(want)); err != nil {
        t.Fatalf("Failed to send message: %v", err)
    }

    // Bound the wait so a broken handler fails the test instead of hanging it.
    conn.SetReadDeadline(time.Now().Add(time.Second))
    gotType, got, err := conn.ReadMessage()
    if err != nil {
        t.Fatalf("Failed to read response: %v", err)
    }

    if gotType != websocket.TextMessage {
        t.Errorf("Expected text message, got %d", gotType)
    }
    if string(got) != want {
        t.Errorf("Expected %q, got %q", want, string(got))
    }
}

Best Practices

  1. Handle connection lifecycle: Properly close connections and clean up resources
  2. Implement ping/pong: Keep connections alive and detect broken connections
  3. Set timeouts: Prevent hanging connections
  4. Limit concurrent connections: Prevent resource exhaustion
  5. Authenticate connections: Don’t trust unauthenticated WebSocket connections
  6. Validate messages: Check message format and size
  7. Handle errors gracefully: Don’t crash on malformed messages
  8. Use structured messages: JSON for complex data
  9. Monitor connections: Track connection count and health
  10. Test thoroughly: WebSocket behavior can be tricky to test

WebSockets enable real-time features like chat, live updates, and collaborative editing. The gorilla/websocket package provides a solid foundation for building WebSocket applications in Go.

For more networking topics, check out our HTTP clients tutorial. If you’re building REST APIs, see the REST APIs tutorial.

Last updated on