package plugins import ( "bytes" "context" "crypto/sha256" "encoding/hex" "errors" "fmt" "io" "log/slog" "net/url" "os" "path/filepath" commonv1 "code.icb4dc0.de/buildr/api/generated/common/v1" wasiv1 "code.icb4dc0.de/buildr/api/generated/wasi/v1" "code.icb4dc0.de/buildr/buildr/internal/config" "code.icb4dc0.de/buildr/buildr/internal/slices" "github.com/go-git/go-git/v5/plumbing/hash" "github.com/google/uuid" "golang.org/x/sync/errgroup" "code.icb4dc0.de/buildr/buildr/internal/ioutils" "code.icb4dc0.de/buildr/buildr/modules" "code.icb4dc0.de/buildr/buildr/modules/plugin" "code.icb4dc0.de/buildr/buildr/modules/state" ) type Manager struct { Downloader Downloader plugins state.Plugins cacheDir string } func (p *Manager) Init(plugins state.Plugins, cacheDir string) { p.cacheDir = cacheDir p.plugins = plugins } func (p *Manager) Register(ctx context.Context, registry *modules.TypeRegistry) error { plugins, err := p.plugins.List(ctx) if err != nil { return err } grp, grpcCtx := errgroup.WithContext(ctx) for _, registeredPlugin := range plugins { registeredPlugin := registeredPlugin grp.Go(func() error { knownModules, err := p.plugins.ModulesForPlugin(grpcCtx, *registeredPlugin.ID) if err != nil { return err } for _, module := range knownModules { module := module registry.RegisterModule(modules.ModuleFactoryFunc(func() modules.ModuleWithMeta { var ( defaultSpec = new(commonv1.ModuleSpec) mod = plugin.Module{ PluginPayload: &plugin.PayloadFile{ Path: registeredPlugin.LocalPath, Checksum: registeredPlugin.Hash, }, PluginCategory: module.Category, PluginType: module.Type, ModuleSpec: make(map[string]any), } ) if err := defaultSpec.UnmarshalVT(module.DefaultSpec); err == nil { mod.SetModuleSpec(defaultSpec) } return &modules.Metadata[plugin.Module]{ Module: mod, } })) } return nil }) } return grp.Wait() } func (p *Manager) UpdatePlugins(ctx context.Context, refs ...config.PluginReference) error { knownPlugins, err := p.plugins.List(ctx) if err != nil { return err } pluginsByName := make(map[string]state.Plugin, len(knownPlugins)) for _, p := range knownPlugins { pluginsByName[p.Name] = p } for _, pluginReference := range refs { var requiredDownload bool if knownPlugin, ok := pluginsByName[pluginReference.Name]; !ok { requiredDownload = true } else { delete(pluginsByName, pluginReference.Name) if requiredDownload, err = p.validateKnownPlugin(knownPlugin, pluginReference); err != nil { return err } } if requiredDownload { slog.Info("Downloading plugin", slog.String("plugin", pluginReference.Name)) if downloadedPlugin, payload, err := p.downloadPlugin(ctx, pluginReference); err != nil { return err } else if pluginID, err := p.plugins.UpsertPlugin(ctx, *downloadedPlugin); err != nil { return err } else if inventory, err := plugin.DiscoverInventory(ctx, payload); err != nil { return err } else { pluginModules := slices.Map(inventory.Specs, func(in *wasiv1.PluginInventoryResponse_InventorySpec) state.PluginModule { return moduleReferenceToPluginModule(pluginID, in) }) if err := p.plugins.UpsertModules(ctx, pluginModules); err != nil { return err } } } } for _, obsoletePlugin := range pluginsByName { slog.Info("Removing obsolete plugin", slog.String("plugin", obsoletePlugin.Name)) if err := os.Remove(obsoletePlugin.LocalPath); err != nil { if !errors.Is(err, os.ErrNotExist) { return err } } if err := p.plugins.Remove(ctx, *obsoletePlugin.ID); err != nil { return err } } return nil } func (p *Manager) downloadPlugin(ctx context.Context, ref config.PluginReference) (pl *state.Plugin, payload []byte, err error) { pluginURL, err := url.Parse(ref.URL) if err != nil { return nil, nil, err } targetPath := filepath.Join(p.cacheDir, ref.Name) if err := p.Downloader.Download(ctx, targetPath, pluginURL); err != nil { return nil, nil, err } computedHash, payload, err := hashFile(targetPath, sha256.New()) if err != nil { return nil, nil, err } if ref.Checksum != nil { expectedHash, err := hex.DecodeString(*ref.Checksum) if err != nil { return nil, nil, err } if !bytes.Equal(expectedHash, computedHash) { return nil, nil, fmt.Errorf("plugin checksum mismatch - expected %s, got %s", *ref.Checksum, hex.EncodeToString(computedHash)) } } return &state.Plugin{ Name: ref.Name, URL: pluginURL, LocalPath: targetPath, Hash: computedHash, }, payload, nil } func (p *Manager) validateKnownPlugin(knownPlugin state.Plugin, ref config.PluginReference) (requireUpdate bool, err error) { // check if URL is valid _, err = url.Parse(ref.URL) if err != nil { return false, err } // expected file does not exist if _, err := os.Stat(knownPlugin.LocalPath); err != nil { return true, nil } // by default require update if URL was changed if ref.URL != knownPlugin.URL.String() { return true, nil } // check if SHA256 is valid computedHash, equal, err := validateHashForFile(knownPlugin.LocalPath, knownPlugin.Hash) if err != nil { return false, err } else if !equal { return true, nil } // if checksum differs from expected, require update if ref.Checksum != nil { expectedChecksum, err := hex.DecodeString(*ref.Checksum) if err != nil { return false, err } if !bytes.Equal(computedHash, expectedChecksum) { return true, nil } } return false, nil } func validateHashForFile(filePath string, expectedHash []byte) (computedHash []byte, equal bool, err error) { computedHash, _, err = hashFile(filePath, sha256.New()) if err != nil { return nil, false, err } if !bytes.Equal(expectedHash, computedHash) { return nil, false, nil } return computedHash, true, nil } func hashFile(filePath string, h hash.Hash) (computedHash, payload []byte, err error) { f, err := os.Open(filePath) if err != nil { return nil, nil, err } defer func() { _ = f.Close() }() buf := bytes.NewBuffer(nil) w := io.MultiWriter(h, buf) if _, err := ioutils.CopyWithPooledBuffer(w, f); err != nil { return nil, nil, err } return h.Sum(nil), buf.Bytes(), nil } func moduleReferenceToPluginModule(pluginID uuid.UUID, in *wasiv1.PluginInventoryResponse_InventorySpec) state.PluginModule { return state.PluginModule{ PluginID: pluginID, Type: in.ModuleRef.ModuleType, Category: in.ModuleRef.ModuleCategory, DefaultSpec: in.EmptySpec, } }