You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

306 lines
7.0 KiB
Go

package chat
import (
	"bytes"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strconv"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
)
var (
	// upgrader converts /connect requests into websocket connections, using
	// 1 KiB read and write buffers. Its CheckOrigin hook is assigned in
	// SetUpHandlers.
	upgrader = websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
	}

	// DbAddress is the base URL of the database service that the handlers in
	// this file contact (e.g. POST /server, GET /file).
	DbAddress = "http://127.0.0.1:8002"
)
// SetUpHandlers registers the chat endpoints on http.DefaultServeMux and
// installs the origin check that upgrader applies to websocket upgrades.
//
// Origin checking is currently switched off (allowAnyOrigin), so every origin
// is accepted. When it is switched back on, only Origin headers ending in
// ":8000" are accepted.
func SetUpHandlers() {
	http.HandleFunc("/messages", MessagesHandler)
	http.HandleFunc("/send", SendHandler)
	http.HandleFunc("/connect", ConnectHandler)
	http.HandleFunc("/clients", ClientsHandler)
	http.HandleFunc("/createserver", CreateServerHandler)
	http.HandleFunc("/file", FileHandler)

	// temp: accept upgrades from any origin. Set to false to enforce
	// originPattern.
	const allowAnyOrigin = true
	// Compiled once here instead of on every upgrade request. Anchored at the
	// end so an origin such as "http://host:80001" does not pass as port 8000.
	originPattern := regexp.MustCompile(`:8000$`)
	upgrader.CheckOrigin = func(r *http.Request) bool {
		if allowAnyOrigin {
			return true
		}
		return originPattern.MatchString(r.Header.Get("Origin"))
	}
}
// ConnectHandler upgrades an authenticated request to a websocket connection.
// It registers that connection as a client of the server named by the "name"
// query parameter.
//
// Both the Token header and the server name are validated before the upgrade.
// After upgrader.Upgrade succeeds, the connection is hijacked and w can no
// longer carry an HTTP status or body. Previously a bad server name was
// reported through a hijacked writer and the websocket was leaked.
func ConnectHandler(w http.ResponseWriter, r *http.Request) {
	user, ok := checkToken(r.Header.Get("Token"))
	if !ok {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	s, err := getServer(r.URL.Query().Get("name"))
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprintln(w, err.Error())
		return
	}
	con, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Per the gorilla/websocket docs, Upgrade has already sent an HTTP
		// error response to the client; only log here.
		fmt.Fprintln(os.Stderr, err.Error())
		return
	}
	c := Client{con, s, make(chan WrappedMessage), user, uuid.NewString()}
	// NOTE(review): s.clients is written without any lock visible here, while
	// ClientsHandler ranges over it; confirm access is synchronised elsewhere.
	s.clients[c.id] = &c
	fmt.Println(c.server.reportClients(), "clients connected to "+c.server.Name+" - latest", c.username, "from", r.RemoteAddr)
	go c.awaitDisconnect()
	go c.awaitMessage()
}
// MessagesHandler returns a page of a server's message history as JSON.
//
// Query parameters:
//   - name: the server to read from (resolved with getServer).
//   - offset: number of messages to skip. Missing, malformed or negative
//     values fall back to 0.
//   - count: number of messages to return. Missing, malformed or negative
//     values fall back to DefaultCacheLength.
//
// Requests with offset+count <= DefaultCacheLength are answered from
// s.messageCache. Any other request goes through s.loadMessages.
func MessagesHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	token := r.Header.Get("Token")
	if _, ok := checkToken(token); !ok {
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte("invalid token provided:" + token))
		return
	}
	query := r.URL.Query()
	s, err := getServer(query.Get("name"))
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}
	// Negative values are not meaningful positions; clamp them instead of
	// passing them on to toMessageSlice / loadMessages.
	offset, err := strconv.Atoi(query.Get("offset"))
	if err != nil || offset < 0 {
		offset = 0
	}
	count, err := strconv.Atoi(query.Get("count"))
	if err != nil || count < 0 {
		count = DefaultCacheLength
	}
	// writeJSON encodes v and writes it. If encoding fails it replies 500 with
	// a JSON error body, the same way on both paths below.
	writeJSON := func(v interface{}) {
		jsonData, ok := encodeJson(v)
		if !ok {
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte(`{"error":"error serialising data"}`))
			return
		}
		w.Write(jsonData)
	}
	// Same test as offset+count <= DefaultCacheLength. Both values are
	// non-negative here, so this form cannot overflow for huge inputs.
	if count <= DefaultCacheLength-offset {
		// Wait for stores started by SendHandler (counted on s.wg), presumably
		// so the cache reflects them.
		s.wg.Wait()
		writeJSON(s.messageCache.toMessageSlice(offset, count))
		return
	}
	// NOTE(review): the first two results of loadMessages are discarded; if
	// one of them reports failure it should probably become a 500.
	_, _, messages := s.loadMessages(offset, count)
	writeJSON(messages)
}
// SendHandler accepts a chat message by POST. It hands the message off for
// storage and broadcast on the server named by the "name" query parameter.
//
// OPTIONS requests (presumably CORS preflights) return at once with no body.
// Any method other than POST gets 405. Only the Text field of the decoded
// body is used; the timestamp, author and server name are set here.
func SendHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	switch r.Method {
	case http.MethodOptions:
		return
	case http.MethodPost:
		// handled below
	default:
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	user, valid := checkToken(r.Header.Get("Token"))
	if !valid {
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte(`{"error":"token expired"}`))
		return
	}

	srv, err := getServer(r.URL.Query().Get("name"))
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}

	var incoming Message
	decodeJson(&r.Body, &w, &incoming)
	msg := Message{incoming.Text, time.Now().UnixMilli(), user, srv.Name}

	// Add runs before the goroutine is started, so a concurrent
	// MessagesHandler waiting on srv.wg already sees this store as pending.
	// This assumes storeMessage calls Done.
	srv.wg.Add(1)
	go srv.storeMessage(&msg)
	go srv.broadcastMessage(&msg)
}
// ClientsHandler lists the distinct usernames connected to the server named by
// the "name" query parameter. The result is a JSON array of
// {"username": ...} objects, and it is [] (not null) when nobody is connected.
//
// todo: only works for clients connected to servers running on this instance
// NOTE(review): unlike the other handlers here, no Token check is performed.
func ClientsHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	type ClientData struct {
		Username string `json:"username"`
	}

	srv, err := getServer(r.URL.Query().Get("name"))
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}

	// One user can hold several connections (clients are keyed by a per-
	// connection id), so report each username only once.
	seen := make(map[string]bool, len(srv.clients))
	unique := []ClientData{}
	for _, cl := range srv.clients {
		if seen[cl.username] {
			continue
		}
		seen[cl.username] = true
		unique = append(unique, ClientData{cl.username})
	}

	payload, encoded := encodeJson(unique)
	if !encoded {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	w.Write(payload)
}
// CreateServerHandler creates a chat server by forwarding a POST to
// DbAddress+"/server".
//
// The JSON body supplies the server name, which must not be blank. The owner
// is always set to the user identified by the Token header, whatever the body
// says. The database service's status code is passed back to the caller.
func CreateServerHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	type ServerConfig struct {
		ServerName string `json:"name"`
		Owner      string `json:"owner"`
	}
	if r.Method != "POST" {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}
	token := r.Header.Get("Token")
	user, ok := checkToken(token)
	if !ok {
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte(`{"error":"token expired"}`))
		return
	}
	var config ServerConfig
	decodeJson(&r.Body, &w, &config)
	if config.ServerName == "" {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(`{"error":"server name blank"}`))
		return
	}
	config.Owner = user
	jsonData, ok := encodeJson(config)
	if !ok {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	res, err := http.Post(DbAddress+"/server", "application/json", bytes.NewBuffer(jsonData))
	if err != nil {
		fmt.Fprintln(os.Stderr, "error reaching db server for POST server", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	// The response body must be closed, otherwise every request leaks the
	// underlying connection.
	defer res.Body.Close()
	if res.StatusCode != http.StatusCreated {
		fmt.Fprintln(os.Stderr, "post server error:", res.StatusCode)
	}
	w.WriteHeader(res.StatusCode)
}
// FileHandler proxies file uploads and downloads to the database service at
// DbAddress. The file is identified by the "name" query parameter.
//
//   - OPTIONS: returns immediately with no body (presumably CORS preflight).
//   - POST: reads at most 5 MiB of body and rejects content types refused by
//     isAcceptedMIMEType. It then stores the bytes via uploadFile and passes
//     back that call's status code.
//   - GET: fetches DbAddress/file for the name and passes the bytes back with
//     the upstream Content-Type.
//   - any other method: 405 Method Not Allowed.
func FileHandler(w http.ResponseWriter, r *http.Request) {
	const MegaByte = 1024 * 1024
	// Handle OPTIONS before authenticating. Browsers send CORS preflights
	// without custom headers such as Token, so checking the token first
	// turned every preflight into a 401.
	if r.Method == http.MethodOptions {
		return
	}
	token := r.Header.Get("Token")
	_, ok := checkToken(token)
	if !ok {
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte(`{"error":"token expired"}`))
		return
	}
	fileName := r.URL.Query().Get("name")
	if fileName == "" {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(`{"error":"filename required"}`))
		return
	}
	switch r.Method {
	case http.MethodPost:
		r.Body = http.MaxBytesReader(w, r.Body, 5*MegaByte)
		body, err := ioutil.ReadAll(r.Body)
		if err != nil {
			w.WriteHeader(http.StatusRequestEntityTooLarge)
			fmt.Fprintln(w, `{"error":"file too large"}`)
			return
		}
		fileType := http.DetectContentType(body)
		if !isAcceptedMIMEType(fileType) {
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, `{"error":"filetype %v not allowed"}`, fileType)
			return
		}
		res, err := uploadFile(&body, fileName, fileType)
		if err != nil {
			fmt.Fprintln(os.Stderr, "error uploading file to db server", err.Error())
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		// NOTE(review): if uploadFile returns an *http.Response, confirm that
		// its Body is closed (here or inside uploadFile).
		w.WriteHeader(res.StatusCode)
		if res.StatusCode != http.StatusCreated {
			w.Write([]byte(`{"error":"expected 201"}`))
		}
	case http.MethodGet:
		// fileName was decoded from our own query string. Re-escape it so
		// characters such as '&', '#', '+' or spaces cannot change the request
		// sent to the database service.
		res, err := http.Get(DbAddress + "/file?name=" + url.QueryEscape(fileName))
		if err != nil {
			fmt.Fprintln(os.Stderr, "could not reach db server for files", err.Error())
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		// Close the body on every path, including non-200 responses, so the
		// connection is not leaked.
		defer res.Body.Close()
		if res.StatusCode != http.StatusOK {
			w.WriteHeader(res.StatusCode)
			return
		}
		file, err := ioutil.ReadAll(res.Body)
		if err != nil {
			fmt.Fprintln(os.Stderr, "error reading response body", err.Error())
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", res.Header.Get("Content-Type"))
		w.Write(file)
	default:
		w.WriteHeader(http.StatusMethodNotAllowed)
	}
}