Add DNS details
This commit is contained in:
parent
0468c93671
commit
6c448fd318
18 changed files with 383 additions and 150 deletions
42
api/proto/pkg/audit/details/dns_details.proto
Normal file
42
api/proto/pkg/audit/details/dns_details.proto
Normal file
|
@ -0,0 +1,42 @@
|
|||
syntax = "proto3";
|
||||
|
||||
option go_package = "gitlab.com/inetmock/inetmock/pkg/audit/details";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "com.github.baez90.inetmock.audit.details";
|
||||
option java_outer_classname = "HandlerEventProto";
|
||||
|
||||
package inetmock.audit;
|
||||
|
||||
enum DNSOpCode {
|
||||
Query = 0;
|
||||
Status = 2;
|
||||
Notify = 4;
|
||||
Update = 5;
|
||||
}
|
||||
|
||||
enum ResourceRecordType {
|
||||
UnknownRR = 0;
|
||||
A = 1;
|
||||
NS = 2;
|
||||
CNAME = 5;
|
||||
SOA = 6;
|
||||
PTR = 12;
|
||||
HINFO = 13;
|
||||
MINFO = 14;
|
||||
MX = 15;
|
||||
TXT = 16;
|
||||
RP = 17;
|
||||
AAAA = 28;
|
||||
SRV = 33;
|
||||
NAPTR = 35;
|
||||
}
|
||||
|
||||
message DNSQuestionEntity {
|
||||
ResourceRecordType type = 1;
|
||||
string name = 2;
|
||||
}
|
||||
|
||||
message DNSDetailsEntity {
|
||||
DNSOpCode opcode = 1;
|
||||
repeated DNSQuestionEntity questions = 2;
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
syntax = "proto3";
|
||||
|
||||
option go_package = "gitlab.com/inetmock/inetmock/pkg/audit";
|
||||
option go_package = "gitlab.com/inetmock/inetmock/pkg/audit/details";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "com.github.baez90.inetmock.audit";
|
||||
option java_package = "com.github.baez90.inetmock.audit.details";
|
||||
option java_outer_classname = "HandlerEventProto";
|
||||
|
||||
package inetmock.audit;
|
47
pkg/audit/details/dns_details.go
Normal file
47
pkg/audit/details/dns_details.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package details
|
||||
|
||||
import (
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
func NewDNSFromWireFormat(entity *DNSDetailsEntity) DNS {
|
||||
d := DNS{
|
||||
OPCode: entity.Opcode,
|
||||
}
|
||||
|
||||
for _, q := range entity.Questions {
|
||||
d.Questions = append(d.Questions, DNSQuestion{
|
||||
RRType: q.Type,
|
||||
Name: q.Name,
|
||||
})
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
type DNSQuestion struct {
|
||||
RRType ResourceRecordType
|
||||
Name string
|
||||
}
|
||||
|
||||
type DNS struct {
|
||||
OPCode DNSOpCode
|
||||
Questions []DNSQuestion
|
||||
}
|
||||
|
||||
func (d DNS) MarshalToWireFormat() (any *anypb.Any, err error) {
|
||||
detailsEntity := &DNSDetailsEntity{
|
||||
Opcode: d.OPCode,
|
||||
}
|
||||
|
||||
for _, q := range d.Questions {
|
||||
detailsEntity.Questions = append(detailsEntity.Questions, &DNSQuestionEntity{
|
||||
Type: q.RRType,
|
||||
Name: q.Name,
|
||||
})
|
||||
}
|
||||
|
||||
any, err = anypb.New(detailsEntity)
|
||||
|
||||
return
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package audit
|
||||
package details
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
@ -7,7 +7,7 @@ import (
|
|||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
type HTTPDetails struct {
|
||||
type HTTP struct {
|
||||
Method string
|
||||
Host string
|
||||
URI string
|
||||
|
@ -15,7 +15,24 @@ type HTTPDetails struct {
|
|||
Headers http.Header
|
||||
}
|
||||
|
||||
func (d HTTPDetails) MarshalToWireFormat() (any *anypb.Any, err error) {
|
||||
func NewHTTPFromWireFormat(entity *HTTPDetailsEntity) HTTP {
|
||||
headers := http.Header{}
|
||||
for name, values := range entity.Headers {
|
||||
for idx := range values.Values {
|
||||
headers.Add(name, values.Values[idx])
|
||||
}
|
||||
}
|
||||
|
||||
return HTTP{
|
||||
Method: entity.Method.String(),
|
||||
Host: entity.Host,
|
||||
URI: entity.Uri,
|
||||
Proto: entity.Proto,
|
||||
Headers: headers,
|
||||
}
|
||||
}
|
||||
|
||||
func (d HTTP) MarshalToWireFormat() (any *anypb.Any, err error) {
|
||||
var method = HTTPMethod_GET
|
||||
if methodValue, known := HTTPMethod_value[strings.ToUpper(d.Method)]; known {
|
||||
method = HTTPMethod(methodValue)
|
|
@ -1,11 +1,12 @@
|
|||
package audit
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/big"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit/details"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
@ -24,35 +25,23 @@ type Event struct {
|
|||
DestinationIP net.IP
|
||||
SourcePort uint16
|
||||
DestinationPort uint16
|
||||
ProtocolDetails *anypb.Any
|
||||
ProtocolDetails Details
|
||||
TLS *TLSDetails
|
||||
}
|
||||
|
||||
func (e *Event) ProtoMessage() proto.Message {
|
||||
var sourceIP isEventEntity_SourceIP
|
||||
if ipv4 := e.SourceIP.To4(); ipv4 != nil {
|
||||
if len(ipv4) == 16 {
|
||||
sourceIP = &EventEntity_SourceIPv4{SourceIPv4: binary.BigEndian.Uint32(ipv4[12:16])}
|
||||
} else {
|
||||
sourceIP = &EventEntity_SourceIPv4{SourceIPv4: binary.BigEndian.Uint32(ipv4)}
|
||||
}
|
||||
sourceIP = &EventEntity_SourceIPv4{SourceIPv4: ipv4ToUint32(ipv4)}
|
||||
} else {
|
||||
ipv6 := big.NewInt(0)
|
||||
ipv6.SetBytes(e.SourceIP)
|
||||
sourceIP = &EventEntity_SourceIPv6{SourceIPv6: ipv6.Uint64()}
|
||||
sourceIP = &EventEntity_SourceIPv6{SourceIPv6: ipv6ToBytes(e.SourceIP)}
|
||||
}
|
||||
|
||||
var destinationIP isEventEntity_DestinationIP
|
||||
if ipv4 := e.DestinationIP.To4(); ipv4 != nil {
|
||||
if len(ipv4) == 16 {
|
||||
destinationIP = &EventEntity_DestinationIPv4{DestinationIPv4: binary.BigEndian.Uint32(ipv4[12:16])}
|
||||
} else {
|
||||
destinationIP = &EventEntity_DestinationIPv4{DestinationIPv4: binary.BigEndian.Uint32(ipv4)}
|
||||
}
|
||||
destinationIP = &EventEntity_DestinationIPv4{DestinationIPv4: ipv4ToUint32(ipv4)}
|
||||
} else {
|
||||
ipv6 := big.NewInt(0)
|
||||
ipv6.SetBytes(e.SourceIP)
|
||||
destinationIP = &EventEntity_DestinationIPv6{DestinationIPv6: ipv6.Uint64()}
|
||||
destinationIP = &EventEntity_DestinationIPv6{DestinationIPv6: ipv6ToBytes(e.DestinationIP)}
|
||||
}
|
||||
|
||||
var tlsDetails *TLSDetailsEntity = nil
|
||||
|
@ -60,6 +49,13 @@ func (e *Event) ProtoMessage() proto.Message {
|
|||
tlsDetails = e.TLS.ProtoMessage()
|
||||
}
|
||||
|
||||
var detailsEntity *anypb.Any = nil
|
||||
if e.ProtocolDetails != nil {
|
||||
if any, err := e.ProtocolDetails.MarshalToWireFormat(); err == nil {
|
||||
detailsEntity = any
|
||||
}
|
||||
}
|
||||
|
||||
return &EventEntity{
|
||||
Id: e.ID,
|
||||
Timestamp: timestamppb.New(e.Timestamp),
|
||||
|
@ -70,7 +66,7 @@ func (e *Event) ProtoMessage() proto.Message {
|
|||
SourcePort: uint32(e.SourcePort),
|
||||
DestinationPort: uint32(e.DestinationPort),
|
||||
Tls: tlsDetails,
|
||||
ProtocolDetails: e.ProtocolDetails,
|
||||
ProtocolDetails: detailsEntity,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -82,27 +78,33 @@ func (e *Event) ApplyDefaults(id int64) {
|
|||
}
|
||||
}
|
||||
|
||||
func (e *Event) SetSourceIPFromAddr(remoteAddr net.Addr) {
|
||||
ip, port := parseIPPortFromAddr(remoteAddr)
|
||||
e.SourceIP = ip
|
||||
e.SourcePort = port
|
||||
}
|
||||
|
||||
func (e *Event) SetDestinationIPFromAddr(localAddr net.Addr) {
|
||||
ip, port := parseIPPortFromAddr(localAddr)
|
||||
e.DestinationIP = ip
|
||||
e.DestinationPort = port
|
||||
}
|
||||
|
||||
func NewEventFromProto(msg *EventEntity) (ev Event) {
|
||||
var sourceIP net.IP
|
||||
switch ip := msg.GetSourceIP().(type) {
|
||||
case *EventEntity_SourceIPv4:
|
||||
buf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(buf, ip.SourceIPv4)
|
||||
sourceIP = buf
|
||||
sourceIP = sourceIP.To4()
|
||||
sourceIP = uint32ToIP(ip.SourceIPv4)
|
||||
case *EventEntity_SourceIPv6:
|
||||
sourceIP = big.NewInt(int64(ip.SourceIPv6)).Bytes()
|
||||
sourceIP = uint64ToIP(ip.SourceIPv6)
|
||||
}
|
||||
|
||||
var destinationIP net.IP
|
||||
switch ip := msg.GetDestinationIP().(type) {
|
||||
case *EventEntity_DestinationIPv4:
|
||||
buf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(buf, ip.DestinationIPv4)
|
||||
destinationIP = buf
|
||||
destinationIP = destinationIP.To4()
|
||||
destinationIP = uint32ToIP(ip.DestinationIPv4)
|
||||
case *EventEntity_DestinationIPv6:
|
||||
destinationIP = big.NewInt(int64(ip.DestinationIPv6)).Bytes()
|
||||
destinationIP = uint64ToIP(ip.DestinationIPv6)
|
||||
}
|
||||
|
||||
ev = Event{
|
||||
|
@ -114,8 +116,46 @@ func NewEventFromProto(msg *EventEntity) (ev Event) {
|
|||
DestinationIP: destinationIP,
|
||||
SourcePort: uint16(msg.GetSourcePort()),
|
||||
DestinationPort: uint16(msg.GetDestinationPort()),
|
||||
ProtocolDetails: msg.GetProtocolDetails(),
|
||||
ProtocolDetails: guessDetailsFromApp(msg.GetProtocolDetails()),
|
||||
TLS: NewTLSDetailsFromProto(msg.GetTls()),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseIPPortFromAddr(addr net.Addr) (ip net.IP, port uint16) {
|
||||
switch a := addr.(type) {
|
||||
case *net.TCPAddr:
|
||||
return a.IP, uint16(a.Port)
|
||||
case *net.UDPAddr:
|
||||
return a.IP, uint16(a.Port)
|
||||
case *net.UnixAddr:
|
||||
return
|
||||
default:
|
||||
ipPortSplit := strings.Split(addr.String(), ":")
|
||||
if len(ipPortSplit) != 2 {
|
||||
return
|
||||
}
|
||||
|
||||
ip = net.ParseIP(ipPortSplit[0])
|
||||
if p, err := strconv.Atoi(ipPortSplit[1]); err == nil {
|
||||
port = uint16(p)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func guessDetailsFromApp(any *anypb.Any) Details {
|
||||
var detailsProto proto.Message
|
||||
var err error
|
||||
if detailsProto, err = any.UnmarshalNew(); err != nil {
|
||||
return nil
|
||||
}
|
||||
switch any.TypeUrl {
|
||||
case "type.googleapis.com/inetmock.audit.HTTPDetailsEntity":
|
||||
return details.NewHTTPFromWireFormat(detailsProto.(*details.HTTPDetailsEntity))
|
||||
case "type.googleapis.com/inetmock.audit.DNSDetailsEntity":
|
||||
return details.NewDNSFromWireFormat(detailsProto.(*details.DNSDetailsEntity))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,20 +9,20 @@ import (
|
|||
"time"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit/details"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultSink = &testSink{
|
||||
name: "test defaultSink",
|
||||
}
|
||||
testEvents = []audit.Event{
|
||||
testEvents = []*audit.Event{
|
||||
{
|
||||
Transport: audit.TransportProtocol_TCP,
|
||||
Application: audit.AppProtocol_HTTP,
|
||||
SourceIP: net.ParseIP("127.0.0.1"),
|
||||
DestinationIP: net.ParseIP("127.0.0.1"),
|
||||
SourceIP: net.ParseIP("127.0.0.1").To4(),
|
||||
DestinationIP: net.ParseIP("127.0.0.1").To4(),
|
||||
SourcePort: 32344,
|
||||
DestinationPort: 80,
|
||||
TLS: &audit.TLSDetails{
|
||||
|
@ -30,7 +30,7 @@ var (
|
|||
CipherSuite: tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
ServerName: "localhost",
|
||||
},
|
||||
ProtocolDetails: mustMarshalToWireformat(audit.HTTPDetails{
|
||||
ProtocolDetails: details.HTTP{
|
||||
Method: "GET",
|
||||
Host: "localhost",
|
||||
URI: "http://localhost/asdf",
|
||||
|
@ -38,13 +38,13 @@ var (
|
|||
Headers: http.Header{
|
||||
"Accept": []string{"application/json"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Transport: audit.TransportProtocol_TCP,
|
||||
Application: audit.AppProtocol_DNS,
|
||||
SourceIP: net.ParseIP("::1"),
|
||||
DestinationIP: net.ParseIP("::1"),
|
||||
SourceIP: net.ParseIP("::1").To16(),
|
||||
DestinationIP: net.ParseIP("::1").To16(),
|
||||
SourcePort: 32344,
|
||||
DestinationPort: 80,
|
||||
},
|
||||
|
@ -80,14 +80,6 @@ func wgMockSink(t testing.TB, wg *sync.WaitGroup) audit.Sink {
|
|||
}
|
||||
}
|
||||
|
||||
func mustMarshalToWireformat(d audit.Details) *anypb.Any {
|
||||
any, err := d.MarshalToWireFormat()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return any
|
||||
}
|
||||
|
||||
func Test_eventStream_RegisterSink(t *testing.T) {
|
||||
type args struct {
|
||||
s audit.Sink
|
||||
|
@ -156,7 +148,7 @@ func Test_eventStream_RegisterSink(t *testing.T) {
|
|||
|
||||
func Test_eventStream_Emit(t *testing.T) {
|
||||
type args struct {
|
||||
evs []audit.Event
|
||||
evs []*audit.Event
|
||||
opts []audit.EventStreamOption
|
||||
}
|
||||
type testCase struct {
|
||||
|
@ -215,9 +207,9 @@ func Test_eventStream_Emit(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
go func(evs []audit.Event, wg *sync.WaitGroup) {
|
||||
go func(evs []*audit.Event, wg *sync.WaitGroup) {
|
||||
for _, ev := range evs {
|
||||
e.Emit(ev)
|
||||
e.Emit(*ev)
|
||||
wg.Done()
|
||||
}
|
||||
}(tt.args.evs, emittedWaitGroup)
|
||||
|
|
61
pkg/audit/event_test.go
Normal file
61
pkg/audit/event_test.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package audit
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit/details"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
func Test_guessDetailsFromApp(t *testing.T) {
|
||||
|
||||
mustAny := func(msg proto.Message) *anypb.Any {
|
||||
a, err := anypb.New(msg)
|
||||
if err != nil {
|
||||
panic(a)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
type args struct {
|
||||
any *anypb.Any
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
args args
|
||||
want Details
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "HTTP etails",
|
||||
args: args{
|
||||
any: mustAny(&details.HTTPDetailsEntity{
|
||||
Method: details.HTTPMethod_GET,
|
||||
Host: "localhost",
|
||||
Uri: "http://localhost/asdf",
|
||||
Proto: "HTTP",
|
||||
}),
|
||||
},
|
||||
want: details.HTTP{
|
||||
Method: "GET",
|
||||
Host: "localhost",
|
||||
URI: "http://localhost/asdf",
|
||||
Proto: "HTTP",
|
||||
Headers: http.Header{},
|
||||
},
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
if got := guessDetailsFromApp(tt.args.any); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("guessDetailsFromApp() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
}
|
||||
}
|
33
pkg/audit/ip_conversion.go
Normal file
33
pkg/audit/ip_conversion.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package audit
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/big"
|
||||
"net"
|
||||
)
|
||||
|
||||
func ipv4ToUint32(ip net.IP) uint32 {
|
||||
if len(ip) == 16 {
|
||||
return binary.BigEndian.Uint32(ip[12:16])
|
||||
}
|
||||
return binary.BigEndian.Uint32(ip)
|
||||
}
|
||||
|
||||
func ipv6ToBytes(ip net.IP) uint64 {
|
||||
ipv6 := big.NewInt(0)
|
||||
ipv6.SetBytes(ip)
|
||||
return ipv6.Uint64()
|
||||
}
|
||||
|
||||
func uint32ToIP(i uint32) (ip net.IP) {
|
||||
buf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(buf, i)
|
||||
ip = buf
|
||||
ip = ip.To4()
|
||||
return
|
||||
}
|
||||
|
||||
func uint64ToIP(i uint64) (ip net.IP) {
|
||||
ip = big.NewInt(int64(i)).FillBytes(make([]byte, 16))
|
||||
return
|
||||
}
|
|
@ -46,6 +46,7 @@ func (l logSink) OnSubscribe(evs <-chan Event) {
|
|||
zap.Uint16("source_port", ev.SourcePort),
|
||||
zap.String("destination_ip", ev.DestinationIP.String()),
|
||||
zap.Uint16("destination_port", ev.DestinationPort),
|
||||
zap.Any("details", ev.ProtocolDetails),
|
||||
)
|
||||
}
|
||||
}(l.logger, evs)
|
||||
|
|
|
@ -20,7 +20,7 @@ func Test_logSink_OnSubscribe(t *testing.T) {
|
|||
type testCase struct {
|
||||
name string
|
||||
fields fields
|
||||
events []audit.Event
|
||||
events []*audit.Event
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
|
@ -93,7 +93,7 @@ func Test_logSink_OnSubscribe(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, ev := range tt.events {
|
||||
evs.Emit(ev)
|
||||
evs.Emit(*ev)
|
||||
}
|
||||
|
||||
select {
|
||||
|
|
|
@ -5,10 +5,10 @@ import (
|
|||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -36,7 +36,7 @@ func Test_eventReader_Read(t *testing.T) {
|
|||
type testCase struct {
|
||||
name string
|
||||
fields fields
|
||||
wantEv audit.Event
|
||||
wantEv *audit.Event
|
||||
wantErr bool
|
||||
}
|
||||
tests := []testCase{
|
||||
|
@ -86,7 +86,7 @@ func Test_eventReader_Read(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
if err == nil && !proto.Equal(gotEv.ProtoMessage(), tt.wantEv.ProtoMessage()) {
|
||||
if err == nil && !reflect.DeepEqual(gotEv, *tt.wantEv) {
|
||||
t.Errorf("Read() gotEv = %v, want %v", gotEv, tt.wantEv)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
func Test_writerCloserSink_OnSubscribe(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
events []audit.Event
|
||||
events []*audit.Event
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
|
@ -55,7 +55,7 @@ func Test_writerCloserSink_OnSubscribe(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, ev := range tt.events {
|
||||
evs.Emit(ev)
|
||||
evs.Emit(*ev)
|
||||
}
|
||||
|
||||
select {
|
||||
|
|
|
@ -23,7 +23,7 @@ func Test_eventWriter_Write(t *testing.T) {
|
|||
order binary.ByteOrder
|
||||
}
|
||||
type args struct {
|
||||
evs []audit.Event
|
||||
evs []*audit.Event
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
|
@ -104,7 +104,7 @@ func Test_eventWriter_Write(t *testing.T) {
|
|||
e := audit.NewEventWriter(writerMock, audit.WithWriterByteOrder(tt.fields.order))
|
||||
|
||||
for _, ev := range tt.args.evs {
|
||||
if err := e.Write(&ev); (err != nil) != tt.wantErr {
|
||||
if err := e.Write(ev); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Write() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,22 +15,23 @@ type dnsHandler struct {
|
|||
dnsServer []*dns.Server
|
||||
}
|
||||
|
||||
func (d *dnsHandler) Start(_ api.PluginContext, config config.HandlerConfig) (err error) {
|
||||
func (d *dnsHandler) Start(pluginCtx api.PluginContext, config config.HandlerConfig) (err error) {
|
||||
var options dnsOptions
|
||||
if options, err = loadFromConfig(config.Options); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
listenAddr := config.ListenAddr()
|
||||
d.logger = d.logger.With(
|
||||
d.logger = pluginCtx.Logger().With(
|
||||
zap.String("handler_name", config.HandlerName),
|
||||
zap.String("address", listenAddr),
|
||||
)
|
||||
|
||||
handler := ®exHandler{
|
||||
handlerName: config.HandlerName,
|
||||
fallback: options.Fallback,
|
||||
logger: d.logger,
|
||||
handlerName: config.HandlerName,
|
||||
fallback: options.Fallback,
|
||||
logger: pluginCtx.Logger(),
|
||||
auditEmitter: pluginCtx.Audit(),
|
||||
}
|
||||
|
||||
for _, rule := range options.Rules {
|
||||
|
|
|
@ -1,35 +1,41 @@
|
|||
package dns_mock
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit/details"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type regexHandler struct {
|
||||
handlerName string
|
||||
routes []resolverRule
|
||||
fallback ResolverFallback
|
||||
logger logging.Logger
|
||||
handlerName string
|
||||
routes []resolverRule
|
||||
fallback ResolverFallback
|
||||
auditEmitter audit.Emitter
|
||||
logger logging.Logger
|
||||
}
|
||||
|
||||
func (rh *regexHandler) AddRule(rule resolverRule) {
|
||||
rh.routes = append(rh.routes, rule)
|
||||
}
|
||||
|
||||
func (rh regexHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
func (rh *regexHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
timer := prometheus.NewTimer(requestDurationHistogram.WithLabelValues(rh.handlerName))
|
||||
defer func() {
|
||||
timer.ObserveDuration()
|
||||
}()
|
||||
|
||||
rh.recordRequest(r, w.LocalAddr(), w.RemoteAddr())
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.Compress = false
|
||||
m.SetReply(r)
|
||||
|
||||
switch r.Opcode {
|
||||
case dns.OpcodeQuery:
|
||||
if r.Opcode == dns.OpcodeQuery {
|
||||
rh.handleQuery(m)
|
||||
}
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
|
@ -40,35 +46,32 @@ func (rh regexHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||
}
|
||||
}
|
||||
|
||||
func (rh regexHandler) handleQuery(m *dns.Msg) {
|
||||
func (rh *regexHandler) handleQuery(m *dns.Msg) {
|
||||
for _, q := range m.Question {
|
||||
rh.logger.Info(
|
||||
"handling question",
|
||||
zap.String("question", q.Name),
|
||||
)
|
||||
switch q.Qtype {
|
||||
case dns.TypeA:
|
||||
totalHandledRequestsCounter.WithLabelValues(rh.handlerName).Inc()
|
||||
for _, rule := range rh.routes {
|
||||
if rule.pattern.MatchString(q.Name) {
|
||||
m.Authoritative = true
|
||||
answer := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
},
|
||||
A: rule.response,
|
||||
}
|
||||
m.Answer = append(m.Answer, answer)
|
||||
rh.logger.Info(
|
||||
"matched DNS rule",
|
||||
zap.String("pattern", rule.pattern.String()),
|
||||
zap.String("response", rule.response.String()),
|
||||
)
|
||||
return
|
||||
if !rule.pattern.MatchString(q.Name) {
|
||||
continue
|
||||
}
|
||||
m.Authoritative = true
|
||||
answer := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
},
|
||||
A: rule.response,
|
||||
}
|
||||
m.Answer = append(m.Answer, answer)
|
||||
rh.logger.Info(
|
||||
"matched DNS rule",
|
||||
zap.String("pattern", rule.pattern.String()),
|
||||
zap.String("response", rule.response.String()),
|
||||
)
|
||||
return
|
||||
}
|
||||
rh.handleFallbackForMessage(m, q)
|
||||
default:
|
||||
|
@ -81,7 +84,7 @@ func (rh regexHandler) handleQuery(m *dns.Msg) {
|
|||
}
|
||||
}
|
||||
|
||||
func (rh regexHandler) handleFallbackForMessage(m *dns.Msg, q dns.Question) {
|
||||
func (rh *regexHandler) handleFallbackForMessage(m *dns.Msg, q dns.Question) {
|
||||
fallbackIP := rh.fallback.GetIP()
|
||||
answer := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
|
@ -99,3 +102,38 @@ func (rh regexHandler) handleFallbackForMessage(m *dns.Msg, q dns.Question) {
|
|||
m.Authoritative = true
|
||||
m.Answer = append(m.Answer, answer)
|
||||
}
|
||||
|
||||
func (rh *regexHandler) recordRequest(m *dns.Msg, localAddr, remoteAddr net.Addr) {
|
||||
dnsDetails := &details.DNS{
|
||||
OPCode: details.DNSOpCode(m.Opcode),
|
||||
}
|
||||
|
||||
for _, q := range m.Question {
|
||||
dnsDetails.Questions = append(dnsDetails.Questions, details.DNSQuestion{
|
||||
RRType: details.ResourceRecordType(q.Qtype),
|
||||
Name: q.Name,
|
||||
})
|
||||
}
|
||||
|
||||
ev := audit.Event{
|
||||
Transport: guessTransportFromAddr(localAddr),
|
||||
Application: audit.AppProtocol_DNS,
|
||||
ProtocolDetails: dnsDetails,
|
||||
}
|
||||
|
||||
ev.SetSourceIPFromAddr(remoteAddr)
|
||||
ev.SetDestinationIPFromAddr(localAddr)
|
||||
|
||||
rh.auditEmitter.Emit(ev)
|
||||
}
|
||||
|
||||
func guessTransportFromAddr(addr net.Addr) audit.TransportProtocol {
|
||||
switch addr.(type) {
|
||||
case *net.TCPAddr:
|
||||
return audit.TransportProtocol_TCP
|
||||
case *net.UDPAddr:
|
||||
return audit.TransportProtocol_UDP
|
||||
default:
|
||||
return audit.TransportProtocol_UNKNOWN_TRANSPORT
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,9 +3,7 @@ package dns_mock
|
|||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"gitlab.com/inetmock/inetmock/pkg/api"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
"gitlab.com/inetmock/inetmock/pkg/metrics"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -20,14 +18,6 @@ var (
|
|||
)
|
||||
|
||||
func AddDNSMock(registry api.HandlerRegistry) (err error) {
|
||||
var logger logging.Logger
|
||||
if logger, err = logging.CreateLogger(); err != nil {
|
||||
return
|
||||
}
|
||||
logger = logger.With(
|
||||
zap.String("protocol_handler", name),
|
||||
)
|
||||
|
||||
if totalHandledRequestsCounter, err = metrics.Counter(
|
||||
name,
|
||||
"handled_requests_total",
|
||||
|
@ -57,9 +47,7 @@ func AddDNSMock(registry api.HandlerRegistry) (err error) {
|
|||
}
|
||||
|
||||
registry.RegisterHandler(name, func() api.ProtocolHandler {
|
||||
return &dnsHandler{
|
||||
logger: logger,
|
||||
}
|
||||
return &dnsHandler{}
|
||||
})
|
||||
|
||||
return
|
||||
|
|
|
@ -82,7 +82,7 @@ func loadFromConfig(config *viper.Viper) (options httpOptions, err error) {
|
|||
}
|
||||
|
||||
if absoluteResponsePath, parseErr = filepath.Abs(i.Response); parseErr != nil {
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
options.Rules = append(options.Rules, targetRule{
|
||||
|
|
|
@ -1,16 +1,14 @@
|
|||
package http_mock
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
details "gitlab.com/inetmock/inetmock/pkg/audit/details"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
type route struct {
|
||||
|
@ -70,47 +68,22 @@ func (f emittingFileHandler) ServeHTTP(writer http.ResponseWriter, request *http
|
|||
}
|
||||
|
||||
func eventFromRequest(request *http.Request) audit.Event {
|
||||
details := audit.HTTPDetails{
|
||||
httpDetails := details.HTTP{
|
||||
Method: request.Method,
|
||||
Host: request.Host,
|
||||
URI: request.RequestURI,
|
||||
Proto: request.Proto,
|
||||
Headers: request.Header,
|
||||
}
|
||||
var wire *anypb.Any
|
||||
if w, err := details.MarshalToWireFormat(); err == nil {
|
||||
wire = w
|
||||
}
|
||||
|
||||
localIP, localPort := ipPortFromAddr(LocalAddr(request.Context()))
|
||||
remoteIP, remotePort := ipPortFromAddr(RemoteAddr(request.Context()))
|
||||
|
||||
return audit.Event{
|
||||
ev := audit.Event{
|
||||
Transport: audit.TransportProtocol_TCP,
|
||||
Application: audit.AppProtocol_HTTP,
|
||||
SourceIP: remoteIP,
|
||||
DestinationIP: localIP,
|
||||
SourcePort: remotePort,
|
||||
DestinationPort: localPort,
|
||||
ProtocolDetails: wire,
|
||||
ProtocolDetails: httpDetails,
|
||||
}
|
||||
}
|
||||
|
||||
func ipPortFromAddr(addr net.Addr) (ip net.IP, port uint16) {
|
||||
if tcpAddr, isTCPAddr := addr.(*net.TCPAddr); isTCPAddr {
|
||||
return tcpAddr.IP, uint16(tcpAddr.Port)
|
||||
}
|
||||
|
||||
ipPortSplit := strings.Split(addr.String(), ":")
|
||||
if len(ipPortSplit) != 2 {
|
||||
return
|
||||
}
|
||||
|
||||
ip = net.ParseIP(ipPortSplit[0])
|
||||
|
||||
if p, err := strconv.Atoi(ipPortSplit[1]); err == nil {
|
||||
port = uint16(p)
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
ev.SetDestinationIPFromAddr(LocalAddr(request.Context()))
|
||||
ev.SetSourceIPFromAddr(RemoteAddr(request.Context()))
|
||||
|
||||
return ev
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue