protohacking/x/budgetchat/server.go
2025-06-17 00:50:55 -04:00

249 lines
4.6 KiB
Go

package budgetchat
import (
"bufio"
"errors"
"net"
"strings"
"sync"
"github.com/charmbracelet/log"
)
// isGoodAscii reports whether every character of str is printable ASCII,
// i.e. lies between ' ' (0x20) and '~' (0x7E) inclusive.
//
// Control characters (including '\t', '\r', '\n'), DEL, every non-ASCII rune,
// and invalid UTF-8 bytes (which decode to U+FFFD) are all rejected. The empty
// string contains nothing illegal and is therefore accepted.
func isGoodAscii(str string) bool {
	outsidePrintable := func(r rune) bool { return r < ' ' || r > '~' }
	return strings.IndexFunc(str, outsidePrintable) < 0
}
// client is one connected chat participant.
type client struct {
// id is assigned by room.addClient; room.leave and room.message compare it to
// tell this client apart from the others.
id uint
// name is the display name the user supplied during room.join.
name string
// conn is the client's network connection.
conn net.Conn
// reader buffers conn so input can be consumed one '\n'-terminated line at a time.
reader *bufio.Reader
// bad is set by send once a write to conn fails; afterwards send refuses to
// write. room.leave prunes clients with this flag set.
// NOTE(review): send can run for the same client from several goroutines at
// once (room.message holds only the read lock), so reads/writes of this flag
// are unsynchronized — a data race under `go test -race`.
bad bool
}
// read blocks until one complete '\n'-terminated line arrives from the client
// and returns it without the trailing '\n' (a preceding '\r', if any, is kept).
//
// Any read error — including a clean EOF when the peer hangs up — is logged,
// the connection is closed, and the error is returned with an empty string;
// a partially received final line is discarded.
func (c *client) read() (string, error) {
	line, err := c.reader.ReadString('\n')
	if err == nil {
		// ReadString guarantees the delimiter is present when err is nil.
		return strings.TrimSuffix(line, "\n"), nil
	}
	log.Error("unable to read from client", "err", err)
	c.conn.Close()
	return "", err
}
// send writes msg to the client's connection exactly as given (callers
// include the trailing newline themselves).
//
// The first failed write flags the client bad, closes its connection, and
// returns the write error. Every later call on a bad client writes nothing
// and returns an error immediately.
func (c *client) send(msg string) error {
	if !c.bad {
		if _, writeErr := c.conn.Write([]byte(msg)); writeErr != nil {
			c.bad = true
			log.Warn("marking client as bad", "err", writeErr)
			c.conn.Close()
			return writeErr
		}
		return nil
	}
	log.Warn("not sending to bad client", "remote", c.conn.RemoteAddr())
	return errors.New("trying to send to a bad client")
}
// room is the single chat room shared by every connection accepted by New.
type room struct {
// inc is the next client id to hand out; addClient assigns it, then increments it.
inc uint
// lock guards inc and clients. Because it is a pointer, a zero-value room is
// not usable until lock is set (New does this).
lock *sync.RWMutex
// clients holds the current members; whoishere and message skip nil entries,
// and leave removes departed and bad clients.
clients []*client
}
// whoishere builds the membership line sent to a newly named client:
// "* Currently online: " followed by the comma-separated names of every
// current member, or "nobody!" when there are none, terminated by '\n'.
// Nil slots in r.clients are ignored. Takes the read lock while reading the
// member list.
func (r *room) whoishere() string {
	r.lock.RLock()
	defer r.lock.RUnlock()

	names := make([]string, 0, len(r.clients))
	for _, c := range r.clients {
		if c == nil {
			continue
		}
		names = append(names, c.name)
	}

	if len(names) == 0 {
		log.Info("nobody is here!")
		return "* Currently online: nobody!\n"
	}
	// strings.Join only places separators *between* names. The previous
	// hand-rolled test (i != len(r.clients)) was always true, so every list
	// ended with a stray ", ".
	return "* Currently online: " + strings.Join(names, ", ") + "\n"
}
// leave removes client from the room and sends "* <name> went offline!" to
// every remaining member. In the same pass it prunes any member that is
// already flagged bad, or whose send of the notification fails. Holds the
// write lock throughout.
func (r *room) leave(client *client) {
	r.lock.Lock()
	defer r.lock.Unlock()

	leaveMsg := "* " + client.name + " went offline!\n"

	// Filter in place. kept shares r.clients' backing array but never grows
	// faster than the loop advances, so no unvisited element is overwritten.
	// The previous swap-with-last removal walked the saved indices in
	// ascending order while shrinking the slice. It could index past the new
	// end (panicking and taking the whole server down) or remove the wrong
	// client.
	kept := r.clients[:0]
	for _, c := range r.clients {
		if c == nil || c.bad || c.id == client.id {
			continue
		}
		if err := c.send(leaveMsg); err != nil {
			// send has already flagged c bad and closed its connection.
			continue
		}
		kept = append(kept, c)
	}
	// Nil out the vacated tail so removed clients can be garbage collected.
	for i := len(kept); i < len(r.clients); i++ {
		r.clients[i] = nil
	}
	r.clients = kept
}
// addClient makes client a member of the room. It stamps the client with the
// room's next unused id, advances the counter, and appends the client to the
// member list, all under the write lock.
func (r *room) addClient(client *client) {
	r.lock.Lock()
	defer r.lock.Unlock()

	next := r.inc
	r.inc++
	client.id = next
	r.clients = append(r.clients, client)
}
// join drives the whole lifetime of one connection. In order, it:
//   - prompts for a name and validates it;
//   - sends the current membership list;
//   - announces the newcomer to existing members;
//   - adds the newcomer to the room;
//   - relays each line the client sends until it disconnects or is kicked
//     for a bad message, then removes it via leave.
//
// conn is always closed before join returns.
func (r *room) join(conn net.Conn) {
	// Closing an already-closed net.Conn merely returns an error, so this is
	// safe even when read/send closed it first. Previously a client kicked for
	// a bad message never had its connection closed.
	defer conn.Close()

	log := log.Default().WithPrefix(conn.RemoteAddr().String())
	log.Info("client connecting")

	client := &client{
		conn:   conn,
		reader: bufio.NewReader(conn),
	}

	// start by asking for a name
	if err := client.send("* Welcome, what's your name?\n"); err != nil {
		log.Warn("kicking client: unable to write welcome", "err", err)
		return
	}

	name, err := client.read()
	if err != nil {
		log.Warn("kicking client: unable to read name", "err", err)
		return
	}

	// Validate the name before any other member hears about this client.
	switch {
	case len(name) == 0:
		log.Warn("kicking client: empty name")
		client.send("* You've been kicked: name is empty\n")
		return
	case len(name) > 64:
		log.Warn("kicking client: name too long")
		client.send("* You've been kicked: name too long\n")
		return
	case !isGoodAscii(name):
		log.Warn("kicking client: name contains illegal characters")
		client.send("* You've been kicked: name contains illegal characters\n")
		return
	}
	client.name = name
	log.Info("name established", "name", client.name)

	log.Info("sending room membership")
	if err := client.send(r.whoishere()); err != nil {
		// Previously this error was silently ignored.
		log.Warn("kicking client: unable to send room membership", "err", err)
		return
	}

	// Announce the newcomer. Hold the read lock so leave/addClient cannot
	// mutate r.clients mid-iteration (this loop used to be unsynchronized).
	// NOTE(review): whoishere, this announcement and addClient each lock
	// separately. Two clients joining at the same moment can therefore miss
	// each other in their membership list / announcement.
	joinMsg := "* " + client.name + " is now online\n"
	r.lock.RLock()
	for _, c := range r.clients {
		if c == nil {
			continue
		}
		c.send(joinMsg)
	}
	r.lock.RUnlock()

	r.addClient(client)
	log.Info("assigned id", "id", client.id)

	// Relay messages until the client disconnects (read fails and closes the
	// connection) or sends a message that fails validation.
	for {
		line, err := client.read()
		if err != nil {
			break
		}
		if err := r.message(client, line); err != nil {
			log.Warn("kicking client: bad message", "err", err)
			client.send("* You've been kicked: " + err.Error() + "\n")
			break
		}
	}
	r.leave(client)
}
// message relays msg from sender `from` to every other room member as
// "[<name>] <msg>\n". Nothing is relayed and an error is returned when msg is
// longer than 1000 bytes or contains anything outside printable ASCII; join
// treats that error as grounds to kick the sender. Nil member slots are
// skipped. Runs under the read lock.
func (r *room) message(from *client, msg string) error {
	r.lock.RLock()
	defer r.lock.RUnlock()

	switch {
	case len(msg) > 1000:
		return errors.New("message too long")
	case !isGoodAscii(msg):
		return errors.New("message contains illegal characters")
	}

	line := "[" + from.name + "] " + msg + "\n"
	for _, recipient := range r.clients {
		if recipient != nil && recipient.id != from.id {
			recipient.send(line)
		}
	}
	return nil
}
// New runs the chat server on listener. It accepts connections forever,
// handing each one to its own goroutine running room.join; all connections
// share a single room. New returns only when Accept fails, and returns that
// error.
func New(listener net.Listener) error {
	r := &room{
		lock:    &sync.RWMutex{},
		clients: []*client{},
	}
	for {
		conn, err := listener.Accept()
		if err != nil {
			return err
		}
		go r.join(conn)
	}
}