commit d340bc49da1c35fac6315c28f78ebb21e8c70179 Author: Nyeogmi Date: Fri Sep 26 20:36:54 2025 -0700 Implementation 1 of a minimalistic IRC server diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f45d9d3 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.chromaticdragon.app/pyrex/minimal-irc-server/v2 + +go 1.21.5 diff --git a/src/irc/commands/all.go b/src/irc/commands/all.go new file mode 100644 index 0000000..b0b3458 --- /dev/null +++ b/src/irc/commands/all.go @@ -0,0 +1,11 @@ +package commands + +import ( + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/irc/world" +) + +func HandleCommands(msg world.WrappedMessage) { + handleAuthCommands(msg) + handleJoinPartCommands(msg) + handlePrivmsgNotifyCommands(msg) +} diff --git a/src/irc/commands/auth.go b/src/irc/commands/auth.go new file mode 100644 index 0000000..5877a46 --- /dev/null +++ b/src/irc/commands/auth.go @@ -0,0 +1,84 @@ +package commands + +import ( + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/irc/users" + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/irc/world" + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/transport" +) + +func handleAuthCommands(msg world.WrappedMessage) { + handleNickAndUser(msg) + completeHandshakeIfPossible(msg) +} + +func handleNickAndUser(msg world.WrappedMessage) { + if msg.Sender.GetHasReceivedAuthHandshakeReply() { + // TODO: Send an error reply + return + } + + if msg.Content.Command == "NICK" { + args := msg.Content.Arguments + if len(args) != 1 { + // TODO: Send an error reply + return + } + + nick := args[0] + validNick, err := users.ValidateNick(nick) + if err != nil { + // TODO: Send an error reply + return + } + + err = msg.Sender.SetNick(&validNick) + if err != nil { + msg.World.Server.TerminateClient(msg.Sender.GetClientId(), err) + } + } + + if msg.Content.Command == "USER" { + args := msg.Content.Arguments + if len(args) != 4 { + // TODO: Send an error reply + return + } + + username := args[0] + zero := args[1] + star := args[2] + realName := args[3] + + if zero != "0" || star != "*" { + // TODO: Send an error reply + return + } + + msg.Sender.SetUsername(&username) + msg.Sender.SetRealName(&realName) + + // TODO: Validation? I wonder if it matters. + } +} + +func completeHandshakeIfPossible(msg world.WrappedMessage) { + sender := msg.Sender + if msg.Sender.GetHasReceivedAuthHandshakeReply() { + return + } + + isReady := sender.GetNick() != nil && sender.GetUsername() != nil && sender.GetRealName() != nil + if !isReady { + return + } + + sender.SetHasReceivedAuthHandshakeReply(true) + msg.World.Server.SendMessage(sender.GetClientId(), transport.Content{ + Command: "NICK", + Arguments: []string{sender.GetNick().Value}, + }) + msg.World.Server.SendMessage(msg.Sender.GetClientId(), transport.Content{ + Command: "USER", + Arguments: []string{*sender.GetUsername(), "0", "*", *sender.GetRealName()}, + }) +} diff --git a/src/irc/commands/joinPart.go b/src/irc/commands/joinPart.go new file mode 100644 index 0000000..0f287bb --- /dev/null +++ b/src/irc/commands/joinPart.go @@ -0,0 +1,59 @@ +package commands + +import ( + "strings" + + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/irc/users" + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/irc/world" +) + +func handleJoinPartCommands(msg world.WrappedMessage) { + if msg.Content.Command == "JOIN" { + if len(msg.Content.Arguments) != 1 { + // TODO: Wrong number of arguments + return + } + channelsToJoin := parseChannelList(msg.Content.Arguments[0]) + + for _, channel := range channelsToJoin { + err := msg.Sender.Join(channel) + if err != nil { + msg.World.Server.TerminateClient(msg.Sender.GetClientId(), err) + return + } + msg.World.RelayToChannel(msg, channel, nil) + } + } + + if msg.Content.Command == "PART" { + n := len(msg.Content.Arguments) + if !(n == 1 || n == 2) { + return + } + channelsToPart := parseChannelList(msg.Content.Arguments[0]) + + for _, channel := range channelsToPart { + err := msg.Sender.Part(channel) + if err != nil { + msg.World.Server.TerminateClient(msg.Sender.GetClientId(), err) + return + } + msg.World.RelayToChannel(msg, channel, nil) + // the user won't see their own #part because they left, so send it + msg.World.RelayToClient(msg, msg.Sender.GetClientId(), nil) + } + } +} + +func parseChannelList(arg string) []users.ChannelName { + var channels []users.ChannelName + for _, channelName := range strings.Split(arg, ",") { + validChannel, err := users.ValidateChannelName(channelName) + if err != nil { // can't join, not a channel + continue + } + + channels = append(channels, validChannel) + } + return channels +} diff --git a/src/irc/commands/privmsgNotify.go b/src/irc/commands/privmsgNotify.go new file mode 100644 index 0000000..83d4459 --- /dev/null +++ b/src/irc/commands/privmsgNotify.go @@ -0,0 +1,21 @@ +package commands + +import ( + "log" + + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/irc/world" + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/transport" +) + +func handlePrivmsgNotifyCommands(msg world.WrappedMessage) { + if msg.Content.Command == "PRIVMSG" || msg.Content.Command == "NOTIFY" || msg.Content.Command == "CTCP" { + log.Printf("message-like command") + if len(msg.Content.Arguments) == 0 { + // TODO: Error reply + return + } + + // Was this message to a user? + msg.World.RelayToVagueDestination(msg, msg.Content.Arguments[0], []transport.ClientId{msg.Sender.GetClientId()}) + } +} diff --git a/src/irc/errors/errors.go b/src/irc/errors/errors.go new file mode 100644 index 0000000..ebf0f0e --- /dev/null +++ b/src/irc/errors/errors.go @@ -0,0 +1,7 @@ +package errors + +import "fmt" + +var ErrAlreadyInChannel = fmt.Errorf("already in channel") +var ErrNickAlreadyInUse = fmt.Errorf("nick already in use") +var ErrNotInChannel = fmt.Errorf("not in channel") diff --git a/src/irc/identifiers.go b/src/irc/identifiers.go new file mode 100644 index 0000000..d8005d0 --- /dev/null +++ b/src/irc/identifiers.go @@ -0,0 +1,61 @@ +package irc + +import ( + "fmt" + "regexp" + "strings" +) + +var ErrNotANick = fmt.Errorf("does not look like a nickname") + +var regexpNick = regexp.MustCompile("^[a-zA-Z0-9]+$") // NOTE: more constrained than real character set + +type Nick string +type CanonicalNick string + +func ValidateNick(s string) (Nick, error) { + // TODO: Fail if the string doesn't look like a nick + if !regexpNick.MatchString(s) { + return "", fmt.Errorf("%w: %s", ErrNotANick, s) + } + return Nick(s), nil +} + +func (n Nick) Canonize() CanonicalNick { + return CanonicalNick(strings.ToLower(string(n))) +} + +func (n *Nick) CanonizeNullable() *CanonicalNick { + if n == nil { + return nil + } + result := n.Canonize() + return &result +} + +var ErrNotAChannel = fmt.Errorf("does not look like a channel name") + +var regexpChannel = regexp.MustCompile("^#[a-zA-Z0-9]+$") // NOTE: more constrained than real character set + +type Channel string +type CanonicalChannel string + +func ValidateChannel(s string) (Channel, error) { + // TODO: Fail if the string doesn't look like a channel name + if !regexpChannel.MatchString(s) { + return "", fmt.Errorf("%w: %s", ErrNotAChannel, s) + } + return Channel(s), nil +} + +func (c Channel) Canonize() CanonicalChannel { + return CanonicalChannel(strings.ToLower(string(c))) +} + +func (c *Channel) CanonizeNullable() *CanonicalChannel { + if c == nil { + return nil + } + result := c.Canonize() + return &result +} diff --git a/src/irc/main.go b/src/irc/main.go new file mode 100644 index 0000000..0c4a016 --- /dev/null +++ b/src/irc/main.go @@ -0,0 +1,25 @@ +package irc + +import ( + "log" + + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/irc/commands" + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/irc/world" + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/transport" +) + +func ServeIrc(server *transport.Server) { + world := world.NewWorld(server) + + for { + rawMessage, err := server.ReceiveMessage() + if err != nil { + log.Println("failed to receive message: %w") + return + } + + wrappedMessage := world.Wrap(rawMessage) + + commands.HandleCommands(wrappedMessage) + } +} diff --git a/src/irc/users/system.go b/src/irc/users/system.go new file mode 100644 index 0000000..07fec43 --- /dev/null +++ b/src/irc/users/system.go @@ -0,0 +1,183 @@ +package users + +import ( + "fmt" + "slices" + + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/irc/errors" + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/transport" +) + +type UserId uint64 + +type UsersSystem struct { + clientIdIndex map[transport.ClientId]*User + nickIndex map[canonicalNick]*User + channelNameIndex map[canonicalChannelName](map[*User]struct{}) +} + +type User struct { + users *UsersSystem + + clientId transport.ClientId + nick *Nick + username *string + realName *string + + hasReceivedAuthHandshakeReply bool + + channels []ChannelName +} + +func NewUsersSystem() *UsersSystem { + return &UsersSystem{ + clientIdIndex: make(map[transport.ClientId]*User), + nickIndex: make(map[canonicalNick]*User), + channelNameIndex: make(map[canonicalChannelName]map[*User]struct{}), + } +} + +func (users *UsersSystem) ByClientIdOrCreate(clientId transport.ClientId) *User { + existing, ok := users.clientIdIndex[clientId] + + if ok { + return existing + } + + user := &User{ + users: users, + + clientId: clientId, + nick: nil, + username: nil, + realName: nil, + + hasReceivedAuthHandshakeReply: false, + + channels: nil, + } + users.clientIdIndex[clientId] = user + return user +} + +func (users *UsersSystem) ByNick(nick Nick) *User { + return users.nickIndex[nick.canonical] +} + +func (users *UsersSystem) ByChannel(channelName ChannelName) map[*User]struct{} { + return users.channelNameIndex[channelName.canonical] +} + +func (user *User) GetClientId() transport.ClientId { + return user.clientId +} + +func (user *User) GetNick() *Nick { + return user.nick +} + +func (user *User) SetNick(newNick *Nick) error { + users := user.users + oldNick := user.nick + + // check if already in use -- if so, refuse + _, alreadyInUse := users.nickIndex[newNick.canonical] + if alreadyInUse { + if oldNick != nil && newNick.canonical == oldNick.canonical { + // it's fine, this is the user who held that nick + // so continue as before + } else { + return fmt.Errorf("%w: %s", errors.ErrNickAlreadyInUse, newNick.Value) + } + } + + // update indexes + if oldNick != nil { + delete(users.nickIndex, oldNick.canonical) + } + if newNick != nil { + users.nickIndex[newNick.canonical] = user + } + + // update me + user.nick = newNick + return nil +} + +func (user *User) GetUsername() *string { + return user.username +} + +func (user *User) SetUsername(username *string) { + user.username = username +} + +func (user *User) GetRealName() *string { + return user.realName +} + +func (user *User) SetRealName(realName *string) { + user.realName = realName +} + +func (user *User) GetHasReceivedAuthHandshakeReply() bool { + return user.hasReceivedAuthHandshakeReply +} + +func (user *User) SetHasReceivedAuthHandshakeReply(value bool) { + user.hasReceivedAuthHandshakeReply = value +} + +func (user *User) IsInChannel(channelName ChannelName) bool { + return slices.ContainsFunc(user.channels, func(existingChannel ChannelName) bool { + return channelName.canonical == existingChannel.canonical + }) +} + +func (user *User) Join(channelName ChannelName) error { + users := user.users + + // if I'm already in this channel, don't join + if user.IsInChannel(channelName) { + return fmt.Errorf("%w: %s", errors.ErrAlreadyInChannel, channelName.Value) + } + + // update indexes + existing, ok := users.channelNameIndex[channelName.canonical] + if !ok { + existing = make(map[*User]struct{}) + users.channelNameIndex[channelName.canonical] = existing + } + _, wasInChannel := existing[user] + if wasInChannel { + panic("tried to join a channel, but I was mysteriously already in it") + } + existing[user] = struct{}{} + + // update me + user.channels = append(user.channels, channelName) + + return nil +} + +func (user *User) Part(channelName ChannelName) error { + users := user.users + + // if i'm not in this channel, don't part + if !user.IsInChannel(channelName) { + return fmt.Errorf("%w: %s", errors.ErrNotInChannel, channelName.Value) + } + + // update indexes + existing, ok := users.channelNameIndex[channelName.canonical] + if ok { + delete(existing, user) + if len(existing) == 0 { + delete(users.channelNameIndex, channelName.canonical) + } + } else { + panic("tried to part from a channel, but was mysteriously absent from it") + } + + return nil +} diff --git a/src/irc/users/types.go b/src/irc/users/types.go new file mode 100644 index 0000000..a551173 --- /dev/null +++ b/src/irc/users/types.go @@ -0,0 +1,46 @@ +package users + +import ( + "fmt" + "regexp" + "strings" +) + +type Nick struct { + Value string + canonical canonicalNick +} +type canonicalNick string + +var ErrNotANick = fmt.Errorf("does not look like a nickname") +var regexpNick = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`) // NOTE: more constrained than real character set + +func ValidateNick(input string) (Nick, error) { + if !regexpNick.MatchString(input) { + return Nick{}, fmt.Errorf("%w: %s", ErrNotANick, input) + } + + return Nick{ + Value: input, + canonical: canonicalNick(strings.ToLower(input)), + }, nil +} + +type ChannelName struct { + Value string + canonical canonicalChannelName +} +type canonicalChannelName string + +var ErrNotAChannelName = fmt.Errorf("does not look like a channel name") +var regexpChannelName = regexp.MustCompile(`^#[a-zA-Z0-9\-_]+$`) + +func ValidateChannelName(input string) (ChannelName, error) { + if !regexpChannelName.MatchString(input) { + return ChannelName{}, fmt.Errorf("%w: %s", ErrNotAChannelName, input) + } + return ChannelName{ + Value: input, + canonical: canonicalChannelName(strings.ToLower(input)), + }, nil +} diff --git a/src/irc/world/world.go b/src/irc/world/world.go new file mode 100644 index 0000000..b03dc29 --- /dev/null +++ b/src/irc/world/world.go @@ -0,0 +1,122 @@ +package world + +import ( + "fmt" + "log" + "slices" + + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/irc/users" + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/transport" +) + +type World struct { + Server *transport.Server + UsersSystem *users.UsersSystem +} + +type WrappedMessage struct { + World *World + Sender *users.User + Content transport.Content +} + +func NewWorld(server *transport.Server) *World { + usersSystem := users.NewUsersSystem() + + return &World{ + Server: server, + UsersSystem: usersSystem, + } +} + +func (world *World) Wrap(msg transport.IncomingMessage) WrappedMessage { + sender := world.UsersSystem.ByClientIdOrCreate(msg.Sender) + + return WrappedMessage{ + World: world, + Sender: sender, + Content: msg.Content, + } +} + +// transmission of messages +func (world *World) RelayToVagueDestination( + msg WrappedMessage, + name string, + exclude []transport.ClientId, +) { + nick, err := users.ValidateNick(name) + if err == nil { + // so it's a nick! + world.RelayToNick(msg, nick, exclude) + return + } + + channel, err := users.ValidateChannelName(name) + if err == nil { + // so it's a channel! + world.RelayToChannel(msg, channel, exclude) + return + } + + log.Fatalf("not sure how to send to %s", name) + // TODO: Error response: "what is this?" +} + +func (world *World) RelayToClient( + msg WrappedMessage, + client transport.ClientId, + exclude []transport.ClientId, +) { + content := createAnnotatedContent(msg) + if slices.Contains(exclude, client) { + return // don't relay + } + world.Server.SendMessage(client, content) +} + +func (world *World) RelayToNick( + msg WrappedMessage, + nick users.Nick, + exclude []transport.ClientId, +) { + content := createAnnotatedContent(msg) + + user := world.UsersSystem.ByNick(nick) + if user == nil { + // TODO: Send an error reply. The user didn't exist + return + } + + if slices.Contains(exclude, user.GetClientId()) { + return // don't relay + } + + world.Server.SendMessage(user.GetClientId(), content) +} + +func (world *World) RelayToChannel( + msg WrappedMessage, + channelName users.ChannelName, + exclude []transport.ClientId, +) { + content := createAnnotatedContent(msg) + + members := world.UsersSystem.ByChannel(channelName) + log.Printf("Members of %s: %v\n", channelName, members) + for member := range members { + if slices.Contains(exclude, member.GetClientId()) { + return // don't relay + } + world.Server.SendMessage(member.GetClientId(), content) + } +} + +func createAnnotatedContent( + msg WrappedMessage, +) transport.Content { + content := msg.Content + fullSource := fmt.Sprintf("%s!clients/%d", msg.Sender.GetNick().Value, msg.Sender.GetClientId()) + content.Source = &fullSource + return content +} diff --git a/src/main.go b/src/main.go new file mode 100644 index 0000000..f92ec19 --- /dev/null +++ b/src/main.go @@ -0,0 +1,18 @@ +package main + +import ( + "log" + + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/irc" + "git.chromaticdragon.app/pyrex/minimal-irc-server/v2/src/transport" +) + +func main() { + server, err := transport.NewServer("127.0.0.1:6667") + if err != nil { + log.Fatalln("couldn't start server: %w", err) + } + defer server.Close() + + irc.ServeIrc(server) +} diff --git a/src/transport/connectedClients.go b/src/transport/connectedClients.go new file mode 100644 index 0000000..2dc496a --- /dev/null +++ b/src/transport/connectedClients.go @@ -0,0 +1,60 @@ +package transport + +import ( + "context" + "sync" +) + +type ClientId uint64 + +type ConnectedClients struct { + mutex sync.Mutex + nextId ClientId + table map[ClientId]*ConnectedClient +} +type ConnectedClient struct { + cancel context.CancelCauseFunc + outgoingMessages chan<- OutgoingMessage +} + +func newConnectedClients() ConnectedClients { + return ConnectedClients{ + mutex: sync.Mutex{}, + nextId: 1, + table: make(map[ClientId]*ConnectedClient), + } +} + +func (cc *ConnectedClients) Enroll(callback func(ClientId) ConnectedClient) ClientId { + cc.mutex.Lock() + defer cc.mutex.Unlock() + + clientId := cc.nextId + cc.nextId += 1 + + newClient := callback(clientId) + cc.table[clientId] = &newClient + + return clientId + +} + +func (cc *ConnectedClients) Unenroll(clientId ClientId) { + cc.mutex.Lock() + defer cc.mutex.Unlock() + + delete(cc.table, clientId) +} + +func (cc *ConnectedClients) BorrowIfPresent(clientId ClientId, callback func(*ConnectedClient)) { + cc.mutex.Lock() + defer cc.mutex.Unlock() + + client, ok := cc.table[clientId] + if !ok { + return + } + + callback(client) + +} diff --git a/src/transport/messages.go b/src/transport/messages.go new file mode 100644 index 0000000..4a84afb --- /dev/null +++ b/src/transport/messages.go @@ -0,0 +1,17 @@ +package transport + +type IncomingMessage struct { + Sender ClientId + Content Content +} + +type OutgoingMessage struct { + Recipient ClientId + Content Content +} + +type Content struct { + Source *string + Command string + Arguments []string +} diff --git a/src/transport/networkingUtilities.go b/src/transport/networkingUtilities.go new file mode 100644 index 0000000..683f3e0 --- /dev/null +++ b/src/transport/networkingUtilities.go @@ -0,0 +1,30 @@ +package transport + +import ( + "bufio" + "io" +) + +type lineByLineItem struct { + Line string + Error error +} + +func readLineByLine(reader io.Reader) <-chan lineByLineItem { + bufReader := bufio.NewReader(reader) + channel := make(chan lineByLineItem) + + go (func() { + defer close(channel) + + for { + line, err := bufReader.ReadString('\n') + channel <- lineByLineItem{Line: line, Error: err} + if err != nil { + return + } + } + })() + + return channel +} diff --git a/src/transport/parsing.go b/src/transport/parsing.go new file mode 100644 index 0000000..89c5c25 --- /dev/null +++ b/src/transport/parsing.go @@ -0,0 +1,117 @@ +package transport + +import ( + "fmt" + "strings" +) + +var ErrInvalidIncomingMessage = fmt.Errorf("invalid content in ingoing message") + +func Deserialize(line string) (*Content, error) { + line, found := strings.CutSuffix(line, "\r\n") + if !found { + return nil, fmt.Errorf("%w: all IRC messages should be terminated by \\r\\n (%s)", ErrInvalidIncomingMessage, line) + } + + if line == "" { + // blank line + return nil, nil + } + + var p = &parser{basis: line, index: 0} + var source *string + if p.Pop(":") { + src := p.PopWhile(isNotWhitespace) + source = &src + p.PopWhile(isWhitespace) + + if len(src) == 0 { + return nil, p.NewError("zero-length source") + } + } + + command := p.PopWhile(isNotWhitespace) + p.PopWhile(isWhitespace) + if len(command) == 0 { + return nil, p.NewError("zero-length command") + } + + var args []string + for !p.Depleted() { + var arg string + if p.Pop(":") { + arg = p.PopWhile(isNotNewline) + } else { + arg = p.PopWhile(isNotWhitespace) + + if len(arg) == 0 { + return nil, p.NewError("zero-length arg in non-final position") + } + } + p.PopWhile(isWhitespace) + + args = append(args, arg) + } + + return &Content{ + Source: source, + Command: strings.ToUpper(command), + Arguments: args, + }, nil +} + +type parser struct { + basis string + index int +} + +func (p *parser) Depleted() bool { + return p.index >= len(p.basis) +} + +func (p *parser) Pop(s string) bool { + n := len(s) + if p.index+n > len(p.basis) { + return false + } + if p.basis[p.index:p.index+n] == s { + p.index += n + return true + } + return false +} + +func (p *parser) PopWhile(pred func(byte) bool) string { + start := p.index + end := start + for { + if end < len(p.basis) && pred(p.basis[end]) { + end = end + 1 + } else { + break + } + } + p.index = end + return p.basis[start:end] + +} + +func (p *parser) NewError(msg string) error { + return fmt.Errorf("%w: %s (%s, %d)", ErrInvalidIncomingMessage, msg, p.basis, p.index) +} + +func isNotWhitespace(b byte) bool { + return !isWhitespace(b) +} + +func isWhitespace(b byte) bool { + return b == '\n' || b == '\r' || b == ' ' +} + +func isNotNewline(b byte) bool { + return !isNewline(b) +} + +func isNewline(b byte) bool { + return b == '\n' || b == '\r' +} diff --git a/src/transport/serialization.go b/src/transport/serialization.go new file mode 100644 index 0000000..d20f216 --- /dev/null +++ b/src/transport/serialization.go @@ -0,0 +1,66 @@ +package transport + +import ( + "fmt" + "strings" +) + +var ErrInvalidContent = fmt.Errorf("invalid content in message") + +func (c Content) Serialize() (*string, error) { + var builder strings.Builder + + if c.Source != nil { + src := *c.Source + builder.WriteString(":") + err := writeDisallowingWhitespace(&builder, src, "space in source") + if err != nil { + return nil, err + } + builder.WriteByte(' ') + + } + + err := writeDisallowingWhitespace(&builder, c.Command, "space in command") + if err != nil { + return nil, err + } + + for ix, arg := range c.Arguments { + builder.WriteByte(' ') + isLast := ix == len(c.Arguments)-1 + if isLast { + builder.WriteString(":") + writeDisallowingNewlines(&builder, arg, "newline in final arg") + } else { + writeDisallowingWhitespace(&builder, arg, "space in non-final arg") + } + } + + builder.WriteString("\r\n") + result := builder.String() + return &result, nil +} +func writeDisallowingWhitespace(sb *strings.Builder, s string, msg string) error { + if containsWhitespace(s) { + return fmt.Errorf("%w: %s (%s)", ErrInvalidContent, s, msg) + } + sb.WriteString(s) + return nil +} + +func writeDisallowingNewlines(sb *strings.Builder, s string, msg string) error { + if containsNewlines(s) { + return fmt.Errorf("%w: %s (%s)", ErrInvalidContent, s, msg) + } + sb.WriteString(s) + return nil +} + +func containsWhitespace(s string) bool { + return strings.Contains(s, " ") || strings.Contains(s, "\n") || strings.Contains(s, "\r") +} + +func containsNewlines(s string) bool { + return strings.Contains(s, "\n") || strings.Contains(s, "\r") +} diff --git a/src/transport/server.go b/src/transport/server.go new file mode 100644 index 0000000..720ac68 --- /dev/null +++ b/src/transport/server.go @@ -0,0 +1,152 @@ +package transport + +import ( + "bufio" + "context" + "fmt" + "log" + "net" +) + +type Server struct { + ctx context.Context + cancel context.CancelCauseFunc + connectedClients ConnectedClients + incomingMessages chan IncomingMessage +} + +var ErrAlreadyClosed = fmt.Errorf("server already closed") + +func NewServer(address string) (*Server, error) { + ctx, cancel := context.WithCancelCause(context.Background()) + listener, err := net.Listen("tcp", address) + + if err != nil { + return nil, err + } + + server := &Server{ + ctx: ctx, + cancel: cancel, + connectedClients: newConnectedClients(), + incomingMessages: make(chan IncomingMessage), + } + + go (func() { + for { + connection, err := listener.Accept() + if err != nil { + cancel(err) + return + } + go server.handleConnection(connection) + } + })() + + return server, nil +} + +func (server *Server) Close() { + close(server.incomingMessages) + server.cancel(nil) +} + +func (server *Server) handleConnection(conn net.Conn) { + defer conn.Close() + + clientCtx, cancel := context.WithCancelCause(server.ctx) + outgoingMessages := make(chan OutgoingMessage) + + clientId := server.connectedClients.Enroll(func(id ClientId) ConnectedClient { + return ConnectedClient{ + cancel: cancel, + outgoingMessages: outgoingMessages, + } + }) + defer server.connectedClients.Unenroll(clientId) + + go (func() { + <-clientCtx.Done() + cause := context.Cause(clientCtx) + log.Printf("client %d done: %s", clientId, cause) + })() + + ingoingLines := readLineByLine(conn) + outgoingLines := bufio.NewWriter(conn) + + for { + select { + case item := <-ingoingLines: + line := item.Line + err := item.Error + + if err != nil { + cancel(err) + continue + } + + msg, err := Deserialize(line) + if err != nil { + cancel(err) + continue + } + if msg == nil { + continue + } + + log.Printf("recv: %v", msg) + server.incomingMessages <- IncomingMessage{ + Sender: clientId, + Content: *msg, + } + case outgoing := <-outgoingMessages: + log.Printf("sent: %v", outgoing.Content) + content, err := outgoing.Content.Serialize() + if err != nil { + cancel(err) + continue + } + log.Printf("content: %s", *content) + + _, err = outgoingLines.WriteString(*content) + if err != nil { + cancel(err) + continue + } + + // TODO: Don't flush on every iteration + err = outgoingLines.Flush() + if err != nil { + cancel(err) + continue + } + case <-clientCtx.Done(): + return + } + } + +} + +func (server *Server) ReceiveMessage() (IncomingMessage, error) { + message, ok := <-server.incomingMessages + if !ok { + return IncomingMessage{}, ErrAlreadyClosed + } + return message, nil +} + +func (server *Server) SendMessage(client ClientId, content Content) { + outgoing := OutgoingMessage{ + Recipient: client, + Content: content, + } + server.connectedClients.BorrowIfPresent(client, func(connectedClient *ConnectedClient) { + connectedClient.outgoingMessages <- outgoing + }) +} + +func (server *Server) TerminateClient(client ClientId, err error) { + server.connectedClients.BorrowIfPresent(client, func(connectedClient *ConnectedClient) { + connectedClient.cancel(err) + }) +}