package main import ( "context" "crypto/sha256" "encoding/base64" "errors" "fmt" "net/http" "time" "golang.org/x/oauth2" ) const ( h5Host = "http://localhost:5173" clientID = "client_id" clientSecret = "client_secret" clientHost = "http://localhost:4002" clientState = "client_state" authHost = "http://localhost:4001" code_challenge = "code_challenge" ) var ( oauth2Config = oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, Scopes: []string{"all"}, RedirectURL: clientHost + "/v1/oauth2", Endpoint: oauth2.Endpoint{ AuthURL: authHost + "/v1/oauth2/authorize", TokenURL: authHost + "/v1/oauth2/token", }, } ) // 登录 重定向到认证服务器 func (app *application) loginHandler(w http.ResponseWriter, r *http.Request) { u := oauth2Config.AuthCodeURL(clientState, oauth2.SetAuthURLParam("code_challenge", genCodeChallengeS256(code_challenge)), oauth2.SetAuthURLParam("code_challenge_method", "S256")) http.Redirect(w, r, u, http.StatusFound) } // 授权回调 func (app *application) oauth2Handler(w http.ResponseWriter, r *http.Request) { r.ParseForm() state := r.Form.Get("state") if state != clientState { app.serverErrorResponse(w, r, errors.New("state is not valid")) return } code := r.Form.Get("code") if code == "" { app.serverErrorResponse(w, r, errors.New("code is not found")) return } token, err := oauth2Config.Exchange(r.Context(), code, oauth2.SetAuthURLParam("code_verifier", code_challenge)) if err != nil { app.serverErrorResponse(w, r, err) return } // TODO 检查资源服务器上是否有该用户信息, 没有则通过authHost/v1/oauth2/get-user-info获取用户信息并保存到资源服务器 http.Redirect(w, r, fmt.Sprintf("%s/authorize?access_token=%s&refresh_token=%s&expiry=%d", h5Host, token.AccessToken, token.RefreshToken, token.Expiry.Unix()), http.StatusFound) } // 刷新token // POST /v1/refresh-token func (app *application) refreshTokenHandler(w http.ResponseWriter, r *http.Request) { var input struct { RefreshToken string `json:"refresh_token"` } err := app.readJSON(w, r, &input) if err != nil { app.serverErrorResponse(w, r, err) return } token, err := oauth2Config.TokenSource(context.Background(), &oauth2.Token{ RefreshToken: input.RefreshToken, Expiry: time.Now(), }).Token() if err != nil { app.serverErrorResponse(w, r, err) return } err = app.writeJSON(w, http.StatusOK, envelope{"token": token}, nil) if err != nil { app.serverErrorResponse(w, r, err) return } } func genCodeChallengeS256(s string) string { s256 := sha256.Sum256([]byte(s)) return base64.URLEncoding.EncodeToString(s256[:]) }