master
george 3 years ago
commit da2d17d81c

@ -0,0 +1,13 @@
FROM golang:latest
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY ./cmd ./cmd
COPY ./internal ./internal
RUN go build -o /chat ./cmd
FROM gcr.io/distroless/base-debian10
WORKDIR /
COPY --from=0 /chat /chat
ENTRYPOINT ["/chat"]
#docker run --network="host" -e <SECRET> chat #127.0.0.1:8000

@ -0,0 +1,51 @@
package main
import (
"flag"
"fmt"
"net/http"
"os"
"strings"
chat "git.wens.org.uk/chat/internal"
)
func main() {
parseArgs()
chat.Servers = make(map[string]*chat.Server)
chat.StartServer("root")
chat.SetUpHandlers()
fmt.Println("chat server listening on", chat.Address)
if err := http.ListenAndServe(chat.Address, nil); err != nil {
fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1)
}
}
func parseArgs() {
db := flag.String("d", chat.DbAddress, "address for db server")
flag.Parse()
if chat.DbAddress != *db {
if !strings.HasPrefix(*db, "http://") {
*db = "http://" + *db
}
chat.DbAddress = *db
}
args := flag.Args()
if len(args) == 1 {
chat.Address = args[0]
}
if len(args) > 1 {
fmt.Fprintln(os.Stderr, "too many arguments. usage: <command> [options] [address]")
flag.PrintDefaults()
os.Exit(1)
}
if chat.JwtSecret == "" {
fmt.Println("Warning: JWT secret is empty string")
}
}

@ -0,0 +1,9 @@
module git.wens.org.uk/chat
go 1.18
require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
)

@ -0,0 +1,6 @@
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=

@ -0,0 +1,107 @@
package chat
// chat server funcs which aren't specific to the server struct
import (
"bytes"
"crypto/md5"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"github.com/golang-jwt/jwt"
)
var Address = "127.0.0.1:8000"
var JwtSecret = os.Getenv("secret")
func checkToken(token string) (username string, ok bool) {
type ParsedToken struct {
Username string
jwt.StandardClaims
}
var t ParsedToken
jwt, err := jwt.ParseWithClaims(token, &t, func(t *jwt.Token) (interface{}, error) {
return []byte(JwtSecret), nil
})
if err != nil {
return "", false
}
if userToken, ok := jwt.Claims.(*ParsedToken); ok && jwt.Valid {
return userToken.Username, true
} else {
fmt.Println(err)
return "", false
}
}
func encodeJson(data interface{}) ([]byte, bool) {
if jsonData, err := json.Marshal(data); err != nil {
fmt.Fprintln(os.Stderr, "Could not marshal json data", data)
return nil, false
} else {
return jsonData, true
}
}
func decodeJson(body *io.ReadCloser, w *http.ResponseWriter, message interface{}) {
if err := json.NewDecoder(*body).Decode(&message); err != nil {
fmt.Fprintln(os.Stderr, "Error decoding", err.Error())
if w != nil && (*w) != nil {
(*w).WriteHeader(http.StatusBadRequest)
(*w).Write([]byte(`{"error":"bad message"}`))
}
}
}
func dbHandleError(body *io.ReadCloser, w *http.ResponseWriter) (message string) {
type ErrorMessage struct {
Error string `json:"error"`
}
var error ErrorMessage
decodeJson(body, w, &error)
return error.Error
}
func isAcceptedMIMEType(mime string) bool {
switch mime {
case "image/bmp",
"image/gif",
"image/jpeg",
"image/png",
"audio/mpeg",
"video/webm":
return true
}
return false
}
func uploadFile(fileContents *[]byte, fileName, mimeType string) (res *http.Response, err error) {
type FileObject struct {
FileName string `json:"filename"`
FileType string `json:"filetype"`
Contents []byte `json:"contents"`
Sum [16]byte `json:"sum"`
}
file := FileObject{
fileName, mimeType, *fileContents, md5.Sum(*fileContents),
}
jsonData, ok := encodeJson(&file)
if !ok {
fmt.Fprintln(os.Stderr, "error encoding file data")
return
}
res, err = http.Post(DbAddress+"/file", "application/json", bytes.NewBuffer(jsonData))
if err != nil {
fmt.Fprintln(os.Stderr, "file POST error:", err)
}
return
}

@ -0,0 +1,94 @@
package chat
import (
"fmt"
"os"
"time"
"github.com/gorilla/websocket"
)
type User struct {
Username string `json:"username"`
Token string `json:"token"`
}
type Client struct {
conn *websocket.Conn
server *Server
receivedMessages chan (WrappedMessage)
username string
id string
}
type WrappedMessage struct {
DataType string `json:"datatype"`
Data Message `json:"data"`
}
var deadline = time.Second * 60
func (c *Client) disconnect() {
dcString := fmt.Sprint("server "+c.server.Name+" ended connection with ", c.username)
delete(c.server.clients, c.id)
c.server.sendMessage(dcString)
fmt.Println(dcString)
fmt.Println(c.server.reportClients(), "clients connected to "+c.server.Name)
c.conn.WriteControl(websocket.CloseGoingAway, []byte("reconnected"), time.Now().Add(deadline))
c.conn.Close()
}
func (c *Client) pongHandler(string) error {
return c.conn.SetReadDeadline(time.Now().Add(deadline))
}
func (c *Client) awaitDisconnect() {
c.conn.SetReadDeadline(time.Now().Add(deadline))
c.conn.SetPongHandler(c.pongHandler)
for {
_, _, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
fmt.Fprintln(os.Stderr, "Unexpected socket close error:", err.Error())
}
c.disconnect()
return
}
}
}
func (c *Client) awaitMessage() {
ticker := time.NewTicker(deadline / 2)
for {
select {
case mess := <-c.receivedMessages:
{
jsonData, ok := encodeJson(mess)
if !ok {
continue
}
if !c.sendMessage(websocket.TextMessage, jsonData) {
return
}
}
case <-ticker.C:
{
if !c.sendMessage(websocket.PingMessage, nil) {
return
}
}
}
}
}
func (c *Client) sendMessage(messageType int, data []byte) bool {
if err := c.conn.WriteMessage(messageType, data); err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
fmt.Fprintln(os.Stderr, "Unexpected socket close error:", err.Error())
c.disconnect()
}
return false
}
return true
}

@ -0,0 +1,305 @@
package chat
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"os"
"regexp"
"strconv"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
var DbAddress = "http://127.0.0.1:8002"
func SetUpHandlers() {
http.HandleFunc("/messages", MessagesHandler)
http.HandleFunc("/send", SendHandler)
http.HandleFunc("/connect", ConnectHandler)
http.HandleFunc("/clients", ClientsHandler)
http.HandleFunc("/createserver", CreateServerHandler)
http.HandleFunc("/file", FileHandler)
upgrader.CheckOrigin = func(r *http.Request) bool {
if true { // temp
return true
}
origin := r.Header.Get("Origin")
reg := regexp.MustCompile(".*:8000")
return reg.MatchString(origin)
}
}
func ConnectHandler(w http.ResponseWriter, r *http.Request) {
user, ok := checkToken(r.Header.Get("Token"))
if !ok {
w.WriteHeader(http.StatusUnauthorized)
return
}
con, err := upgrader.Upgrade(w, r, nil)
if err != nil {
fmt.Fprintln(os.Stderr, err.Error())
return
}
s, err := getServer(r.URL.Query().Get("name"))
if err != nil {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintln(w, err.Error())
return
}
c := Client{con, s, make(chan WrappedMessage), user, uuid.NewString()}
s.clients[c.id] = &c
fmt.Println(c.server.reportClients(), "clients connected to "+c.server.Name+" - latest", c.username, "from", r.RemoteAddr)
go c.awaitDisconnect()
go c.awaitMessage()
}
func MessagesHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
token := r.Header.Get("Token")
if _, ok := checkToken(token); !ok {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("invalid token provided:" + token))
return
}
s, err := getServer(r.URL.Query().Get("name"))
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(err.Error()))
return
}
offset, err := strconv.Atoi(r.URL.Query().Get("offset"))
if err != nil {
offset = 0
}
count, err := strconv.Atoi(r.URL.Query().Get("count"))
if err != nil {
count = DefaultCacheLength
}
if offset+count <= DefaultCacheLength {
s.wg.Wait()
jsonData, ok := encodeJson(s.messageCache.toMessageSlice(offset, count))
if !ok {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"error serialising data"}`))
return
}
w.Write(jsonData)
} else {
_, _, messages := s.loadMessages(offset, count)
jsonData, ok := encodeJson(messages)
if !ok {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Write(jsonData)
}
}
func SendHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if r.Method == "OPTIONS" {
return
}
if r.Method != "POST" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
token := r.Header.Get("Token")
user, ok := checkToken(token)
if !ok {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":"token expired"}`))
return
}
s, err := getServer(r.URL.Query().Get("name"))
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(err.Error()))
return
}
var mess Message
decodeJson(&r.Body, &w, &mess)
mess = Message{mess.Text, time.Now().UnixMilli(), user, s.Name}
s.wg.Add(1)
go s.storeMessage(&mess)
go s.broadcastMessage(&mess)
}
// todo: only works for clients connected to servers running on this instance
func ClientsHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
type ClientData struct {
Username string `json:"username"`
}
clientsMap := make(map[string]struct{})
s, err := getServer(r.URL.Query().Get("name"))
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(err.Error()))
return
}
for _, client := range s.clients {
clientsMap[client.username] = struct{}{}
}
clientsSlice := make([]ClientData, 0)
for username := range clientsMap {
clientsSlice = append(clientsSlice, ClientData{username})
}
jsonData, ok := encodeJson(clientsSlice)
if !ok {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Write(jsonData)
}
func CreateServerHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
type ServerConfig struct {
ServerName string `json:"name"`
Owner string `json:"owner"`
}
if r.Method != "POST" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
token := r.Header.Get("Token")
user, ok := checkToken(token)
if !ok {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":"token expired"}`))
return
}
var config ServerConfig
decodeJson(&r.Body, &w, &config)
if config.ServerName == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error":"server name blank"}`))
return
}
config.Owner = user
jsonData, ok := encodeJson(config)
if !ok {
w.WriteHeader(http.StatusInternalServerError)
return
}
res, err := http.Post(DbAddress+"/server", "application/json", bytes.NewBuffer(jsonData))
if err != nil {
fmt.Fprintln(os.Stderr, "error reaching db server for POST server", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if res.StatusCode != http.StatusCreated {
fmt.Println("post server error:", res.StatusCode)
}
w.WriteHeader(res.StatusCode)
}
func FileHandler(w http.ResponseWriter, r *http.Request) {
const MegaByte = 1024 * 1024
token := r.Header.Get("Token")
_, ok := checkToken(token)
if !ok {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":"token expired"}`))
return
}
fileName := r.URL.Query().Get("name")
if fileName == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error":"filename required"}`))
return
}
switch r.Method {
case "OPTIONS":
return
case "POST":
{
r.Body = http.MaxBytesReader(w, r.Body, 5*MegaByte)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusRequestEntityTooLarge)
fmt.Fprintln(w, `{"error":"file too large"}`)
return
}
fileType := http.DetectContentType(body)
if !isAcceptedMIMEType(fileType) {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, `{"error":"filetype %v not allowed"}`, fileType)
return
}
res, err := uploadFile(&body, fileName, fileType)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(res.StatusCode)
if res.StatusCode != http.StatusCreated {
w.Write([]byte(`{"error":"expected 201"}`))
}
}
case "GET":
{
res, err := http.Get(DbAddress + "/file?name=" + fileName)
if err != nil {
fmt.Fprintln(os.Stderr, "could not reach db server for files", err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}
if res.StatusCode != http.StatusOK {
w.WriteHeader(res.StatusCode)
return
}
file, err := ioutil.ReadAll(res.Body)
if err != nil {
fmt.Fprintln(os.Stderr, "error reading response body", err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", res.Header.Get("Content-Type"))
w.Write(file)
}
}
}

@ -0,0 +1,59 @@
package chat
type List struct {
root *ListNode
length int
maxLength int
}
type ListNode struct {
next *ListNode
data Message
}
func newList(maxLength int) List {
return List{nil, 0, maxLength}
}
func (l *List) add(node *ListNode) {
if l.root == nil {
l.root = node
} else {
var n *ListNode
for n = l.root; n.next != nil; n = n.next {
}
n.next = node
l.length++
}
if l.length >= l.maxLength {
l.trimFirst()
l.length = l.maxLength
}
}
func (l *List) trimFirst() {
l.root = l.root.next
}
func (l *List) toMessageSlice(offset, count int) (messages []Message) {
messages = make([]Message, 0)
if l.root == nil {
return
}
for node, i := l.root, 0; node != nil; node, i = node.next, i+1 {
if i < offset {
continue
}
if i-offset >= count {
break
}
messages = append(messages, node.data)
}
return
}
func (l *List) fromMessageSlice(messages []Message) {
for i := len(messages) - 1; i >= 0; i-- { // expected order is in reverse i.e. latest first
l.add(&ListNode{nil, messages[i]})
}
}

@ -0,0 +1,176 @@
package chat
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"strconv"
"sync"
"time"
)
const DefaultCacheLength = 20
type Server struct {
clients map[string]*Client
messageCache List
Name string `json:"name"`
wg sync.WaitGroup
}
type Message struct {
Text string `json:"text"`
Time int64 `json:"time"`
User string `json:"user"`
ServerId string `json:"serverId"`
}
var Servers map[string]*Server
func StartServer(name string) (s Server) {
s = Server{make(map[string]*Client), newList(DefaultCacheLength), name, sync.WaitGroup{}}
Servers[name] = &s
s.putServer()
s.loadInitMessages()
return
}
func getServer(name string) (*Server, error) {
if name == "" {
name = "root"
}
if s, ok := Servers[name]; ok {
return s, nil
}
_, exists := retrieveServerFromDb(name)
if !exists {
return nil, fmt.Errorf("server does not exist")
}
server := StartServer(name)
return &server, nil
}
func retrieveServerFromDb(name string) (s *Server, found bool) {
res, err := http.Get(DbAddress + "/server?name=" + name)
if err != nil {
fmt.Fprintln(os.Stderr, "could not reach db server", err)
return
}
if res.StatusCode != http.StatusOK {
fmt.Println("server", name, "not found on db")
return
}
if err := json.NewDecoder(res.Body).Decode(&s); err != nil {
fmt.Fprintln(os.Stderr, "error retrieving server details from db", err)
return
}
return s, true
}
func (s *Server) reportClients() (count int) {
for range s.clients {
count++
}
return
}
func (s *Server) sendMessage(message string) {
mess := Message{message, time.Now().UnixMilli(), s.Name, ""}
wm := WrappedMessage{"server", mess}
for _, client := range s.clients {
client.receivedMessages <- wm
}
}
func (s *Server) broadcastMessage(mess *Message) {
wm := WrappedMessage{"", *mess}
for _, client := range s.clients {
client.receivedMessages <- wm
}
}
func (s *Server) storeMessage(mess *Message) {
s.messageCache.add(&ListNode{nil, *mess})
s.wg.Done()
jsonData, ok := encodeJson(mess)
if !ok {
return
}
if res, err := http.Post(DbAddress+"/message?name="+s.Name, "application/json", bytes.NewBuffer(jsonData)); err != nil {
fmt.Fprintln(os.Stderr, "db server post error", err.Error())
} else {
body, err := ioutil.ReadAll(res.Body)
if err != nil {
fmt.Fprintln(os.Stderr, "error reading response body:", err)
}
if res.StatusCode != http.StatusOK {
fmt.Println("could not store message", string(body))
}
}
}
func (s *Server) loadInitMessages() {
status, _, messages := s.loadMessages(0, DefaultCacheLength)
if status != http.StatusOK {
fmt.Fprintln(os.Stderr, "Could not load initial messages from db")
return
}
s.messageCache.fromMessageSlice(messages)
}
func (s *Server) loadMessages(offset, count int) (status int, message string, messages []Message) {
if res, err := http.Get(DbAddress + "/message?name=" + s.Name + "&offset=" + strconv.Itoa(offset) + "&count=" + strconv.Itoa(count)); err != nil {
fmt.Fprintln(os.Stderr, "db server get messages error", err.Error())
} else {
body, err := ioutil.ReadAll(res.Body)
if err != nil {
fmt.Fprintln(os.Stderr, "error reading response body:", err)
}
if res.StatusCode != http.StatusOK {
fmt.Println("load messages error", string(body))
return res.StatusCode, string(body), nil
} else {
messages = make([]Message, 0)
if err := json.Unmarshal(body, &messages); err != nil {
fmt.Fprintln(os.Stderr, "Error unmarshalling messages from db", err.Error())
return http.StatusInternalServerError, "Error unmarshalling messages from db " + err.Error(), nil
}
}
}
return http.StatusOK, "", messages
}
func (s *Server) putServer() {
jsonData, ok := encodeJson(&s)
if !ok {
fmt.Fprintln(os.Stderr, "error encoding server")
return
}
res, err := http.Post(DbAddress+"/server", "application/json", bytes.NewBuffer(jsonData))
if err != nil {
fmt.Fprintln(os.Stderr, "POST server to db error:", err.Error())
return
}
if res.StatusCode == http.StatusOK {
fmt.Println("server", s.Name, "already in db")
var temp Server
if err := json.NewDecoder(res.Body).Decode(&temp); err != nil {
fmt.Println("error updating local server details with those from db")
return
}
//TODO: update details which get pulled from db rather than on creation
return
}
if res.StatusCode == http.StatusCreated {
fmt.Println("server", s.Name, "added to db")
return
}
}

@ -0,0 +1,45 @@
package testing
import (
"bytes"
"net/http"
"net/http/cookiejar"
"testing"
"time"
chat "git.wens.org.uk/chat/internal"
"github.com/golang-jwt/jwt"
)
func createJWT(username string) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"exp": time.Now().Add(5 * time.Minute).Unix(),
"username": username,
})
tokenString, err := token.SignedString([]byte(chat.JwtSecret))
return tokenString, err
}
func TestAddServer(t *testing.T) {
token, err := createJWT("testersson")
if err != nil {
t.Fatalf("could not create jwt: %v", err)
}
client := &http.Client{Jar: &cookiejar.Jar{}}
req, err := http.NewRequest("POST", "http://"+chat.Address+"/createserver", bytes.NewBuffer([]byte(`{"name":"newserver"}`)))
if err != nil {
t.Fatalf("could not create POST request: %v", err)
}
req.Header.Set("Token", token)
res, err := client.Do(req)
if err != nil {
t.Fatalf("could not reach chat server: %v", err)
}
if res.StatusCode != http.StatusCreated {
t.Errorf("unexpected return code from add server. Got: %v and expected %v\n", res.StatusCode, http.StatusCreated)
}
}
Loading…
Cancel
Save