init
commit
da2d17d81c
@ -0,0 +1,13 @@
|
|||||||
|
FROM golang:latest
|
||||||
|
WORKDIR /app
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
COPY ./cmd ./cmd
|
||||||
|
COPY ./internal ./internal
|
||||||
|
RUN go build -o /chat ./cmd
|
||||||
|
|
||||||
|
FROM gcr.io/distroless/base-debian10
|
||||||
|
WORKDIR /
|
||||||
|
COPY --from=0 /chat /chat
|
||||||
|
ENTRYPOINT ["/chat"]
|
||||||
|
#docker run --network="host" -e <SECRET> chat #127.0.0.1:8000
|
@ -0,0 +1,51 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
chat "git.wens.org.uk/chat/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
parseArgs()
|
||||||
|
|
||||||
|
chat.Servers = make(map[string]*chat.Server)
|
||||||
|
chat.StartServer("root")
|
||||||
|
chat.SetUpHandlers()
|
||||||
|
|
||||||
|
fmt.Println("chat server listening on", chat.Address)
|
||||||
|
if err := http.ListenAndServe(chat.Address, nil); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err.Error())
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseArgs() {
|
||||||
|
db := flag.String("d", chat.DbAddress, "address for db server")
|
||||||
|
flag.Parse()
|
||||||
|
if chat.DbAddress != *db {
|
||||||
|
if !strings.HasPrefix(*db, "http://") {
|
||||||
|
*db = "http://" + *db
|
||||||
|
}
|
||||||
|
chat.DbAddress = *db
|
||||||
|
}
|
||||||
|
|
||||||
|
args := flag.Args()
|
||||||
|
if len(args) == 1 {
|
||||||
|
chat.Address = args[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) > 1 {
|
||||||
|
fmt.Fprintln(os.Stderr, "too many arguments. usage: <command> [options] [address]")
|
||||||
|
flag.PrintDefaults()
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chat.JwtSecret == "" {
|
||||||
|
fmt.Println("Warning: JWT secret is empty string")
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,9 @@
|
|||||||
|
module git.wens.org.uk/chat
|
||||||
|
|
||||||
|
go 1.18
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||||
|
github.com/google/uuid v1.3.0
|
||||||
|
github.com/gorilla/websocket v1.5.0
|
||||||
|
)
|
@ -0,0 +1,6 @@
|
|||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||||
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||||
|
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
@ -0,0 +1,107 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
// chat server funcs which aren't specific to the server struct
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/md5"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var Address = "127.0.0.1:8000"
|
||||||
|
var JwtSecret = os.Getenv("secret")
|
||||||
|
|
||||||
|
func checkToken(token string) (username string, ok bool) {
|
||||||
|
type ParsedToken struct {
|
||||||
|
Username string
|
||||||
|
jwt.StandardClaims
|
||||||
|
}
|
||||||
|
var t ParsedToken
|
||||||
|
|
||||||
|
jwt, err := jwt.ParseWithClaims(token, &t, func(t *jwt.Token) (interface{}, error) {
|
||||||
|
return []byte(JwtSecret), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
if userToken, ok := jwt.Claims.(*ParsedToken); ok && jwt.Valid {
|
||||||
|
return userToken.Username, true
|
||||||
|
} else {
|
||||||
|
fmt.Println(err)
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeJson(data interface{}) ([]byte, bool) {
|
||||||
|
if jsonData, err := json.Marshal(data); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "Could not marshal json data", data)
|
||||||
|
return nil, false
|
||||||
|
} else {
|
||||||
|
return jsonData, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeJson(body *io.ReadCloser, w *http.ResponseWriter, message interface{}) {
|
||||||
|
if err := json.NewDecoder(*body).Decode(&message); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "Error decoding", err.Error())
|
||||||
|
if w != nil && (*w) != nil {
|
||||||
|
(*w).WriteHeader(http.StatusBadRequest)
|
||||||
|
(*w).Write([]byte(`{"error":"bad message"}`))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dbHandleError(body *io.ReadCloser, w *http.ResponseWriter) (message string) {
|
||||||
|
type ErrorMessage struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
var error ErrorMessage
|
||||||
|
decodeJson(body, w, &error)
|
||||||
|
return error.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func isAcceptedMIMEType(mime string) bool {
|
||||||
|
switch mime {
|
||||||
|
case "image/bmp",
|
||||||
|
"image/gif",
|
||||||
|
"image/jpeg",
|
||||||
|
"image/png",
|
||||||
|
"audio/mpeg",
|
||||||
|
"video/webm":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func uploadFile(fileContents *[]byte, fileName, mimeType string) (res *http.Response, err error) {
|
||||||
|
type FileObject struct {
|
||||||
|
FileName string `json:"filename"`
|
||||||
|
FileType string `json:"filetype"`
|
||||||
|
Contents []byte `json:"contents"`
|
||||||
|
Sum [16]byte `json:"sum"`
|
||||||
|
}
|
||||||
|
|
||||||
|
file := FileObject{
|
||||||
|
fileName, mimeType, *fileContents, md5.Sum(*fileContents),
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, ok := encodeJson(&file)
|
||||||
|
if !ok {
|
||||||
|
fmt.Fprintln(os.Stderr, "error encoding file data")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err = http.Post(DbAddress+"/file", "application/json", bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "file POST error:", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
@ -0,0 +1,94 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
conn *websocket.Conn
|
||||||
|
server *Server
|
||||||
|
receivedMessages chan (WrappedMessage)
|
||||||
|
username string
|
||||||
|
id string
|
||||||
|
}
|
||||||
|
|
||||||
|
type WrappedMessage struct {
|
||||||
|
DataType string `json:"datatype"`
|
||||||
|
Data Message `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var deadline = time.Second * 60
|
||||||
|
|
||||||
|
func (c *Client) disconnect() {
|
||||||
|
dcString := fmt.Sprint("server "+c.server.Name+" ended connection with ", c.username)
|
||||||
|
delete(c.server.clients, c.id)
|
||||||
|
c.server.sendMessage(dcString)
|
||||||
|
fmt.Println(dcString)
|
||||||
|
fmt.Println(c.server.reportClients(), "clients connected to "+c.server.Name)
|
||||||
|
c.conn.WriteControl(websocket.CloseGoingAway, []byte("reconnected"), time.Now().Add(deadline))
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) pongHandler(string) error {
|
||||||
|
return c.conn.SetReadDeadline(time.Now().Add(deadline))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) awaitDisconnect() {
|
||||||
|
c.conn.SetReadDeadline(time.Now().Add(deadline))
|
||||||
|
c.conn.SetPongHandler(c.pongHandler)
|
||||||
|
for {
|
||||||
|
_, _, err := c.conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||||
|
fmt.Fprintln(os.Stderr, "Unexpected socket close error:", err.Error())
|
||||||
|
}
|
||||||
|
c.disconnect()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) awaitMessage() {
|
||||||
|
ticker := time.NewTicker(deadline / 2)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case mess := <-c.receivedMessages:
|
||||||
|
{
|
||||||
|
jsonData, ok := encodeJson(mess)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.sendMessage(websocket.TextMessage, jsonData) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case <-ticker.C:
|
||||||
|
{
|
||||||
|
if !c.sendMessage(websocket.PingMessage, nil) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) sendMessage(messageType int, data []byte) bool {
|
||||||
|
if err := c.conn.WriteMessage(messageType, data); err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||||
|
fmt.Fprintln(os.Stderr, "Unexpected socket close error:", err.Error())
|
||||||
|
c.disconnect()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
@ -0,0 +1,305 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
var upgrader = websocket.Upgrader{
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
var DbAddress = "http://127.0.0.1:8002"
|
||||||
|
|
||||||
|
func SetUpHandlers() {
|
||||||
|
http.HandleFunc("/messages", MessagesHandler)
|
||||||
|
http.HandleFunc("/send", SendHandler)
|
||||||
|
http.HandleFunc("/connect", ConnectHandler)
|
||||||
|
http.HandleFunc("/clients", ClientsHandler)
|
||||||
|
http.HandleFunc("/createserver", CreateServerHandler)
|
||||||
|
http.HandleFunc("/file", FileHandler)
|
||||||
|
|
||||||
|
upgrader.CheckOrigin = func(r *http.Request) bool {
|
||||||
|
if true { // temp
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
reg := regexp.MustCompile(".*:8000")
|
||||||
|
return reg.MatchString(origin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConnectHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user, ok := checkToken(r.Header.Get("Token"))
|
||||||
|
if !ok {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
con, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := getServer(r.URL.Query().Get("name"))
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprintln(w, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c := Client{con, s, make(chan WrappedMessage), user, uuid.NewString()}
|
||||||
|
s.clients[c.id] = &c
|
||||||
|
|
||||||
|
fmt.Println(c.server.reportClients(), "clients connected to "+c.server.Name+" - latest", c.username, "from", r.RemoteAddr)
|
||||||
|
|
||||||
|
go c.awaitDisconnect()
|
||||||
|
go c.awaitMessage()
|
||||||
|
}
|
||||||
|
|
||||||
|
func MessagesHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
token := r.Header.Get("Token")
|
||||||
|
if _, ok := checkToken(token); !ok {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
w.Write([]byte("invalid token provided:" + token))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := getServer(r.URL.Query().Get("name"))
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
w.Write([]byte(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
offset, err := strconv.Atoi(r.URL.Query().Get("offset"))
|
||||||
|
if err != nil {
|
||||||
|
offset = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := strconv.Atoi(r.URL.Query().Get("count"))
|
||||||
|
if err != nil {
|
||||||
|
count = DefaultCacheLength
|
||||||
|
}
|
||||||
|
|
||||||
|
if offset+count <= DefaultCacheLength {
|
||||||
|
s.wg.Wait()
|
||||||
|
jsonData, ok := encodeJson(s.messageCache.toMessageSlice(offset, count))
|
||||||
|
if !ok {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"error":"error serialising data"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write(jsonData)
|
||||||
|
} else {
|
||||||
|
_, _, messages := s.loadMessages(offset, count)
|
||||||
|
jsonData, ok := encodeJson(messages)
|
||||||
|
if !ok {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write(jsonData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SendHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
if r.Method == "OPTIONS" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Method != "POST" {
|
||||||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token := r.Header.Get("Token")
|
||||||
|
user, ok := checkToken(token)
|
||||||
|
if !ok {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
w.Write([]byte(`{"error":"token expired"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := getServer(r.URL.Query().Get("name"))
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
w.Write([]byte(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var mess Message
|
||||||
|
decodeJson(&r.Body, &w, &mess)
|
||||||
|
|
||||||
|
mess = Message{mess.Text, time.Now().UnixMilli(), user, s.Name}
|
||||||
|
s.wg.Add(1)
|
||||||
|
go s.storeMessage(&mess)
|
||||||
|
go s.broadcastMessage(&mess)
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo: only works for clients connected to servers running on this instance
|
||||||
|
func ClientsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
type ClientData struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
}
|
||||||
|
clientsMap := make(map[string]struct{})
|
||||||
|
|
||||||
|
s, err := getServer(r.URL.Query().Get("name"))
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
w.Write([]byte(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, client := range s.clients {
|
||||||
|
clientsMap[client.username] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
clientsSlice := make([]ClientData, 0)
|
||||||
|
for username := range clientsMap {
|
||||||
|
clientsSlice = append(clientsSlice, ClientData{username})
|
||||||
|
}
|
||||||
|
jsonData, ok := encodeJson(clientsSlice)
|
||||||
|
if !ok {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write(jsonData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateServerHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
type ServerConfig struct {
|
||||||
|
ServerName string `json:"name"`
|
||||||
|
Owner string `json:"owner"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Method != "POST" {
|
||||||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token := r.Header.Get("Token")
|
||||||
|
user, ok := checkToken(token)
|
||||||
|
if !ok {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
w.Write([]byte(`{"error":"token expired"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var config ServerConfig
|
||||||
|
decodeJson(&r.Body, &w, &config)
|
||||||
|
|
||||||
|
if config.ServerName == "" {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
w.Write([]byte(`{"error":"server name blank"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config.Owner = user
|
||||||
|
jsonData, ok := encodeJson(config)
|
||||||
|
if !ok {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := http.Post(DbAddress+"/server", "application/json", bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "error reaching db server for POST server", err)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.StatusCode != http.StatusCreated {
|
||||||
|
fmt.Println("post server error:", res.StatusCode)
|
||||||
|
}
|
||||||
|
w.WriteHeader(res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func FileHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
const MegaByte = 1024 * 1024
|
||||||
|
|
||||||
|
token := r.Header.Get("Token")
|
||||||
|
_, ok := checkToken(token)
|
||||||
|
if !ok {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
w.Write([]byte(`{"error":"token expired"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := r.URL.Query().Get("name")
|
||||||
|
if fileName == "" {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
w.Write([]byte(`{"error":"filename required"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r.Method {
|
||||||
|
case "OPTIONS":
|
||||||
|
return
|
||||||
|
case "POST":
|
||||||
|
{
|
||||||
|
r.Body = http.MaxBytesReader(w, r.Body, 5*MegaByte)
|
||||||
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusRequestEntityTooLarge)
|
||||||
|
fmt.Fprintln(w, `{"error":"file too large"}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fileType := http.DetectContentType(body)
|
||||||
|
if !isAcceptedMIMEType(fileType) {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprintf(w, `{"error":"filetype %v not allowed"}`, fileType)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := uploadFile(&body, fileName, fileType)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(res.StatusCode)
|
||||||
|
if res.StatusCode != http.StatusCreated {
|
||||||
|
w.Write([]byte(`{"error":"expected 201"}`))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "GET":
|
||||||
|
{
|
||||||
|
res, err := http.Get(DbAddress + "/file?name=" + fileName)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "could not reach db server for files", err.Error())
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
w.WriteHeader(res.StatusCode)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file, err := ioutil.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "error reading response body", err.Error())
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", res.Header.Get("Content-Type"))
|
||||||
|
w.Write(file)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,59 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
type List struct {
|
||||||
|
root *ListNode
|
||||||
|
length int
|
||||||
|
maxLength int
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListNode struct {
|
||||||
|
next *ListNode
|
||||||
|
data Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func newList(maxLength int) List {
|
||||||
|
return List{nil, 0, maxLength}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *List) add(node *ListNode) {
|
||||||
|
if l.root == nil {
|
||||||
|
l.root = node
|
||||||
|
} else {
|
||||||
|
var n *ListNode
|
||||||
|
for n = l.root; n.next != nil; n = n.next {
|
||||||
|
}
|
||||||
|
n.next = node
|
||||||
|
l.length++
|
||||||
|
}
|
||||||
|
if l.length >= l.maxLength {
|
||||||
|
l.trimFirst()
|
||||||
|
l.length = l.maxLength
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *List) trimFirst() {
|
||||||
|
l.root = l.root.next
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *List) toMessageSlice(offset, count int) (messages []Message) {
|
||||||
|
messages = make([]Message, 0)
|
||||||
|
if l.root == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for node, i := l.root, 0; node != nil; node, i = node.next, i+1 {
|
||||||
|
if i < offset {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if i-offset >= count {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
messages = append(messages, node.data)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *List) fromMessageSlice(messages []Message) {
|
||||||
|
for i := len(messages) - 1; i >= 0; i-- { // expected order is in reverse i.e. latest first
|
||||||
|
l.add(&ListNode{nil, messages[i]})
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,176 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const DefaultCacheLength = 20
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
clients map[string]*Client
|
||||||
|
messageCache List
|
||||||
|
Name string `json:"name"`
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
Time int64 `json:"time"`
|
||||||
|
User string `json:"user"`
|
||||||
|
ServerId string `json:"serverId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var Servers map[string]*Server
|
||||||
|
|
||||||
|
func StartServer(name string) (s Server) {
|
||||||
|
s = Server{make(map[string]*Client), newList(DefaultCacheLength), name, sync.WaitGroup{}}
|
||||||
|
Servers[name] = &s
|
||||||
|
s.putServer()
|
||||||
|
s.loadInitMessages()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getServer(name string) (*Server, error) {
|
||||||
|
if name == "" {
|
||||||
|
name = "root"
|
||||||
|
}
|
||||||
|
if s, ok := Servers[name]; ok {
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, exists := retrieveServerFromDb(name)
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("server does not exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
server := StartServer(name)
|
||||||
|
return &server, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func retrieveServerFromDb(name string) (s *Server, found bool) {
|
||||||
|
res, err := http.Get(DbAddress + "/server?name=" + name)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "could not reach db server", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
fmt.Println("server", name, "not found on db")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(res.Body).Decode(&s); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "error retrieving server details from db", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return s, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) reportClients() (count int) {
|
||||||
|
for range s.clients {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) sendMessage(message string) {
|
||||||
|
mess := Message{message, time.Now().UnixMilli(), s.Name, ""}
|
||||||
|
wm := WrappedMessage{"server", mess}
|
||||||
|
for _, client := range s.clients {
|
||||||
|
client.receivedMessages <- wm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) broadcastMessage(mess *Message) {
|
||||||
|
wm := WrappedMessage{"", *mess}
|
||||||
|
for _, client := range s.clients {
|
||||||
|
client.receivedMessages <- wm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) storeMessage(mess *Message) {
|
||||||
|
s.messageCache.add(&ListNode{nil, *mess})
|
||||||
|
s.wg.Done()
|
||||||
|
jsonData, ok := encodeJson(mess)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if res, err := http.Post(DbAddress+"/message?name="+s.Name, "application/json", bytes.NewBuffer(jsonData)); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "db server post error", err.Error())
|
||||||
|
} else {
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "error reading response body:", err)
|
||||||
|
}
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
fmt.Println("could not store message", string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) loadInitMessages() {
|
||||||
|
status, _, messages := s.loadMessages(0, DefaultCacheLength)
|
||||||
|
if status != http.StatusOK {
|
||||||
|
fmt.Fprintln(os.Stderr, "Could not load initial messages from db")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.messageCache.fromMessageSlice(messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) loadMessages(offset, count int) (status int, message string, messages []Message) {
|
||||||
|
if res, err := http.Get(DbAddress + "/message?name=" + s.Name + "&offset=" + strconv.Itoa(offset) + "&count=" + strconv.Itoa(count)); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "db server get messages error", err.Error())
|
||||||
|
} else {
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "error reading response body:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
fmt.Println("load messages error", string(body))
|
||||||
|
return res.StatusCode, string(body), nil
|
||||||
|
} else {
|
||||||
|
messages = make([]Message, 0)
|
||||||
|
if err := json.Unmarshal(body, &messages); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "Error unmarshalling messages from db", err.Error())
|
||||||
|
return http.StatusInternalServerError, "Error unmarshalling messages from db " + err.Error(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return http.StatusOK, "", messages
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) putServer() {
|
||||||
|
jsonData, ok := encodeJson(&s)
|
||||||
|
if !ok {
|
||||||
|
fmt.Fprintln(os.Stderr, "error encoding server")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res, err := http.Post(DbAddress+"/server", "application/json", bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "POST server to db error:", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if res.StatusCode == http.StatusOK {
|
||||||
|
fmt.Println("server", s.Name, "already in db")
|
||||||
|
var temp Server
|
||||||
|
if err := json.NewDecoder(res.Body).Decode(&temp); err != nil {
|
||||||
|
fmt.Println("error updating local server details with those from db")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
//TODO: update details which get pulled from db rather than on creation
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if res.StatusCode == http.StatusCreated {
|
||||||
|
fmt.Println("server", s.Name, "added to db")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,45 @@
|
|||||||
|
package testing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"net/http/cookiejar"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
chat "git.wens.org.uk/chat/internal"
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createJWT(username string) (string, error) {
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||||
|
"exp": time.Now().Add(5 * time.Minute).Unix(),
|
||||||
|
"username": username,
|
||||||
|
})
|
||||||
|
|
||||||
|
tokenString, err := token.SignedString([]byte(chat.JwtSecret))
|
||||||
|
return tokenString, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddServer(t *testing.T) {
|
||||||
|
token, err := createJWT("testersson")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("could not create jwt: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{Jar: &cookiejar.Jar{}}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", "http://"+chat.Address+"/createserver", bytes.NewBuffer([]byte(`{"name":"newserver"}`)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("could not create POST request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Token", token)
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("could not reach chat server: %v", err)
|
||||||
|
}
|
||||||
|
if res.StatusCode != http.StatusCreated {
|
||||||
|
t.Errorf("unexpected return code from add server. Got: %v and expected %v\n", res.StatusCode, http.StatusCreated)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue