From da2d17d81ce3838e4cebdbfa4b92f417fa93a891 Mon Sep 17 00:00:00 2001 From: george Date: Wed, 18 May 2022 19:35:42 +0100 Subject: [PATCH] init --- Dockerfile | 13 ++ cmd/main.go | 51 ++++++++ go.mod | 9 ++ go.sum | 6 + internal/chat.go | 107 +++++++++++++++ internal/client.go | 94 +++++++++++++ internal/handlers.go | 305 +++++++++++++++++++++++++++++++++++++++++++ internal/list.go | 59 +++++++++ internal/server.go | 176 +++++++++++++++++++++++++ test/server_test.go | 45 +++++++ 10 files changed, 865 insertions(+) create mode 100644 Dockerfile create mode 100755 cmd/main.go create mode 100755 go.mod create mode 100755 go.sum create mode 100755 internal/chat.go create mode 100755 internal/client.go create mode 100755 internal/handlers.go create mode 100755 internal/list.go create mode 100755 internal/server.go create mode 100755 test/server_test.go diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..3a182a7 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,13 @@ +FROM golang:latest +WORKDIR /app +COPY go.mod go.sum ./ +RUN go mod download +COPY ./cmd ./cmd +COPY ./internal ./internal +RUN go build -o /chat ./cmd + +FROM gcr.io/distroless/base-debian10 +WORKDIR / +COPY --from=0 /chat /chat +ENTRYPOINT ["/chat"] +#docker run --network="host" -e chat #127.0.0.1:8000 diff --git a/cmd/main.go b/cmd/main.go new file mode 100755 index 0000000..80a6a9b --- /dev/null +++ b/cmd/main.go @@ -0,0 +1,51 @@ +package main + +import ( + "flag" + "fmt" + "net/http" + "os" + "strings" + + chat "git.wens.org.uk/chat/internal" +) + +func main() { + parseArgs() + + chat.Servers = make(map[string]*chat.Server) + chat.StartServer("root") + chat.SetUpHandlers() + + fmt.Println("chat server listening on", chat.Address) + if err := http.ListenAndServe(chat.Address, nil); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } +} + +func parseArgs() { + db := flag.String("d", chat.DbAddress, "address for db server") + flag.Parse() + if chat.DbAddress != *db { + if !strings.HasPrefix(*db, "http://") { + *db = "http://" + *db + } + chat.DbAddress = *db + } + + args := flag.Args() + if len(args) == 1 { + chat.Address = args[0] + } + + if len(args) > 1 { + fmt.Fprintln(os.Stderr, "too many arguments. usage: [options] [address]") + flag.PrintDefaults() + os.Exit(1) + } + + if chat.JwtSecret == "" { + fmt.Println("Warning: JWT secret is empty string") + } +} diff --git a/go.mod b/go.mod new file mode 100755 index 0000000..832ccef --- /dev/null +++ b/go.mod @@ -0,0 +1,9 @@ +module git.wens.org.uk/chat + +go 1.18 + +require ( + github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/google/uuid v1.3.0 + github.com/gorilla/websocket v1.5.0 +) diff --git a/go.sum b/go.sum new file mode 100755 index 0000000..2f4d54e --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/internal/chat.go b/internal/chat.go new file mode 100755 index 0000000..fa62010 --- /dev/null +++ b/internal/chat.go @@ -0,0 +1,107 @@ +package chat + +// chat server funcs which aren't specific to the server struct + +import ( + "bytes" + "crypto/md5" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + + "github.com/golang-jwt/jwt" +) + +var Address = "127.0.0.1:8000" +var JwtSecret = os.Getenv("secret") + +func checkToken(token string) (username string, ok bool) { + type ParsedToken struct { + Username string + jwt.StandardClaims + } + var t ParsedToken + + jwt, err := jwt.ParseWithClaims(token, &t, func(t *jwt.Token) (interface{}, error) { + return []byte(JwtSecret), nil + }) + + if err != nil { + return "", false + } + + if userToken, ok := jwt.Claims.(*ParsedToken); ok && jwt.Valid { + return userToken.Username, true + } else { + fmt.Println(err) + return "", false + } +} + +func encodeJson(data interface{}) ([]byte, bool) { + if jsonData, err := json.Marshal(data); err != nil { + fmt.Fprintln(os.Stderr, "Could not marshal json data", data) + return nil, false + } else { + return jsonData, true + } +} + +func decodeJson(body *io.ReadCloser, w *http.ResponseWriter, message interface{}) { + if err := json.NewDecoder(*body).Decode(&message); err != nil { + fmt.Fprintln(os.Stderr, "Error decoding", err.Error()) + if w != nil && (*w) != nil { + (*w).WriteHeader(http.StatusBadRequest) + (*w).Write([]byte(`{"error":"bad message"}`)) + } + } +} + +func dbHandleError(body *io.ReadCloser, w *http.ResponseWriter) (message string) { + type ErrorMessage struct { + Error string `json:"error"` + } + var error ErrorMessage + decodeJson(body, w, &error) + return error.Error +} + +func isAcceptedMIMEType(mime string) bool { + switch mime { + case "image/bmp", + "image/gif", + "image/jpeg", + "image/png", + "audio/mpeg", + "video/webm": + return true + } + return false +} + +func uploadFile(fileContents *[]byte, fileName, mimeType string) (res *http.Response, err error) { + type FileObject struct { + FileName string `json:"filename"` + FileType string `json:"filetype"` + Contents []byte `json:"contents"` + Sum [16]byte `json:"sum"` + } + + file := FileObject{ + fileName, mimeType, *fileContents, md5.Sum(*fileContents), + } + + jsonData, ok := encodeJson(&file) + if !ok { + fmt.Fprintln(os.Stderr, "error encoding file data") + return + } + + res, err = http.Post(DbAddress+"/file", "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + fmt.Fprintln(os.Stderr, "file POST error:", err) + } + return +} diff --git a/internal/client.go b/internal/client.go new file mode 100755 index 0000000..7a357a4 --- /dev/null +++ b/internal/client.go @@ -0,0 +1,94 @@ +package chat + +import ( + "fmt" + "os" + "time" + + "github.com/gorilla/websocket" +) + +type User struct { + Username string `json:"username"` + Token string `json:"token"` +} + +type Client struct { + conn *websocket.Conn + server *Server + receivedMessages chan (WrappedMessage) + username string + id string +} + +type WrappedMessage struct { + DataType string `json:"datatype"` + Data Message `json:"data"` +} + +var deadline = time.Second * 60 + +func (c *Client) disconnect() { + dcString := fmt.Sprint("server "+c.server.Name+" ended connection with ", c.username) + delete(c.server.clients, c.id) + c.server.sendMessage(dcString) + fmt.Println(dcString) + fmt.Println(c.server.reportClients(), "clients connected to "+c.server.Name) + c.conn.WriteControl(websocket.CloseGoingAway, []byte("reconnected"), time.Now().Add(deadline)) + c.conn.Close() +} + +func (c *Client) pongHandler(string) error { + return c.conn.SetReadDeadline(time.Now().Add(deadline)) +} + +func (c *Client) awaitDisconnect() { + c.conn.SetReadDeadline(time.Now().Add(deadline)) + c.conn.SetPongHandler(c.pongHandler) + for { + _, _, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + fmt.Fprintln(os.Stderr, "Unexpected socket close error:", err.Error()) + } + c.disconnect() + return + } + } +} + +func (c *Client) awaitMessage() { + ticker := time.NewTicker(deadline / 2) + for { + select { + case mess := <-c.receivedMessages: + { + jsonData, ok := encodeJson(mess) + if !ok { + continue + } + + if !c.sendMessage(websocket.TextMessage, jsonData) { + return + } + } + case <-ticker.C: + { + if !c.sendMessage(websocket.PingMessage, nil) { + return + } + } + } + } +} + +func (c *Client) sendMessage(messageType int, data []byte) bool { + if err := c.conn.WriteMessage(messageType, data); err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + fmt.Fprintln(os.Stderr, "Unexpected socket close error:", err.Error()) + c.disconnect() + } + return false + } + return true +} diff --git a/internal/handlers.go b/internal/handlers.go new file mode 100755 index 0000000..08c068d --- /dev/null +++ b/internal/handlers.go @@ -0,0 +1,305 @@ +package chat + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "os" + "regexp" + "strconv" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + +var DbAddress = "http://127.0.0.1:8002" + +func SetUpHandlers() { + http.HandleFunc("/messages", MessagesHandler) + http.HandleFunc("/send", SendHandler) + http.HandleFunc("/connect", ConnectHandler) + http.HandleFunc("/clients", ClientsHandler) + http.HandleFunc("/createserver", CreateServerHandler) + http.HandleFunc("/file", FileHandler) + + upgrader.CheckOrigin = func(r *http.Request) bool { + if true { // temp + return true + } + origin := r.Header.Get("Origin") + reg := regexp.MustCompile(".*:8000") + return reg.MatchString(origin) + } +} + +func ConnectHandler(w http.ResponseWriter, r *http.Request) { + user, ok := checkToken(r.Header.Get("Token")) + if !ok { + w.WriteHeader(http.StatusUnauthorized) + return + } + + con, err := upgrader.Upgrade(w, r, nil) + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + return + } + + s, err := getServer(r.URL.Query().Get("name")) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintln(w, err.Error()) + return + } + + c := Client{con, s, make(chan WrappedMessage), user, uuid.NewString()} + s.clients[c.id] = &c + + fmt.Println(c.server.reportClients(), "clients connected to "+c.server.Name+" - latest", c.username, "from", r.RemoteAddr) + + go c.awaitDisconnect() + go c.awaitMessage() +} + +func MessagesHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + token := r.Header.Get("Token") + if _, ok := checkToken(token); !ok { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("invalid token provided:" + token)) + return + } + + s, err := getServer(r.URL.Query().Get("name")) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + return + } + + offset, err := strconv.Atoi(r.URL.Query().Get("offset")) + if err != nil { + offset = 0 + } + + count, err := strconv.Atoi(r.URL.Query().Get("count")) + if err != nil { + count = DefaultCacheLength + } + + if offset+count <= DefaultCacheLength { + s.wg.Wait() + jsonData, ok := encodeJson(s.messageCache.toMessageSlice(offset, count)) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"error serialising data"}`)) + return + } + w.Write(jsonData) + } else { + _, _, messages := s.loadMessages(offset, count) + jsonData, ok := encodeJson(messages) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Write(jsonData) + } +} + +func SendHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + if r.Method == "OPTIONS" { + return + } + if r.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + token := r.Header.Get("Token") + user, ok := checkToken(token) + if !ok { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error":"token expired"}`)) + return + } + + s, err := getServer(r.URL.Query().Get("name")) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + return + } + + var mess Message + decodeJson(&r.Body, &w, &mess) + + mess = Message{mess.Text, time.Now().UnixMilli(), user, s.Name} + s.wg.Add(1) + go s.storeMessage(&mess) + go s.broadcastMessage(&mess) +} + +// todo: only works for clients connected to servers running on this instance +func ClientsHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + type ClientData struct { + Username string `json:"username"` + } + clientsMap := make(map[string]struct{}) + + s, err := getServer(r.URL.Query().Get("name")) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + return + } + + for _, client := range s.clients { + clientsMap[client.username] = struct{}{} + } + + clientsSlice := make([]ClientData, 0) + for username := range clientsMap { + clientsSlice = append(clientsSlice, ClientData{username}) + } + jsonData, ok := encodeJson(clientsSlice) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Write(jsonData) +} + +func CreateServerHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + type ServerConfig struct { + ServerName string `json:"name"` + Owner string `json:"owner"` + } + + if r.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + token := r.Header.Get("Token") + user, ok := checkToken(token) + if !ok { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error":"token expired"}`)) + return + } + + var config ServerConfig + decodeJson(&r.Body, &w, &config) + + if config.ServerName == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"server name blank"}`)) + return + } + config.Owner = user + jsonData, ok := encodeJson(config) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + return + } + + res, err := http.Post(DbAddress+"/server", "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + fmt.Fprintln(os.Stderr, "error reaching db server for POST server", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + if res.StatusCode != http.StatusCreated { + fmt.Println("post server error:", res.StatusCode) + } + w.WriteHeader(res.StatusCode) +} + +func FileHandler(w http.ResponseWriter, r *http.Request) { + const MegaByte = 1024 * 1024 + + token := r.Header.Get("Token") + _, ok := checkToken(token) + if !ok { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error":"token expired"}`)) + return + } + + fileName := r.URL.Query().Get("name") + if fileName == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"filename required"}`)) + return + } + + switch r.Method { + case "OPTIONS": + return + case "POST": + { + r.Body = http.MaxBytesReader(w, r.Body, 5*MegaByte) + body, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusRequestEntityTooLarge) + fmt.Fprintln(w, `{"error":"file too large"}`) + return + } + + fileType := http.DetectContentType(body) + if !isAcceptedMIMEType(fileType) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, `{"error":"filetype %v not allowed"}`, fileType) + return + } + + res, err := uploadFile(&body, fileName, fileType) + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(res.StatusCode) + if res.StatusCode != http.StatusCreated { + w.Write([]byte(`{"error":"expected 201"}`)) + } + } + case "GET": + { + res, err := http.Get(DbAddress + "/file?name=" + fileName) + if err != nil { + fmt.Fprintln(os.Stderr, "could not reach db server for files", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + if res.StatusCode != http.StatusOK { + w.WriteHeader(res.StatusCode) + return + } + file, err := ioutil.ReadAll(res.Body) + if err != nil { + fmt.Fprintln(os.Stderr, "error reading response body", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", res.Header.Get("Content-Type")) + w.Write(file) + } + } +} diff --git a/internal/list.go b/internal/list.go new file mode 100755 index 0000000..bbee122 --- /dev/null +++ b/internal/list.go @@ -0,0 +1,59 @@ +package chat + +type List struct { + root *ListNode + length int + maxLength int +} + +type ListNode struct { + next *ListNode + data Message +} + +func newList(maxLength int) List { + return List{nil, 0, maxLength} +} + +func (l *List) add(node *ListNode) { + if l.root == nil { + l.root = node + } else { + var n *ListNode + for n = l.root; n.next != nil; n = n.next { + } + n.next = node + l.length++ + } + if l.length >= l.maxLength { + l.trimFirst() + l.length = l.maxLength + } +} + +func (l *List) trimFirst() { + l.root = l.root.next +} + +func (l *List) toMessageSlice(offset, count int) (messages []Message) { + messages = make([]Message, 0) + if l.root == nil { + return + } + for node, i := l.root, 0; node != nil; node, i = node.next, i+1 { + if i < offset { + continue + } + if i-offset >= count { + break + } + messages = append(messages, node.data) + } + return +} + +func (l *List) fromMessageSlice(messages []Message) { + for i := len(messages) - 1; i >= 0; i-- { // expected order is in reverse i.e. latest first + l.add(&ListNode{nil, messages[i]}) + } +} diff --git a/internal/server.go b/internal/server.go new file mode 100755 index 0000000..d6b7b1a --- /dev/null +++ b/internal/server.go @@ -0,0 +1,176 @@ +package chat + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "strconv" + "sync" + "time" +) + +const DefaultCacheLength = 20 + +type Server struct { + clients map[string]*Client + messageCache List + Name string `json:"name"` + wg sync.WaitGroup +} + +type Message struct { + Text string `json:"text"` + Time int64 `json:"time"` + User string `json:"user"` + ServerId string `json:"serverId"` +} + +var Servers map[string]*Server + +func StartServer(name string) (s Server) { + s = Server{make(map[string]*Client), newList(DefaultCacheLength), name, sync.WaitGroup{}} + Servers[name] = &s + s.putServer() + s.loadInitMessages() + return +} + +func getServer(name string) (*Server, error) { + if name == "" { + name = "root" + } + if s, ok := Servers[name]; ok { + return s, nil + } + + _, exists := retrieveServerFromDb(name) + if !exists { + return nil, fmt.Errorf("server does not exist") + } + + server := StartServer(name) + return &server, nil +} + +func retrieveServerFromDb(name string) (s *Server, found bool) { + res, err := http.Get(DbAddress + "/server?name=" + name) + if err != nil { + fmt.Fprintln(os.Stderr, "could not reach db server", err) + return + } + if res.StatusCode != http.StatusOK { + fmt.Println("server", name, "not found on db") + return + } + if err := json.NewDecoder(res.Body).Decode(&s); err != nil { + fmt.Fprintln(os.Stderr, "error retrieving server details from db", err) + return + } + return s, true +} + +func (s *Server) reportClients() (count int) { + for range s.clients { + count++ + } + return +} + +func (s *Server) sendMessage(message string) { + mess := Message{message, time.Now().UnixMilli(), s.Name, ""} + wm := WrappedMessage{"server", mess} + for _, client := range s.clients { + client.receivedMessages <- wm + } +} + +func (s *Server) broadcastMessage(mess *Message) { + wm := WrappedMessage{"", *mess} + for _, client := range s.clients { + client.receivedMessages <- wm + } +} + +func (s *Server) storeMessage(mess *Message) { + s.messageCache.add(&ListNode{nil, *mess}) + s.wg.Done() + jsonData, ok := encodeJson(mess) + if !ok { + return + } + + if res, err := http.Post(DbAddress+"/message?name="+s.Name, "application/json", bytes.NewBuffer(jsonData)); err != nil { + fmt.Fprintln(os.Stderr, "db server post error", err.Error()) + } else { + body, err := ioutil.ReadAll(res.Body) + if err != nil { + fmt.Fprintln(os.Stderr, "error reading response body:", err) + } + if res.StatusCode != http.StatusOK { + fmt.Println("could not store message", string(body)) + } + } +} + +func (s *Server) loadInitMessages() { + status, _, messages := s.loadMessages(0, DefaultCacheLength) + if status != http.StatusOK { + fmt.Fprintln(os.Stderr, "Could not load initial messages from db") + return + } + s.messageCache.fromMessageSlice(messages) +} + +func (s *Server) loadMessages(offset, count int) (status int, message string, messages []Message) { + if res, err := http.Get(DbAddress + "/message?name=" + s.Name + "&offset=" + strconv.Itoa(offset) + "&count=" + strconv.Itoa(count)); err != nil { + fmt.Fprintln(os.Stderr, "db server get messages error", err.Error()) + } else { + body, err := ioutil.ReadAll(res.Body) + if err != nil { + fmt.Fprintln(os.Stderr, "error reading response body:", err) + } + + if res.StatusCode != http.StatusOK { + fmt.Println("load messages error", string(body)) + return res.StatusCode, string(body), nil + } else { + messages = make([]Message, 0) + if err := json.Unmarshal(body, &messages); err != nil { + fmt.Fprintln(os.Stderr, "Error unmarshalling messages from db", err.Error()) + return http.StatusInternalServerError, "Error unmarshalling messages from db " + err.Error(), nil + } + } + } + return http.StatusOK, "", messages +} + +func (s *Server) putServer() { + jsonData, ok := encodeJson(&s) + if !ok { + fmt.Fprintln(os.Stderr, "error encoding server") + return + } + res, err := http.Post(DbAddress+"/server", "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + fmt.Fprintln(os.Stderr, "POST server to db error:", err.Error()) + return + } + if res.StatusCode == http.StatusOK { + fmt.Println("server", s.Name, "already in db") + var temp Server + if err := json.NewDecoder(res.Body).Decode(&temp); err != nil { + fmt.Println("error updating local server details with those from db") + return + } + //TODO: update details which get pulled from db rather than on creation + return + } + if res.StatusCode == http.StatusCreated { + fmt.Println("server", s.Name, "added to db") + return + } + +} diff --git a/test/server_test.go b/test/server_test.go new file mode 100755 index 0000000..5895b63 --- /dev/null +++ b/test/server_test.go @@ -0,0 +1,45 @@ +package testing + +import ( + "bytes" + "net/http" + "net/http/cookiejar" + "testing" + "time" + + chat "git.wens.org.uk/chat/internal" + "github.com/golang-jwt/jwt" +) + +func createJWT(username string) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "exp": time.Now().Add(5 * time.Minute).Unix(), + "username": username, + }) + + tokenString, err := token.SignedString([]byte(chat.JwtSecret)) + return tokenString, err +} + +func TestAddServer(t *testing.T) { + token, err := createJWT("testersson") + if err != nil { + t.Fatalf("could not create jwt: %v", err) + } + + client := &http.Client{Jar: &cookiejar.Jar{}} + + req, err := http.NewRequest("POST", "http://"+chat.Address+"/createserver", bytes.NewBuffer([]byte(`{"name":"newserver"}`))) + if err != nil { + t.Fatalf("could not create POST request: %v", err) + } + req.Header.Set("Token", token) + res, err := client.Do(req) + if err != nil { + t.Fatalf("could not reach chat server: %v", err) + } + if res.StatusCode != http.StatusCreated { + t.Errorf("unexpected return code from add server. Got: %v and expected %v\n", res.StatusCode, http.StatusCreated) + } + +}