generated from example/golang-server-template
198 lines
5.2 KiB
Go
198 lines
5.2 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"text/template"
|
|
"time"
|
|
|
|
"github.com/go-oauth2/oauth2/v4"
|
|
"github.com/go-oauth2/oauth2/v4/manage"
|
|
"github.com/go-oauth2/oauth2/v4/models"
|
|
"github.com/go-oauth2/oauth2/v4/server"
|
|
"github.com/go-oauth2/oauth2/v4/store"
|
|
"github.com/go-session/session"
|
|
)
|
|
|
|
// Credentials of the single OAuth2 client that initAuth registers in memory.
const (
	// clientHost is the client's registered Domain; redirect URIs are checked against it.
	clientHost = "http://localhost:4002"

	// clientID identifies the registered client.
	clientID = "client_id"

	// clientSecret is the shared secret for clientID.
	// NOTE(review): hard-coded placeholder; load from configuration for anything beyond local development.
	clientSecret = "client_secret"
)
|
|
|
|
var manager *manage.Manager
|
|
var srv *server.Server
|
|
|
|
func (app *application) initAuth() error {
|
|
manager = manage.NewDefaultManager()
|
|
manager.MustTokenStorage(store.NewMemoryTokenStore())
|
|
client_store := store.NewClientStore()
|
|
client_store.Set(clientID, &models.Client{
|
|
ID: clientID,
|
|
Secret: clientSecret,
|
|
Domain: clientHost,
|
|
})
|
|
manager.MapClientStorage(client_store)
|
|
manager.SetValidateURIHandler(func(baseURI, redirectURI string) error {
|
|
return nil
|
|
})
|
|
srv = server.NewDefaultServer(manager)
|
|
srv.SetAllowGetAccessRequest(true)
|
|
srv.SetClientInfoHandler(server.ClientFormHandler)
|
|
srv.SetAllowedGrantType(oauth2.AuthorizationCode, oauth2.Refreshing)
|
|
srv.SetAllowedResponseType(oauth2.Code)
|
|
srv.SetUserAuthorizationHandler(app.userAuthorizationHandler)
|
|
srv.SetPasswordAuthorizationHandler(app.passwordAuthorizationHandler)
|
|
return nil
|
|
}
|
|
|
|
// 授权逻辑
|
|
// 授权成功返回user_id, 授权失败重定向到登录页面
|
|
func (app *application) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (user_id string, err error) {
|
|
store, err := session.Start(r.Context(), w, r)
|
|
if err != nil {
|
|
app.serverErrorResponse(w, r, err)
|
|
return "", err
|
|
}
|
|
uid, ok := store.Get("LoggedInUserId")
|
|
if !ok {
|
|
if r.Form == nil {
|
|
r.ParseForm()
|
|
}
|
|
store.Set("ReturnUri", r.Form)
|
|
store.Save()
|
|
w.Header().Set("Location", "/v1/oauth2/login")
|
|
w.WriteHeader(http.StatusFound)
|
|
return "", nil
|
|
}
|
|
user_id = uid.(string)
|
|
return user_id, nil
|
|
}
|
|
|
|
// 校验登录账号密码
|
|
func (app *application) passwordAuthorizationHandler(ctx context.Context, clientID, username, password string) (userID string, err error) {
|
|
if username == "admin" && password == "admin" {
|
|
return "admin", nil
|
|
}
|
|
return "", errors.New("账号或密码错误")
|
|
}
|
|
|
|
// 授权逻辑
|
|
func (app *application) authorizeHandler(w http.ResponseWriter, r *http.Request) {
|
|
store, err := session.Start(r.Context(), w, r)
|
|
if err != nil {
|
|
app.serverErrorResponse(w, r, err)
|
|
return
|
|
}
|
|
var form url.Values
|
|
if v, ok := store.Get("ReturnUri"); ok {
|
|
form = v.(url.Values)
|
|
}
|
|
r.Form = form
|
|
store.Delete("ReturnUri")
|
|
store.Save()
|
|
err = srv.HandleAuthorizeRequest(w, r)
|
|
if err != nil {
|
|
app.serverErrorResponse(w, r, err)
|
|
return
|
|
}
|
|
}
|
|
|
|
// 登录页面逻辑
|
|
// 登录成功重定向到授权页面, 否则返回登录页面
|
|
func (app *application) loginHandler(w http.ResponseWriter, r *http.Request) {
|
|
store, err := session.Start(r.Context(), w, r)
|
|
if err != nil {
|
|
app.serverErrorResponse(w, r, err)
|
|
return
|
|
}
|
|
// 表单登录
|
|
if r.Method == http.MethodPost {
|
|
r.ParseForm()
|
|
user_id, err := srv.PasswordAuthorizationHandler(r.Context(), "", r.Form.Get("username"), r.Form.Get("password"))
|
|
if err != nil {
|
|
errMsg := struct{ Message string }{Message: err.Error()}
|
|
temp, err := template.ParseFiles("internal/static/login.html")
|
|
if err != nil {
|
|
errMsg.Message = err.Error()
|
|
}
|
|
temp.Execute(w, errMsg)
|
|
return
|
|
}
|
|
store.Set("LoggedInUserId", user_id)
|
|
store.Save()
|
|
w.Header().Set("Location", "/v1/oauth2/agree-auth")
|
|
w.WriteHeader(http.StatusFound)
|
|
return
|
|
}
|
|
errMsg := struct{ Message string }{Message: ""}
|
|
temp, err := template.ParseFiles("internal/static/login.html")
|
|
if err != nil {
|
|
errMsg.Message = err.Error()
|
|
}
|
|
temp.Execute(w, errMsg)
|
|
}
|
|
|
|
// token获取
|
|
func (app *application) tokenHandler(w http.ResponseWriter, r *http.Request) {
|
|
err := srv.HandleTokenRequest(w, r)
|
|
if err != nil {
|
|
app.serverErrorResponse(w, r, err)
|
|
return
|
|
}
|
|
}
|
|
|
|
// 授权页面逻辑
|
|
func (app *application) agreeAuthHandler(w http.ResponseWriter, r *http.Request) {
|
|
store, err := session.Start(r.Context(), w, r)
|
|
if err != nil {
|
|
app.serverErrorResponse(w, r, err)
|
|
return
|
|
}
|
|
if _, ok := store.Get("LoggedInUserId"); !ok {
|
|
w.Header().Set("Location", "/v1/oauth2/login")
|
|
w.WriteHeader(http.StatusFound)
|
|
return
|
|
}
|
|
app.outputHTML(w, r, "internal/static/agree-auth.html")
|
|
|
|
}
|
|
|
|
// 获取用户信息(请求头携带access_token)
|
|
func (app *application) getUserInfoHandler(w http.ResponseWriter, r *http.Request) {
|
|
access_token, ok := srv.BearerAuth(r)
|
|
if !ok {
|
|
return
|
|
}
|
|
root_ctx := context.Background()
|
|
ctx, cancel := context.WithTimeout(root_ctx, time.Second)
|
|
defer cancel()
|
|
token_info, err := srv.Manager.LoadAccessToken(ctx, access_token)
|
|
if err != nil {
|
|
app.serverErrorResponse(w, r, err)
|
|
return
|
|
}
|
|
if token_info.GetScope() != "all" {
|
|
app.serverErrorResponse(w, r, errors.New("invalid grant scope"))
|
|
return
|
|
}
|
|
user_id := token_info.GetUserID()
|
|
// TODO 根据user_id获取用户信息
|
|
app.writeJSON(w, http.StatusOK, envelope{"user_id": user_id}, nil)
|
|
}
|
|
|
|
// html文件部署
|
|
func (app *application) outputHTML(w http.ResponseWriter, req *http.Request, filename string) {
|
|
file, err := os.Open(filename)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), 500)
|
|
return
|
|
}
|
|
defer file.Close()
|
|
fi, _ := file.Stat()
|
|
http.ServeContent(w, req, file.Name(), fi.ModTime(), file)
|
|
}
|