pg_v_man/infrastructure/db/replication_client.go
2025-02-14 08:44:34 +01:00

385 lines
11 KiB
Go

package db
import (
"context"
"errors"
"fmt"
"log"
"log/slog"
"time"
"github.com/google/uuid"
"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/jackc/pgx/v5/pgtype"
"code.icb4dc0.de/prskr/pg_v_man/core/domain"
"code.icb4dc0.de/prskr/pg_v_man/core/ports"
"code.icb4dc0.de/prskr/pg_v_man/infrastructure/config"
)
// outputPlugin is the logical decoding output plugin requested when the
// replication slot is created.
const outputPlugin = "pgoutput"

// defaultConsumerTimeout is the ConsumerTimeout assigned to clients built by
// NewReplicationClient.
const defaultConsumerTimeout = 200 * time.Millisecond
// NewReplicationClient connects to the database described by cfg, creates a
// temporary replication slot named cfg.SlotName using the pgoutput plugin and
// starts logical replication for cfg.Publication from the position reported
// by IDENTIFY_SYSTEM.
//
// The returned client delivers change events to consumer once Receive is
// called. If any step after connecting fails, the connection is closed and
// the close error (if any) is joined onto the returned error.
func NewReplicationClient(ctx context.Context, cfg config.DB, consumer ports.ReplicationEventConsumer) (client *ReplicationClient, err error) {
	conn, err := pgconn.Connect(ctx, cfg.ConnectionString)
	if err != nil {
		return nil, fmt.Errorf("could not connect to database: %w", err)
	}

	// Relies on the named result: every failing return below sets err before
	// this runs, so the connection never outlives a failed construction.
	defer func() {
		if err == nil {
			return
		}
		err = errors.Join(err, conn.Close(context.Background()))
	}()

	client = &ReplicationClient{
		Conn:            conn,
		Consumer:        consumer,
		ConsumerTimeout: defaultConsumerTimeout,
		typeMap:         pgtype.NewMap(),
		relations:       make(map[uint32]*pglogrepl.RelationMessageV2),
	}

	if client.sysident, err = pglogrepl.IdentifySystem(ctx, conn); err != nil {
		return nil, fmt.Errorf("could not identify system: %w", err)
	}

	slog.Info(
		"System identity",
		slog.String("systemId", client.sysident.SystemID),
		slog.Int("timeline", int(client.sysident.Timeline)),
		slog.String("xlogpos", client.sysident.XLogPos.String()),
		slog.String("dbname", client.sysident.DBName),
	)

	slotOptions := pglogrepl.CreateReplicationSlotOptions{Temporary: true}
	if _, err = pglogrepl.CreateReplicationSlot(ctx, conn, cfg.SlotName, outputPlugin, slotOptions); err != nil {
		return nil, fmt.Errorf("could not create replication slot: %w", err)
	}

	slog.Info("Replication slot created", slog.String("slotName", cfg.SlotName))

	startOptions := pglogrepl.StartReplicationOptions{
		PluginArgs: []string{
			"proto_version '2'",
			fmt.Sprintf("publication_names '%s'", cfg.Publication),
			"messages 'true'",
			"streaming 'true'",
		},
	}
	if err = pglogrepl.StartReplication(ctx, conn, cfg.SlotName, client.sysident.XLogPos, startOptions); err != nil {
		return nil, fmt.Errorf("could not start replication: %w", err)
	}

	slog.Info("Replication started", slog.String("slotName", cfg.SlotName))

	return client, nil
}
// ReplicationClient reads a logical replication stream from Conn and forwards
// row-level changes to Consumer.
type ReplicationClient struct {
// ConsumerTimeout bounds the context passed to the handling of each
// received WAL data message.
ConsumerTimeout time.Duration
// Conn is the replication connection messages are received from.
Conn *pgconn.PgConn
// Consumer receives a ReplicationEvent via OnDataChange for every insert,
// update, delete and truncate.
Consumer ports.ReplicationEventConsumer
// sysident is the IDENTIFY_SYSTEM result; its XLogPos is the initial client
// WAL position and its DBName is stamped on every emitted event.
sysident pglogrepl.IdentifySystemResult
// typeMap is used to decode text-format column values by type OID.
typeMap *pgtype.Map
// relations caches relation metadata by relation ID, filled from received
// Relation messages.
relations map[uint32]*pglogrepl.RelationMessageV2
// latestXid is the transaction id recorded from the replication stream and
// stamped on emitted events.
latestXid uint32
// inStream tracks whether the client is between a stream start and a
// stream stop message; it is passed to the message parser.
inStream bool
}
// Receive consumes the replication stream until ctx is done or an error
// occurs.
//
// Each iteration sends a standby status update when the current deadline has
// passed (or the server asked for a reply), then waits for the next message
// until that deadline. A receive timeout only triggers the next iteration.
// WAL data is handed to handleReceivedLog with a context limited to
// ConsumerTimeout; any error from it ends the loop.
//
// It returns nil when ctx is found to be done at the top of the loop, and a
// wrapped error when sending, receiving, parsing or handling fails.
func (c *ReplicationClient) Receive(ctx context.Context) (err error) {
	const standbyMessageTimeout = 10 * time.Second

	clientXLogPos := c.sysident.XLogPos
	nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout)

	for ctx.Err() == nil {
		if time.Now().After(nextStandbyMessageDeadline) {
			err = pglogrepl.SendStandbyStatusUpdate(context.Background(), c.Conn, pglogrepl.StandbyStatusUpdate{WALWritePosition: clientXLogPos})
			if err != nil {
				return fmt.Errorf("SendStandbyStatusUpdate failed: %w", err)
			}
			slog.Debug("Sent Standby status message", slog.String("xlogpos", clientXLogPos.String()))
			nextStandbyMessageDeadline = time.Now().Add(standbyMessageTimeout)
		}

		// Wait no longer than the next status update is due.
		receiveCtx, stop := context.WithDeadline(ctx, nextStandbyMessageDeadline)
		rawMsg, err := c.Conn.ReceiveMessage(receiveCtx)
		stop()
		if err != nil {
			if pgconn.Timeout(err) {
				continue
			}
			return fmt.Errorf("could not receive message: %w", err)
		}

		// NOTE(review): server errors are only logged and the loop keeps
		// going; confirm this is intended rather than returning the error.
		if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok {
			slog.Error("Error response", slog.String("message", errMsg.Message))
			continue
		}

		msg, ok := rawMsg.(*pgproto3.CopyData)
		if !ok {
			slog.Warn("Received unexpected message", slog.String("message", fmt.Sprintf("%T", rawMsg)))
			continue
		}

		// An empty payload has no message type byte; indexing it would panic.
		if len(msg.Data) == 0 {
			slog.Warn("Received empty CopyData message")
			continue
		}

		switch msg.Data[0] {
		case pglogrepl.PrimaryKeepaliveMessageByteID:
			pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:])
			if err != nil {
				return fmt.Errorf("ParsePrimaryKeepaliveMessage failed: %w", err)
			}
			slog.Debug(
				"Primary Keepalive Message",
				slog.String("ServerWALEnd", pkm.ServerWALEnd.String()),
				slog.Time("ServerTime", pkm.ServerTime),
				slog.Bool("ReplyRequested", pkm.ReplyRequested),
			)
			if pkm.ServerWALEnd > clientXLogPos {
				clientXLogPos = pkm.ServerWALEnd
			}
			if pkm.ReplyRequested {
				// A zero deadline is always in the past, so the next
				// iteration sends a status update immediately.
				nextStandbyMessageDeadline = time.Time{}
			}

		case pglogrepl.XLogDataByteID:
			xld, err := pglogrepl.ParseXLogData(msg.Data[1:])
			if err != nil {
				return fmt.Errorf("ParseXLogData failed: %w", err)
			}
			slog.Debug(
				"Received log data",
				slog.String("WALStart", xld.WALStart.String()),
				slog.String("ServerWALEnd", xld.ServerWALEnd.String()),
				slog.Time("ServerTime", xld.ServerTime),
			)

			processCtx, stop := context.WithTimeout(ctx, c.ConsumerTimeout)
			err = c.handleReceivedLog(processCtx, xld.WALData)
			stop()
			if err != nil {
				return err
			}

			// Only advance the reported position after the data was handled.
			if xld.WALStart > clientXLogPos {
				clientXLogPos = xld.WALStart
			}

		default:
			slog.Warn("Received unexpected CopyData message", slog.Int("byteID", int(msg.Data[0])))
		}
	}

	return nil
}
// Close closes the underlying replication connection and returns the result
// of that close.
func (c *ReplicationClient) Close(ctx context.Context) error {
	conn := c.Conn
	return conn.Close(ctx)
}
// handleReceivedLog parses one logical replication message from walData and
// reacts to it.
//
// Relation messages are cached by relation ID. Begin and stream start
// messages record the current transaction id. Insert, update, delete and
// truncate messages are converted into a domain.ReplicationEvent and passed
// to the consumer; truncate emits one event per affected relation. All other
// known message types are at most logged.
//
// It returns an error when parsing fails, when a change refers to a relation
// ID that has not been cached, or when the consumer returns an error.
func (c *ReplicationClient) handleReceivedLog(ctx context.Context, walData []byte) error {
	logicalMsg, err := pglogrepl.ParseV2(walData, c.inStream)
	if err != nil {
		return fmt.Errorf("parse logical replication message: %w", err)
	}

	slog.Debug("Receive a logical replication message", slog.String("type", logicalMsg.Type().String()))

	switch logicalMsg := logicalMsg.(type) {
	case *pglogrepl.RelationMessageV2:
		c.relations[logicalMsg.RelationID] = logicalMsg

	case *pglogrepl.BeginMessage:
		// Indicates the beginning of a group of changes in a transaction. This is only sent for committed transactions. You won't get any events from rolled back transactions.
		slog.Debug("Begin message", slog.Int("xid", int(logicalMsg.Xid)))
		c.latestXid = logicalMsg.Xid

	case *pglogrepl.CommitMessage:
		slog.Debug("Commit message", slog.String("commit_lsn", logicalMsg.CommitLSN.String()))

	case *pglogrepl.InsertMessageV2:
		rel, ok := c.relations[logicalMsg.RelationID]
		if !ok {
			return fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID)
		}

		slog.Debug("Received insert message",
			slog.Uint64("xid", uint64(c.latestXid)),
			slog.String("namespace", rel.Namespace),
			slog.String("relation", rel.RelationName),
			slog.Bool("in_stream", c.inStream),
		)

		ev := domain.ReplicationEvent{
			EventType:     domain.EventTypeInsert,
			TransactionId: c.latestXid,
			DBName:        c.sysident.DBName,
			Namespace:     rel.Namespace,
			Relation:      rel.RelationName,
			NewValues:     c.tupleToMap(rel, logicalMsg.Tuple, false),
		}

		return c.Consumer.OnDataChange(ctx, ev)

	case *pglogrepl.UpdateMessageV2:
		rel, ok := c.relations[logicalMsg.RelationID]
		if !ok {
			return fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID)
		}

		slog.Debug(
			"Received update message",
			slog.Uint64("xid", uint64(c.latestXid)),
			slog.String("namespace", rel.Namespace),
			slog.String("relation", rel.RelationName),
			slog.Bool("in_stream", c.inStream),
		)

		ev := domain.ReplicationEvent{
			EventType:     domain.EventTypeUpdate,
			TransactionId: c.latestXid,
			DBName:        c.sysident.DBName,
			Namespace:     rel.Namespace,
			Relation:      rel.RelationName,
			NewValues:     c.tupleToMap(rel, logicalMsg.NewTuple, false),
			OldValues:     c.tupleToMap(rel, logicalMsg.OldTuple, false),
		}

		return c.Consumer.OnDataChange(ctx, ev)

	case *pglogrepl.DeleteMessageV2:
		rel, ok := c.relations[logicalMsg.RelationID]
		if !ok {
			return fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID)
		}

		slog.Debug(
			"Received deletion message",
			slog.Uint64("xid", uint64(c.latestXid)),
			slog.String("namespace", rel.Namespace),
			slog.String("relation", rel.RelationName),
			slog.Bool("in_stream", c.inStream),
		)

		ev := domain.ReplicationEvent{
			EventType:     domain.EventTypeDelete,
			TransactionId: c.latestXid,
			DBName:        c.sysident.DBName,
			Namespace:     rel.Namespace,
			Relation:      rel.RelationName,
			// Only key columns are kept for deletions.
			OldValues: c.tupleToMap(rel, logicalMsg.OldTuple, true),
		}

		return c.Consumer.OnDataChange(ctx, ev)

	case *pglogrepl.TruncateMessageV2:
		for _, relID := range logicalMsg.RelationIDs {
			rel, ok := c.relations[relID]
			if !ok {
				return fmt.Errorf("unknown relation ID %d", relID)
			}

			// The xid carried by the message itself is logged under its own
			// key so the record does not contain two "xid" attributes.
			slog.Debug(
				"Received truncation message",
				slog.Uint64("xid", uint64(c.latestXid)),
				slog.String("namespace", rel.Namespace),
				slog.String("relation", rel.RelationName),
				slog.Int("stream_xid", int(logicalMsg.Xid)),
				slog.Bool("in_stream", c.inStream),
			)

			ev := domain.ReplicationEvent{
				EventType:     domain.EventTypeTruncate,
				TransactionId: c.latestXid,
				DBName:        c.sysident.DBName,
				Namespace:     rel.Namespace,
				Relation:      rel.RelationName,
			}

			if err := c.Consumer.OnDataChange(ctx, ev); err != nil {
				return err
			}
		}

	case *pglogrepl.TypeMessageV2:
		// Nothing to do.

	case *pglogrepl.OriginMessage:
		// Nothing to do.

	case *pglogrepl.LogicalDecodingMessageV2:
		slog.Debug(
			"Logical decoding message",
			slog.String("prefix", logicalMsg.Prefix),
			slog.Any("content", logicalMsg.Content),
			slog.Int("xid", int(logicalMsg.Xid)),
		)

	case *pglogrepl.StreamStartMessageV2:
		c.inStream = true
		// Streamed transactions are not introduced by a Begin message, so the
		// xid announced here is the one the following changes belong to.
		c.latestXid = logicalMsg.Xid
		slog.Debug("Stream start message",
			slog.Int("xid", int(logicalMsg.Xid)),
			slog.Uint64("first_segment", uint64(logicalMsg.FirstSegment)),
		)

	case *pglogrepl.StreamStopMessageV2:
		c.inStream = false
		slog.Debug("Stream stop message")

	case *pglogrepl.StreamCommitMessageV2:
		slog.Debug("Stream commit message", slog.Int("xid", int(logicalMsg.Xid)))

	case *pglogrepl.StreamAbortMessageV2:
		slog.Debug("Stream abort message", slog.Int("xid", int(logicalMsg.Xid)))

	default:
		slog.Warn("Unknown message type", slog.String("type", fmt.Sprintf("%T", logicalMsg)))
	}

	return nil
}
// tupleToMap converts the columns of data into domain values, using relation
// for the column names, key flags and type OIDs.
//
// It returns nil when data is nil. With onlyKey set, columns whose key flag
// is not set are skipped. Null columns are added with a nil value, binary
// columns with their raw bytes, text columns with the decoded value, and
// unchanged TOAST columns are omitted.
//
// A text value that cannot be decoded is logged and added as its raw text
// instead of terminating the process. Tuple columns without matching
// relation metadata are logged and dropped.
func (c *ReplicationClient) tupleToMap(relation *pglogrepl.RelationMessageV2, data *pglogrepl.TupleData, onlyKey bool) *domain.Values {
	if data == nil {
		return nil
	}

	values := domain.NewValues()
	for idx, col := range data.Columns {
		// Guard against a tuple that is wider than the cached relation;
		// indexing past the metadata would panic.
		if idx >= len(relation.Columns) {
			log.Println("tuple has more columns than relation metadata:", relation.RelationName)
			break
		}

		relCol := relation.Columns[idx]
		isKey := relCol.Flags&1 > 0
		if onlyKey && !isKey {
			continue
		}

		colName := relCol.Name
		switch col.DataType {
		case pglogrepl.TupleDataTypeNull:
			values.AddValue(isKey, colName, nil)
		case pglogrepl.TupleDataTypeToast: // unchanged toast
			// This TOAST value was not changed. TOAST values are not stored in the tuple, and logical replication doesn't want to spend a disk read to fetch its value for you.
		case pglogrepl.TupleDataTypeBinary:
			values.AddValue(isKey, colName, col.Data)
		case pglogrepl.TupleDataTypeText:
			val, err := c.decodeTextColumnData(col.Data, relCol.DataType)
			if err != nil {
				// One bad column must not take the whole process down; keep
				// the change and hand over the undecoded text.
				log.Println("error decoding column data:", err)
				val = string(col.Data)
			}
			values.AddValue(isKey, colName, val)
		}
	}

	return values
}
// decodeTextColumnData decodes a text-format column value using the codec
// registered in the client's type map for dataType.
//
// When no type is registered for dataType, the data is returned as a string.
// UUID values decoded as a 16-byte array are converted to uuid.UUID; any
// other decoded value is returned unchanged. Decoding errors are returned
// as-is.
func (c *ReplicationClient) decodeTextColumnData(data []byte, dataType uint32) (any, error) {
	dt, ok := c.typeMap.TypeForOID(dataType)
	if !ok {
		return string(data), nil
	}

	decoded, err := dt.Codec.DecodeValue(c.typeMap, dataType, pgtype.TextFormatCode, data)
	if err != nil {
		return nil, err
	}

	if dataType == pgtype.UUIDOID {
		// Checked assertion: a nil or differently typed result is passed
		// through instead of panicking.
		if raw, isBytes := decoded.([16]byte); isBytes {
			return uuid.FromBytes(raw[:])
		}
	}

	return decoded, nil
}