package kinesis_consumer

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/kinesis"
	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"

	"github.com/influxdata/telegraf"
)

type recordHandler func(ctx context.Context, shard string, r *types.Record)

type shardConsumer struct {
	seqnr     string
	interval  time.Duration
	log       telegraf.Logger
	client    *kinesis.Client
	params    *kinesis.GetShardIteratorInput
	onMessage recordHandler
}

func (c *shardConsumer) consume(ctx context.Context, shard string) ([]types.ChildShard, error) {
	ticker := time.NewTicker(c.interval)
	defer ticker.Stop()

	// Get the first shard iterator
	iter, err := c.iterator(ctx)
	if err != nil {
		return nil, fmt.Errorf("getting first shard iterator failed: %w", err)
	}

	for {
		// Get new records from the shard
		resp, err := c.client.GetRecords(ctx, &kinesis.GetRecordsInput{
			ShardIterator: iter,
		})
		if err != nil {
			// Handle recoverable errors
			var throughputErr *types.ProvisionedThroughputExceededException
			var expiredIterErr *types.ExpiredIteratorException
			switch {
			case errors.As(err, &throughputErr):
				// Wait a second before trying again as suggested by
				// https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
				c.log.Tracef("throughput exceeded when getting records for shard %s...", shard)
				time.Sleep(time.Second)
				continue
			case errors.As(err, &expiredIterErr):
				c.log.Tracef("iterator expired for shard %s...", shard)
				if iter, err = c.iterator(ctx); err != nil {
					return nil, fmt.Errorf("getting shard iterator failed: %w", err)
				}
				continue
			case errors.Is(err, context.Canceled):
				return nil, nil
			default:
				c.log.Tracef("get-records error is of type %T", err)
				return nil, fmt.Errorf("getting records failed: %w", err)
			}
		}
		c.log.Tracef("read %d records for shard %s...", len(resp.Records), shard)

		// Check if we fully read the shard
		if resp.NextShardIterator == nil {
			return resp.ChildShards, nil
		}
		iter = resp.NextShardIterator

		// Process the records and keep track of the last sequence number
		// consumed for recreating the iterator.
		for _, r := range resp.Records {
			c.onMessage(ctx, shard, &r)
			c.seqnr = *r.SequenceNumber
			if errors.Is(ctx.Err(), context.Canceled) {
				return nil, nil
			}
		}

		// Wait for the poll interval to pass or cancel
		select {
		case <-ctx.Done():
			return nil, nil
		case <-ticker.C:
			continue
		}
	}
}

func (c *shardConsumer) iterator(ctx context.Context) (*string, error) {
	for {
		resp, err := c.client.GetShardIterator(ctx, c.params)
		if err != nil {
			var throughputErr *types.ProvisionedThroughputExceededException
			if errors.As(err, &throughputErr) {
				// We called the function too often and should wait a bit
				// before trying again
				c.log.Tracef("throughput exceeded when getting iterator for shard %s...", *c.params.ShardId)
				time.Sleep(time.Second)
				continue
			}
			return nil, err
		}
		c.log.Tracef("successfully updated iterator for shard %s (%s)...", *c.params.ShardId, c.seqnr)
		return resp.ShardIterator, nil
	}
}

type consumer struct {
	config              aws.Config
	stream              string
	iterType            types.ShardIteratorType
	pollInterval        time.Duration
	shardUpdateInterval time.Duration
	log                 telegraf.Logger

	onMessage recordHandler
	position  func(shard string) string

	client         *kinesis.Client
	shardsConsumed map[string]bool
	shardConsumers map[string]*shardConsumer
	wg             sync.WaitGroup
	sync.Mutex
}

func (c *consumer) init() error {
	if c.stream == "" {
		return errors.New("stream cannot be empty")
	}
	if c.pollInterval <= 0 {
		return errors.New("invalid poll interval")
	}
	if c.onMessage == nil {
		return errors.New("message handler is undefined")
	}
	c.shardsConsumed = make(map[string]bool)
	c.shardConsumers = make(map[string]*shardConsumer)

	return nil
}

func (c *consumer) start(ctx context.Context) {
	// Set up the client
	c.client = kinesis.NewFromConfig(c.config)

	// Do the initial discovery of shards
	if err := c.updateShardConsumers(ctx); err != nil {
		c.log.Errorf("Initializing shards failed: %v", err)
	}

	// If the consumer has a shard-update interval, use a ticker to update
	// the available shards on a regular basis
	if c.shardUpdateInterval <= 0 {
		return
	}
	ticker := time.NewTicker(c.shardUpdateInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := c.updateShardConsumers(ctx); err != nil {
				c.log.Errorf("Updating shards failed: %v", err)
			}
		}
	}
}

func (c *consumer) updateShardConsumers(ctx context.Context) error {
	// List all shards of the given stream
	var availableShards []types.Shard
	req := &kinesis.ListShardsInput{StreamName: aws.String(c.stream)}
	for {
		resp, err := c.client.ListShards(ctx, req)
		if err != nil {
			return fmt.Errorf("listing shards failed: %w", err)
		}
		availableShards = append(availableShards, resp.Shards...)
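		// Keep paging as long as the service returns a continuation token.
		// Note that the ListShards API rejects requests specifying both a
		// stream name and a token, so follow-up requests carry the token only.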
		if resp.NextToken == nil {
			break
		}
		req = &kinesis.ListShardsInput{NextToken: resp.NextToken}
	}
	c.log.Tracef("got %d shards during update", len(availableShards))

	// All following operations need to be locked to create a consistent
	// state of the shards and consumers
	c.Lock()
	defer c.Unlock()

	// Filter out all shards that are already actively consumed
	inactiveShards := make([]types.Shard, 0, len(availableShards))
	for _, shard := range availableShards {
		id := *shard.ShardId
		if _, found := c.shardConsumers[id]; found {
			c.log.Tracef("shard %s is actively consumed...", id)
			continue
		}
		c.log.Tracef("shard %s is not actively consumed...", id)
		inactiveShards = append(inactiveShards, shard)
	}

	// Skip the shards that are already fully consumed and look up the
	// positions of the remaining ones if the consumer is backed by an
	// iterator store
	newShards := make([]types.Shard, 0, len(inactiveShards))
	seqnrs := make(map[string]string, len(inactiveShards))
	for _, shard := range inactiveShards {
		id := *shard.ShardId
		if c.shardsConsumed[id] {
			c.log.Tracef("shard %s is already fully consumed...", id)
			continue
		}
		c.log.Tracef("shard %s is not fully consumed...", id)

		// Retrieve the shard position from the store
		if c.position != nil {
			seqnr := c.position(id)
			if seqnr == "" {
				// A truly new shard
				newShards = append(newShards, shard)
				c.log.Tracef("shard %s is new...", id)
				continue
			}
			seqnrs[id] = seqnr

			// Check if a closed shard was already fully consumed
			end := shard.SequenceNumberRange.EndingSequenceNumber
			if end != nil && *end == seqnr {
				c.log.Tracef("shard %s is closed and already fully consumed...", id)
				c.shardsConsumed[id] = true
				continue
			}
			c.log.Tracef("shard %s is not yet fully consumed...", id)
		}

		// The shard is not fully consumed yet, so save the sequence number
		// and mark the shard as "new".
		newShards = append(newShards, shard)
	}

	// Create a new consumer for every remaining new shard, respecting
	// resharding artifacts
	for _, shard := range newShards {
		id := *shard.ShardId

		// Handle resharding by making sure all parents are consumed already
		// before starting a consumer on a child shard. If the parents are not
		// fully consumed, we ignore this shard here as it will be reported
		// as a child by a later call to `GetRecords`.
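		// A shard split yields children with a single parent, while a merge
		// yields a child with both a parent and an adjacent parent, so both
		// fields have to be checked here.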
		if shard.ParentShardId != nil && *shard.ParentShardId != "" {
			pid := *shard.ParentShardId
			if !c.shardsConsumed[pid] {
				c.log.Tracef("shard %s has parent %s which is not fully consumed yet...", id, pid)
				continue
			}
		}
		if shard.AdjacentParentShardId != nil && *shard.AdjacentParentShardId != "" {
			pid := *shard.AdjacentParentShardId
			if !c.shardsConsumed[pid] {
				c.log.Tracef("shard %s has adjacent parent %s which is not fully consumed yet...", id, pid)
				continue
			}
		}

		// Create a new consumer and start it
		c.wg.Add(1)
		go func(shardID string) {
			defer c.wg.Done()
			c.startShardConsumer(ctx, shardID, seqnrs[shardID])
		}(id)
	}

	return nil
}

func (c *consumer) startShardConsumer(ctx context.Context, id, seqnr string) {
	c.log.Tracef("starting consumer for shard %s at sequence number %q...", id, seqnr)
	sc := &shardConsumer{
		seqnr:     seqnr,
		interval:  c.pollInterval,
		log:       c.log,
		onMessage: c.onMessage,
		client:    c.client,
		params: &kinesis.GetShardIteratorInput{
			ShardId:           &id,
			ShardIteratorType: c.iterType,
			StreamName:        &c.stream,
		},
	}
	if seqnr != "" {
		sc.params.ShardIteratorType = types.ShardIteratorTypeAfterSequenceNumber
		sc.params.StartingSequenceNumber = &seqnr
	}

	// Register the consumer under lock as the map is read concurrently by
	// updateShardConsumers
	c.Lock()
	c.shardConsumers[id] = sc
	c.Unlock()

	childShards, err := sc.consume(ctx, id)
	if err != nil {
		c.log.Errorf("Consuming shard %s failed: %v", id, err)
		return
	}
	c.log.Tracef("finished consuming shard %s", id)

	c.Lock()
	defer c.Unlock()
	c.shardsConsumed[id] = true
	delete(c.shardConsumers, id)

	for _, shard := range childShards {
		cid := *shard.ShardId
		startable := true
		for _, pid := range shard.ParentShards {
			startable = startable && c.shardsConsumed[pid]
		}
		if !startable {
			c.log.Tracef("child shard %s of shard %s is not startable as its parents are not fully consumed yet...", cid, id)
			continue
		}
		c.log.Tracef("child shard %s of shard %s is startable...", cid, id)

		var cseqnr string
		if c.position != nil {
			cseqnr = c.position(cid)
		}

		c.wg.Add(1)
		go func() {
			defer c.wg.Done()
			c.startShardConsumer(ctx, cid, cseqnr)
		}()
	}
}

func (c *consumer) stop() {
	c.wg.Wait()
}
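
// The following sketch illustrates one way this consumer could be wired up.
// It is a minimal, untested example for orientation only: the stream name,
// the handler body, the chosen intervals and the use of testutil.Logger
// (github.com/influxdata/telegraf/testutil) are assumptions, not part of
// this package.
//
//	func runConsumerExample() error {
//		ctx, cancel := context.WithCancel(context.Background())
//		defer cancel()
//
//		// Load region and credentials from the default AWS sources
//		cfg, err := config.LoadDefaultConfig(ctx) // github.com/aws/aws-sdk-go-v2/config
//		if err != nil {
//			return err
//		}
//
//		c := &consumer{
//			config:       cfg,
//			stream:       "my-stream", // placeholder
//			iterType:     types.ShardIteratorTypeTrimHorizon,
//			pollInterval: time.Second,
//			log:          testutil.Logger{},
//			onMessage: func(_ context.Context, shard string, r *types.Record) {
//				fmt.Printf("shard %s: %d bytes at %s\n", shard, len(r.Data), *r.SequenceNumber)
//			},
//		}
//		if err := c.init(); err != nil {
//			return err
//		}
//
//		// start blocks while the shard-update ticker runs, so run it in the
//		// background; cancelling the context ends consumption
//		go c.start(ctx)
//		time.Sleep(30 * time.Second)
//		cancel()
//		c.stop()
//		return nil
//	}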