package worker import ( "context" "errors" "log/slog" "sync" ) type ( // A simple task to deal with list. ChanProcessor[T any] interface { Query(context.Context) (<-chan T, error) Process(context.Context, T) error } OnFail[T any] interface { OnFail(context.Context, T, error) } ListProcessor[T any] interface { Query(context.Context) ([]T, error) Process(context.Context, T) error } chanProcessorTask[T any] struct { chanProcessor ChanProcessor[T] logger *slog.Logger scheduler *Scheduler } batchProcessorTask[T any] struct { batchProcessor ListProcessor[T] logger *slog.Logger scheduler *Scheduler } serialProcessorTask[T any] struct { batchProcessor ListProcessor[T] logger *slog.Logger scheduler *Scheduler } ) func NewTaskFromBatchProcessor[T any]( batchProcessor ListProcessor[T], scheduler *Scheduler, logger *slog.Logger, ) Task { return &batchProcessorTask[T]{ batchProcessor: batchProcessor, scheduler: scheduler, logger: logger, } } func NewTaskFromSerialProcessor[T any]( batchProcessor ListProcessor[T], scheduler *Scheduler, logger *slog.Logger, ) Task { return &serialProcessorTask[T]{ batchProcessor: batchProcessor, scheduler: scheduler, logger: logger, } } func NewTaskFromChanProcessor[T any]( chanProcessor ChanProcessor[T], scheduler *Scheduler, logger *slog.Logger, ) Task { return &chanProcessorTask[T]{ chanProcessor: chanProcessor, scheduler: scheduler, logger: logger, } } func (l *batchProcessorTask[T]) Start(ctx context.Context) error { for { values, err := l.batchProcessor.Query(ctx) if err != nil { return err } select { case <-ctx.Done(): return ctx.Err() default: } if len(values) == 0 { return nil } var wg sync.WaitGroup for _, v := range values { select { case <-ctx.Done(): return ctx.Err() default: } wg.Add(1) l.scheduler.Take() go func(v T) { defer l.scheduler.Return() defer wg.Done() if err := l.batchProcessor.Process(ctx, v); err != nil && !errors.Is(err, context.Canceled) { l.logger.Error( "Error processing batch", slog.String("error", err.Error()), ) if failure, ok := l.batchProcessor.(OnFail[T]); ok { failure.OnFail(ctx, v, err) } } }(v) } wg.Wait() } } func (l *serialProcessorTask[T]) Start(ctx context.Context) error { for { values, err := l.batchProcessor.Query(ctx) if err != nil { return err } select { case <-ctx.Done(): return ctx.Err() default: } if len(values) == 0 { return nil } for _, v := range values { select { case <-ctx.Done(): return ctx.Err() default: } l.scheduler.Take() if err := l.batchProcessor.Process(ctx, v); err != nil && !errors.Is(err, context.Canceled) { l.logger.Error( "Error processing batch", slog.String("error", err.Error()), ) if failure, ok := l.batchProcessor.(OnFail[T]); ok { failure.OnFail(ctx, v, err) } } l.scheduler.Return() } } } func (l *chanProcessorTask[T]) Start(ctx context.Context) error { c, err := l.chanProcessor.Query(ctx) if err != nil { return err } for { select { case <-ctx.Done(): return ctx.Err() case v, ok := <-c: if !ok { return nil } l.scheduler.Take() go func(v T) { defer l.scheduler.Return() if err := l.chanProcessor.Process(ctx, v); err != nil { l.logger.Error( "Error processing batch", slog.String("error", err.Error()), ) if failure, ok := l.chanProcessor.(OnFail[T]); ok { failure.OnFail(ctx, v, err) } } }(v) } } }