golang-server-template/internal/data/users.go
2024-07-08 17:36:56 +08:00

200 lines
4.9 KiB
Go

package data
import (
"context"
"crypto/sha256"
"database/sql"
"errors"
"time"
"golang.org/x/crypto/bcrypt"
"greenlight.alexedwards.net/internal/validator"
)
// ErrDuplicateEmail is the sentinel error returned by UserModel.Insert and
// UserModel.Update when the database rejects a write because the email
// address violates the "users_email_key" unique constraint.
var ErrDuplicateEmail = errors.New("duplicate email")
// User is the in-memory representation of one row of the users table.
//
// Password and Version carry the `json:"-"` tag, so encoding/json never
// includes them in serialized output.
type User struct {
	// ID is filled from the "id" column.
	ID int64 `json:"id"`
	// CreateAt is filled from the "created_at" column.
	CreateAt time.Time `json:"create_at"`
	// Name is the user's display name.
	Name string `json:"name"`
	// Email is the user's email address.
	Email string `json:"email"`
	// Password holds the bcrypt hash and, when set through Set, the plaintext.
	Password password `json:"-"`
	// Activated mirrors the "activated" column.
	Activated bool `json:"activated"`
	// Version is the row version that Update matches on and increments.
	Version int `json:"-"`
}
// AnonymousUser is a shared zero-valued *User. IsAnonymous recognises it by
// pointer identity, so it must be passed around as-is rather than copied.
var AnonymousUser = new(User)
// password pairs a bcrypt hash with the optional plaintext it came from.
type password struct {
	// plaintext is nil unless Set has stored a value; ValidateUser only
	// validates the plaintext when this pointer is non-nil.
	plaintext *string
	// hash is the bcrypt hash written to / read from "password_hash".
	hash []byte
}
// IsAnonymous reports whether u is the AnonymousUser pointer itself. The
// comparison is by identity, so a distinct zero-valued User is not anonymous.
func (u *User) IsAnonymous() bool {
	return AnonymousUser == u
}
// Set hashes plaintextPassword with bcrypt at cost 12 and, on success, stores
// both the hash and a pointer to the plaintext on p. If hashing fails the
// error is returned and p is left unmodified.
func (p *password) Set(plaintextPassword string) error {
	hashed, err := bcrypt.GenerateFromPassword([]byte(plaintextPassword), 12)
	if err == nil {
		p.plaintext, p.hash = &plaintextPassword, hashed
	}
	return err
}
// Matches reports whether plaintextPassword corresponds to the stored bcrypt
// hash. A plain mismatch yields (false, nil); any other bcrypt failure is
// returned as (false, err).
func (p *password) Matches(plaintextPassword string) (bool, error) {
	err := bcrypt.CompareHashAndPassword(p.hash, []byte(plaintextPassword))
	if err == nil {
		return true, nil
	}
	// A mismatch is an expected outcome, not a failure of the comparison.
	if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
		return false, nil
	}
	return false, err
}
// ValidateEmail runs two checks on v under the "email" key: the address must
// be non-empty and must match validator.EmailRX.
func ValidateEmail(v *validator.Validator, email string) {
	v.Check(len(email) > 0, "email", "must be provided")

	wellFormed := validator.Matches(email, validator.EmailRX)
	v.Check(wellFormed, "email", "must be a valid email address")
}
// ValidatePasswordPlaintext runs three checks on v under the "password" key:
// the password must be non-empty, at least 8 bytes and at most 72 bytes.
// Lengths are measured in bytes, not runes.
func ValidatePasswordPlaintext(v *validator.Validator, password string) {
	size := len(password)

	v.Check(size > 0, "password", "must be provided")
	v.Check(size >= 8, "password", "must be at least 8 bytes long")
	v.Check(size <= 72, "password", "must not be more than 72 bytes long")
}
// ValidateUser checks the name (non-empty, at most 500 bytes), delegates the
// email to ValidateEmail and, when a plaintext password is present, delegates
// it to ValidatePasswordPlaintext.
//
// It panics if the user has no password hash: that indicates a programming
// error rather than bad input, so it is not reported through v.
func ValidateUser(v *validator.Validator, user *User) {
	nameLen := len(user.Name)
	v.Check(nameLen > 0, "name", "must be provided")
	v.Check(nameLen <= 500, "name", "must not be more than 500 bytes long")

	ValidateEmail(v, user.Email)

	// The plaintext is only available when the password was set via Set.
	if plain := user.Password.plaintext; plain != nil {
		ValidatePasswordPlaintext(v, *plain)
	}

	if user.Password.hash == nil {
		panic("missing password hash for user")
	}
}
// UserModel groups the queries that read and write the users table.
type UserModel struct {
	// DB is the connection pool every query is executed against.
	DB *sql.DB
}
// Insert writes user to the users table and copies the generated id,
// created_at and version values back into user.
//
// The query runs under a 3-second timeout. It returns ErrDuplicateEmail when
// the driver reports a violation of the "users_email_key" unique constraint;
// any other error is returned unchanged.
func (m UserModel) Insert(user *User) error {
	const stmt = `
INSERT INTO users (name, email, password_hash, activated)
VALUES ($1, $2, $3, $4)
RETURNING id, created_at, version`

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	row := m.DB.QueryRowContext(ctx, stmt, user.Name, user.Email, user.Password.hash, user.Activated)

	err := row.Scan(&user.ID, &user.CreateAt, &user.Version)
	if err == nil {
		return nil
	}
	// The duplicate-email case is recognised by the exact driver message.
	if err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"` {
		return ErrDuplicateEmail
	}
	return err
}
// GetByEmail loads the user whose email column equals email.
//
// The query runs under a 3-second timeout. It returns ErrRecordNotFound when
// no row matches; any other error is returned unchanged. The returned user's
// Password carries only the stored hash.
func (m UserModel) GetByEmail(email string) (*User, error) {
	const stmt = `
SELECT id, created_at, name, email, password_hash, activated, version
FROM users
WHERE email = $1`

	found := new(User)

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	err := m.DB.QueryRowContext(ctx, stmt, email).Scan(
		&found.ID,
		&found.CreateAt,
		&found.Name,
		&found.Email,
		&found.Password.hash,
		&found.Activated,
		&found.Version,
	)

	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrRecordNotFound
	default:
		return nil, err
	}
}
// Update writes user's name, email, password hash and activated flag back to
// the row matching both user.ID and user.Version, increments the stored
// version, and copies the new version into user.
//
// The query runs under a 3-second timeout. It returns ErrDuplicateEmail when
// the driver reports a violation of the "users_email_key" unique constraint,
// and ErrEditConflict when no row matched the id/version pair. Any other
// error is returned unchanged.
func (m UserModel) Update(user *User) error {
	const stmt = `
UPDATE users
SET name = $1, email = $2, password_hash = $3, activated = $4, version = version + 1
WHERE id = $5 AND version = $6
RETURNING version`

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	row := m.DB.QueryRowContext(
		ctx, stmt,
		user.Name, user.Email, user.Password.hash, user.Activated,
		user.ID, user.Version,
	)

	err := row.Scan(&user.Version)
	if err == nil {
		return nil
	}
	// The duplicate-email case is recognised by the exact driver message.
	if err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"` {
		return ErrDuplicateEmail
	}
	// No returned row means the id/version pair no longer matches.
	if errors.Is(err, sql.ErrNoRows) {
		return ErrEditConflict
	}
	return err
}
// GetForToken loads the user that owns the token identified by tokenScope and
// tokenPlaintext.
//
// The plaintext is hashed with SHA-256 and matched against tokens.hash; the
// token must also have the given scope and an expiry later than the current
// time. The query runs under a 3-second timeout. It returns ErrRecordNotFound
// when no row matches; any other error is returned unchanged.
func (m UserModel) GetForToken(tokenScope, tokenPlaintext string) (*User, error) {
	const stmt = `
SELECT users.id, users.created_at, users.name, users.email, users.password_hash, users.activated, users.version
FROM users
INNER JOIN tokens
ON users.id = tokens.user_id
WHERE tokens.hash = $1
AND tokens.scope = $2
AND tokens.expiry > $3`

	// Only the hash of the token is compared against the database.
	digest := sha256.Sum256([]byte(tokenPlaintext))
	now := time.Now()

	owner := new(User)

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	err := m.DB.QueryRowContext(ctx, stmt, digest[:], tokenScope, now).Scan(
		&owner.ID,
		&owner.CreateAt,
		&owner.Name,
		&owner.Email,
		&owner.Password.hash,
		&owner.Activated,
		&owner.Version,
	)

	switch {
	case err == nil:
		return owner, nil
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrRecordNotFound
	default:
		return nil, err
	}
}