package db

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pglogrepl"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgproto3"
	"github.com/jackc/pgx/v5/pgtype"

	"code.icb4dc0.de/prskr/pg_v_man/core/domain"
	"code.icb4dc0.de/prskr/pg_v_man/core/ports"
	"code.icb4dc0.de/prskr/pg_v_man/infrastructure/config"
)

const (
	// outputPlugin is the logical decoding output plugin used for the slot.
	outputPlugin = "pgoutput"

	// defaultConsumerTimeout is the initial ReplicationClient.ConsumerTimeout.
	defaultConsumerTimeout = 200 * time.Millisecond
)

// NewReplicationClient connects to the database at cfg.ConnectionString,
// identifies the system, creates a temporary logical replication slot named
// cfg.SlotName using the pgoutput plugin and starts replication for the
// publication cfg.Publication at the server's current WAL position.
//
// Decoded changes are delivered to consumer later, by Receive. If any setup
// step fails, the connection is closed and any close error is joined into
// the returned error.
func NewReplicationClient(ctx context.Context, cfg config.DB, consumer ports.ReplicationEventConsumer) (client *ReplicationClient, err error) {
	conn, err := pgconn.Connect(ctx, cfg.ConnectionString)
	if err != nil {
		return nil, fmt.Errorf("could not connect to database: %w", err)
	}

	// Relies on the named result: any error returned below also closes conn.
	defer func() {
		if err != nil {
			err = errors.Join(err, conn.Close(context.Background()))
		}
	}()

	client = &ReplicationClient{
		Conn:            conn,
		Consumer:        consumer,
		ConsumerTimeout: defaultConsumerTimeout,
		typeMap:         pgtype.NewMap(),
		relations:       make(map[uint32]*pglogrepl.RelationMessageV2),
	}

	client.sysident, err = pglogrepl.IdentifySystem(ctx, conn)
	if err != nil {
		return nil, fmt.Errorf("could not identify system: %w", err)
	}

	slog.Info(
		"System identity",
		slog.String("systemId", client.sysident.SystemID),
		slog.Int("timeline", int(client.sysident.Timeline)),
		slog.String("xlogpos", client.sysident.XLogPos.String()),
		slog.String("dbname", client.sysident.DBName),
	)

	_, err = pglogrepl.CreateReplicationSlot(ctx, conn, cfg.SlotName, outputPlugin, pglogrepl.CreateReplicationSlotOptions{Temporary: true})
	if err != nil {
		return nil, fmt.Errorf("could not create replication slot: %w", err)
	}

	slog.Info("Replication slot created", slog.String("slotName", cfg.SlotName))

	pluginArguments := []string{
		"proto_version '2'",
		fmt.Sprintf("publication_names '%s'", cfg.Publication),
		"messages 'true'",
		// Large in-progress transactions are sent in StreamStart/StreamStop
		// blocks and may later be aborted (see handleReceivedLog).
		"streaming 'true'",
	}

	err = pglogrepl.StartReplication(ctx, conn, cfg.SlotName, client.sysident.XLogPos, pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments})
	if err != nil {
		return nil, fmt.Errorf("could not start replication: %w", err)
	}

	slog.Info("Replication started", slog.String("slotName", cfg.SlotName))

	return client, nil
}

// ReplicationClient reads a pgoutput logical replication stream from a single
// replication connection and forwards row changes to Consumer as
// domain.ReplicationEvent values.
//
// NOTE(review): the relation cache, latestXid and inStream fields are mutated
// by Receive without synchronization; presumably Receive must not run
// concurrently on one client — confirm with callers.
type ReplicationClient struct {
	// ConsumerTimeout bounds the context passed to the consumer for each
	// XLogData message processed by Receive.
	ConsumerTimeout time.Duration

	// Conn is the replication connection opened by NewReplicationClient.
	Conn *pgconn.PgConn

	// Consumer receives one event per inserted/updated/deleted row and per
	// truncated relation.
	Consumer ports.ReplicationEventConsumer

	sysident pglogrepl.IdentifySystemResult
	typeMap  *pgtype.Map

	// relations caches RELATION messages by relation ID; DML messages only
	// reference the relation by that ID.
	relations map[uint32]*pglogrepl.RelationMessageV2

	// latestXid is the transaction ID stamped on emitted events. It is set by
	// BEGIN for regular transactions and by STREAM START for streamed ones.
	latestXid uint32

	// inStream is true between STREAM START and STREAM STOP; ParseV2 needs it
	// because streamed messages carry an extra XID field.
	inStream bool
}

// Receive runs the replication loop until ctx is cancelled or an error occurs.
//
// It reads CopyData messages from the connection, tracks the WAL position
// reported by keepalives and XLogData messages, passes every XLogData payload
// to handleReceivedLog (bounded by ConsumerTimeout), and sends a standby
// status update whenever 10 seconds have passed since the last one or the
// server asked for a reply. An ErrorResponse from the server ends the loop
// with an error. Receive returns nil when it stops because ctx was cancelled.
func (c *ReplicationClient) Receive(ctx context.Context) (err error) {
	const standbyMessageTimeout = 10 * time.Second

	clientXLogPos := c.sysident.XLogPos
	nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout)

	for ctx.Err() == nil {
		if time.Now().After(nextStandbyMessageDeadline) {
			// Sent with context.Background(), i.e. not bound to ctx.
			err = pglogrepl.SendStandbyStatusUpdate(context.Background(), c.Conn, pglogrepl.StandbyStatusUpdate{WALWritePosition: clientXLogPos})
			if err != nil {
				return fmt.Errorf("SendStandbyStatusUpdate failed: %w", err)
			}
			slog.Debug("Sent Standby status message", slog.String("xlogpos", clientXLogPos.String()))
			nextStandbyMessageDeadline = time.Now().Add(standbyMessageTimeout)
		}

		receiveCtx, stop := context.WithDeadline(ctx, nextStandbyMessageDeadline)
		var rawMsg pgproto3.BackendMessage
		rawMsg, err = c.Conn.ReceiveMessage(receiveCtx)
		stop()
		if err != nil {
			if pgconn.Timeout(err) {
				// Either the status-update deadline passed or ctx was
				// cancelled; the loop condition and the deadline check above
				// decide what happens next.
				continue
			}
			return fmt.Errorf("could not receive message: %w", err)
		}

		if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok {
			// The backend leaves copy-out mode after an ErrorResponse, so no
			// further replication data will arrive; waiting would hang.
			return fmt.Errorf("received error response from server: %w", pgconn.ErrorResponseToPgError(errMsg))
		}

		msg, ok := rawMsg.(*pgproto3.CopyData)
		if !ok {
			slog.Warn("Received unexpected message", slog.String("message", fmt.Sprintf("%T", rawMsg)))
			continue
		}

		if len(msg.Data) == 0 {
			// Guards the msg.Data[0] access below.
			slog.Warn("Received empty CopyData message")
			continue
		}

		switch msg.Data[0] {
		case pglogrepl.PrimaryKeepaliveMessageByteID:
			pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:])
			if err != nil {
				return fmt.Errorf("ParsePrimaryKeepaliveMessage failed: %w", err)
			}
			slog.Debug(
				"Primary Keepalive Message",
				slog.String("ServerWALEnd", pkm.ServerWALEnd.String()),
				slog.Time("ServerTime", pkm.ServerTime),
				slog.Bool("ReplyRequested", pkm.ReplyRequested),
			)
			if pkm.ServerWALEnd > clientXLogPos {
				clientXLogPos = pkm.ServerWALEnd
			}
			if pkm.ReplyRequested {
				// Zero deadline forces a status update on the next iteration.
				nextStandbyMessageDeadline = time.Time{}
			}

		case pglogrepl.XLogDataByteID:
			xld, err := pglogrepl.ParseXLogData(msg.Data[1:])
			if err != nil {
				return fmt.Errorf("ParseXLogData failed: %w", err)
			}

			slog.Debug(
				"Received log data",
				slog.String("WALStart", xld.WALStart.String()),
				slog.String("ServerWALEnd", xld.ServerWALEnd.String()),
				slog.Time("ServerTime", xld.ServerTime),
			)

			processCtx, stop := context.WithTimeout(ctx, c.ConsumerTimeout)
			err = c.handleReceivedLog(processCtx, xld.WALData)
			stop()
			if err != nil {
				return err
			}

			if xld.WALStart > clientXLogPos {
				clientXLogPos = xld.WALStart
			}
		}
	}

	return nil
}

// Close closes the underlying replication connection.
func (c *ReplicationClient) Close(ctx context.Context) error {
	return c.Conn.Close(ctx)
}

// lookupRelation returns the cached RELATION message for id, or an error if
// the server has not described that relation on this connection yet.
func (c *ReplicationClient) lookupRelation(id uint32) (*pglogrepl.RelationMessageV2, error) {
	rel, ok := c.relations[id]
	if !ok {
		return nil, fmt.Errorf("unknown relation ID %d", id)
	}
	return rel, nil
}

// handleReceivedLog parses one pgoutput (protocol v2) message and acts on it.
//
// RELATION messages update the relation cache. BEGIN and STREAM START update
// the transaction ID stamped on events. INSERT, UPDATE, DELETE and TRUNCATE
// are converted to domain.ReplicationEvent values and passed to
// Consumer.OnDataChange with ctx. Other message types are only logged. It
// returns parse errors, unknown-relation errors, column decode errors and
// consumer errors.
func (c *ReplicationClient) handleReceivedLog(ctx context.Context, walData []byte) error {
	logicalMsg, err := pglogrepl.ParseV2(walData, c.inStream)
	if err != nil {
		return fmt.Errorf("parse logical replication message: %w", err)
	}
	slog.Debug("Receive a logical replication message", slog.String("type", logicalMsg.Type().String()))

	switch logicalMsg := logicalMsg.(type) {
	case *pglogrepl.RelationMessageV2:
		c.relations[logicalMsg.RelationID] = logicalMsg

	case *pglogrepl.BeginMessage:
		// Start of a regular (non-streamed) transaction; these are only sent
		// for committed transactions. With streaming enabled, large
		// transactions arrive in STREAM START/STOP blocks instead and may be
		// aborted afterwards (STREAM ABORT).
		slog.Debug("Begin message", slog.Int("xid", int(logicalMsg.Xid)))
		c.latestXid = logicalMsg.Xid

	case *pglogrepl.CommitMessage:
		slog.Debug("Commit message", slog.String("commit_lsn", logicalMsg.CommitLSN.String()))

	case *pglogrepl.InsertMessageV2:
		rel, err := c.lookupRelation(logicalMsg.RelationID)
		if err != nil {
			return err
		}
		slog.Debug("Received insert message",
			slog.Uint64("xid", uint64(c.latestXid)),
			slog.String("namespace", rel.Namespace),
			slog.String("relation", rel.RelationName),
			slog.Bool("in_stream", c.inStream),
		)
		newValues, err := c.tupleToMap(rel, logicalMsg.Tuple, false)
		if err != nil {
			return err
		}
		return c.Consumer.OnDataChange(ctx, domain.ReplicationEvent{
			EventType:     domain.EventTypeInsert,
			TransactionId: c.latestXid,
			DBName:        c.sysident.DBName,
			Namespace:     rel.Namespace,
			Relation:      rel.RelationName,
			NewValues:     newValues,
		})

	case *pglogrepl.UpdateMessageV2:
		rel, err := c.lookupRelation(logicalMsg.RelationID)
		if err != nil {
			return err
		}
		slog.Debug(
			"Received update message",
			slog.Uint64("xid", uint64(c.latestXid)),
			slog.String("namespace", rel.Namespace),
			slog.String("relation", rel.RelationName),
			slog.Bool("in_stream", c.inStream),
		)
		newValues, err := c.tupleToMap(rel, logicalMsg.NewTuple, false)
		if err != nil {
			return err
		}
		// OldTuple may be nil; tupleToMap then yields nil OldValues.
		oldValues, err := c.tupleToMap(rel, logicalMsg.OldTuple, false)
		if err != nil {
			return err
		}
		return c.Consumer.OnDataChange(ctx, domain.ReplicationEvent{
			EventType:     domain.EventTypeUpdate,
			TransactionId: c.latestXid,
			DBName:        c.sysident.DBName,
			Namespace:     rel.Namespace,
			Relation:      rel.RelationName,
			NewValues:     newValues,
			OldValues:     oldValues,
		})

	case *pglogrepl.DeleteMessageV2:
		rel, err := c.lookupRelation(logicalMsg.RelationID)
		if err != nil {
			return err
		}
		slog.Debug(
			"Received deletion message",
			slog.Uint64("xid", uint64(c.latestXid)),
			slog.String("namespace", rel.Namespace),
			slog.String("relation", rel.RelationName),
			slog.Bool("in_stream", c.inStream),
		)
		// Only key columns are reported for deletes.
		oldValues, err := c.tupleToMap(rel, logicalMsg.OldTuple, true)
		if err != nil {
			return err
		}
		return c.Consumer.OnDataChange(ctx, domain.ReplicationEvent{
			EventType:     domain.EventTypeDelete,
			TransactionId: c.latestXid,
			DBName:        c.sysident.DBName,
			Namespace:     rel.Namespace,
			Relation:      rel.RelationName,
			OldValues:     oldValues,
		})

	case *pglogrepl.TruncateMessageV2:
		// One event per truncated relation; stops at the first failure.
		for _, relID := range logicalMsg.RelationIDs {
			rel, err := c.lookupRelation(relID)
			if err != nil {
				return err
			}
			slog.Debug(
				"Received truncation message",
				slog.Uint64("xid", uint64(c.latestXid)),
				slog.String("namespace", rel.Namespace),
				slog.String("relation", rel.RelationName),
				slog.Bool("in_stream", c.inStream),
			)
			ev := domain.ReplicationEvent{
				EventType:     domain.EventTypeTruncate,
				TransactionId: c.latestXid,
				DBName:        c.sysident.DBName,
				Namespace:     rel.Namespace,
				Relation:      rel.RelationName,
			}
			if err := c.Consumer.OnDataChange(ctx, ev); err != nil {
				return err
			}
		}

	case *pglogrepl.TypeMessageV2:
		// Ignored.
	case *pglogrepl.OriginMessage:
		// Ignored.

	case *pglogrepl.LogicalDecodingMessageV2:
		slog.Debug(
			"Logical decoding message",
			slog.String("prefix", logicalMsg.Prefix),
			slog.Any("content", logicalMsg.Content),
			slog.Int("xid", int(logicalMsg.Xid)),
		)

	case *pglogrepl.StreamStartMessageV2:
		c.inStream = true
		// Events in this block belong to the streamed transaction, not to
		// the last BEGIN seen.
		c.latestXid = logicalMsg.Xid
		slog.Debug("Stream start message",
			slog.Int("xid", int(logicalMsg.Xid)),
			slog.Uint64("first_segment", uint64(logicalMsg.FirstSegment)),
		)

	case *pglogrepl.StreamStopMessageV2:
		c.inStream = false
		slog.Debug("Stream stop message")

	case *pglogrepl.StreamCommitMessageV2:
		slog.Debug("Stream commit message", slog.Int("xid", int(logicalMsg.Xid)))

	case *pglogrepl.StreamAbortMessageV2:
		// NOTE(review): events of the aborted streamed transaction were
		// already passed to the consumer; nothing is undone here.
		slog.Debug("Stream abort message", slog.Int("xid", int(logicalMsg.Xid)))

	default:
		slog.Warn("Unknown message type", slog.String("type", fmt.Sprintf("%T", logicalMsg)))
	}

	return nil
}

// tupleToMap converts the tuple data of relation into domain values.
//
// It returns (nil, nil) when data is nil. With onlyKey set, only columns whose
// flag bit 1 is set (part of the key) are kept. Unchanged TOAST columns are
// skipped because their value is not sent. Binary columns are stored as raw
// bytes. Text columns are decoded by decodeTextColumnData, and decoding
// failures are returned as errors.
func (c *ReplicationClient) tupleToMap(relation *pglogrepl.RelationMessageV2, data *pglogrepl.TupleData, onlyKey bool) (*domain.Values, error) {
	if data == nil {
		return nil, nil
	}

	if len(data.Columns) > len(relation.Columns) {
		return nil, fmt.Errorf(
			"tuple for %s.%s has %d columns but relation describes %d",
			relation.Namespace, relation.RelationName, len(data.Columns), len(relation.Columns),
		)
	}

	values := domain.NewValues()
	for idx, col := range data.Columns {
		relCol := relation.Columns[idx]
		isKey := relCol.Flags&1 > 0
		if onlyKey && !isKey {
			continue
		}

		switch col.DataType {
		case pglogrepl.TupleDataTypeNull:
			values.AddValue(isKey, relCol.Name, nil)
		case pglogrepl.TupleDataTypeToast:
			// Unchanged TOAST value: logical replication does not send it, so
			// the column is left out.
		case pglogrepl.TupleDataTypeBinary:
			values.AddValue(isKey, relCol.Name, col.Data)
		case pglogrepl.TupleDataTypeText:
			val, err := c.decodeTextColumnData(col.Data, relCol.DataType)
			if err != nil {
				return nil, fmt.Errorf("decode column %q of %s.%s: %w", relCol.Name, relation.Namespace, relation.RelationName, err)
			}
			values.AddValue(isKey, relCol.Name, val)
		}
	}

	return values, nil
}

// decodeTextColumnData decodes a text-format column value of type OID
// dataType using the client's pgtype.Map.
//
// Types unknown to the map are returned as their raw text (string). UUID
// values are returned as uuid.UUID; everything else is whatever the codec's
// DecodeValue produces.
func (c *ReplicationClient) decodeTextColumnData(data []byte, dataType uint32) (any, error) {
	dt, ok := c.typeMap.TypeForOID(dataType)
	if !ok {
		return string(data), nil
	}

	decoded, err := dt.Codec.DecodeValue(c.typeMap, dataType, pgtype.TextFormatCode, data)
	if err != nil {
		return nil, err
	}

	if dataType == pgtype.UUIDOID {
		raw, ok := decoded.([16]byte)
		if !ok {
			return nil, fmt.Errorf("unexpected decoded UUID value of type %T", decoded)
		}
		return uuid.UUID(raw), nil
	}

	return decoded, nil
}