init
commit
da2d17d81c
@ -0,0 +1,13 @@
|
||||
FROM golang:latest
|
||||
WORKDIR /app
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY ./cmd ./cmd
|
||||
COPY ./internal ./internal
|
||||
RUN go build -o /chat ./cmd
|
||||
|
||||
FROM gcr.io/distroless/base-debian10
|
||||
WORKDIR /
|
||||
COPY --from=0 /chat /chat
|
||||
ENTRYPOINT ["/chat"]
|
||||
#docker run --network="host" -e <SECRET> chat #127.0.0.1:8000
|
@ -0,0 +1,51 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
chat "git.wens.org.uk/chat/internal"
|
||||
)
|
||||
|
||||
func main() {
|
||||
parseArgs()
|
||||
|
||||
chat.Servers = make(map[string]*chat.Server)
|
||||
chat.StartServer("root")
|
||||
chat.SetUpHandlers()
|
||||
|
||||
fmt.Println("chat server listening on", chat.Address)
|
||||
if err := http.ListenAndServe(chat.Address, nil); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func parseArgs() {
|
||||
db := flag.String("d", chat.DbAddress, "address for db server")
|
||||
flag.Parse()
|
||||
if chat.DbAddress != *db {
|
||||
if !strings.HasPrefix(*db, "http://") {
|
||||
*db = "http://" + *db
|
||||
}
|
||||
chat.DbAddress = *db
|
||||
}
|
||||
|
||||
args := flag.Args()
|
||||
if len(args) == 1 {
|
||||
chat.Address = args[0]
|
||||
}
|
||||
|
||||
if len(args) > 1 {
|
||||
fmt.Fprintln(os.Stderr, "too many arguments. usage: <command> [options] [address]")
|
||||
flag.PrintDefaults()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if chat.JwtSecret == "" {
|
||||
fmt.Println("Warning: JWT secret is empty string")
|
||||
}
|
||||
}
|
@ -0,0 +1,9 @@
|
||||
module git.wens.org.uk/chat
|
||||
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
)
|
@ -0,0 +1,6 @@
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
@ -0,0 +1,107 @@
|
||||
package chat
|
||||
|
||||
// chat server funcs which aren't specific to the server struct
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
var Address = "127.0.0.1:8000"
|
||||
var JwtSecret = os.Getenv("secret")
|
||||
|
||||
func checkToken(token string) (username string, ok bool) {
|
||||
type ParsedToken struct {
|
||||
Username string
|
||||
jwt.StandardClaims
|
||||
}
|
||||
var t ParsedToken
|
||||
|
||||
jwt, err := jwt.ParseWithClaims(token, &t, func(t *jwt.Token) (interface{}, error) {
|
||||
return []byte(JwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if userToken, ok := jwt.Claims.(*ParsedToken); ok && jwt.Valid {
|
||||
return userToken.Username, true
|
||||
} else {
|
||||
fmt.Println(err)
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func encodeJson(data interface{}) ([]byte, bool) {
|
||||
if jsonData, err := json.Marshal(data); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Could not marshal json data", data)
|
||||
return nil, false
|
||||
} else {
|
||||
return jsonData, true
|
||||
}
|
||||
}
|
||||
|
||||
func decodeJson(body *io.ReadCloser, w *http.ResponseWriter, message interface{}) {
|
||||
if err := json.NewDecoder(*body).Decode(&message); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error decoding", err.Error())
|
||||
if w != nil && (*w) != nil {
|
||||
(*w).WriteHeader(http.StatusBadRequest)
|
||||
(*w).Write([]byte(`{"error":"bad message"}`))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func dbHandleError(body *io.ReadCloser, w *http.ResponseWriter) (message string) {
|
||||
type ErrorMessage struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
var error ErrorMessage
|
||||
decodeJson(body, w, &error)
|
||||
return error.Error
|
||||
}
|
||||
|
||||
func isAcceptedMIMEType(mime string) bool {
|
||||
switch mime {
|
||||
case "image/bmp",
|
||||
"image/gif",
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"audio/mpeg",
|
||||
"video/webm":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func uploadFile(fileContents *[]byte, fileName, mimeType string) (res *http.Response, err error) {
|
||||
type FileObject struct {
|
||||
FileName string `json:"filename"`
|
||||
FileType string `json:"filetype"`
|
||||
Contents []byte `json:"contents"`
|
||||
Sum [16]byte `json:"sum"`
|
||||
}
|
||||
|
||||
file := FileObject{
|
||||
fileName, mimeType, *fileContents, md5.Sum(*fileContents),
|
||||
}
|
||||
|
||||
jsonData, ok := encodeJson(&file)
|
||||
if !ok {
|
||||
fmt.Fprintln(os.Stderr, "error encoding file data")
|
||||
return
|
||||
}
|
||||
|
||||
res, err = http.Post(DbAddress+"/file", "application/json", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "file POST error:", err)
|
||||
}
|
||||
return
|
||||
}
|
@ -0,0 +1,94 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Username string `json:"username"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
conn *websocket.Conn
|
||||
server *Server
|
||||
receivedMessages chan (WrappedMessage)
|
||||
username string
|
||||
id string
|
||||
}
|
||||
|
||||
type WrappedMessage struct {
|
||||
DataType string `json:"datatype"`
|
||||
Data Message `json:"data"`
|
||||
}
|
||||
|
||||
var deadline = time.Second * 60
|
||||
|
||||
func (c *Client) disconnect() {
|
||||
dcString := fmt.Sprint("server "+c.server.Name+" ended connection with ", c.username)
|
||||
delete(c.server.clients, c.id)
|
||||
c.server.sendMessage(dcString)
|
||||
fmt.Println(dcString)
|
||||
fmt.Println(c.server.reportClients(), "clients connected to "+c.server.Name)
|
||||
c.conn.WriteControl(websocket.CloseGoingAway, []byte("reconnected"), time.Now().Add(deadline))
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *Client) pongHandler(string) error {
|
||||
return c.conn.SetReadDeadline(time.Now().Add(deadline))
|
||||
}
|
||||
|
||||
func (c *Client) awaitDisconnect() {
|
||||
c.conn.SetReadDeadline(time.Now().Add(deadline))
|
||||
c.conn.SetPongHandler(c.pongHandler)
|
||||
for {
|
||||
_, _, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
fmt.Fprintln(os.Stderr, "Unexpected socket close error:", err.Error())
|
||||
}
|
||||
c.disconnect()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) awaitMessage() {
|
||||
ticker := time.NewTicker(deadline / 2)
|
||||
for {
|
||||
select {
|
||||
case mess := <-c.receivedMessages:
|
||||
{
|
||||
jsonData, ok := encodeJson(mess)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if !c.sendMessage(websocket.TextMessage, jsonData) {
|
||||
return
|
||||
}
|
||||
}
|
||||
case <-ticker.C:
|
||||
{
|
||||
if !c.sendMessage(websocket.PingMessage, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) sendMessage(messageType int, data []byte) bool {
|
||||
if err := c.conn.WriteMessage(messageType, data); err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
fmt.Fprintln(os.Stderr, "Unexpected socket close error:", err.Error())
|
||||
c.disconnect()
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
@ -0,0 +1,305 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
|
||||
var DbAddress = "http://127.0.0.1:8002"
|
||||
|
||||
func SetUpHandlers() {
|
||||
http.HandleFunc("/messages", MessagesHandler)
|
||||
http.HandleFunc("/send", SendHandler)
|
||||
http.HandleFunc("/connect", ConnectHandler)
|
||||
http.HandleFunc("/clients", ClientsHandler)
|
||||
http.HandleFunc("/createserver", CreateServerHandler)
|
||||
http.HandleFunc("/file", FileHandler)
|
||||
|
||||
upgrader.CheckOrigin = func(r *http.Request) bool {
|
||||
if true { // temp
|
||||
return true
|
||||
}
|
||||
origin := r.Header.Get("Origin")
|
||||
reg := regexp.MustCompile(".*:8000")
|
||||
return reg.MatchString(origin)
|
||||
}
|
||||
}
|
||||
|
||||
func ConnectHandler(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := checkToken(r.Header.Get("Token"))
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
con, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s, err := getServer(r.URL.Query().Get("name"))
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintln(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c := Client{con, s, make(chan WrappedMessage), user, uuid.NewString()}
|
||||
s.clients[c.id] = &c
|
||||
|
||||
fmt.Println(c.server.reportClients(), "clients connected to "+c.server.Name+" - latest", c.username, "from", r.RemoteAddr)
|
||||
|
||||
go c.awaitDisconnect()
|
||||
go c.awaitMessage()
|
||||
}
|
||||
|
||||
func MessagesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
token := r.Header.Get("Token")
|
||||
if _, ok := checkToken(token); !ok {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("invalid token provided:" + token))
|
||||
return
|
||||
}
|
||||
|
||||
s, err := getServer(r.URL.Query().Get("name"))
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
offset, err := strconv.Atoi(r.URL.Query().Get("offset"))
|
||||
if err != nil {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
count, err := strconv.Atoi(r.URL.Query().Get("count"))
|
||||
if err != nil {
|
||||
count = DefaultCacheLength
|
||||
}
|
||||
|
||||
if offset+count <= DefaultCacheLength {
|
||||
s.wg.Wait()
|
||||
jsonData, ok := encodeJson(s.messageCache.toMessageSlice(offset, count))
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(`{"error":"error serialising data"}`))
|
||||
return
|
||||
}
|
||||
w.Write(jsonData)
|
||||
} else {
|
||||
_, _, messages := s.loadMessages(offset, count)
|
||||
jsonData, ok := encodeJson(messages)
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(jsonData)
|
||||
}
|
||||
}
|
||||
|
||||
func SendHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
return
|
||||
}
|
||||
if r.Method != "POST" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
token := r.Header.Get("Token")
|
||||
user, ok := checkToken(token)
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{"error":"token expired"}`))
|
||||
return
|
||||
}
|
||||
|
||||
s, err := getServer(r.URL.Query().Get("name"))
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var mess Message
|
||||
decodeJson(&r.Body, &w, &mess)
|
||||
|
||||
mess = Message{mess.Text, time.Now().UnixMilli(), user, s.Name}
|
||||
s.wg.Add(1)
|
||||
go s.storeMessage(&mess)
|
||||
go s.broadcastMessage(&mess)
|
||||
}
|
||||
|
||||
// todo: only works for clients connected to servers running on this instance
|
||||
func ClientsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
type ClientData struct {
|
||||
Username string `json:"username"`
|
||||
}
|
||||
clientsMap := make(map[string]struct{})
|
||||
|
||||
s, err := getServer(r.URL.Query().Get("name"))
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
for _, client := range s.clients {
|
||||
clientsMap[client.username] = struct{}{}
|
||||
}
|
||||
|
||||
clientsSlice := make([]ClientData, 0)
|
||||
for username := range clientsMap {
|
||||
clientsSlice = append(clientsSlice, ClientData{username})
|
||||
}
|
||||
jsonData, ok := encodeJson(clientsSlice)
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(jsonData)
|
||||
}
|
||||
|
||||
func CreateServerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
type ServerConfig struct {
|
||||
ServerName string `json:"name"`
|
||||
Owner string `json:"owner"`
|
||||
}
|
||||
|
||||
if r.Method != "POST" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
token := r.Header.Get("Token")
|
||||
user, ok := checkToken(token)
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{"error":"token expired"}`))
|
||||
return
|
||||
}
|
||||
|
||||
var config ServerConfig
|
||||
decodeJson(&r.Body, &w, &config)
|
||||
|
||||
if config.ServerName == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"server name blank"}`))
|
||||
return
|
||||
}
|
||||
config.Owner = user
|
||||
jsonData, ok := encodeJson(config)
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
res, err := http.Post(DbAddress+"/server", "application/json", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "error reaching db server for POST server", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
fmt.Println("post server error:", res.StatusCode)
|
||||
}
|
||||
w.WriteHeader(res.StatusCode)
|
||||
}
|
||||
|
||||
func FileHandler(w http.ResponseWriter, r *http.Request) {
|
||||
const MegaByte = 1024 * 1024
|
||||
|
||||
token := r.Header.Get("Token")
|
||||
_, ok := checkToken(token)
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{"error":"token expired"}`))
|
||||
return
|
||||
}
|
||||
|
||||
fileName := r.URL.Query().Get("name")
|
||||
if fileName == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"filename required"}`))
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case "OPTIONS":
|
||||
return
|
||||
case "POST":
|
||||
{
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 5*MegaByte)
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusRequestEntityTooLarge)
|
||||
fmt.Fprintln(w, `{"error":"file too large"}`)
|
||||
return
|
||||
}
|
||||
|
||||
fileType := http.DetectContentType(body)
|
||||
if !isAcceptedMIMEType(fileType) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, `{"error":"filetype %v not allowed"}`, fileType)
|
||||
return
|
||||
}
|
||||
|
||||
res, err := uploadFile(&body, fileName, fileType)
|
||||
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(res.StatusCode)
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
w.Write([]byte(`{"error":"expected 201"}`))
|
||||
}
|
||||
}
|
||||
case "GET":
|
||||
{
|
||||
res, err := http.Get(DbAddress + "/file?name=" + fileName)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "could not reach db server for files", err.Error())
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if res.StatusCode != http.StatusOK {
|
||||
w.WriteHeader(res.StatusCode)
|
||||
return
|
||||
}
|
||||
file, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "error reading response body", err.Error())
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", res.Header.Get("Content-Type"))
|
||||
w.Write(file)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,59 @@
|
||||
package chat
|
||||
|
||||
type List struct {
|
||||
root *ListNode
|
||||
length int
|
||||
maxLength int
|
||||
}
|
||||
|
||||
type ListNode struct {
|
||||
next *ListNode
|
||||
data Message
|
||||
}
|
||||
|
||||
func newList(maxLength int) List {
|
||||
return List{nil, 0, maxLength}
|
||||
}
|
||||
|
||||
func (l *List) add(node *ListNode) {
|
||||
if l.root == nil {
|
||||
l.root = node
|
||||
} else {
|
||||
var n *ListNode
|
||||
for n = l.root; n.next != nil; n = n.next {
|
||||
}
|
||||
n.next = node
|
||||
l.length++
|
||||
}
|
||||
if l.length >= l.maxLength {
|
||||
l.trimFirst()
|
||||
l.length = l.maxLength
|
||||
}
|
||||
}
|
||||
|
||||
func (l *List) trimFirst() {
|
||||
l.root = l.root.next
|
||||
}
|
||||
|
||||
func (l *List) toMessageSlice(offset, count int) (messages []Message) {
|
||||
messages = make([]Message, 0)
|
||||
if l.root == nil {
|
||||
return
|
||||
}
|
||||
for node, i := l.root, 0; node != nil; node, i = node.next, i+1 {
|
||||
if i < offset {
|
||||
continue
|
||||
}
|
||||
if i-offset >= count {
|
||||
break
|
||||
}
|
||||
messages = append(messages, node.data)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *List) fromMessageSlice(messages []Message) {
|
||||
for i := len(messages) - 1; i >= 0; i-- { // expected order is in reverse i.e. latest first
|
||||
l.add(&ListNode{nil, messages[i]})
|
||||
}
|
||||
}
|
@ -0,0 +1,176 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const DefaultCacheLength = 20
|
||||
|
||||
type Server struct {
|
||||
clients map[string]*Client
|
||||
messageCache List
|
||||
Name string `json:"name"`
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Text string `json:"text"`
|
||||
Time int64 `json:"time"`
|
||||
User string `json:"user"`
|
||||
ServerId string `json:"serverId"`
|
||||
}
|
||||
|
||||
var Servers map[string]*Server
|
||||
|
||||
func StartServer(name string) (s Server) {
|
||||
s = Server{make(map[string]*Client), newList(DefaultCacheLength), name, sync.WaitGroup{}}
|
||||
Servers[name] = &s
|
||||
s.putServer()
|
||||
s.loadInitMessages()
|
||||
return
|
||||
}
|
||||
|
||||
func getServer(name string) (*Server, error) {
|
||||
if name == "" {
|
||||
name = "root"
|
||||
}
|
||||
if s, ok := Servers[name]; ok {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
_, exists := retrieveServerFromDb(name)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("server does not exist")
|
||||
}
|
||||
|
||||
server := StartServer(name)
|
||||
return &server, nil
|
||||
}
|
||||
|
||||
func retrieveServerFromDb(name string) (s *Server, found bool) {
|
||||
res, err := http.Get(DbAddress + "/server?name=" + name)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "could not reach db server", err)
|
||||
return
|
||||
}
|
||||
if res.StatusCode != http.StatusOK {
|
||||
fmt.Println("server", name, "not found on db")
|
||||
return
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&s); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "error retrieving server details from db", err)
|
||||
return
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
|
||||
func (s *Server) reportClients() (count int) {
|
||||
for range s.clients {
|
||||
count++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) sendMessage(message string) {
|
||||
mess := Message{message, time.Now().UnixMilli(), s.Name, ""}
|
||||
wm := WrappedMessage{"server", mess}
|
||||
for _, client := range s.clients {
|
||||
client.receivedMessages <- wm
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) broadcastMessage(mess *Message) {
|
||||
wm := WrappedMessage{"", *mess}
|
||||
for _, client := range s.clients {
|
||||
client.receivedMessages <- wm
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) storeMessage(mess *Message) {
|
||||
s.messageCache.add(&ListNode{nil, *mess})
|
||||
s.wg.Done()
|
||||
jsonData, ok := encodeJson(mess)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if res, err := http.Post(DbAddress+"/message?name="+s.Name, "application/json", bytes.NewBuffer(jsonData)); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "db server post error", err.Error())
|
||||
} else {
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "error reading response body:", err)
|
||||
}
|
||||
if res.StatusCode != http.StatusOK {
|
||||
fmt.Println("could not store message", string(body))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) loadInitMessages() {
|
||||
status, _, messages := s.loadMessages(0, DefaultCacheLength)
|
||||
if status != http.StatusOK {
|
||||
fmt.Fprintln(os.Stderr, "Could not load initial messages from db")
|
||||
return
|
||||
}
|
||||
s.messageCache.fromMessageSlice(messages)
|
||||
}
|
||||
|
||||
func (s *Server) loadMessages(offset, count int) (status int, message string, messages []Message) {
|
||||
if res, err := http.Get(DbAddress + "/message?name=" + s.Name + "&offset=" + strconv.Itoa(offset) + "&count=" + strconv.Itoa(count)); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "db server get messages error", err.Error())
|
||||
} else {
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "error reading response body:", err)
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
fmt.Println("load messages error", string(body))
|
||||
return res.StatusCode, string(body), nil
|
||||
} else {
|
||||
messages = make([]Message, 0)
|
||||
if err := json.Unmarshal(body, &messages); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error unmarshalling messages from db", err.Error())
|
||||
return http.StatusInternalServerError, "Error unmarshalling messages from db " + err.Error(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return http.StatusOK, "", messages
|
||||
}
|
||||
|
||||
func (s *Server) putServer() {
|
||||
jsonData, ok := encodeJson(&s)
|
||||
if !ok {
|
||||
fmt.Fprintln(os.Stderr, "error encoding server")
|
||||
return
|
||||
}
|
||||
res, err := http.Post(DbAddress+"/server", "application/json", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "POST server to db error:", err.Error())
|
||||
return
|
||||
}
|
||||
if res.StatusCode == http.StatusOK {
|
||||
fmt.Println("server", s.Name, "already in db")
|
||||
var temp Server
|
||||
if err := json.NewDecoder(res.Body).Decode(&temp); err != nil {
|
||||
fmt.Println("error updating local server details with those from db")
|
||||
return
|
||||
}
|
||||
//TODO: update details which get pulled from db rather than on creation
|
||||
return
|
||||
}
|
||||
if res.StatusCode == http.StatusCreated {
|
||||
fmt.Println("server", s.Name, "added to db")
|
||||
return
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,45 @@
|
||||
package testing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
chat "git.wens.org.uk/chat/internal"
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
func createJWT(username string) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"exp": time.Now().Add(5 * time.Minute).Unix(),
|
||||
"username": username,
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte(chat.JwtSecret))
|
||||
return tokenString, err
|
||||
}
|
||||
|
||||
func TestAddServer(t *testing.T) {
|
||||
token, err := createJWT("testersson")
|
||||
if err != nil {
|
||||
t.Fatalf("could not create jwt: %v", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Jar: &cookiejar.Jar{}}
|
||||
|
||||
req, err := http.NewRequest("POST", "http://"+chat.Address+"/createserver", bytes.NewBuffer([]byte(`{"name":"newserver"}`)))
|
||||
if err != nil {
|
||||
t.Fatalf("could not create POST request: %v", err)
|
||||
}
|
||||
req.Header.Set("Token", token)
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("could not reach chat server: %v", err)
|
||||
}
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
t.Errorf("unexpected return code from add server. Got: %v and expected %v\n", res.StatusCode, http.StatusCreated)
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue