oauth2-resource-server/cmd/api/auth.go
2025-01-02 15:13:41 +08:00

96 lines
2.7 KiB
Go

package main
import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"time"

	"golang.org/x/oauth2"
)
// Origins of the parties involved in the OAuth2 authorization-code flow.
const (
	// h5Host is the front-end (H5) origin that the callback redirects to,
	// carrying the issued tokens.
	h5Host = "http://localhost:5173"

	// clientHost is the origin of this service; the OAuth2 redirect URI
	// lives under it.
	clientHost = "http://localhost:4002"

	// authHost is the origin of the authorization server (authorize and
	// token endpoints).
	authHost = "http://localhost:4001"
)

// OAuth2 client credentials and fixed flow parameters.
const (
	clientID     = "client_id"
	clientSecret = "client_secret"

	// clientState is a fixed OAuth2 "state" value.
	// NOTE(review): a constant state offers no CSRF protection; state should
	// be random per login attempt and bound to the user's browser.
	clientState = "client_state"

	// code_challenge is, despite its name, a fixed PKCE code *verifier*: the
	// S256 challenge is derived from it and it is sent as code_verifier.
	// NOTE(review): it is public in source and shorter than the 43-character
	// minimum of RFC 7636, so it does not provide PKCE's guarantees.
	code_challenge = "code_challenge"
)
// oauth2Config describes this service as an OAuth2 client of the
// authorization server at authHost: browsers are sent to AuthURL to authorize,
// the server redirects back to RedirectURL on clientHost, and tokens are
// obtained and refreshed at TokenURL.
// NOTE(review): the client secret is compiled in; consider loading it from
// configuration instead.
var oauth2Config = oauth2.Config{
	Endpoint: oauth2.Endpoint{
		AuthURL:  authHost + "/v1/oauth2/authorize",
		TokenURL: authHost + "/v1/oauth2/token",
	},
	ClientID:     clientID,
	ClientSecret: clientSecret,
	RedirectURL:  clientHost + "/v1/oauth2",
	Scopes:       []string{"all"},
}
// oauth2StateCookie and oauth2VerifierCookie name the short-lived cookies that
// carry the per-login OAuth2 state and PKCE code verifier from loginHandler to
// oauth2Handler. oauth2CookieMaxAge (seconds) bounds how long a login attempt
// may take before its callback is rejected.
const (
	oauth2StateCookie    = "oauth2_state"
	oauth2VerifierCookie = "oauth2_code_verifier"
	oauth2CookieMaxAge   = 10 * 60
)

// loginHandler starts the OAuth2 authorization-code flow with PKCE (S256) by
// redirecting the browser to the authorization server.
//
// A fresh random state and code verifier are generated for every attempt and
// stored in HttpOnly cookies; only the S256 challenge derived from the
// verifier is sent to the authorization server. oauth2Handler later checks the
// returned state against its cookie (CSRF protection) and sends the verifier
// when exchanging the code.
func (app *application) loginHandler(w http.ResponseWriter, r *http.Request) {
	state, err := newOAuth2Secret()
	if err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}
	verifier, err := newOAuth2Secret()
	if err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}
	setOAuth2Cookie(w, r, oauth2StateCookie, state, oauth2CookieMaxAge)
	setOAuth2Cookie(w, r, oauth2VerifierCookie, verifier, oauth2CookieMaxAge)

	u := oauth2Config.AuthCodeURL(state,
		oauth2.SetAuthURLParam("code_challenge", genCodeChallengeS256(verifier)),
		oauth2.SetAuthURLParam("code_challenge_method", "S256"),
	)
	http.Redirect(w, r, u, http.StatusFound)
}

// oauth2Handler is the OAuth2 authorization callback (oauth2Config.RedirectURL).
//
// It checks the "state" parameter against the cookie set by loginHandler,
// exchanges the authorization "code" for tokens using the PKCE verifier from
// the matching cookie, then redirects the browser to h5Host + "/authorize"
// with access_token, refresh_token and expiry (Unix seconds) as query
// parameters. Both cookies are cleared on every call, so a login attempt can
// be completed at most once. All failures go through app.serverErrorResponse.
func (app *application) oauth2Handler(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		app.serverErrorResponse(w, r, fmt.Errorf("parsing oauth2 callback: %w", err))
		return
	}
	wantState := takeOAuth2Cookie(w, r, oauth2StateCookie)
	verifier := takeOAuth2Cookie(w, r, oauth2VerifierCookie)

	state := r.Form.Get("state")
	// A missing cookie means this browser never started (or already finished)
	// this login, so no state value can be valid.
	if wantState == "" || state != wantState {
		app.serverErrorResponse(w, r, errors.New("state is not valid"))
		return
	}
	code := r.Form.Get("code")
	if code == "" {
		app.serverErrorResponse(w, r, errors.New("code is not found"))
		return
	}
	if verifier == "" {
		app.serverErrorResponse(w, r, errors.New("code verifier is not found"))
		return
	}
	token, err := oauth2Config.Exchange(r.Context(), code, oauth2.SetAuthURLParam("code_verifier", verifier))
	if err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}
	// TODO: check whether the resource server already has this user's info;
	// if not, fetch it via authHost/v1/oauth2/get-user-info and store it on
	// the resource server.

	// Tokens may legally contain characters such as '+', '/' or '=' that would
	// corrupt a raw query string, so escape them.
	// NOTE(review): tokens in a URL end up in browser history and access logs;
	// a URL fragment or a POST would expose them less.
	target := fmt.Sprintf("%s/authorize?access_token=%s&refresh_token=%s&expiry=%d",
		h5Host, url.QueryEscape(token.AccessToken), url.QueryEscape(token.RefreshToken), token.Expiry.Unix())
	http.Redirect(w, r, target, http.StatusFound)
}

// newOAuth2Secret returns 32 bytes from crypto/rand encoded as unpadded
// base64url: a 43-character string over [A-Za-z0-9_-]. That satisfies the
// RFC 7636 code_verifier format and also serves as an unguessable state value.
func newOAuth2Secret() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("generating oauth2 secret: %w", err)
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}

// setOAuth2Cookie stores value in an HttpOnly cookie that lives for maxAge
// seconds (a negative maxAge deletes it). SameSite=Lax still lets the browser
// send the cookie on the top-level redirect back from the authorization
// server. Secure is set only when this request itself arrived over TLS.
// NOTE(review): behind a TLS-terminating proxy r.TLS is nil, so the cookie is
// not marked Secure.
func setOAuth2Cookie(w http.ResponseWriter, r *http.Request, name, value string, maxAge int) {
	http.SetCookie(w, &http.Cookie{
		Name:     name,
		Value:    value,
		Path:     "/",
		MaxAge:   maxAge,
		HttpOnly: true,
		Secure:   r.TLS != nil,
		SameSite: http.SameSiteLaxMode,
	})
}

// takeOAuth2Cookie returns the value of the named request cookie ("" if it is
// absent) and tells the browser to delete it, making the value single-use.
func takeOAuth2Cookie(w http.ResponseWriter, r *http.Request, name string) string {
	setOAuth2Cookie(w, r, name, "", -1)
	c, err := r.Cookie(name)
	if err != nil {
		return ""
	}
	return c.Value
}
// refreshTokenHandler exchanges a refresh token for a new token set at the
// authorization server's token endpoint.
//
// POST /v1/refresh-token
//
// Request body: {"refresh_token": "..."}. On success it responds 200 with
// {"token": <oauth2.Token>}. Decoding, refresh and encoding failures are all
// reported through app.serverErrorResponse.
func (app *application) refreshTokenHandler(w http.ResponseWriter, r *http.Request) {
	var input struct {
		RefreshToken string `json:"refresh_token"`
	}
	if err := app.readJSON(w, r, &input); err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}

	// Derive from the request context so the upstream call is abandoned when
	// the client disconnects, and bound it so a stalled token endpoint cannot
	// hold this handler indefinitely.
	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
	defer cancel()

	// The seed token has no access token and is already expired, so the
	// TokenSource performs a refresh using RefreshToken straight away.
	token, err := oauth2Config.TokenSource(ctx, &oauth2.Token{
		RefreshToken: input.RefreshToken,
		Expiry:       time.Now(),
	}).Token()
	if err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}

	if err := app.writeJSON(w, http.StatusOK, envelope{"token": token}, nil); err != nil {
		app.serverErrorResponse(w, r, err)
	}
}
func genCodeChallengeS256(s string) string {
s256 := sha256.Sum256([]byte(s))
return base64.URLEncoding.EncodeToString(s256[:])
}