commit 37e94fdf578c0503c8018663afa6e9948363bef4 Author: george Date: Wed May 18 19:43:34 2022 +0100 init diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..410e46f --- /dev/null +++ b/Dockerfile @@ -0,0 +1,13 @@ +FROM golang:latest +WORKDIR /app +COPY go.mod go.sum ./ +RUN go mod download +COPY ./cmd ./cmd +COPY ./internal ./internal +RUN go build -o /login ./cmd + +FROM gcr.io/distroless/base-debian10 +WORKDIR / +COPY --from=0 /login /login +ENTRYPOINT ["/login"] +#docker run --network="host" -e login #127.0.0.1:8001 diff --git a/cmd/main.go b/cmd/main.go new file mode 100755 index 0000000..51ccd51 --- /dev/null +++ b/cmd/main.go @@ -0,0 +1,49 @@ +package main + +import ( + "flag" + "fmt" + "net/http" + "os" + "strings" + + login "git.wens.org.uk/login/internal" +) + +func main() { + parseArgs() + login.SetUpHandlers() + + fmt.Println("login server listening on", login.Address) + if err := http.ListenAndServe(login.Address, nil); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } +} + +func parseArgs() { + db := flag.String("d", login.DbAddress, "address for db server") + flag.Parse() + + if login.DbAddress != *db { + if !strings.HasPrefix(*db, "http://") { + *db = "http://" + *db + } + login.DbAddress = *db + } + + args := flag.Args() + if len(args) == 1 { + login.Address = os.Args[0] + } + + if len(args) > 1 { + fmt.Fprintln(os.Stderr, "too many arguments. usage: [options] [address]") + flag.PrintDefaults() + os.Exit(1) + } + + if login.JwtSecret == "" { + fmt.Println("Warning: JWT secret is empty string") + } +} diff --git a/go.mod b/go.mod new file mode 100755 index 0000000..3798400 --- /dev/null +++ b/go.mod @@ -0,0 +1,9 @@ +module git.wens.org.uk/login + +go 1.18 + +require ( + github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/google/uuid v1.3.0 + golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f +) diff --git a/go.sum b/go.sum new file mode 100755 index 0000000..2eda662 --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f h1:OeJjE6G4dgCY4PIXvIRQbE8+RX+uXZyGhUy/ksMGJoc= +golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= diff --git a/internal/handlers.go b/internal/handlers.go new file mode 100755 index 0000000..f8e0274 --- /dev/null +++ b/internal/handlers.go @@ -0,0 +1,197 @@ +package login + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "time" + + "golang.org/x/crypto/bcrypt" +) + +func SetUpHandlers() { + http.HandleFunc("/register", registerHandler) + http.HandleFunc("/login", loginHandler) + http.HandleFunc("/users", usersHandler) + http.HandleFunc("/auth", authHandler) +} + +func registerHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Content-Type", "application/json") + login, ok := decodeLogin(&w, &r) + if !ok { + return + } + + if message, ok := validateRegistration(login); !ok { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(message)) + return + } + + hash, err := hashPassword(login.Password) + if err != nil { + fmt.Fprintln(os.Stderr, "Error hashing password:", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + + id, sessionId, err := newSession(login.Username) + if err != nil { + return + } + + user := User{login.Username, sessionId, hash} + jsonData, err := json.Marshal(user) + if err != nil { + fmt.Println("Error marshalling data") + w.WriteHeader(http.StatusBadRequest) + return + } + + res, err := http.Post(DbAddress+"/user", "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + fmt.Fprintln(os.Stderr, "Error adding user to db server:", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + + if mess, ok := checkBody(&res); !ok { + w.WriteHeader(res.StatusCode) + w.Write(mess) + return + } + + if mess, ok := storeSessionToken(login.Username, sessionId); !ok { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(mess)) + return + } + + c := &http.Cookie{HttpOnly: true, Name: "session", Value: sessionId, Expires: time.Now().Add(time.Hour * 24 * 30)} + http.SetCookie(w, c) + + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"token":"` + id + `"}`)) +} + +func loginHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Content-Type", "application/json") + login, ok := decodeLogin(&w, &r) + if !ok { + return + } + + res, err := http.Get(DbAddress + "/user?username=" + login.Username) + if err != nil { + fmt.Fprintln(os.Stderr, "db server error:", err.Error()) + return + } + + body, ok := checkBody(&res) + if !ok { + w.WriteHeader(res.StatusCode) + w.Write(body) + return + } + + var user User + if err := json.Unmarshal(body, &user); err != nil { + fmt.Fprintln(os.Stderr, "bad user retrieved from db server:", string(body)) + w.WriteHeader(http.StatusInternalServerError) + return + } + + if err := bcrypt.CompareHashAndPassword(user.HashedPassword, []byte(login.Password)); err != nil { + fmt.Println("Failed login for user", login.Username) + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error":"login failed"}`)) + return + } + + id, sessionId, err := newSession(login.Username) + if err != nil { + return + } + + if mess, ok := storeSessionToken(login.Username, sessionId); !ok { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(mess)) + return + } + + c := &http.Cookie{HttpOnly: true, Name: "session", Value: sessionId} + http.SetCookie(w, c) + + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"token":"` + id + `"}`)) +} + +func usersHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Content-Type", "application/json") + + if r.Method != "GET" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + if res, err := http.Get(DbAddress + "/users"); err != nil { + fmt.Fprintln(os.Stderr, "db GET /users error:", err.Error()) + } else { + body, err := ioutil.ReadAll(res.Body) + if err != nil { + fmt.Fprintln(os.Stderr, "error reading response body:", err) + } + w.WriteHeader(res.StatusCode) + w.Write(body) + } +} + +func authHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Content-Type", "application/json") + + if r.Method != "GET" { + w.WriteHeader(http.StatusMethodNotAllowed) + } + + val, err := r.Cookie("session") + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + _, _, ok := retrieveSessionToken(val.Value) + if !ok { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"token invalid"}`)) + return + } + + user, ok := getUserBySession(val.Value) + if !ok { + fmt.Println("not found") + w.WriteHeader(http.StatusNotFound) + return + } + + id, sessionId, err := newSession(user.Username) + if err != nil { + return + } + + if mess, ok := storeSessionToken(user.Username, sessionId); !ok { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(mess)) + return + } + + c := &http.Cookie{HttpOnly: true, Name: "session", Value: sessionId, Expires: time.Now().Add(time.Hour * 24 * 30)} + http.SetCookie(w, c) + w.Write([]byte(`{"token":"` + id + `"}`)) +} diff --git a/internal/login.go b/internal/login.go new file mode 100755 index 0000000..1eadc37 --- /dev/null +++ b/internal/login.go @@ -0,0 +1,196 @@ +package login + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "time" + + "github.com/golang-jwt/jwt" + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" +) + +type Login struct { + Username string `json:"username"` + Password string `json:"password"` +} + +type User struct { + Username string `json:"username"` + Session string `json:"session"` + HashedPassword []byte `json:"hashedPassword"` +} + +var JwtSecret = os.Getenv("secret") +var Address = "127.0.0.1:8001" +var DbAddress = "http://127.0.0.1:8002" + +func decodeLogin(w *http.ResponseWriter, r **http.Request) (login Login, ok bool) { + (*w).Header().Set("Access-Control-Allow-Origin", "*") + if (*r).Method == "OPTIONS" { + return + } + if (*r).Method != "POST" { + (*w).WriteHeader(http.StatusMethodNotAllowed) + return Login{}, false + } + + if err := json.NewDecoder((*r).Body).Decode(&login); err != nil { + fmt.Println("Error decoding login", err.Error()) + (*w).WriteHeader(http.StatusBadRequest) + (*w).Write([]byte(`{"error":"malformed login"}`)) + return Login{}, false + } + + return login, true +} + +func validateRegistration(login Login) (message string, ok bool) { + const ( + numLowerBound = 48 + numUpperBound = 57 + lcAlphaLowerBound = 97 + lcAlphaUpperBound = 122 + ucAlphaLowerBound = 65 + ucAlphaUpperBound = 90 + ) + + if len(login.Password) < 8 { + message = `{"error":"password too short"}` + return + } + + var lcAlphaCount, ucAlphaCount, numCount, specialCount int + for _, char := range login.Password { + intVal := int(char) + if numLowerBound <= intVal && intVal <= numUpperBound { + numCount++ + } else if lcAlphaLowerBound <= intVal && intVal <= lcAlphaUpperBound { + lcAlphaCount++ + } else if ucAlphaLowerBound <= intVal && intVal <= ucAlphaUpperBound { + ucAlphaCount++ + } else { + specialCount++ + } + } + + if lcAlphaCount == 0 || ucAlphaCount == 0 || numCount == 0 || specialCount == 0 { + message = `{"error":"password failed criteria"}` + return + } + return "", true +} + +func hashPassword(password string) (hash []byte, err error) { + hash, err = bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + fmt.Fprintln(os.Stderr, "Error hashing password", err.Error()) + } + return +} + +func checkBody(res **http.Response) (mess []byte, ok bool) { + body, err := ioutil.ReadAll((*res).Body) + if err != nil { + fmt.Fprintln(os.Stderr, "error reading response body:", err) + return nil, false + } + if (*res).StatusCode != http.StatusOK { + fmt.Println(string(body)) + return body, false + } + return body, true +} + +func encodeJson(data interface{}) ([]byte, bool) { + if jsonData, err := json.Marshal(data); err != nil { + fmt.Fprintln(os.Stderr, "Could not marshal json data", data) + return nil, false + } else { + return jsonData, true + } +} + +func createJWT(username string) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "exp": time.Now().Add(5 * time.Minute).Unix(), + "username": username, + }) + + tokenString, err := token.SignedString([]byte(JwtSecret)) + return tokenString, err +} + +func newSession(username string) (jwt, sessionId string, err error) { + jwt, err = createJWT(username) + return jwt, uuid.NewString(), err +} + +func storeSessionToken(username, token string) (message string, ok bool) { + user := User{username, token, nil} + jsonData, ok := encodeJson(user) + if !ok { + return "error encoding json", false + } + res, err := http.Post(DbAddress+"/session", "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + fmt.Fprintln(os.Stderr, "could not reach db server:", err) + return "could not reach db server", false + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return "could not decode response body", false + } + + if err != nil || res.StatusCode != http.StatusOK { + return string(body), false + } + + return "", true +} + +func retrieveSessionToken(username string) (token, message string, ok bool) { + res, err := http.Get(DbAddress + "/session?username=" + username) + if err != nil { + fmt.Fprintln(os.Stderr, "could not reach db server:", err) + return "", "could not reach db server", false + } + body, err := ioutil.ReadAll(res.Body) + + if err != nil || res.StatusCode != http.StatusOK { + return "", string(body), false + } + + var user User + if err := json.Unmarshal(body, &user); err != nil { + return "", err.Error(), false + } + + return user.Session, "", true +} + +func getUserBySession(session string) (user User, ok bool) { + res, err := http.Get(DbAddress + "/user?session=" + session) + if err != nil { + fmt.Fprintln(os.Stderr, "error reaching db server", err.Error()) + return User{}, false + } + if res.StatusCode != http.StatusOK { + return User{}, false + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + fmt.Fprintln(os.Stderr, "could not read response body:", err) + return User{}, false + } + if err := json.Unmarshal(body, &user); err != nil { + fmt.Fprintln(os.Stderr, "unmarhsal user error", err.Error()) + return user, false + } + return user, true +}