385 lines
11 KiB
Go
385 lines
11 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"log/slog"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pglogrepl"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgproto3"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
|
|
"code.icb4dc0.de/prskr/pg_v_man/core/domain"
|
|
"code.icb4dc0.de/prskr/pg_v_man/core/ports"
|
|
"code.icb4dc0.de/prskr/pg_v_man/infrastructure/config"
|
|
)
|
|
|
|
// Replication defaults used by this package.
const (
	// defaultConsumerTimeout is the ConsumerTimeout assigned to clients built
	// by NewReplicationClient.
	defaultConsumerTimeout = 200 * time.Millisecond

	// outputPlugin is the logical decoding output plugin requested when the
	// replication slot is created.
	outputPlugin = "pgoutput"
)
|
|
|
|
func NewReplicationClient(ctx context.Context, cfg config.DB, consumer ports.ReplicationEventConsumer) (client *ReplicationClient, err error) {
|
|
conn, err := pgconn.Connect(ctx, cfg.ConnectionString)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not connect to database: %w", err)
|
|
}
|
|
|
|
defer func() {
|
|
if err != nil {
|
|
err = errors.Join(err, conn.Close(context.Background()))
|
|
}
|
|
}()
|
|
|
|
client = &ReplicationClient{
|
|
Conn: conn,
|
|
Consumer: consumer,
|
|
ConsumerTimeout: defaultConsumerTimeout,
|
|
typeMap: pgtype.NewMap(),
|
|
relations: make(map[uint32]*pglogrepl.RelationMessageV2),
|
|
}
|
|
|
|
client.sysident, err = pglogrepl.IdentifySystem(ctx, conn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not identify system: %w", err)
|
|
}
|
|
|
|
slog.Info(
|
|
"System identity",
|
|
slog.String("systemId", client.sysident.SystemID),
|
|
slog.Int("timeline", int(client.sysident.Timeline)),
|
|
slog.String("xlogpos", client.sysident.XLogPos.String()),
|
|
slog.String("dbname", client.sysident.DBName),
|
|
)
|
|
|
|
_, err = pglogrepl.CreateReplicationSlot(ctx, conn, cfg.SlotName, outputPlugin, pglogrepl.CreateReplicationSlotOptions{Temporary: true})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not create replication slot: %w", err)
|
|
}
|
|
|
|
slog.Info("Replication slot created", slog.String("slotName", cfg.SlotName))
|
|
|
|
pluginArguments := []string{
|
|
"proto_version '2'",
|
|
fmt.Sprintf("publication_names '%s'", cfg.Publication),
|
|
"messages 'true'",
|
|
"streaming 'true'",
|
|
}
|
|
|
|
err = pglogrepl.StartReplication(ctx, conn, cfg.SlotName, client.sysident.XLogPos, pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not start replication: %w", err)
|
|
}
|
|
|
|
slog.Info("Replication started", slog.String("slotName", cfg.SlotName))
|
|
|
|
return client, nil
|
|
}
|
|
|
|
type ReplicationClient struct {
|
|
ConsumerTimeout time.Duration
|
|
Conn *pgconn.PgConn
|
|
Consumer ports.ReplicationEventConsumer
|
|
sysident pglogrepl.IdentifySystemResult
|
|
typeMap *pgtype.Map
|
|
relations map[uint32]*pglogrepl.RelationMessageV2
|
|
latestXid uint32
|
|
inStream bool
|
|
}
|
|
|
|
func (c *ReplicationClient) Receive(ctx context.Context) (err error) {
|
|
clientXLogPos := c.sysident.XLogPos
|
|
standbyMessageTimeout := time.Second * 10
|
|
nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout)
|
|
|
|
for ctx.Err() == nil {
|
|
if time.Now().After(nextStandbyMessageDeadline) {
|
|
err = pglogrepl.SendStandbyStatusUpdate(context.Background(), c.Conn, pglogrepl.StandbyStatusUpdate{WALWritePosition: clientXLogPos})
|
|
if err != nil {
|
|
return fmt.Errorf("SendStandbyStatusUpdate failed: %w", err)
|
|
}
|
|
|
|
slog.Debug("Sent Standby status message", slog.String("xlogpos", clientXLogPos.String()))
|
|
nextStandbyMessageDeadline = time.Now().Add(standbyMessageTimeout)
|
|
}
|
|
|
|
receiveCtx, stop := context.WithDeadline(ctx, nextStandbyMessageDeadline)
|
|
rawMsg, err := c.Conn.ReceiveMessage(receiveCtx)
|
|
stop()
|
|
|
|
if err != nil {
|
|
if pgconn.Timeout(err) {
|
|
continue
|
|
}
|
|
|
|
return fmt.Errorf("could not receive message: %w", err)
|
|
}
|
|
|
|
if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok {
|
|
slog.Error("Error response", slog.String("message", errMsg.Message))
|
|
continue
|
|
}
|
|
|
|
msg, ok := rawMsg.(*pgproto3.CopyData)
|
|
if !ok {
|
|
slog.Warn("Received unexpected message", slog.String("message", fmt.Sprintf("%T", rawMsg)))
|
|
continue
|
|
}
|
|
|
|
switch msg.Data[0] {
|
|
case pglogrepl.PrimaryKeepaliveMessageByteID:
|
|
pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:])
|
|
if err != nil {
|
|
return fmt.Errorf("ParsePrimaryKeepaliveMessage failed: %w", err)
|
|
}
|
|
|
|
slog.Debug(
|
|
"Primary Keepalive Message",
|
|
slog.String("ServerWALEnd", pkm.ServerWALEnd.String()),
|
|
slog.Time("ServerTime", pkm.ServerTime),
|
|
slog.Bool("ReplyRequested", pkm.ReplyRequested),
|
|
)
|
|
|
|
if pkm.ServerWALEnd > clientXLogPos {
|
|
clientXLogPos = pkm.ServerWALEnd
|
|
}
|
|
if pkm.ReplyRequested {
|
|
nextStandbyMessageDeadline = time.Time{}
|
|
}
|
|
|
|
case pglogrepl.XLogDataByteID:
|
|
xld, err := pglogrepl.ParseXLogData(msg.Data[1:])
|
|
if err != nil {
|
|
return fmt.Errorf("ParseXLogData failed: %w", err)
|
|
}
|
|
|
|
slog.Debug(
|
|
"Received log data",
|
|
slog.String("WALStart", xld.WALStart.String()),
|
|
slog.String("ServerWALEnd", xld.ServerWALEnd.String()),
|
|
slog.Time("ServerTime", xld.ServerTime),
|
|
)
|
|
|
|
processCtx, stop := context.WithTimeout(ctx, c.ConsumerTimeout)
|
|
err = c.handleReceivedLog(processCtx, xld.WALData)
|
|
stop()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if xld.WALStart > clientXLogPos {
|
|
clientXLogPos = xld.WALStart
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *ReplicationClient) Close(ctx context.Context) error {
|
|
return c.Conn.Close(ctx)
|
|
}
|
|
|
|
func (c *ReplicationClient) handleReceivedLog(ctx context.Context, walData []byte) error {
|
|
logicalMsg, err := pglogrepl.ParseV2(walData, c.inStream)
|
|
if err != nil {
|
|
return fmt.Errorf("parse logical replication message: %w", err)
|
|
}
|
|
|
|
slog.Debug("Receive a logical replication message", slog.String("type", logicalMsg.Type().String()))
|
|
|
|
switch logicalMsg := logicalMsg.(type) {
|
|
case *pglogrepl.RelationMessageV2:
|
|
c.relations[logicalMsg.RelationID] = logicalMsg
|
|
|
|
case *pglogrepl.BeginMessage:
|
|
// Indicates the beginning of a group of changes in a transaction. This is only sent for committed transactions. You won't get any events from rolled back transactions.
|
|
slog.Debug("Begin message", slog.Int("xid", int(logicalMsg.Xid)))
|
|
c.latestXid = logicalMsg.Xid
|
|
case *pglogrepl.CommitMessage:
|
|
slog.Debug("Commit message", slog.String("commit_lsn", logicalMsg.CommitLSN.String()))
|
|
case *pglogrepl.InsertMessageV2:
|
|
rel, ok := c.relations[logicalMsg.RelationID]
|
|
if !ok {
|
|
return fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID)
|
|
}
|
|
|
|
slog.Debug("Received insert message",
|
|
slog.Uint64("xid", uint64(c.latestXid)),
|
|
slog.String("namespace", rel.Namespace),
|
|
slog.String("relation", rel.RelationName),
|
|
slog.Bool("in_stream", c.inStream),
|
|
)
|
|
|
|
ev := domain.ReplicationEvent{
|
|
EventType: domain.EventTypeInsert,
|
|
TransactionId: c.latestXid,
|
|
DBName: c.sysident.DBName,
|
|
Namespace: rel.Namespace,
|
|
Relation: rel.RelationName,
|
|
NewValues: c.tupleToMap(rel, logicalMsg.Tuple, false),
|
|
}
|
|
|
|
return c.Consumer.OnDataChange(ctx, ev)
|
|
|
|
case *pglogrepl.UpdateMessageV2:
|
|
rel, ok := c.relations[logicalMsg.RelationID]
|
|
if !ok {
|
|
return fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID)
|
|
}
|
|
|
|
slog.Debug(
|
|
"Received update message",
|
|
slog.Uint64("xid", uint64(c.latestXid)),
|
|
slog.String("namespace", rel.Namespace),
|
|
slog.String("relation", rel.RelationName),
|
|
slog.Bool("in_stream", c.inStream),
|
|
)
|
|
|
|
ev := domain.ReplicationEvent{
|
|
EventType: domain.EventTypeUpdate,
|
|
TransactionId: c.latestXid,
|
|
DBName: c.sysident.DBName,
|
|
Namespace: rel.Namespace,
|
|
Relation: rel.RelationName,
|
|
NewValues: c.tupleToMap(rel, logicalMsg.NewTuple, false),
|
|
OldValues: c.tupleToMap(rel, logicalMsg.OldTuple, false),
|
|
}
|
|
|
|
return c.Consumer.OnDataChange(ctx, ev)
|
|
case *pglogrepl.DeleteMessageV2:
|
|
rel, ok := c.relations[logicalMsg.RelationID]
|
|
if !ok {
|
|
return fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID)
|
|
}
|
|
|
|
slog.Debug(
|
|
"Received deletion message",
|
|
slog.Uint64("xid", uint64(c.latestXid)),
|
|
slog.String("namespace", rel.Namespace),
|
|
slog.String("relation", rel.RelationName),
|
|
slog.Bool("in_stream", c.inStream),
|
|
)
|
|
|
|
ev := domain.ReplicationEvent{
|
|
EventType: domain.EventTypeDelete,
|
|
TransactionId: c.latestXid,
|
|
DBName: c.sysident.DBName,
|
|
Namespace: rel.Namespace,
|
|
Relation: rel.RelationName,
|
|
OldValues: c.tupleToMap(rel, logicalMsg.OldTuple, true),
|
|
}
|
|
|
|
return c.Consumer.OnDataChange(ctx, ev)
|
|
case *pglogrepl.TruncateMessageV2:
|
|
for _, relId := range logicalMsg.RelationIDs {
|
|
rel, ok := c.relations[relId]
|
|
if !ok {
|
|
return fmt.Errorf("unknown relation ID %d", relId)
|
|
}
|
|
|
|
slog.Debug(
|
|
"Received truncation message",
|
|
slog.Uint64("xid", uint64(c.latestXid)),
|
|
slog.String("namespace", rel.Namespace),
|
|
slog.String("relation", rel.RelationName), slog.Int("xid", int(logicalMsg.Xid)),
|
|
slog.Bool("in_stream", c.inStream),
|
|
)
|
|
|
|
ev := domain.ReplicationEvent{
|
|
EventType: domain.EventTypeTruncate,
|
|
TransactionId: c.latestXid,
|
|
DBName: c.sysident.DBName,
|
|
Namespace: rel.Namespace,
|
|
Relation: rel.RelationName,
|
|
}
|
|
|
|
if err := c.Consumer.OnDataChange(ctx, ev); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
case *pglogrepl.TypeMessageV2:
|
|
case *pglogrepl.OriginMessage:
|
|
case *pglogrepl.LogicalDecodingMessageV2:
|
|
slog.Debug(
|
|
"Logical decoding message",
|
|
slog.String("prefix", logicalMsg.Prefix),
|
|
slog.Any("content", logicalMsg.Content),
|
|
slog.Int("xid", int(logicalMsg.Xid)),
|
|
)
|
|
case *pglogrepl.StreamStartMessageV2:
|
|
c.inStream = true
|
|
slog.Debug("Stream start message",
|
|
slog.Int("xid", int(logicalMsg.Xid)),
|
|
slog.Uint64("first_segment", uint64(logicalMsg.FirstSegment)),
|
|
)
|
|
|
|
case *pglogrepl.StreamStopMessageV2:
|
|
c.inStream = false
|
|
slog.Debug("Stream stop message")
|
|
case *pglogrepl.StreamCommitMessageV2:
|
|
slog.Debug("Stream commit message", slog.Int("xid", int(logicalMsg.Xid)))
|
|
case *pglogrepl.StreamAbortMessageV2:
|
|
slog.Debug("Stream abort message", slog.Int("xid", int(logicalMsg.Xid)))
|
|
default:
|
|
slog.Warn("Unknown message type", slog.String("type", fmt.Sprintf("%T", logicalMsg)))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *ReplicationClient) tupleToMap(relation *pglogrepl.RelationMessageV2, data *pglogrepl.TupleData, onlyKey bool) *domain.Values {
|
|
if data == nil {
|
|
return nil
|
|
}
|
|
values := domain.NewValues()
|
|
for idx, col := range data.Columns {
|
|
isKey := relation.Columns[idx].Flags&1 > 0
|
|
if onlyKey && !isKey {
|
|
continue
|
|
}
|
|
|
|
colName := relation.Columns[idx].Name
|
|
switch col.DataType {
|
|
case pglogrepl.TupleDataTypeNull:
|
|
values.AddValue(isKey, colName, nil)
|
|
case pglogrepl.TupleDataTypeToast: // unchanged toast
|
|
// This TOAST value was not changed. TOAST values are not stored in the tuple, and logical replication doesn't want to spend a disk read to fetch its value for you.
|
|
case pglogrepl.TupleDataTypeBinary:
|
|
values.AddValue(isKey, colName, col.Data)
|
|
case pglogrepl.TupleDataTypeText:
|
|
val, err := c.decodeTextColumnData(col.Data, relation.Columns[idx].DataType)
|
|
if err != nil {
|
|
log.Fatalln("error decoding column data:", err)
|
|
}
|
|
values.AddValue(isKey, colName, val)
|
|
}
|
|
}
|
|
|
|
return values
|
|
}
|
|
|
|
func (c *ReplicationClient) decodeTextColumnData(data []byte, dataType uint32) (any, error) {
|
|
if dt, ok := c.typeMap.TypeForOID(dataType); ok {
|
|
decoded, err := dt.Codec.DecodeValue(c.typeMap, dataType, pgtype.TextFormatCode, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch dataType {
|
|
case pgtype.UUIDOID:
|
|
raw := decoded.([16]byte)
|
|
return uuid.FromBytes(raw[:])
|
|
default:
|
|
return decoded, nil
|
|
}
|
|
}
|
|
return string(data), nil
|
|
}
|