package chat

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strconv"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
)

// upgrader turns /connect requests into websocket connections. Its
// CheckOrigin policy is installed by SetUpHandlers.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
}

// DbAddress is the base URL of the database service that the handlers in
// this file call for server creation (POST /server) and file downloads
// (GET /file).
var DbAddress = "http://127.0.0.1:8002"

// SetUpHandlers registers the chat HTTP endpoints on http.DefaultServeMux
// and installs the websocket origin check on upgrader.
func SetUpHandlers() {
	http.HandleFunc("/messages", MessagesHandler)
	http.HandleFunc("/send", SendHandler)
	http.HandleFunc("/connect", ConnectHandler)
	http.HandleFunc("/clients", ClientsHandler)
	http.HandleFunc("/createserver", CreateServerHandler)
	http.HandleFunc("/file", FileHandler)

	// Compiled once here rather than on every websocket handshake.
	originPattern := regexp.MustCompile(".*:8000")
	upgrader.CheckOrigin = func(r *http.Request) bool {
		if true { // temp: every origin is accepted for now
			return true
		}
		return originPattern.MatchString(r.Header.Get("Origin"))
	}
}

// ConnectHandler authenticates the "Token" header, upgrades the request to a
// websocket and registers the resulting Client on the server named by the
// "name" query parameter. It then starts the client's disconnect and message
// goroutines.
func ConnectHandler(w http.ResponseWriter, r *http.Request) {
	user, ok := checkToken(r.Header.Get("Token"))
	if !ok {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}

	// Resolve the server before upgrading: once Upgrade succeeds the
	// connection is hijacked and w can no longer carry an HTTP error reply.
	s, err := getServer(r.URL.Query().Get("name"))
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprintln(w, err.Error())
		return
	}

	con, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// gorilla/websocket has already written an HTTP error response.
		fmt.Fprintln(os.Stderr, err.Error())
		return
	}

	c := Client{con, s, make(chan WrappedMessage), user, uuid.NewString()}
	s.clients[c.id] = &c
	fmt.Println(c.server.reportClients(), "clients connected to "+c.server.Name+" - latest", c.username, "from", r.RemoteAddr)
	go c.awaitDisconnect()
	go c.awaitMessage()
}

// MessagesHandler writes, as JSON, up to "count" messages starting at
// "offset" for the server named by the "name" query parameter. Windows that
// lie within the first DefaultCacheLength messages are served from
// s.messageCache after waiting on s.wg; anything further back is read with
// s.loadMessages. Missing, malformed or negative offset/count values fall
// back to 0 and DefaultCacheLength respectively.
func MessagesHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	token := r.Header.Get("Token")
	if _, ok := checkToken(token); !ok {
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte("invalid token provided:" + token))
		return
	}

	query := r.URL.Query()
	s, err := getServer(query.Get("name"))
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}

	offset, err := strconv.Atoi(query.Get("offset"))
	if err != nil || offset < 0 {
		offset = 0
	}
	count, err := strconv.Atoi(query.Get("count"))
	if err != nil || count < 0 {
		count = DefaultCacheLength
	}

	// Compared via subtraction so a huge offset+count cannot overflow into a
	// negative sum and be routed to the cache.
	if offset <= DefaultCacheLength-count {
		// SendHandler calls s.wg.Add(1) before each storeMessage goroutine;
		// waiting here presumably lets pending stores reach the cache first.
		s.wg.Wait()
		jsonData, ok := encodeJson(s.messageCache.toMessageSlice(offset, count))
		if !ok {
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte(`{"error":"error serialising data"}`))
			return
		}
		w.Write(jsonData)
		return
	}

	_, _, messages := s.loadMessages(offset, count)
	jsonData, ok := encodeJson(messages)
	if !ok {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"error serialising data"}`))
		return
	}
	w.Write(jsonData)
}

// SendHandler accepts a POSTed JSON message for the server named by the
// "name" query parameter. It keeps only the decoded Text and stamps the rest
// of the Message with the current Unix-millisecond time, the token's user
// and the server name. It then stores and broadcasts the message in
// background goroutines. OPTIONS requests return immediately with no body.
func SendHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if r.Method == http.MethodOptions {
		return
	}
	if r.Method != http.MethodPost {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	token := r.Header.Get("Token")
	user, ok := checkToken(token)
	if !ok {
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte(`{"error":"token expired"}`))
		return
	}

	s, err := getServer(r.URL.Query().Get("name"))
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}

	var mess Message
	// NOTE(review): any result of decodeJson is ignored, so a malformed body
	// proceeds with an empty Text — confirm decodeJson reports errors itself.
	decodeJson(&r.Body, &w, &mess)
	mess = Message{mess.Text, time.Now().UnixMilli(), user, s.Name}

	// Add before starting the goroutine so MessagesHandler's s.wg.Wait()
	// cannot miss this pending store.
	s.wg.Add(1)
	go s.storeMessage(&mess)
	go s.broadcastMessage(&mess)
}

// ClientsHandler writes a JSON array of {"username": ...} objects listing
// the distinct usernames connected to the server named by the "name" query
// parameter.
//
// todo: only works for clients connected to servers running on this instance
//
// NOTE(review): unlike the other handlers this one does not check the Token
// header, and s.clients is read here while ConnectHandler writes it with no
// visible locking — confirm both are intended.
func ClientsHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	type ClientData struct {
		Username string `json:"username"`
	}

	s, err := getServer(r.URL.Query().Get("name"))
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}

	// One user may hold several connections; collapse them to one entry.
	clientsMap := make(map[string]struct{})
	for _, client := range s.clients {
		clientsMap[client.username] = struct{}{}
	}

	// Non-nil so an empty result encodes as [] rather than null.
	clientsSlice := make([]ClientData, 0, len(clientsMap))
	for username := range clientsMap {
		clientsSlice = append(clientsSlice, ClientData{username})
	}

	jsonData, ok := encodeJson(clientsSlice)
	if !ok {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	w.Write(jsonData)
}

// CreateServerHandler accepts a POSTed JSON {"name": ...} body and forwards
// it, with "owner" set to the token's user, to DbAddress+"/server". The
// database service's status code is relayed to the caller.
func CreateServerHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	type ServerConfig struct {
		ServerName string `json:"name"`
		Owner      string `json:"owner"`
	}

	if r.Method != http.MethodPost {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	token := r.Header.Get("Token")
	user, ok := checkToken(token)
	if !ok {
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte(`{"error":"token expired"}`))
		return
	}

	var config ServerConfig
	decodeJson(&r.Body, &w, &config)
	if config.ServerName == "" {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(`{"error":"server name blank"}`))
		return
	}
	// The owner always comes from the token, never from the request body.
	config.Owner = user

	jsonData, ok := encodeJson(config)
	if !ok {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	res, err := http.Post(DbAddress+"/server", "application/json", bytes.NewBuffer(jsonData))
	if err != nil {
		fmt.Fprintln(os.Stderr, "error reaching db server for POST server", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	// Close the body so the transport can reuse the connection.
	defer res.Body.Close()

	if res.StatusCode != http.StatusCreated {
		fmt.Println("post server error:", res.StatusCode)
	}
	w.WriteHeader(res.StatusCode)
}

// FileHandler proxies file uploads (POST) and downloads (GET) for the file
// named by the "name" query parameter; a valid "Token" header is required.
// Uploads are limited to 5 MiB and must have a MIME type, as sniffed by
// http.DetectContentType, accepted by isAcceptedMIMEType. OPTIONS returns
// immediately; other methods get 405.
func FileHandler(w http.ResponseWriter, r *http.Request) {
	const MegaByte = 1024 * 1024

	// Handled before authentication (as in SendHandler): browsers do not send
	// custom headers such as Token on CORS preflight requests.
	if r.Method == http.MethodOptions {
		return
	}

	token := r.Header.Get("Token")
	if _, ok := checkToken(token); !ok {
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte(`{"error":"token expired"}`))
		return
	}

	fileName := r.URL.Query().Get("name")
	if fileName == "" {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(`{"error":"filename required"}`))
		return
	}

	switch r.Method {
	case http.MethodPost:
		r.Body = http.MaxBytesReader(w, r.Body, 5*MegaByte)
		body, err := ioutil.ReadAll(r.Body)
		if err != nil {
			// NOTE(review): any read error, not only the size cap, is
			// reported as 413 here.
			w.WriteHeader(http.StatusRequestEntityTooLarge)
			fmt.Fprintln(w, `{"error":"file too large"}`)
			return
		}

		fileType := http.DetectContentType(body)
		if !isAcceptedMIMEType(fileType) {
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, `{"error":"filetype %v not allowed"}`, fileType)
			return
		}

		// NOTE(review): if uploadFile returns an *http.Response, its Body
		// should be closed here as well — confirm its return type.
		res, err := uploadFile(&body, fileName, fileType)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.WriteHeader(res.StatusCode)
		if res.StatusCode != http.StatusCreated {
			w.Write([]byte(`{"error":"expected 201"}`))
		}

	case http.MethodGet:
		// Escape the user-supplied name so characters like '&', '#' or
		// spaces cannot corrupt the upstream query string.
		res, err := http.Get(DbAddress + "/file?name=" + url.QueryEscape(fileName))
		if err != nil {
			fmt.Fprintln(os.Stderr, "could not reach db server for files", err.Error())
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		defer res.Body.Close()

		if res.StatusCode != http.StatusOK {
			w.WriteHeader(res.StatusCode)
			return
		}

		file, err := ioutil.ReadAll(res.Body)
		if err != nil {
			fmt.Fprintln(os.Stderr, "error reading response body", err.Error())
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", res.Header.Get("Content-Type"))
		w.Write(file)

	default:
		w.WriteHeader(http.StatusMethodNotAllowed)
	}
}