package db

import (
	"context"
	"errors"
	"fmt"
	"log"
	"log/slog"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pglogrepl"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgproto3"
	"github.com/jackc/pgx/v5/pgtype"

	"code.icb4dc0.de/prskr/pg_v_man/core/domain"
	"code.icb4dc0.de/prskr/pg_v_man/core/ports"
	"code.icb4dc0.de/prskr/pg_v_man/infrastructure/config"
)

// outputPlugin is the logical decoding output plugin the replication slot is
// created with.
const outputPlugin = "pgoutput"

// defaultConsumerTimeout is the initial value of
// ReplicationClient.ConsumerTimeout set by NewReplicationClient.
const defaultConsumerTimeout = 200 * time.Millisecond

// NewReplicationClient connects to the database at cfg.ConnectionString and
// prepares it for logical replication:
//
//  1. IDENTIFY_SYSTEM, whose result is kept on the client,
//  2. creation of a temporary slot named cfg.SlotName using the pgoutput plugin,
//  3. START_REPLICATION from the WAL position reported in step 1, with protocol
//     version 2, the publication cfg.Publication, messages and streaming enabled.
//
// Decoded changes are delivered to consumer once Receive is called. If any step
// after connecting fails, the connection is closed and any close error is
// joined onto the returned error.
func NewReplicationClient(ctx context.Context, cfg config.DB, consumer ports.ReplicationEventConsumer) (client *ReplicationClient, err error) {
	conn, connErr := pgconn.Connect(ctx, cfg.ConnectionString)
	if connErr != nil {
		return nil, fmt.Errorf("could not connect to database: %w", connErr)
	}

	// From here on every failure path must release the connection.
	defer func() {
		if err == nil {
			return
		}
		err = errors.Join(err, conn.Close(context.Background()))
	}()

	rc := &ReplicationClient{
		Conn:            conn,
		Consumer:        consumer,
		ConsumerTimeout: defaultConsumerTimeout,
		typeMap:         pgtype.NewMap(),
		relations:       make(map[uint32]*pglogrepl.RelationMessageV2),
	}

	if rc.sysident, err = pglogrepl.IdentifySystem(ctx, conn); err != nil {
		return nil, fmt.Errorf("could not identify system: %w", err)
	}

	ident := rc.sysident
	slog.Info(
		"System identity",
		slog.String("systemId", ident.SystemID),
		slog.Int("timeline", int(ident.Timeline)),
		slog.String("xlogpos", ident.XLogPos.String()),
		slog.String("dbname", ident.DBName),
	)

	slotOpts := pglogrepl.CreateReplicationSlotOptions{Temporary: true}
	if _, err = pglogrepl.CreateReplicationSlot(ctx, conn, cfg.SlotName, outputPlugin, slotOpts); err != nil {
		return nil, fmt.Errorf("could not create replication slot: %w", err)
	}
	slog.Info("Replication slot created", slog.String("slotName", cfg.SlotName))

	// NOTE(review): cfg.Publication is interpolated verbatim inside a quoted
	// literal; a publication name containing a single quote would break the
	// START_REPLICATION command. Escape it if the value can be user-supplied.
	startOpts := pglogrepl.StartReplicationOptions{
		PluginArgs: []string{
			"proto_version '2'",
			fmt.Sprintf("publication_names '%s'", cfg.Publication),
			"messages 'true'",
			"streaming 'true'",
		},
	}
	if err = pglogrepl.StartReplication(ctx, conn, cfg.SlotName, ident.XLogPos, startOpts); err != nil {
		return nil, fmt.Errorf("could not start replication: %w", err)
	}
	slog.Info("Replication started", slog.String("slotName", cfg.SlotName))

	return rc, nil
}

// ReplicationClient reads a pgoutput logical replication stream from Conn and
// forwards decoded row changes to Consumer as domain.ReplicationEvent values.
// Construct it with NewReplicationClient, then call Receive to run the loop.
type ReplicationClient struct {
	// ConsumerTimeout bounds the context handed to handleReceivedLog (and so to
	// Consumer.OnDataChange) for each XLogData message.
	ConsumerTimeout time.Duration
	// Conn is the connection the replication commands are issued on and the
	// replication stream is read from.
	Conn            *pgconn.PgConn
	// Consumer receives insert/update/delete/truncate events.
	Consumer        ports.ReplicationEventConsumer
	// sysident is the IDENTIFY_SYSTEM result captured at construction; its
	// XLogPos is the start position and its DBName is copied onto every event.
	sysident        pglogrepl.IdentifySystemResult
	// typeMap decodes text-format column values by type OID.
	typeMap         *pgtype.Map
	// relations caches Relation messages by relation ID; data messages only
	// carry the ID and are resolved against this map.
	relations       map[uint32]*pglogrepl.RelationMessageV2
	// latestXid is the transaction id attached to emitted events; it is
	// maintained by handleReceivedLog.
	latestXid       uint32
	// inStream is true between Stream Start and Stream Stop messages; it is
	// passed to pglogrepl.ParseV2 when decoding each message.
	inStream        bool
}

// Receive runs the replication loop until ctx is cancelled or an
// unrecoverable error occurs.
//
// It reads the CopyData stream started by NewReplicationClient, passes every
// XLogData payload to handleReceivedLog (bounded by ConsumerTimeout), and
// periodically reports the processed WAL position back to the server with a
// standby status update. A status update is also sent promptly when a
// keepalive asks for a reply.
//
// Receive returns nil when ctx is cancelled, whether the cancellation is
// noticed between messages or while waiting for one. It returns an error when
// the connection fails, the server sends an ErrorResponse (which terminates
// the replication stream), a message cannot be parsed, or the consumer fails.
func (c *ReplicationClient) Receive(ctx context.Context) error {
	const standbyMessageTimeout = 10 * time.Second

	clientXLogPos := c.sysident.XLogPos
	nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout)

	for ctx.Err() == nil {
		if time.Now().After(nextStandbyMessageDeadline) {
			// Sent with context.Background(), so cancellation of ctx does not
			// interrupt the status write itself.
			status := pglogrepl.StandbyStatusUpdate{WALWritePosition: clientXLogPos}
			if err := pglogrepl.SendStandbyStatusUpdate(context.Background(), c.Conn, status); err != nil {
				return fmt.Errorf("SendStandbyStatusUpdate failed: %w", err)
			}

			slog.Debug("Sent Standby status message", slog.String("xlogpos", clientXLogPos.String()))
			nextStandbyMessageDeadline = time.Now().Add(standbyMessageTimeout)
		}

		// Wake up no later than the next status deadline so status updates are
		// sent even while the server is idle.
		receiveCtx, stop := context.WithDeadline(ctx, nextStandbyMessageDeadline)
		rawMsg, err := c.Conn.ReceiveMessage(receiveCtx)
		stop()

		if err != nil {
			switch {
			case pgconn.Timeout(err):
				// Deadline reached: loop around to send the status update.
				continue
			case ctx.Err() != nil:
				// Cancelled while blocked in ReceiveMessage: same clean shutdown
				// as a cancellation noticed at the top of the loop.
				return nil
			default:
				return fmt.Errorf("could not receive message: %w", err)
			}
		}

		// An ErrorResponse ends the COPY stream; no further data will arrive,
		// so continuing would only spin on a dead stream.
		if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok {
			return fmt.Errorf("replication stream failed: %w", pgconn.ErrorResponseToPgError(errMsg))
		}

		msg, ok := rawMsg.(*pgproto3.CopyData)
		if !ok {
			slog.Warn("Received unexpected message", slog.String("message", fmt.Sprintf("%T", rawMsg)))
			continue
		}
		if len(msg.Data) == 0 {
			// Guard the msg.Data[0] type-byte access below.
			slog.Warn("Received empty CopyData message")
			continue
		}

		switch msg.Data[0] {
		case pglogrepl.PrimaryKeepaliveMessageByteID:
			pkm, parseErr := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:])
			if parseErr != nil {
				return fmt.Errorf("ParsePrimaryKeepaliveMessage failed: %w", parseErr)
			}

			slog.Debug(
				"Primary Keepalive Message",
				slog.String("ServerWALEnd", pkm.ServerWALEnd.String()),
				slog.Time("ServerTime", pkm.ServerTime),
				slog.Bool("ReplyRequested", pkm.ReplyRequested),
			)

			if pkm.ServerWALEnd > clientXLogPos {
				clientXLogPos = pkm.ServerWALEnd
			}
			if pkm.ReplyRequested {
				// A zero deadline forces a status update on the next iteration.
				nextStandbyMessageDeadline = time.Time{}
			}

		case pglogrepl.XLogDataByteID:
			xld, parseErr := pglogrepl.ParseXLogData(msg.Data[1:])
			if parseErr != nil {
				return fmt.Errorf("ParseXLogData failed: %w", parseErr)
			}

			slog.Debug(
				"Received log data",
				slog.String("WALStart", xld.WALStart.String()),
				slog.String("ServerWALEnd", xld.ServerWALEnd.String()),
				slog.Time("ServerTime", xld.ServerTime),
			)

			processCtx, cancel := context.WithTimeout(ctx, c.ConsumerTimeout)
			handleErr := c.handleReceivedLog(processCtx, xld.WALData)
			cancel()

			if handleErr != nil {
				return handleErr
			}

			// Acknowledge up to the end of this record, not its start;
			// otherwise the last processed record is always reported as pending.
			if end := xld.WALStart + pglogrepl.LSN(len(xld.WALData)); end > clientXLogPos {
				clientXLogPos = end
			}
		}
	}

	return nil
}

// Close shuts down the client's connection by delegating to
// pgconn.PgConn.Close with the given ctx, and returns its error.
func (c *ReplicationClient) Close(ctx context.Context) error {
	conn := c.Conn
	closeErr := conn.Close(ctx)
	return closeErr
}

// handleReceivedLog decodes one pgoutput message from walData and reacts to it.
//
// Behavior by message type:
//   - Relation messages are cached in c.relations.
//   - Begin and Stream Start messages record the transaction id attached to
//     subsequent events.
//   - Stream Start/Stop toggle c.inStream, which is passed to ParseV2.
//   - Insert, Update, Delete and Truncate messages are turned into
//     domain.ReplicationEvent values and handed to c.Consumer.OnDataChange;
//     the consumer's error, if any, is returned.
//   - All other message types are only logged.
//
// It returns an error if walData cannot be parsed, or if a data message refers
// to a relation ID that has not been announced by a Relation message.
func (c *ReplicationClient) handleReceivedLog(ctx context.Context, walData []byte) error {
	logicalMsg, err := pglogrepl.ParseV2(walData, c.inStream)
	if err != nil {
		return fmt.Errorf("parse logical replication message: %w", err)
	}

	slog.Debug("Receive a logical replication message", slog.String("type", logicalMsg.Type().String()))

	switch logicalMsg := logicalMsg.(type) {
	case *pglogrepl.RelationMessageV2:
		c.relations[logicalMsg.RelationID] = logicalMsg

	case *pglogrepl.BeginMessage:
		// Starts a non-streamed transaction group; these are only sent for
		// committed transactions.
		// NOTE: because NewReplicationClient enables "streaming 'true'", large
		// in-progress transactions instead arrive in Stream Start/Stop chunks,
		// and their events are forwarded before StreamCommit/StreamAbort says
		// whether they committed. Consumers may therefore see events of
		// transactions that are later rolled back.
		slog.Debug("Begin message", slog.Int("xid", int(logicalMsg.Xid)))
		c.latestXid = logicalMsg.Xid
	case *pglogrepl.CommitMessage:
		slog.Debug("Commit message", slog.String("commit_lsn", logicalMsg.CommitLSN.String()))
	case *pglogrepl.InsertMessageV2:
		rel, ok := c.relations[logicalMsg.RelationID]
		if !ok {
			return fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID)
		}

		slog.Debug("Received insert message",
			slog.Uint64("xid", uint64(c.latestXid)),
			slog.String("namespace", rel.Namespace),
			slog.String("relation", rel.RelationName),
			slog.Bool("in_stream", c.inStream),
		)

		ev := domain.ReplicationEvent{
			EventType:     domain.EventTypeInsert,
			TransactionId: c.latestXid,
			DBName:        c.sysident.DBName,
			Namespace:     rel.Namespace,
			Relation:      rel.RelationName,
			NewValues:     c.tupleToMap(rel, logicalMsg.Tuple, false),
		}

		return c.Consumer.OnDataChange(ctx, ev)

	case *pglogrepl.UpdateMessageV2:
		rel, ok := c.relations[logicalMsg.RelationID]
		if !ok {
			return fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID)
		}

		slog.Debug(
			"Received update message",
			slog.Uint64("xid", uint64(c.latestXid)),
			slog.String("namespace", rel.Namespace),
			slog.String("relation", rel.RelationName),
			slog.Bool("in_stream", c.inStream),
		)

		ev := domain.ReplicationEvent{
			EventType:     domain.EventTypeUpdate,
			TransactionId: c.latestXid,
			DBName:        c.sysident.DBName,
			Namespace:     rel.Namespace,
			Relation:      rel.RelationName,
			NewValues:     c.tupleToMap(rel, logicalMsg.NewTuple, false),
			OldValues:     c.tupleToMap(rel, logicalMsg.OldTuple, false),
		}

		return c.Consumer.OnDataChange(ctx, ev)
	case *pglogrepl.DeleteMessageV2:
		rel, ok := c.relations[logicalMsg.RelationID]
		if !ok {
			return fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID)
		}

		slog.Debug(
			"Received deletion message",
			slog.Uint64("xid", uint64(c.latestXid)),
			slog.String("namespace", rel.Namespace),
			slog.String("relation", rel.RelationName),
			slog.Bool("in_stream", c.inStream),
		)

		// Only key columns are kept for the old row of a delete.
		ev := domain.ReplicationEvent{
			EventType:     domain.EventTypeDelete,
			TransactionId: c.latestXid,
			DBName:        c.sysident.DBName,
			Namespace:     rel.Namespace,
			Relation:      rel.RelationName,
			OldValues:     c.tupleToMap(rel, logicalMsg.OldTuple, true),
		}

		return c.Consumer.OnDataChange(ctx, ev)
	case *pglogrepl.TruncateMessageV2:
		// One message may truncate several relations; emit one event each and
		// stop at the first consumer error.
		for _, relID := range logicalMsg.RelationIDs {
			rel, ok := c.relations[relID]
			if !ok {
				return fmt.Errorf("unknown relation ID %d", relID)
			}

			slog.Debug(
				"Received truncation message",
				slog.Uint64("xid", uint64(c.latestXid)),
				slog.String("namespace", rel.Namespace),
				slog.String("relation", rel.RelationName),
				slog.Bool("in_stream", c.inStream),
			)

			ev := domain.ReplicationEvent{
				EventType:     domain.EventTypeTruncate,
				TransactionId: c.latestXid,
				DBName:        c.sysident.DBName,
				Namespace:     rel.Namespace,
				Relation:      rel.RelationName,
			}

			if err := c.Consumer.OnDataChange(ctx, ev); err != nil {
				return err
			}
		}
	case *pglogrepl.TypeMessageV2:
	case *pglogrepl.OriginMessage:
	case *pglogrepl.LogicalDecodingMessageV2:
		slog.Debug(
			"Logical decoding message",
			slog.String("prefix", logicalMsg.Prefix),
			slog.Any("content", logicalMsg.Content),
			slog.Int("xid", int(logicalMsg.Xid)),
		)
	case *pglogrepl.StreamStartMessageV2:
		c.inStream = true
		// Events in this chunk belong to the streamed transaction, not to
		// whichever transaction last sent a Begin message.
		c.latestXid = logicalMsg.Xid
		slog.Debug("Stream start message",
			slog.Int("xid", int(logicalMsg.Xid)),
			slog.Uint64("first_segment", uint64(logicalMsg.FirstSegment)),
		)

	case *pglogrepl.StreamStopMessageV2:
		c.inStream = false
		slog.Debug("Stream stop message")
	case *pglogrepl.StreamCommitMessageV2:
		slog.Debug("Stream commit message", slog.Int("xid", int(logicalMsg.Xid)))
	case *pglogrepl.StreamAbortMessageV2:
		slog.Debug("Stream abort message", slog.Int("xid", int(logicalMsg.Xid)))
	default:
		slog.Warn("Unknown message type", slog.String("type", fmt.Sprintf("%T", logicalMsg)))
	}

	return nil
}

// tupleToMap converts a pgoutput tuple into domain.Values. Each column is
// named after the column at the same index in relation.
//
// Parameters:
//   - relation: the relation the tuple belongs to.
//   - data: the tuple to convert; a nil data yields nil.
//   - onlyKey: when true, only columns flagged as key columns are kept.
//
// Column handling by data type:
//   - NULL columns are added with a nil value.
//   - Unchanged-TOAST columns are skipped, because the tuple does not carry
//     their value.
//   - Binary columns are added as raw bytes.
//   - Text columns are decoded with decodeTextColumnData. If decoding fails,
//     the error is logged and the raw text is stored instead, which is the
//     same representation used for columns of unknown type.
func (c *ReplicationClient) tupleToMap(relation *pglogrepl.RelationMessageV2, data *pglogrepl.TupleData, onlyKey bool) *domain.Values {
	if data == nil {
		return nil
	}

	values := domain.NewValues()
	for idx, col := range data.Columns {
		relCol := relation.Columns[idx]
		// In the pgoutput Relation message, column flag 1 marks a column as
		// part of the key.
		isKey := relCol.Flags&1 > 0
		if onlyKey && !isKey {
			continue
		}

		switch col.DataType {
		case pglogrepl.TupleDataTypeNull:
			values.AddValue(isKey, relCol.Name, nil)
		case pglogrepl.TupleDataTypeToast: // unchanged toast
			// This TOAST value was not changed. TOAST values are not stored in
			// the tuple, and logical replication doesn't spend a disk read to
			// fetch it.
		case pglogrepl.TupleDataTypeBinary:
			values.AddValue(isKey, relCol.Name, col.Data)
		case pglogrepl.TupleDataTypeText:
			val, err := c.decodeTextColumnData(col.Data, relCol.DataType)
			if err != nil {
				// Previously this called log.Fatalln, which terminated the whole
				// process because of a single value. Report the error and keep
				// the raw text instead.
				log.Printf("error decoding column %q of %s.%s: %v", relCol.Name, relation.Namespace, relation.RelationName, err)
				val = string(col.Data)
			}
			values.AddValue(isKey, relCol.Name, val)
		}
	}

	return values
}

// decodeTextColumnData decodes a text-format column value. It uses the codec
// that c.typeMap has registered for dataType (the column's type OID).
//
// Return values:
//   - Type OIDs unknown to the type map: the raw text as a string.
//   - UUID columns: a uuid.UUID, rather than the codec's [16]byte.
//   - Everything else: whatever the codec's DecodeValue returned.
//
// A decoding error from the codec is returned unchanged.
func (c *ReplicationClient) decodeTextColumnData(data []byte, dataType uint32) (any, error) {
	dt, ok := c.typeMap.TypeForOID(dataType)
	if !ok {
		return string(data), nil
	}

	decoded, err := dt.Codec.DecodeValue(c.typeMap, dataType, pgtype.TextFormatCode, data)
	if err != nil {
		return nil, err
	}

	if dataType == pgtype.UUIDOID {
		// Checked assertion: a nil or unexpected decoded value must not panic
		// the replication loop. uuid.UUID is a [16]byte, so the conversion
		// cannot fail (unlike uuid.FromBytes, which needs a length check).
		if raw, isBytes := decoded.([16]byte); isBytes {
			return uuid.UUID(raw), nil
		}
	}

	return decoded, nil
}