59 lines
1.1 KiB
Go
59 lines
1.1 KiB
Go
|
package database
|
||
|
|
||
|
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
)
|
||
|
|
||
|
func Context(ctx context.Context, db *sql.DB) context.Context {
|
||
|
return context.WithValue(ctx, dbCtxKey, db)
|
||
|
}
|
||
|
|
||
|
func ForContext(ctx context.Context) (*sql.Conn, error) {
|
||
|
raw, ok := ctx.Value(dbCtxKey).(*sql.DB)
|
||
|
if !ok {
|
||
|
panic(errors.New("Invalid database context"))
|
||
|
}
|
||
|
return raw.Conn(ctx)
|
||
|
}
|
||
|
|
||
|
func DBForContext(ctx context.Context) *sql.DB {
|
||
|
raw, ok := ctx.Value(dbCtxKey).(*sql.DB)
|
||
|
if !ok {
|
||
|
panic(errors.New("Invalid database context"))
|
||
|
}
|
||
|
return raw
|
||
|
}
|
||
|
|
||
|
func WithTx(ctx context.Context, opts *sql.TxOptions, fn func(tx *sql.Tx) error) error {
|
||
|
conn, err := ForContext(ctx)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer conn.Close()
|
||
|
tx, err := conn.BeginTx(ctx, opts)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer func() {
|
||
|
if r := recover(); r != nil {
|
||
|
tx.Rollback()
|
||
|
panic(r)
|
||
|
}
|
||
|
}()
|
||
|
err = fn(tx)
|
||
|
if err != nil {
|
||
|
err := tx.Rollback()
|
||
|
if err != nil && err != sql.ErrTxDone {
|
||
|
panic(err)
|
||
|
}
|
||
|
} else {
|
||
|
err := tx.Commit()
|
||
|
if err != nil && err != sql.ErrTxDone {
|
||
|
panic(err)
|
||
|
}
|
||
|
}
|
||
|
return err
|
||
|
}
|