init
						commit
						da2d17d81c
					
				| @ -0,0 +1,13 @@ | |||||||
|  | FROM golang:latest | ||||||
|  | WORKDIR /app | ||||||
|  | COPY go.mod go.sum ./ | ||||||
|  | RUN go mod download | ||||||
|  | COPY ./cmd ./cmd | ||||||
|  | COPY ./internal ./internal | ||||||
|  | RUN go build -o /chat ./cmd | ||||||
|  | 
 | ||||||
|  | FROM gcr.io/distroless/base-debian10 | ||||||
|  | WORKDIR / | ||||||
|  | COPY --from=0 /chat /chat | ||||||
|  | ENTRYPOINT ["/chat"] | ||||||
|  | #docker run --network="host" -e <SECRET> chat #127.0.0.1:8000 | ||||||
| @ -0,0 +1,51 @@ | |||||||
|  | package main | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"flag" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | 	"os" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	chat "git.wens.org.uk/chat/internal" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func main() { | ||||||
|  |     parseArgs() | ||||||
|  | 
 | ||||||
|  | 	chat.Servers = make(map[string]*chat.Server) | ||||||
|  | 	chat.StartServer("root") | ||||||
|  | 	chat.SetUpHandlers() | ||||||
|  | 
 | ||||||
|  | 	fmt.Println("chat server listening on", chat.Address) | ||||||
|  | 	if err := http.ListenAndServe(chat.Address, nil); err != nil { | ||||||
|  | 		fmt.Fprintln(os.Stderr, err.Error()) | ||||||
|  | 		os.Exit(1) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func parseArgs() { | ||||||
|  |     db := flag.String("d", chat.DbAddress, "address for db server") | ||||||
|  |     flag.Parse() | ||||||
|  |     if chat.DbAddress != *db { | ||||||
|  |         if !strings.HasPrefix(*db, "http://") { | ||||||
|  |             *db = "http://" + *db | ||||||
|  |         } | ||||||
|  |        chat.DbAddress = *db  | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     args := flag.Args() | ||||||
|  |     if len(args) == 1 { | ||||||
|  |         chat.Address = args[0] | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if len(args) > 1 { | ||||||
|  |         fmt.Fprintln(os.Stderr, "too many arguments. usage: <command> [options] [address]") | ||||||
|  |         flag.PrintDefaults() | ||||||
|  |         os.Exit(1) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if chat.JwtSecret == "" { | ||||||
|  |         fmt.Println("Warning: JWT secret is empty string") | ||||||
|  |     } | ||||||
|  | } | ||||||
| @ -0,0 +1,9 @@ | |||||||
|  | module git.wens.org.uk/chat | ||||||
|  | 
 | ||||||
|  | go 1.18 | ||||||
|  | 
 | ||||||
|  | require ( | ||||||
|  | 	github.com/golang-jwt/jwt v3.2.2+incompatible | ||||||
|  | 	github.com/google/uuid v1.3.0 | ||||||
|  | 	github.com/gorilla/websocket v1.5.0 | ||||||
|  | ) | ||||||
| @ -0,0 +1,6 @@ | |||||||
|  | github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= | ||||||
|  | github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= | ||||||
|  | github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= | ||||||
|  | github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= | ||||||
|  | github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= | ||||||
|  | github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= | ||||||
| @ -0,0 +1,107 @@ | |||||||
|  | package chat | ||||||
|  | 
 | ||||||
|  | // chat server funcs which aren't specific to the server struct
 | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"crypto/md5" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"os" | ||||||
|  | 
 | ||||||
|  | 	"github.com/golang-jwt/jwt" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var Address = "127.0.0.1:8000" | ||||||
|  | var JwtSecret = os.Getenv("secret") | ||||||
|  | 
 | ||||||
|  | func checkToken(token string) (username string, ok bool) { | ||||||
|  | 	type ParsedToken struct { | ||||||
|  | 		Username string | ||||||
|  | 		jwt.StandardClaims | ||||||
|  | 	} | ||||||
|  | 	var t ParsedToken | ||||||
|  | 
 | ||||||
|  | 	jwt, err := jwt.ParseWithClaims(token, &t, func(t *jwt.Token) (interface{}, error) { | ||||||
|  | 		return []byte(JwtSecret), nil | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if userToken, ok := jwt.Claims.(*ParsedToken); ok && jwt.Valid { | ||||||
|  | 		return userToken.Username, true | ||||||
|  | 	} else { | ||||||
|  | 		fmt.Println(err) | ||||||
|  | 		return "", false | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func encodeJson(data interface{}) ([]byte, bool) { | ||||||
|  | 	if jsonData, err := json.Marshal(data); err != nil { | ||||||
|  | 		fmt.Fprintln(os.Stderr, "Could not marshal json data", data) | ||||||
|  | 		return nil, false | ||||||
|  | 	} else { | ||||||
|  | 		return jsonData, true | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func decodeJson(body *io.ReadCloser, w *http.ResponseWriter, message interface{}) { | ||||||
|  | 	if err := json.NewDecoder(*body).Decode(&message); err != nil { | ||||||
|  | 		fmt.Fprintln(os.Stderr, "Error decoding", err.Error()) | ||||||
|  | 		if w != nil && (*w) != nil { | ||||||
|  | 			(*w).WriteHeader(http.StatusBadRequest) | ||||||
|  | 			(*w).Write([]byte(`{"error":"bad message"}`)) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func dbHandleError(body *io.ReadCloser, w *http.ResponseWriter) (message string) { | ||||||
|  | 	type ErrorMessage struct { | ||||||
|  | 		Error string `json:"error"` | ||||||
|  | 	} | ||||||
|  | 	var error ErrorMessage | ||||||
|  | 	decodeJson(body, w, &error) | ||||||
|  | 	return error.Error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func isAcceptedMIMEType(mime string) bool { | ||||||
|  | 	switch mime { | ||||||
|  | 	case "image/bmp", | ||||||
|  | 		"image/gif", | ||||||
|  | 		"image/jpeg", | ||||||
|  | 		"image/png", | ||||||
|  | 		"audio/mpeg", | ||||||
|  | 		"video/webm": | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func uploadFile(fileContents *[]byte, fileName, mimeType string) (res *http.Response, err error) { | ||||||
|  | 	type FileObject struct { | ||||||
|  | 		FileName string   `json:"filename"` | ||||||
|  | 		FileType string   `json:"filetype"` | ||||||
|  | 		Contents []byte   `json:"contents"` | ||||||
|  | 		Sum      [16]byte `json:"sum"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	file := FileObject{ | ||||||
|  | 		fileName, mimeType, *fileContents, md5.Sum(*fileContents), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	jsonData, ok := encodeJson(&file) | ||||||
|  | 	if !ok { | ||||||
|  | 		fmt.Fprintln(os.Stderr, "error encoding file data") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	res, err = http.Post(DbAddress+"/file", "application/json", bytes.NewBuffer(jsonData)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Fprintln(os.Stderr, "file POST error:", err) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
| @ -0,0 +1,94 @@ | |||||||
|  | package chat | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"os" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/gorilla/websocket" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type User struct { | ||||||
|  | 	Username string `json:"username"` | ||||||
|  | 	Token    string `json:"token"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type Client struct { | ||||||
|  | 	conn             *websocket.Conn | ||||||
|  | 	server           *Server | ||||||
|  | 	receivedMessages chan (WrappedMessage) | ||||||
|  | 	username         string | ||||||
|  | 	id               string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type WrappedMessage struct { | ||||||
|  | 	DataType string  `json:"datatype"` | ||||||
|  | 	Data     Message `json:"data"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | var deadline = time.Second * 60 | ||||||
|  | 
 | ||||||
|  | func (c *Client) disconnect() { | ||||||
|  | 	dcString := fmt.Sprint("server "+c.server.Name+" ended connection with ", c.username) | ||||||
|  | 	delete(c.server.clients, c.id) | ||||||
|  | 	c.server.sendMessage(dcString) | ||||||
|  | 	fmt.Println(dcString) | ||||||
|  | 	fmt.Println(c.server.reportClients(), "clients connected to "+c.server.Name) | ||||||
|  | 	c.conn.WriteControl(websocket.CloseGoingAway, []byte("reconnected"), time.Now().Add(deadline)) | ||||||
|  | 	c.conn.Close() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *Client) pongHandler(string) error { | ||||||
|  | 	return c.conn.SetReadDeadline(time.Now().Add(deadline)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *Client) awaitDisconnect() { | ||||||
|  | 	c.conn.SetReadDeadline(time.Now().Add(deadline)) | ||||||
|  | 	c.conn.SetPongHandler(c.pongHandler) | ||||||
|  | 	for { | ||||||
|  | 		_, _, err := c.conn.ReadMessage() | ||||||
|  | 		if err != nil { | ||||||
|  | 			if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { | ||||||
|  | 				fmt.Fprintln(os.Stderr, "Unexpected socket close error:", err.Error()) | ||||||
|  | 			} | ||||||
|  | 			c.disconnect() | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *Client) awaitMessage() { | ||||||
|  | 	ticker := time.NewTicker(deadline / 2) | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case mess := <-c.receivedMessages: | ||||||
|  | 			{ | ||||||
|  | 				jsonData, ok := encodeJson(mess) | ||||||
|  | 				if !ok { | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				if !c.sendMessage(websocket.TextMessage, jsonData) { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		case <-ticker.C: | ||||||
|  | 			{ | ||||||
|  | 				if !c.sendMessage(websocket.PingMessage, nil) { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *Client) sendMessage(messageType int, data []byte) bool { | ||||||
|  | 	if err := c.conn.WriteMessage(messageType, data); err != nil { | ||||||
|  | 		if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { | ||||||
|  | 			fmt.Fprintln(os.Stderr, "Unexpected socket close error:", err.Error()) | ||||||
|  | 			c.disconnect() | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
|  | } | ||||||
| @ -0,0 +1,305 @@ | |||||||
|  | package chat | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"net/http" | ||||||
|  | 	"os" | ||||||
|  | 	"regexp" | ||||||
|  | 	"strconv" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/google/uuid" | ||||||
|  | 	"github.com/gorilla/websocket" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var upgrader = websocket.Upgrader{ | ||||||
|  | 	ReadBufferSize:  1024, | ||||||
|  | 	WriteBufferSize: 1024, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | var DbAddress = "http://127.0.0.1:8002" | ||||||
|  | 
 | ||||||
|  | func SetUpHandlers() { | ||||||
|  | 	http.HandleFunc("/messages", MessagesHandler) | ||||||
|  | 	http.HandleFunc("/send", SendHandler) | ||||||
|  | 	http.HandleFunc("/connect", ConnectHandler) | ||||||
|  | 	http.HandleFunc("/clients", ClientsHandler) | ||||||
|  | 	http.HandleFunc("/createserver", CreateServerHandler) | ||||||
|  | 	http.HandleFunc("/file", FileHandler) | ||||||
|  | 
 | ||||||
|  | 	upgrader.CheckOrigin = func(r *http.Request) bool { | ||||||
|  | 		if true { // temp
 | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 		origin := r.Header.Get("Origin") | ||||||
|  | 		reg := regexp.MustCompile(".*:8000") | ||||||
|  | 		return reg.MatchString(origin) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func ConnectHandler(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 	user, ok := checkToken(r.Header.Get("Token")) | ||||||
|  | 	if !ok { | ||||||
|  | 		w.WriteHeader(http.StatusUnauthorized) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	con, err := upgrader.Upgrade(w, r, nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Fprintln(os.Stderr, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s, err := getServer(r.URL.Query().Get("name")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.WriteHeader(http.StatusBadRequest) | ||||||
|  | 		fmt.Fprintln(w, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	c := Client{con, s, make(chan WrappedMessage), user, uuid.NewString()} | ||||||
|  | 	s.clients[c.id] = &c | ||||||
|  | 
 | ||||||
|  | 	fmt.Println(c.server.reportClients(), "clients connected to "+c.server.Name+" - latest", c.username, "from", r.RemoteAddr) | ||||||
|  | 
 | ||||||
|  | 	go c.awaitDisconnect() | ||||||
|  | 	go c.awaitMessage() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func MessagesHandler(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 	w.Header().Set("Content-Type", "application/json") | ||||||
|  | 
 | ||||||
|  | 	token := r.Header.Get("Token") | ||||||
|  | 	if _, ok := checkToken(token); !ok { | ||||||
|  | 		w.WriteHeader(http.StatusUnauthorized) | ||||||
|  | 		w.Write([]byte("invalid token provided:" + token)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s, err := getServer(r.URL.Query().Get("name")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.WriteHeader(http.StatusBadRequest) | ||||||
|  | 		w.Write([]byte(err.Error())) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	offset, err := strconv.Atoi(r.URL.Query().Get("offset")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		offset = 0 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	count, err := strconv.Atoi(r.URL.Query().Get("count")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		count = DefaultCacheLength | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if offset+count <= DefaultCacheLength { | ||||||
|  | 		s.wg.Wait() | ||||||
|  | 		jsonData, ok := encodeJson(s.messageCache.toMessageSlice(offset, count)) | ||||||
|  | 		if !ok { | ||||||
|  | 			w.WriteHeader(http.StatusInternalServerError) | ||||||
|  | 			w.Write([]byte(`{"error":"error serialising data"}`)) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		w.Write(jsonData) | ||||||
|  | 	} else { | ||||||
|  | 		_, _, messages := s.loadMessages(offset, count) | ||||||
|  | 		jsonData, ok := encodeJson(messages) | ||||||
|  | 		if !ok { | ||||||
|  | 			w.WriteHeader(http.StatusInternalServerError) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		w.Write(jsonData) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func SendHandler(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 	w.Header().Set("Content-Type", "application/json") | ||||||
|  | 
 | ||||||
|  | 	if r.Method == "OPTIONS" { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if r.Method != "POST" { | ||||||
|  | 		w.WriteHeader(http.StatusMethodNotAllowed) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	token := r.Header.Get("Token") | ||||||
|  | 	user, ok := checkToken(token) | ||||||
|  | 	if !ok { | ||||||
|  | 		w.WriteHeader(http.StatusUnauthorized) | ||||||
|  | 		w.Write([]byte(`{"error":"token expired"}`)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s, err := getServer(r.URL.Query().Get("name")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.WriteHeader(http.StatusBadRequest) | ||||||
|  | 		w.Write([]byte(err.Error())) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var mess Message | ||||||
|  | 	decodeJson(&r.Body, &w, &mess) | ||||||
|  | 
 | ||||||
|  | 	mess = Message{mess.Text, time.Now().UnixMilli(), user, s.Name} | ||||||
|  | 	s.wg.Add(1) | ||||||
|  | 	go s.storeMessage(&mess) | ||||||
|  | 	go s.broadcastMessage(&mess) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // todo: only works for clients connected to servers running on this instance
 | ||||||
|  | func ClientsHandler(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 	w.Header().Set("Content-Type", "application/json") | ||||||
|  | 
 | ||||||
|  | 	type ClientData struct { | ||||||
|  | 		Username string `json:"username"` | ||||||
|  | 	} | ||||||
|  | 	clientsMap := make(map[string]struct{}) | ||||||
|  | 
 | ||||||
|  | 	s, err := getServer(r.URL.Query().Get("name")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.WriteHeader(http.StatusBadRequest) | ||||||
|  | 		w.Write([]byte(err.Error())) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, client := range s.clients { | ||||||
|  | 		clientsMap[client.username] = struct{}{} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	clientsSlice := make([]ClientData, 0) | ||||||
|  | 	for username := range clientsMap { | ||||||
|  | 		clientsSlice = append(clientsSlice, ClientData{username}) | ||||||
|  | 	} | ||||||
|  | 	jsonData, ok := encodeJson(clientsSlice) | ||||||
|  | 	if !ok { | ||||||
|  | 		w.WriteHeader(http.StatusInternalServerError) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	w.Write(jsonData) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func CreateServerHandler(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 	w.Header().Set("Content-Type", "application/json") | ||||||
|  | 
 | ||||||
|  | 	type ServerConfig struct { | ||||||
|  | 		ServerName string `json:"name"` | ||||||
|  | 		Owner      string `json:"owner"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if r.Method != "POST" { | ||||||
|  | 		w.WriteHeader(http.StatusMethodNotAllowed) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	token := r.Header.Get("Token") | ||||||
|  | 	user, ok := checkToken(token) | ||||||
|  | 	if !ok { | ||||||
|  | 		w.WriteHeader(http.StatusUnauthorized) | ||||||
|  | 		w.Write([]byte(`{"error":"token expired"}`)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var config ServerConfig | ||||||
|  | 	decodeJson(&r.Body, &w, &config) | ||||||
|  | 
 | ||||||
|  | 	if config.ServerName == "" { | ||||||
|  | 		w.WriteHeader(http.StatusBadRequest) | ||||||
|  | 		w.Write([]byte(`{"error":"server name blank"}`)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	config.Owner = user | ||||||
|  | 	jsonData, ok := encodeJson(config) | ||||||
|  | 	if !ok { | ||||||
|  | 		w.WriteHeader(http.StatusInternalServerError) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	res, err := http.Post(DbAddress+"/server", "application/json", bytes.NewBuffer(jsonData)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Fprintln(os.Stderr, "error reaching db server for POST server", err) | ||||||
|  | 		w.WriteHeader(http.StatusInternalServerError) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if res.StatusCode != http.StatusCreated { | ||||||
|  | 		fmt.Println("post server error:", res.StatusCode) | ||||||
|  | 	} | ||||||
|  | 	w.WriteHeader(res.StatusCode) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func FileHandler(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 	const MegaByte = 1024 * 1024 | ||||||
|  | 
 | ||||||
|  | 	token := r.Header.Get("Token") | ||||||
|  | 	_, ok := checkToken(token) | ||||||
|  | 	if !ok { | ||||||
|  | 		w.WriteHeader(http.StatusUnauthorized) | ||||||
|  | 		w.Write([]byte(`{"error":"token expired"}`)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	fileName := r.URL.Query().Get("name") | ||||||
|  | 	if fileName == "" { | ||||||
|  | 		w.WriteHeader(http.StatusBadRequest) | ||||||
|  | 		w.Write([]byte(`{"error":"filename required"}`)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	switch r.Method { | ||||||
|  | 	case "OPTIONS": | ||||||
|  | 		return | ||||||
|  | 	case "POST": | ||||||
|  | 		{ | ||||||
|  | 			r.Body = http.MaxBytesReader(w, r.Body, 5*MegaByte) | ||||||
|  | 			body, err := ioutil.ReadAll(r.Body) | ||||||
|  | 			if err != nil { | ||||||
|  | 				w.WriteHeader(http.StatusRequestEntityTooLarge) | ||||||
|  | 				fmt.Fprintln(w, `{"error":"file too large"}`) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			fileType := http.DetectContentType(body) | ||||||
|  | 			if !isAcceptedMIMEType(fileType) { | ||||||
|  | 				w.WriteHeader(http.StatusBadRequest) | ||||||
|  | 				fmt.Fprintf(w, `{"error":"filetype %v not allowed"}`, fileType) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			res, err := uploadFile(&body, fileName, fileType) | ||||||
|  | 
 | ||||||
|  | 			if err != nil { | ||||||
|  | 				w.WriteHeader(http.StatusInternalServerError) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			w.WriteHeader(res.StatusCode) | ||||||
|  | 			if res.StatusCode != http.StatusCreated { | ||||||
|  | 				w.Write([]byte(`{"error":"expected 201"}`)) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	case "GET": | ||||||
|  | 		{ | ||||||
|  | 			res, err := http.Get(DbAddress + "/file?name=" + fileName) | ||||||
|  | 			if err != nil { | ||||||
|  | 				fmt.Fprintln(os.Stderr, "could not reach db server for files", err.Error()) | ||||||
|  | 				w.WriteHeader(http.StatusInternalServerError) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			if res.StatusCode != http.StatusOK { | ||||||
|  | 				w.WriteHeader(res.StatusCode) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			file, err := ioutil.ReadAll(res.Body) | ||||||
|  | 			if err != nil { | ||||||
|  | 				fmt.Fprintln(os.Stderr, "error reading response body", err.Error()) | ||||||
|  | 				w.WriteHeader(http.StatusInternalServerError) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			w.Header().Set("Content-Type", res.Header.Get("Content-Type")) | ||||||
|  | 			w.Write(file) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @ -0,0 +1,59 @@ | |||||||
|  | package chat | ||||||
|  | 
 | ||||||
|  | type List struct { | ||||||
|  | 	root      *ListNode | ||||||
|  | 	length    int | ||||||
|  | 	maxLength int | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type ListNode struct { | ||||||
|  | 	next *ListNode | ||||||
|  | 	data Message | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func newList(maxLength int) List { | ||||||
|  | 	return List{nil, 0, maxLength} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *List) add(node *ListNode) { | ||||||
|  | 	if l.root == nil { | ||||||
|  | 		l.root = node | ||||||
|  | 	} else { | ||||||
|  | 		var n *ListNode | ||||||
|  | 		for n = l.root; n.next != nil; n = n.next { | ||||||
|  | 		} | ||||||
|  | 		n.next = node | ||||||
|  | 		l.length++ | ||||||
|  | 	} | ||||||
|  | 	if l.length >= l.maxLength { | ||||||
|  | 		l.trimFirst() | ||||||
|  | 		l.length = l.maxLength | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *List) trimFirst() { | ||||||
|  | 	l.root = l.root.next | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *List) toMessageSlice(offset, count int) (messages []Message) { | ||||||
|  | 	messages = make([]Message, 0) | ||||||
|  | 	if l.root == nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	for node, i := l.root, 0; node != nil; node, i = node.next, i+1 { | ||||||
|  | 		if i < offset { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		if i-offset >= count { | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 		messages = append(messages, node.data) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *List) fromMessageSlice(messages []Message) { | ||||||
|  | 	for i := len(messages) - 1; i >= 0; i-- { // expected order is in reverse i.e. latest first
 | ||||||
|  | 		l.add(&ListNode{nil, messages[i]}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @ -0,0 +1,176 @@ | |||||||
|  | package chat | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"net/http" | ||||||
|  | 	"os" | ||||||
|  | 	"strconv" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const DefaultCacheLength = 20 | ||||||
|  | 
 | ||||||
|  | type Server struct { | ||||||
|  | 	clients      map[string]*Client | ||||||
|  | 	messageCache List | ||||||
|  | 	Name         string `json:"name"` | ||||||
|  | 	wg           sync.WaitGroup | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type Message struct { | ||||||
|  | 	Text     string `json:"text"` | ||||||
|  | 	Time     int64  `json:"time"` | ||||||
|  | 	User     string `json:"user"` | ||||||
|  | 	ServerId string `json:"serverId"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | var Servers map[string]*Server | ||||||
|  | 
 | ||||||
|  | func StartServer(name string) (s Server) { | ||||||
|  | 	s = Server{make(map[string]*Client), newList(DefaultCacheLength), name, sync.WaitGroup{}} | ||||||
|  | 	Servers[name] = &s | ||||||
|  | 	s.putServer() | ||||||
|  | 	s.loadInitMessages() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func getServer(name string) (*Server, error) { | ||||||
|  | 	if name == "" { | ||||||
|  | 		name = "root" | ||||||
|  | 	} | ||||||
|  | 	if s, ok := Servers[name]; ok { | ||||||
|  | 		return s, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	_, exists := retrieveServerFromDb(name) | ||||||
|  | 	if !exists { | ||||||
|  | 		return nil, fmt.Errorf("server does not exist") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	server := StartServer(name) | ||||||
|  | 	return &server, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func retrieveServerFromDb(name string) (s *Server, found bool) { | ||||||
|  | 	res, err := http.Get(DbAddress + "/server?name=" + name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Fprintln(os.Stderr, "could not reach db server", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if res.StatusCode != http.StatusOK { | ||||||
|  | 		fmt.Println("server", name, "not found on db") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if err := json.NewDecoder(res.Body).Decode(&s); err != nil { | ||||||
|  | 		fmt.Fprintln(os.Stderr, "error retrieving server details from db", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	return s, true | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) reportClients() (count int) { | ||||||
|  | 	for range s.clients { | ||||||
|  | 		count++ | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) sendMessage(message string) { | ||||||
|  | 	mess := Message{message, time.Now().UnixMilli(), s.Name, ""} | ||||||
|  | 	wm := WrappedMessage{"server", mess} | ||||||
|  | 	for _, client := range s.clients { | ||||||
|  | 		client.receivedMessages <- wm | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) broadcastMessage(mess *Message) { | ||||||
|  | 	wm := WrappedMessage{"", *mess} | ||||||
|  | 	for _, client := range s.clients { | ||||||
|  | 		client.receivedMessages <- wm | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) storeMessage(mess *Message) { | ||||||
|  | 	s.messageCache.add(&ListNode{nil, *mess}) | ||||||
|  | 	s.wg.Done() | ||||||
|  | 	jsonData, ok := encodeJson(mess) | ||||||
|  | 	if !ok { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if res, err := http.Post(DbAddress+"/message?name="+s.Name, "application/json", bytes.NewBuffer(jsonData)); err != nil { | ||||||
|  | 		fmt.Fprintln(os.Stderr, "db server post error", err.Error()) | ||||||
|  | 	} else { | ||||||
|  | 		body, err := ioutil.ReadAll(res.Body) | ||||||
|  | 		if err != nil { | ||||||
|  | 			fmt.Fprintln(os.Stderr, "error reading response body:", err) | ||||||
|  | 		} | ||||||
|  | 		if res.StatusCode != http.StatusOK { | ||||||
|  | 			fmt.Println("could not store message", string(body)) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) loadInitMessages() { | ||||||
|  | 	status, _, messages := s.loadMessages(0, DefaultCacheLength) | ||||||
|  | 	if status != http.StatusOK { | ||||||
|  | 		fmt.Fprintln(os.Stderr, "Could not load initial messages from db") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	s.messageCache.fromMessageSlice(messages) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) loadMessages(offset, count int) (status int, message string, messages []Message) { | ||||||
|  | 	if res, err := http.Get(DbAddress + "/message?name=" + s.Name + "&offset=" + strconv.Itoa(offset) + "&count=" + strconv.Itoa(count)); err != nil { | ||||||
|  | 		fmt.Fprintln(os.Stderr, "db server get messages error", err.Error()) | ||||||
|  | 	} else { | ||||||
|  | 		body, err := ioutil.ReadAll(res.Body) | ||||||
|  | 		if err != nil { | ||||||
|  | 			fmt.Fprintln(os.Stderr, "error reading response body:", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if res.StatusCode != http.StatusOK { | ||||||
|  | 			fmt.Println("load messages error", string(body)) | ||||||
|  | 			return res.StatusCode, string(body), nil | ||||||
|  | 		} else { | ||||||
|  | 			messages = make([]Message, 0) | ||||||
|  | 			if err := json.Unmarshal(body, &messages); err != nil { | ||||||
|  | 				fmt.Fprintln(os.Stderr, "Error unmarshalling messages from db", err.Error()) | ||||||
|  | 				return http.StatusInternalServerError, "Error unmarshalling messages from db " + err.Error(), nil | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return http.StatusOK, "", messages | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) putServer() { | ||||||
|  | 	jsonData, ok := encodeJson(&s) | ||||||
|  | 	if !ok { | ||||||
|  | 		fmt.Fprintln(os.Stderr, "error encoding server") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	res, err := http.Post(DbAddress+"/server", "application/json", bytes.NewBuffer(jsonData)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Fprintln(os.Stderr, "POST server to db error:", err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if res.StatusCode == http.StatusOK { | ||||||
|  | 		fmt.Println("server", s.Name, "already in db") | ||||||
|  | 		var temp Server | ||||||
|  | 		if err := json.NewDecoder(res.Body).Decode(&temp); err != nil { | ||||||
|  | 			fmt.Println("error updating local server details with those from db") | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		//TODO: update details which get pulled from db rather than on creation
 | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if res.StatusCode == http.StatusCreated { | ||||||
|  | 		fmt.Println("server", s.Name, "added to db") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | } | ||||||
| @ -0,0 +1,45 @@ | |||||||
|  | package testing | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/cookiejar" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	chat "git.wens.org.uk/chat/internal" | ||||||
|  | 	"github.com/golang-jwt/jwt" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func createJWT(username string) (string, error) { | ||||||
|  | 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ | ||||||
|  | 		"exp":      time.Now().Add(5 * time.Minute).Unix(), | ||||||
|  | 		"username": username, | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	tokenString, err := token.SignedString([]byte(chat.JwtSecret)) | ||||||
|  | 	return tokenString, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestAddServer(t *testing.T) { | ||||||
|  | 	token, err := createJWT("testersson") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("could not create jwt: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	client := &http.Client{Jar: &cookiejar.Jar{}} | ||||||
|  | 
 | ||||||
|  | 	req, err := http.NewRequest("POST", "http://"+chat.Address+"/createserver", bytes.NewBuffer([]byte(`{"name":"newserver"}`))) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("could not create POST request: %v", err) | ||||||
|  | 	} | ||||||
|  | 	req.Header.Set("Token", token) | ||||||
|  | 	res, err := client.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("could not reach chat server: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if res.StatusCode != http.StatusCreated { | ||||||
|  | 		t.Errorf("unexpected return code from add server. Got: %v and expected %v\n", res.StatusCode, http.StatusCreated) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | } | ||||||
					Loading…
					
					
				
		Reference in New Issue