package main import ( "context" "errors" "net/http" "net/url" "os" "text/template" "time" "github.com/go-oauth2/oauth2/v4" "github.com/go-oauth2/oauth2/v4/manage" "github.com/go-oauth2/oauth2/v4/models" "github.com/go-oauth2/oauth2/v4/server" "github.com/go-oauth2/oauth2/v4/store" "github.com/go-session/session" ) const ( clientHost = "http://localhost:4002" clientID = "client_id" clientSecret = "client_secret" ) var manager *manage.Manager var srv *server.Server func (app *application) initAuth() error { manager = manage.NewDefaultManager() manager.MustTokenStorage(store.NewMemoryTokenStore()) client_store := store.NewClientStore() client_store.Set(clientID, &models.Client{ ID: clientID, Secret: clientSecret, Domain: clientHost, }) manager.MapClientStorage(client_store) manager.SetValidateURIHandler(func(baseURI, redirectURI string) error { return nil }) srv = server.NewDefaultServer(manager) srv.SetAllowGetAccessRequest(true) srv.SetClientInfoHandler(server.ClientFormHandler) srv.SetAllowedGrantType(oauth2.AuthorizationCode, oauth2.Refreshing) srv.SetAllowedResponseType(oauth2.Code) srv.SetUserAuthorizationHandler(app.userAuthorizationHandler) srv.SetPasswordAuthorizationHandler(app.passwordAuthorizationHandler) return nil } // 授权逻辑 // 授权成功返回user_id, 授权失败重定向到登录页面 func (app *application) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (user_id string, err error) { store, err := session.Start(r.Context(), w, r) if err != nil { app.serverErrorResponse(w, r, err) return "", err } uid, ok := store.Get("LoggedInUserId") if !ok { if r.Form == nil { r.ParseForm() } store.Set("ReturnUri", r.Form) store.Save() w.Header().Set("Location", "/v1/oauth2/login") w.WriteHeader(http.StatusFound) return "", nil } user_id = uid.(string) return user_id, nil } // 校验登录账号密码 func (app *application) passwordAuthorizationHandler(ctx context.Context, clientID, username, password string) (userID string, err error) { if username == "admin" && password == "admin" { return "admin", nil } return "", errors.New("账号或密码错误") } // 授权逻辑 func (app *application) authorizeHandler(w http.ResponseWriter, r *http.Request) { store, err := session.Start(r.Context(), w, r) if err != nil { app.serverErrorResponse(w, r, err) return } var form url.Values if v, ok := store.Get("ReturnUri"); ok { form = v.(url.Values) } r.Form = form store.Delete("ReturnUri") store.Save() err = srv.HandleAuthorizeRequest(w, r) if err != nil { app.serverErrorResponse(w, r, err) return } } // 登录页面逻辑 // 登录成功重定向到授权页面, 否则返回登录页面 func (app *application) loginHandler(w http.ResponseWriter, r *http.Request) { store, err := session.Start(r.Context(), w, r) if err != nil { app.serverErrorResponse(w, r, err) return } // 表单登录 if r.Method == http.MethodPost { r.ParseForm() user_id, err := srv.PasswordAuthorizationHandler(r.Context(), "", r.Form.Get("username"), r.Form.Get("password")) if err != nil { errMsg := struct{ Message string }{Message: err.Error()} temp, err := template.ParseFiles("internal/static/login.html") if err != nil { errMsg.Message = err.Error() } temp.Execute(w, errMsg) return } store.Set("LoggedInUserId", user_id) store.Save() w.Header().Set("Location", "/v1/oauth2/agree-auth") w.WriteHeader(http.StatusFound) return } errMsg := struct{ Message string }{Message: ""} temp, err := template.ParseFiles("internal/static/login.html") if err != nil { errMsg.Message = err.Error() } temp.Execute(w, errMsg) } // token获取 func (app *application) tokenHandler(w http.ResponseWriter, r *http.Request) { err := srv.HandleTokenRequest(w, r) if err != nil { app.serverErrorResponse(w, r, err) return } } // 授权页面逻辑 func (app *application) agreeAuthHandler(w http.ResponseWriter, r *http.Request) { store, err := session.Start(r.Context(), w, r) if err != nil { app.serverErrorResponse(w, r, err) return } if _, ok := store.Get("LoggedInUserId"); !ok { w.Header().Set("Location", "/v1/oauth2/login") w.WriteHeader(http.StatusFound) return } app.outputHTML(w, r, "internal/static/agree-auth.html") } // 获取用户信息(请求头携带access_token) func (app *application) getUserInfoHandler(w http.ResponseWriter, r *http.Request) { access_token, ok := srv.BearerAuth(r) if !ok { return } root_ctx := context.Background() ctx, cancel := context.WithTimeout(root_ctx, time.Second) defer cancel() token_info, err := srv.Manager.LoadAccessToken(ctx, access_token) if err != nil { app.serverErrorResponse(w, r, err) return } if token_info.GetScope() != "all" { app.serverErrorResponse(w, r, errors.New("invalid grant scope")) return } user_id := token_info.GetUserID() // TODO 根据user_id获取用户信息 app.writeJSON(w, http.StatusOK, envelope{"user_id": user_id}, nil) } // html文件部署 func (app *application) outputHTML(w http.ResponseWriter, req *http.Request, filename string) { file, err := os.Open(filename) if err != nil { http.Error(w, err.Error(), 500) return } defer file.Close() fi, _ := file.Stat() http.ServeContent(w, req, file.Name(), fi.ModTime(), file) }