package budgetchat import ( "bufio" "errors" "net" "strings" "sync" "github.com/charmbracelet/log" ) func isGoodAscii(str string) bool { // we disallow anything outside the printable ascii range // that means no control characters for _, c := range str { if c < ' ' || c > '~' { return false } } return true } type client struct { id uint name string conn net.Conn reader *bufio.Reader bad bool } func (c *client) read() (string, error) { line, err := c.reader.ReadString('\n') if err != nil { log.Error("unable to read from client", "err", err) c.conn.Close() return "", err } return line[:len(line)-1], nil } func (c *client) send(msg string) error { if c.bad { log.Warn("not sending to bad client", "remote", c.conn.RemoteAddr()) return errors.New("trying to send to a bad client") } _, err := c.conn.Write([]byte(msg)) if err != nil { c.bad = true log.Warn("marking client as bad", "err", err) c.conn.Close() return err } return nil } type room struct { inc uint lock *sync.RWMutex clients []*client } func (r *room) whoishere() string { r.lock.RLock() defer r.lock.RUnlock() builder := new(strings.Builder) builder.WriteString("* Currently online: ") if len(r.clients) == 0 { log.Info("nobody is here!") builder.WriteString("nobody!\n") return builder.String() } for i, client := range r.clients { if client == nil { continue } builder.WriteString(client.name) if i != len(r.clients) { builder.WriteString(", ") } } builder.WriteRune('\n') return builder.String() } func (r *room) leave(client *client) { r.lock.Lock() defer r.lock.Unlock() leaveMsg := "* " + client.name + " went offline!\n" // loop over the current clients, saving the index of the one we're removing // and messaging everyone else that they left removable := []int{} for i, c := range r.clients { if c.bad || c.id == client.id { removable = append(removable, i) } else { c.send(leaveMsg) } } // remove the client who left and any bad connections for _, index := range removable { r.clients[index] = r.clients[len(r.clients)-1] r.clients = r.clients[:len(r.clients)-1] } } func (r *room) addClient(client *client) { r.lock.Lock() defer r.lock.Unlock() client.id = r.inc r.inc += 1 r.clients = append(r.clients, client) } func (r *room) join(conn net.Conn) { log := log.Default().WithPrefix(conn.RemoteAddr().String()) log.Info("client connecting") client := new(client) client.conn = conn client.reader = bufio.NewReader(conn) client.bad = false // start by asking for a name err := client.send("* Welcome, what's your name?\n") if err != nil { log.Warn("kicking client: unable to write welcome", "err", err) conn.Close() return } // get the client's name name, err := client.read() if err != nil { log.Warn("kicking client: unable to read name", "err", err) return } // check to make sure the name is ok if len(name) > 64 { log.Warn("kicking client: name too long") client.send("* You've been kicked: name too long\n") conn.Close() return } if !isGoodAscii(name) { log.Warn("kicking client: name contains illegal characters") client.send("* You've been kicked: name contains illegal characters\n") conn.Close() return } // establish the name client.name = name log.Info("name established", "name", client.name) log.Info("sending room membership") err = client.send(r.whoishere()) if err != nil { } // announce the new client joinMsg := "* " + client.name + " is now online\n" for _, c := range r.clients { c.send(joinMsg) } // add the client to the room r.addClient(client) log.Info("assigned id", "id", client.id) // start message loop for { line, err := client.read() if err != nil { break } err = r.message(client, line) // if they send a bad message we tell them and kick them if err != nil { log.Warn("kicking client: bad message", "err", err) client.send("* You've been kicked: " + err.Error() + "\n") break } } r.leave(client) return } func (r *room) message(from *client, msg string) error { r.lock.RLock() defer r.lock.RUnlock() if len(msg) > 1000 { return errors.New("message too long") } if !isGoodAscii(msg) { return errors.New("message contains illegal characters") } out := "[" + from.name + "] " + msg + "\n" for _, client := range r.clients { if client == nil { continue } if client.id == from.id { continue } client.send(out) } return nil } func New(listener net.Listener) error { room := new(room) room.lock = new(sync.RWMutex) room.clients = []*client{} for { conn, err := listener.Accept() if err != nil { return err } go room.join(conn) } }