package vault

import (
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"strings"
	"sync"

	"code.icb4dc0.de/buildr/buildr/internal/maps"
)

var (
	// ErrVaultNotInitialized is returned by Vault operations invoked on a nil
	// *Vault, such as the nil vault NewVault returns when an option reports a
	// not-exist error.
	ErrVaultNotInitialized = errors.New("vault is not properly initialized")

	_ json.Marshaler   = (*Vault)(nil)
	_ json.Unmarshaler = (*Vault)(nil)
)

// Entry is a single encrypted vault value together with the salt and nonce
// returned by the encryption when the value was stored.
type Entry struct {
	Value []byte
	Salt  []byte
	Nonce []byte
}

// EntryFromText parses the textual form produced by Entry.MarshalText: three
// standard-base64 components separated by ':' in the order nonce:salt:value.
// It returns an error if the component count is wrong or any component is
// not valid base64.
func EntryFromText(txt string) (Entry, error) {
	const expectedNumberOfComponents = 3

	parts := strings.Split(txt, ":")
	if len(parts) != expectedNumberOfComponents {
		return Entry{}, fmt.Errorf("unexpected number of components for entry: %d", len(parts))
	}

	encoding := base64.StdEncoding

	nonce, err := encoding.DecodeString(parts[0])
	if err != nil {
		return Entry{}, fmt.Errorf("decoding entry nonce: %w", err)
	}

	salt, err := encoding.DecodeString(parts[1])
	if err != nil {
		return Entry{}, fmt.Errorf("decoding entry salt: %w", err)
	}

	value, err := encoding.DecodeString(parts[2])
	if err != nil {
		return Entry{}, fmt.Errorf("decoding entry value: %w", err)
	}

	return Entry{Value: value, Salt: salt, Nonce: nonce}, nil
}

// MarshalText encodes e as "nonce:salt:value", each component in standard
// base64. The result can be parsed back with EntryFromText.
//
// Note: despite its name this does not implement encoding.TextMarshaler (it
// returns a string and no error).
func (e Entry) MarshalText() string {
	encoding := base64.StdEncoding

	return encoding.EncodeToString(e.Nonce) + ":" +
		encoding.EncodeToString(e.Salt) + ":" +
		encoding.EncodeToString(e.Value)
}

// Decrypt returns the plaintext of e by passing its value, salt and nonce,
// together with passphrase, to decrypter.
func (e Entry) Decrypt(decrypter Decrypter, passphrase string) ([]byte, error) {
	return decrypter.Decrypt(e.Value, passphrase, e.Salt, e.Nonce)
}

// NewVault creates an empty vault using NewAesGcmEncryption(Pbkdf2Deriver())
// and then applies opts in order.
//
// If an option fails with an error that is (or wraps) os.ErrNotExist,
// NewVault returns (nil, nil). Callers must treat a nil *Vault as "no vault";
// Vault methods tolerate a nil receiver. Any other option error is returned
// as is.
func NewVault(opts ...InitOption) (*Vault, error) {
	v := &Vault{
		encryption: NewAesGcmEncryption(Pbkdf2Deriver()),
		entries:    make(map[string]Entry),
	}

	for _, opt := range opts {
		if err := opt.applyToVault(v); err != nil {
			// errors.Is (unlike os.IsNotExist) also recognises not-exist
			// errors that an option wrapped with additional context.
			if errors.Is(err, os.ErrNotExist) {
				//nolint:nilnil // correct from a domain level
				return nil, nil
			}

			return nil, err
		}
	}

	return v, nil
}

// Vault holds encrypted entries keyed by name. The encryption and passphrase
// are used for every encrypt/decrypt. lock guards entries and is acquired by
// every exported method.
type Vault struct {
	encryption Encryption
	entries    map[string]Entry
	passphrase string
	lock       sync.RWMutex
}

// Keys returns the names of all stored entries as produced by maps.Keys.
// A nil vault has no keys and yields nil.
func (v *Vault) Keys() []string {
	if v == nil {
		return nil
	}

	v.lock.RLock()
	defer v.lock.RUnlock()

	return maps.Keys(v.entries)
}

// GetValue decrypts and returns the value stored under entryKey. It returns
// (nil, nil) when no such entry exists, and ErrVaultNotInitialized when v is
// nil.
func (v *Vault) GetValue(entryKey string) ([]byte, error) {
	if v == nil {
		return nil, ErrVaultNotInitialized
	}

	v.lock.RLock()
	defer v.lock.RUnlock()

	return v.doGetValue(entryKey)
}

// SetValue encrypts plainText with the vault's passphrase and stores it under
// entryKey, replacing any previous entry. It returns ErrVaultNotInitialized
// when v is nil.
func (v *Vault) SetValue(entryKey string, plainText []byte) error {
	if v == nil {
		return ErrVaultNotInitialized
	}

	v.lock.Lock()
	defer v.lock.Unlock()

	return v.doSetValue(entryKey, plainText)
}

// RemoveValue deletes the entry stored under entryKey, if any. It is a no-op
// on a nil vault.
func (v *Vault) RemoveValue(entryKey string) {
	if v == nil {
		return
	}

	v.lock.Lock()
	defer v.lock.Unlock()

	delete(v.entries, entryKey)
}

// doGetValue looks up and decrypts entryKey; (nil, nil) if absent.
// The caller must hold v.lock (read or write).
func (v *Vault) doGetValue(entryKey string) ([]byte, error) {
	entry, ok := v.entries[entryKey]
	if !ok {
		return nil, nil
	}

	return entry.Decrypt(v.encryption, v.passphrase)
}

// doSetValue encrypts plainText and stores the ciphertext, salt and nonce
// under entryKey. The caller must hold v.lock for writing.
func (v *Vault) doSetValue(entryKey string, plainText []byte) error {
	cipher, salt, nonce, err := v.encryption.Encrypt(plainText, v.passphrase)
	if err != nil {
		return err
	}

	v.entries[entryKey] = Entry{
		Value: cipher,
		Salt:  salt,
		Nonce: nonce,
	}

	return nil
}

// UnmarshalJSON replaces the vault's entries with those decoded from data, a
// JSON object mapping entry names to Entry.MarshalText strings.
//
// Every entry is test-decrypted with the vault's current encryption and
// passphrase, so a wrong passphrase or a corrupted entry is reported here
// rather than on first read. On any error the existing entries are left
// untouched.
func (v *Vault) UnmarshalJSON(data []byte) error {
	rawEntries := make(map[string]string)
	if err := json.Unmarshal(data, &rawEntries); err != nil {
		return err
	}

	v.lock.Lock()
	defer v.lock.Unlock()

	// Build into a fresh map and swap it in only on success, so a failure
	// never leaves the vault half-populated.
	entries := make(map[string]Entry, len(rawEntries))
	for key, raw := range rawEntries {
		entry, err := EntryFromText(raw)
		if err != nil {
			return fmt.Errorf("parsing entry %q: %w", key, err)
		}

		if _, err := entry.Decrypt(v.encryption, v.passphrase); err != nil {
			return fmt.Errorf("failed to decrypt entry %q: %w", key, err)
		}

		entries[key] = entry
	}

	v.entries = entries

	return nil
}

// MarshalJSON encodes the vault as a JSON object mapping entry names to their
// Entry.MarshalText form. Values stay encrypted.
func (v *Vault) MarshalJSON() ([]byte, error) {
	// Take the lock before touching entries at all; even len(v.entries) would
	// race with concurrent writers otherwise.
	v.lock.RLock()
	defer v.lock.RUnlock()

	toSerialize := make(map[string]string, len(v.entries))
	for key, entry := range v.entries {
		toSerialize[key] = entry.MarshalText()
	}

	return json.Marshal(toSerialize)
}