searchhut/database/context.go

59 lines
1.1 KiB
Go
Raw Normal View History

2022-07-08 19:46:11 +02:00
package database
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
)
// Context returns a child of ctx that carries db under dbCtxKey, so that
// ForContext, DBForContext and WithTx can retrieve it later.
func Context(ctx context.Context, db *sql.DB) context.Context {
	derived := context.WithValue(ctx, dbCtxKey, db)
	return derived
}
func ForContext(ctx context.Context) (*sql.Conn, error) {
raw, ok := ctx.Value(dbCtxKey).(*sql.DB)
if !ok {
panic(errors.New("Invalid database context"))
}
return raw.Conn(ctx)
}
// DBForContext returns the *sql.DB previously attached to ctx with Context.
//
// It panics if ctx holds no *sql.DB under dbCtxKey, i.e. when Context was
// not used to install one on this context chain.
func DBForContext(ctx context.Context) *sql.DB {
	if db, found := ctx.Value(dbCtxKey).(*sql.DB); found {
		return db
	}
	panic(errors.New("Invalid database context"))
}
// WithTx runs fn inside a transaction opened on a dedicated connection taken
// from the *sql.DB stored in ctx (see ForContext). opts is passed unchanged
// to BeginTx; nil selects the driver defaults.
//
// Outcomes:
//   - fn returns nil: the transaction is committed and the commit's error,
//     if any, is returned.
//   - fn returns an error: the transaction is rolled back and fn's error is
//     returned. If the rollback itself fails, the rollback error is appended
//     to the message while fn's error stays reachable via errors.Is/As.
//   - fn panics: the transaction is rolled back and the panic is re-raised.
func WithTx(ctx context.Context, opts *sql.TxOptions, fn func(tx *sql.Tx) error) error {
	conn, err := ForContext(ctx)
	if err != nil {
		return err
	}
	defer conn.Close()

	tx, err := conn.BeginTx(ctx, opts)
	if err != nil {
		return err
	}
	defer func() {
		if r := recover(); r != nil {
			// Best effort: we are already unwinding a panic, so a rollback
			// error has nowhere useful to go.
			_ = tx.Rollback()
			panic(r)
		}
	}()

	if fnErr := fn(tx); fnErr != nil {
		// ErrTxDone means the transaction was already finished (by fn or by
		// context cancellation), so there is nothing left to roll back.
		if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
			return fmt.Errorf("%w (rollback failed: %v)", fnErr, rbErr)
		}
		return fnErr
	}

	if commitErr := tx.Commit(); commitErr != nil {
		if !errors.Is(commitErr, sql.ErrTxDone) {
			return commitErr
		}
		// The transaction ended before we could commit it. If ctx was
		// cancelled, database/sql rolled it back on its own, so reporting
		// success here would silently discard fn's writes.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
	}
	return nil
}