From fa20c381a3d94ad825a5b1d454cabdfe74377be8 Mon Sep 17 00:00:00 2001 From: Dan Cojocaru Date: Wed, 27 Sep 2023 21:45:58 +0200 Subject: [PATCH] Implement subscriptions --- main.go | 88 ++++++++++++++++++--- pkg/handlers/findTrain.go | 119 ++++++++++++++++++----------- pkg/subscriptions/subscriptions.go | 84 ++++++++++++++++---- 3 files changed, 222 insertions(+), 69 deletions(-) diff --git a/main.go b/main.go index 944cf01..0a34fff 100644 --- a/main.go +++ b/main.go @@ -71,7 +71,11 @@ func main() { } database.SetDatabase(db) - subs, err := subscriptions.LoadSubscriptions() + subBot, err := tgBot.New(botToken) + if err != nil { + panic(err) + } + subs, err := subscriptions.LoadSubscriptions(subBot) if err != nil { subs = nil fmt.Printf("WARN : Could not load subscriptions: %s\n", err.Error()) @@ -153,7 +157,7 @@ func handler(ctx context.Context, b *tgBot.Bot, update *models.Update, subs *sub switch { case strings.HasPrefix(update.Message.Text, trainInfoCommand): - response = handleFindTrainStages(ctx, b, update, subs) + response = handleFindTrainStages(ctx, b, update) case strings.HasPrefix(update.Message.Text, cancelCommand): handlers.SetChatFlow(chatFlow, handlers.InitialFlowType, handlers.InitialFlowType, "") response = &handlers.HandlerResponse{ @@ -170,7 +174,7 @@ func handler(ctx context.Context, b *tgBot.Bot, update *models.Update, subs *sub }) case handlers.TrainInfoFlowType: log.Printf("DEBUG: trainInfoFlowType with stage %s\n", chatFlow.Stage) - response = handleFindTrainStages(ctx, b, update, subs) + response = handleFindTrainStages(ctx, b, update) } } } @@ -208,18 +212,18 @@ func handler(ctx context.Context, b *tgBot.Bot, update *models.Update, subs *sub ChatID: update.CallbackQuery.Message.Chat.ID, Text: pleaseWaitMessage, }) - response = handlers.HandleTrainNumberCommand(ctx, trainNumber, date, -1) + response, _ = handlers.HandleTrainNumberCommand(ctx, trainNumber, date, -1, false) if err == nil { response.ProgressMessageToEditId = message.ID } handlers.SetChatFlow(chatFlow, handlers.InitialFlowType, handlers.InitialFlowType, "") case handlers.TrainInfoChooseGroupCallbackQuery: + trainNumber := splitted[1] dateInt, _ := strconv.ParseInt(splitted[2], 10, 64) date := time.Unix(dateInt, 0) groupIndex, _ := strconv.ParseInt(splitted[3], 10, 31) - log.Printf("%s, %v, %d", update.CallbackQuery.Data, splitted, groupIndex) - originalResponse := handlers.HandleTrainNumberCommand(ctx, splitted[1], date, int(groupIndex)) + originalResponse, _ := handlers.HandleTrainNumberCommand(ctx, trainNumber, date, int(groupIndex), false) response = &handlers.HandlerResponse{ MessageEdits: []*tgBot.EditMessageTextParams{ { @@ -231,12 +235,78 @@ func handler(ctx context.Context, b *tgBot.Bot, update *models.Update, subs *sub }, }, } + + case handlers.TrainInfoSubscribeCallbackQuery: + trainNumber := splitted[1] + dateInt, _ := strconv.ParseInt(splitted[2], 10, 64) + date := time.Unix(dateInt, 0) + groupIndex, _ := strconv.ParseInt(splitted[3], 10, 31) + err := subs.InsertSubscription(subscriptions.SubData{ + ChatId: update.CallbackQuery.Message.Chat.ID, + MessageId: update.CallbackQuery.Message.ID, + TrainNumber: trainNumber, + Date: date, + GroupIndex: int(groupIndex), + }) + if err != nil { + log.Printf("ERROR: Subscribe error: %s", err.Error()) + response = &handlers.HandlerResponse{ + CallbackAnswer: &tgBot.AnswerCallbackQueryParams{ + Text: fmt.Sprintf("Error when subscribing."), + ShowAlert: true, + }, + } + } else { + // TODO: Update message to contain unsubscribe button + response = &handlers.HandlerResponse{ + CallbackAnswer: &tgBot.AnswerCallbackQueryParams{ + Text: fmt.Sprintf("Subscribed successfully!"), + }, + MessageMarkupEdits: []*tgBot.EditMessageReplyMarkupParams{ + { + ChatID: update.CallbackQuery.Message.Chat.ID, + MessageID: update.CallbackQuery.Message.ID, + ReplyMarkup: handlers.GetTrainNumberCommandResponseButtons(trainNumber, date, int(groupIndex), handlers.TrainInfoResponseButtonIncludeUnsub), + }, + }, + } + } + + case handlers.TrainInfoUnsubscribeCallbackQuery: + trainNumber := splitted[1] + dateInt, _ := strconv.ParseInt(splitted[2], 10, 64) + date := time.Unix(dateInt, 0) + groupIndex, _ := strconv.ParseInt(splitted[3], 10, 31) + _, err := subs.DeleteSubscription(update.CallbackQuery.Message.Chat.ID, update.CallbackQuery.Message.ID) + if err != nil { + log.Printf("ERROR: Unsubscribe error: %s", err.Error()) + response = &handlers.HandlerResponse{ + CallbackAnswer: &tgBot.AnswerCallbackQueryParams{ + Text: fmt.Sprintf("Error when unsubscribing."), + ShowAlert: true, + }, + } + } else { + // TODO: Update message to contain unsubscribe button + response = &handlers.HandlerResponse{ + CallbackAnswer: &tgBot.AnswerCallbackQueryParams{ + Text: fmt.Sprintf("Unsubscribed successfully!"), + }, + MessageMarkupEdits: []*tgBot.EditMessageReplyMarkupParams{ + { + ChatID: update.CallbackQuery.Message.Chat.ID, + MessageID: update.CallbackQuery.Message.ID, + ReplyMarkup: handlers.GetTrainNumberCommandResponseButtons(trainNumber, date, int(groupIndex), handlers.TrainInfoResponseButtonIncludeSub), + }, + }, + } + } } } } } -func handleFindTrainStages(ctx context.Context, b *tgBot.Bot, update *models.Update, subs *subscriptions.Subscriptions) *handlers.HandlerResponse { +func handleFindTrainStages(ctx context.Context, b *tgBot.Bot, update *models.Update) *handlers.HandlerResponse { log.Println("DEBUG: handleFindTrainStages") var response *handlers.HandlerResponse @@ -270,7 +340,7 @@ func handleFindTrainStages(ctx context.Context, b *tgBot.Bot, update *models.Upd groupIndex, _ = strconv.Atoi(commandParams[2]) } - response = handlers.HandleTrainNumberCommand(ctx, trainNumber, date, groupIndex) + response, _ = handlers.HandleTrainNumberCommand(ctx, trainNumber, date, groupIndex, false) if err == nil { response.ProgressMessageToEditId = message.ID } @@ -306,7 +376,7 @@ func handleFindTrainStages(ctx context.Context, b *tgBot.Bot, update *models.Upd ChatID: update.Message.Chat.ID, Text: pleaseWaitMessage, }) - response = handlers.HandleTrainNumberCommand(ctx, chatFlow.Extra, date, -1) + response, _ = handlers.HandleTrainNumberCommand(ctx, chatFlow.Extra, date, -1, false) if err == nil { response.ProgressMessageToEditId = message.ID } diff --git a/pkg/handlers/findTrain.go b/pkg/handlers/findTrain.go index 519e39b..165a392 100644 --- a/pkg/handlers/findTrain.go +++ b/pkg/handlers/findTrain.go @@ -18,11 +18,23 @@ import ( const ( TrainInfoChooseDateCallbackQuery = "TI_CHOOSE_DATE" TrainInfoChooseGroupCallbackQuery = "TI_CHOOSE_GROUP" + TrainInfoSubscribeCallbackQuery = "TI_SUB" + TrainInfoUnsubscribeCallbackQuery = "TI_UNSUB" viewInKaiBaseUrl = "https://kai.infotren.dcdev.ro/view-train.html" + + subscribeButton = "Subscribe to updates" + unsubscribeButton = "Unsubscribe from updates" + openInWebAppButton = "Open in WebApp" +) + +const ( + TrainInfoResponseButtonExcludeSub = iota + TrainInfoResponseButtonIncludeSub + TrainInfoResponseButtonIncludeUnsub ) -func HandleTrainNumberCommand(ctx context.Context, trainNumber string, date time.Time, groupIndex int) *HandlerResponse { +func HandleTrainNumberCommand(ctx context.Context, trainNumber string, date time.Time, groupIndex int, isSubscribed bool) (*HandlerResponse, bool) { trainData, err := api.GetTrain(ctx, trainNumber, date) switch { @@ -34,50 +46,46 @@ func HandleTrainNumberCommand(ctx context.Context, trainNumber string, date time Message: &bot.SendMessageParams{ Text: fmt.Sprintf("The train %s was not found.", trainNumber), }, - } + }, false case errors.Is(err, api.ServerError): log.Printf("ERROR: In handle train number: %s", err.Error()) return &HandlerResponse{ Message: &bot.SendMessageParams{ Text: fmt.Sprintf("Unknown server error when searching for train %s.", trainNumber), }, - } + }, false default: log.Printf("ERROR: In handle train number: %s", err.Error()) - return nil + return nil, false } if len(trainData.Groups) == 1 { groupIndex = 0 } - kaiUrl, _ := url.Parse(viewInKaiBaseUrl) - kaiUrlQuery := kaiUrl.Query() - kaiUrlQuery.Add("train", trainData.Number) - kaiUrlQuery.Add("date", trainData.Groups[0].Stations[0].Departure.ScheduleTime.Format(time.RFC3339)) - if groupIndex != -1 { - kaiUrlQuery.Add("groupIndex", strconv.Itoa(groupIndex)) - } - kaiUrl.RawQuery = kaiUrlQuery.Encode() - message := bot.SendMessageParams{} if groupIndex == -1 { message.Text = fmt.Sprintf("Train %s %s contains multiple groups. Please choose one.", trainData.Rank, trainData.Number) - replyButtons := make([][]models.InlineKeyboardButton, len(trainData.Groups)+1) - for i := range replyButtons { - if i == len(trainData.Groups) { - replyButtons[i] = append(replyButtons[i], models.InlineKeyboardButton{ - Text: "Open in WebApp", - URL: kaiUrl.String(), - }) - } else { - group := &trainData.Groups[i] - replyButtons[i] = append(replyButtons[i], models.InlineKeyboardButton{ + replyButtons := make([][]models.InlineKeyboardButton, 0, len(trainData.Groups)+1) + for i, group := range trainData.Groups { + replyButtons = append(replyButtons, []models.InlineKeyboardButton{ + { Text: fmt.Sprintf("%s ➔ %s", group.Route.From, group.Route.To), CallbackData: fmt.Sprintf(TrainInfoChooseGroupCallbackQuery+"\x1b%s\x1b%d\x1b%d", trainNumber, date.Unix(), i), - }) - } + }, + }) } + kaiUrl, _ := url.Parse(viewInKaiBaseUrl) + kaiUrlQuery := kaiUrl.Query() + kaiUrlQuery.Add("train", trainData.Number) + kaiUrlQuery.Add("date", trainData.Groups[0].Stations[0].Departure.ScheduleTime.Format(time.RFC3339)) + kaiUrl.RawQuery = kaiUrlQuery.Encode() + replyButtons = append(replyButtons, []models.InlineKeyboardButton{ + { + Text: "Open in WebApp", + URL: kaiUrl.String(), + }, + }) message.ReplyMarkup = models.InlineKeyboardMarkup{ InlineKeyboard: replyButtons, } @@ -127,16 +135,11 @@ func HandleTrainNumberCommand(ctx context.Context, trainNumber string, date time Length: len(fmt.Sprintf("%s %s", trainData.Rank, trainData.Number)), }, } - message.ReplyMarkup = models.InlineKeyboardMarkup{ - InlineKeyboard: [][]models.InlineKeyboardButton{ - { - models.InlineKeyboardButton{ - Text: "Open in WebApp", - URL: kaiUrl.String(), - }, - }, - }, + buttonKind := TrainInfoResponseButtonIncludeSub + if isSubscribed { + buttonKind = TrainInfoResponseButtonIncludeUnsub } + message.ReplyMarkup = GetTrainNumberCommandResponseButtons(trainData.Number, group.Stations[0].Departure.ScheduleTime, groupIndex, buttonKind) } else { message.Text = fmt.Sprintf("The status of the train %s %s is unknown.", trainData.Rank, trainData.Number) message.Entities = []models.MessageEntity{ @@ -146,19 +149,47 @@ func HandleTrainNumberCommand(ctx context.Context, trainNumber string, date time Length: len(fmt.Sprintf("%s %s", trainData.Rank, trainData.Number)), }, } - message.ReplyMarkup = models.InlineKeyboardMarkup{ - InlineKeyboard: [][]models.InlineKeyboardButton{ - { - models.InlineKeyboardButton{ - Text: "Open in WebApp", - URL: kaiUrl.String(), - }, - }, - }, - } + message.ReplyMarkup = GetTrainNumberCommandResponseButtons(trainData.Number, trainData.Groups[0].Stations[0].Departure.ScheduleTime, groupIndex, TrainInfoResponseButtonExcludeSub) } return &HandlerResponse{ Message: &message, + }, true +} + +func GetTrainNumberCommandResponseButtons(trainNumber string, date time.Time, groupIndex int, responseButton int) models.ReplyMarkup { + kaiUrl, _ := url.Parse(viewInKaiBaseUrl) + kaiUrlQuery := kaiUrl.Query() + kaiUrlQuery.Add("train", trainNumber) + kaiUrlQuery.Add("date", date.Format(time.RFC3339)) + if groupIndex != -1 { + kaiUrlQuery.Add("groupIndex", strconv.Itoa(groupIndex)) + } + kaiUrl.RawQuery = kaiUrlQuery.Encode() + + result := make([][]models.InlineKeyboardButton, 0) + if responseButton == TrainInfoResponseButtonIncludeSub { + result = append(result, []models.InlineKeyboardButton{ + { + Text: subscribeButton, + CallbackData: fmt.Sprintf(TrainInfoSubscribeCallbackQuery+"\x1b%s\x1b%d\x1b%d", trainNumber, date.Unix(), groupIndex), + }, + }) + } else if responseButton == TrainInfoResponseButtonIncludeUnsub { + result = append(result, []models.InlineKeyboardButton{ + { + Text: unsubscribeButton, + CallbackData: fmt.Sprintf(TrainInfoUnsubscribeCallbackQuery+"\x1b%s\x1b%d\x1b%d", trainNumber, date.Unix(), groupIndex), + }, + }) + } + result = append(result, []models.InlineKeyboardButton{ + { + Text: openInWebAppButton, + URL: kaiUrl.String(), + }, + }) + return models.InlineKeyboardMarkup{ + InlineKeyboard: result, } } diff --git a/pkg/subscriptions/subscriptions.go b/pkg/subscriptions/subscriptions.go index 8e5635d..249f44c 100644 --- a/pkg/subscriptions/subscriptions.go +++ b/pkg/subscriptions/subscriptions.go @@ -2,7 +2,9 @@ package subscriptions import ( "context" + "dcdev.ro/CfrTrainInfoTelegramBot/pkg/handlers" "fmt" + "github.com/go-telegram/bot" "log" "sync" "time" @@ -17,14 +19,16 @@ type SubData struct { MessageId int TrainNumber string Date time.Time + GroupIndex int } type Subscriptions struct { mutex sync.RWMutex data map[int64][]SubData + tgBot *bot.Bot } -func LoadSubscriptions() (*Subscriptions, error) { +func LoadSubscriptions(tgBot *bot.Bot) (*Subscriptions, error) { subs := make([]SubData, 0) _, err := database.ReadDB(func(db *gorm.DB) (*gorm.DB, error) { result := db.Find(&subs) @@ -37,6 +41,7 @@ func LoadSubscriptions() (*Subscriptions, error) { return &Subscriptions{ mutex: sync.RWMutex{}, data: result, + tgBot: tgBot, }, err } @@ -58,12 +63,12 @@ func (sub *Subscriptions) Replace(chatId int64, data []SubData) error { return err } -func (sub *Subscriptions) InsertSubscription(chatId int64, data SubData) error { +func (sub *Subscriptions) InsertSubscription(data SubData) error { sub.mutex.Lock() defer sub.mutex.Unlock() - datas := sub.data[chatId] + datas := sub.data[data.ChatId] datas = append(datas, data) - sub.data[chatId] = datas + sub.data[data.ChatId] = datas _, err := database.WriteDB(func(db *gorm.DB) (*gorm.DB, error) { db.Create(&data) return db, db.Error @@ -121,23 +126,70 @@ func (sub *Subscriptions) DeleteSubscription(chatId int64, messageId int) (*SubD func (sub *Subscriptions) CheckSubscriptions(ctx context.Context) { ticker := time.NewTicker(time.Second * 90) + sub.executeChecks(ctx) for { select { case <-ticker.C: - func() { - sub.mutex.RLock() - defer sub.mutex.RUnlock() - - for chatId, datas := range sub.data { - // TODO: Check for updates - for i := range datas { - data := &datas[i] - log.Printf("DEBUG: Timer tick, update for chat %d, train %s", chatId, data.TrainNumber) - } - } - }() + sub.executeChecks(ctx) case <-ctx.Done(): return } } } + +type workerData struct { + tgBot *bot.Bot + data SubData +} + +func (sub *Subscriptions) executeChecks(ctx context.Context) { + sub.mutex.RLock() + defer sub.mutex.RUnlock() + + // Only allow 8 concurrent requests + // TODO: Make configurable instead of hardcoded + workerCount := 8 + workerChan := make(chan workerData, workerCount) + wg := &sync.WaitGroup{} + for i := 0; i < workerCount; i++ { + wg.Add(1) + go checkWorker(ctx, workerChan, wg) + } + + for _, datas := range sub.data { + for i := range datas { + workerChan <- workerData{ + tgBot: sub.tgBot, + data: datas[i], + } + } + } + close(workerChan) + wg.Wait() +} + +func checkWorker(ctx context.Context, workerChan <-chan workerData, wg *sync.WaitGroup) { + defer wg.Done() + for wData := range workerChan { + data := wData.data + log.Printf("DEBUG: Timer tick, update for chat %d, train %s, date %s, group %d", data.ChatId, data.TrainNumber, data.Date.Format("2006-01-02"), data.GroupIndex) + + resp, ok := handlers.HandleTrainNumberCommand(ctx, data.TrainNumber, data.Date, data.GroupIndex, true) + + if !ok || resp == nil || resp.Message == nil { + // Silently discard update errors + log.Printf("DEBUG: Error when updating chat %d, train %s, date %s, group %d", data.ChatId, data.TrainNumber, data.Date.Format("2006-01-02"), data.GroupIndex) + return + } + + _, _ = wData.tgBot.EditMessageText(ctx, &bot.EditMessageTextParams{ + ChatID: data.ChatId, + MessageID: data.MessageId, + Text: resp.Message.Text, + ParseMode: resp.Message.ParseMode, + Entities: resp.Message.Entities, + DisableWebPagePreview: resp.Message.DisableWebPagePreview, + ReplyMarkup: resp.Message.ReplyMarkup, + }) + } +}