Browse Source

Implement subscriptions

master
Kenneth Bruen 1 year ago
parent
commit
fa20c381a3
Signed by: kbruen
GPG Key ID: C1980A470C3EE5B1
  1. 88
      main.go
  2. 115
      pkg/handlers/findTrain.go
  3. 74
      pkg/subscriptions/subscriptions.go

88
main.go

@ -71,7 +71,11 @@ func main() {
} }
database.SetDatabase(db) database.SetDatabase(db)
subs, err := subscriptions.LoadSubscriptions() subBot, err := tgBot.New(botToken)
if err != nil {
panic(err)
}
subs, err := subscriptions.LoadSubscriptions(subBot)
if err != nil { if err != nil {
subs = nil subs = nil
fmt.Printf("WARN : Could not load subscriptions: %s\n", err.Error()) fmt.Printf("WARN : Could not load subscriptions: %s\n", err.Error())
@ -153,7 +157,7 @@ func handler(ctx context.Context, b *tgBot.Bot, update *models.Update, subs *sub
switch { switch {
case strings.HasPrefix(update.Message.Text, trainInfoCommand): case strings.HasPrefix(update.Message.Text, trainInfoCommand):
response = handleFindTrainStages(ctx, b, update, subs) response = handleFindTrainStages(ctx, b, update)
case strings.HasPrefix(update.Message.Text, cancelCommand): case strings.HasPrefix(update.Message.Text, cancelCommand):
handlers.SetChatFlow(chatFlow, handlers.InitialFlowType, handlers.InitialFlowType, "") handlers.SetChatFlow(chatFlow, handlers.InitialFlowType, handlers.InitialFlowType, "")
response = &handlers.HandlerResponse{ response = &handlers.HandlerResponse{
@ -170,7 +174,7 @@ func handler(ctx context.Context, b *tgBot.Bot, update *models.Update, subs *sub
}) })
case handlers.TrainInfoFlowType: case handlers.TrainInfoFlowType:
log.Printf("DEBUG: trainInfoFlowType with stage %s\n", chatFlow.Stage) log.Printf("DEBUG: trainInfoFlowType with stage %s\n", chatFlow.Stage)
response = handleFindTrainStages(ctx, b, update, subs) response = handleFindTrainStages(ctx, b, update)
} }
} }
} }
@ -208,18 +212,18 @@ func handler(ctx context.Context, b *tgBot.Bot, update *models.Update, subs *sub
ChatID: update.CallbackQuery.Message.Chat.ID, ChatID: update.CallbackQuery.Message.Chat.ID,
Text: pleaseWaitMessage, Text: pleaseWaitMessage,
}) })
response = handlers.HandleTrainNumberCommand(ctx, trainNumber, date, -1) response, _ = handlers.HandleTrainNumberCommand(ctx, trainNumber, date, -1, false)
if err == nil { if err == nil {
response.ProgressMessageToEditId = message.ID response.ProgressMessageToEditId = message.ID
} }
handlers.SetChatFlow(chatFlow, handlers.InitialFlowType, handlers.InitialFlowType, "") handlers.SetChatFlow(chatFlow, handlers.InitialFlowType, handlers.InitialFlowType, "")
case handlers.TrainInfoChooseGroupCallbackQuery: case handlers.TrainInfoChooseGroupCallbackQuery:
trainNumber := splitted[1]
dateInt, _ := strconv.ParseInt(splitted[2], 10, 64) dateInt, _ := strconv.ParseInt(splitted[2], 10, 64)
date := time.Unix(dateInt, 0) date := time.Unix(dateInt, 0)
groupIndex, _ := strconv.ParseInt(splitted[3], 10, 31) groupIndex, _ := strconv.ParseInt(splitted[3], 10, 31)
log.Printf("%s, %v, %d", update.CallbackQuery.Data, splitted, groupIndex) originalResponse, _ := handlers.HandleTrainNumberCommand(ctx, trainNumber, date, int(groupIndex), false)
originalResponse := handlers.HandleTrainNumberCommand(ctx, splitted[1], date, int(groupIndex))
response = &handlers.HandlerResponse{ response = &handlers.HandlerResponse{
MessageEdits: []*tgBot.EditMessageTextParams{ MessageEdits: []*tgBot.EditMessageTextParams{
{ {
@ -231,12 +235,78 @@ func handler(ctx context.Context, b *tgBot.Bot, update *models.Update, subs *sub
}, },
}, },
} }
case handlers.TrainInfoSubscribeCallbackQuery:
trainNumber := splitted[1]
dateInt, _ := strconv.ParseInt(splitted[2], 10, 64)
date := time.Unix(dateInt, 0)
groupIndex, _ := strconv.ParseInt(splitted[3], 10, 31)
err := subs.InsertSubscription(subscriptions.SubData{
ChatId: update.CallbackQuery.Message.Chat.ID,
MessageId: update.CallbackQuery.Message.ID,
TrainNumber: trainNumber,
Date: date,
GroupIndex: int(groupIndex),
})
if err != nil {
log.Printf("ERROR: Subscribe error: %s", err.Error())
response = &handlers.HandlerResponse{
CallbackAnswer: &tgBot.AnswerCallbackQueryParams{
Text: fmt.Sprintf("Error when subscribing."),
ShowAlert: true,
},
}
} else {
// TODO: Update message to contain unsubscribe button
response = &handlers.HandlerResponse{
CallbackAnswer: &tgBot.AnswerCallbackQueryParams{
Text: fmt.Sprintf("Subscribed successfully!"),
},
MessageMarkupEdits: []*tgBot.EditMessageReplyMarkupParams{
{
ChatID: update.CallbackQuery.Message.Chat.ID,
MessageID: update.CallbackQuery.Message.ID,
ReplyMarkup: handlers.GetTrainNumberCommandResponseButtons(trainNumber, date, int(groupIndex), handlers.TrainInfoResponseButtonIncludeUnsub),
},
},
}
}
case handlers.TrainInfoUnsubscribeCallbackQuery:
trainNumber := splitted[1]
dateInt, _ := strconv.ParseInt(splitted[2], 10, 64)
date := time.Unix(dateInt, 0)
groupIndex, _ := strconv.ParseInt(splitted[3], 10, 31)
_, err := subs.DeleteSubscription(update.CallbackQuery.Message.Chat.ID, update.CallbackQuery.Message.ID)
if err != nil {
log.Printf("ERROR: Unsubscribe error: %s", err.Error())
response = &handlers.HandlerResponse{
CallbackAnswer: &tgBot.AnswerCallbackQueryParams{
Text: fmt.Sprintf("Error when unsubscribing."),
ShowAlert: true,
},
}
} else {
// TODO: Update message to contain unsubscribe button
response = &handlers.HandlerResponse{
CallbackAnswer: &tgBot.AnswerCallbackQueryParams{
Text: fmt.Sprintf("Unsubscribed successfully!"),
},
MessageMarkupEdits: []*tgBot.EditMessageReplyMarkupParams{
{
ChatID: update.CallbackQuery.Message.Chat.ID,
MessageID: update.CallbackQuery.Message.ID,
ReplyMarkup: handlers.GetTrainNumberCommandResponseButtons(trainNumber, date, int(groupIndex), handlers.TrainInfoResponseButtonIncludeSub),
},
},
}
}
} }
} }
} }
} }
func handleFindTrainStages(ctx context.Context, b *tgBot.Bot, update *models.Update, subs *subscriptions.Subscriptions) *handlers.HandlerResponse { func handleFindTrainStages(ctx context.Context, b *tgBot.Bot, update *models.Update) *handlers.HandlerResponse {
log.Println("DEBUG: handleFindTrainStages") log.Println("DEBUG: handleFindTrainStages")
var response *handlers.HandlerResponse var response *handlers.HandlerResponse
@ -270,7 +340,7 @@ func handleFindTrainStages(ctx context.Context, b *tgBot.Bot, update *models.Upd
groupIndex, _ = strconv.Atoi(commandParams[2]) groupIndex, _ = strconv.Atoi(commandParams[2])
} }
response = handlers.HandleTrainNumberCommand(ctx, trainNumber, date, groupIndex) response, _ = handlers.HandleTrainNumberCommand(ctx, trainNumber, date, groupIndex, false)
if err == nil { if err == nil {
response.ProgressMessageToEditId = message.ID response.ProgressMessageToEditId = message.ID
} }
@ -306,7 +376,7 @@ func handleFindTrainStages(ctx context.Context, b *tgBot.Bot, update *models.Upd
ChatID: update.Message.Chat.ID, ChatID: update.Message.Chat.ID,
Text: pleaseWaitMessage, Text: pleaseWaitMessage,
}) })
response = handlers.HandleTrainNumberCommand(ctx, chatFlow.Extra, date, -1) response, _ = handlers.HandleTrainNumberCommand(ctx, chatFlow.Extra, date, -1, false)
if err == nil { if err == nil {
response.ProgressMessageToEditId = message.ID response.ProgressMessageToEditId = message.ID
} }

115
pkg/handlers/findTrain.go

@ -18,11 +18,23 @@ import (
const ( const (
TrainInfoChooseDateCallbackQuery = "TI_CHOOSE_DATE" TrainInfoChooseDateCallbackQuery = "TI_CHOOSE_DATE"
TrainInfoChooseGroupCallbackQuery = "TI_CHOOSE_GROUP" TrainInfoChooseGroupCallbackQuery = "TI_CHOOSE_GROUP"
TrainInfoSubscribeCallbackQuery = "TI_SUB"
TrainInfoUnsubscribeCallbackQuery = "TI_UNSUB"
viewInKaiBaseUrl = "https://kai.infotren.dcdev.ro/view-train.html" viewInKaiBaseUrl = "https://kai.infotren.dcdev.ro/view-train.html"
subscribeButton = "Subscribe to updates"
unsubscribeButton = "Unsubscribe from updates"
openInWebAppButton = "Open in WebApp"
) )
func HandleTrainNumberCommand(ctx context.Context, trainNumber string, date time.Time, groupIndex int) *HandlerResponse { const (
TrainInfoResponseButtonExcludeSub = iota
TrainInfoResponseButtonIncludeSub
TrainInfoResponseButtonIncludeUnsub
)
func HandleTrainNumberCommand(ctx context.Context, trainNumber string, date time.Time, groupIndex int, isSubscribed bool) (*HandlerResponse, bool) {
trainData, err := api.GetTrain(ctx, trainNumber, date) trainData, err := api.GetTrain(ctx, trainNumber, date)
switch { switch {
@ -34,50 +46,46 @@ func HandleTrainNumberCommand(ctx context.Context, trainNumber string, date time
Message: &bot.SendMessageParams{ Message: &bot.SendMessageParams{
Text: fmt.Sprintf("The train %s was not found.", trainNumber), Text: fmt.Sprintf("The train %s was not found.", trainNumber),
}, },
} }, false
case errors.Is(err, api.ServerError): case errors.Is(err, api.ServerError):
log.Printf("ERROR: In handle train number: %s", err.Error()) log.Printf("ERROR: In handle train number: %s", err.Error())
return &HandlerResponse{ return &HandlerResponse{
Message: &bot.SendMessageParams{ Message: &bot.SendMessageParams{
Text: fmt.Sprintf("Unknown server error when searching for train %s.", trainNumber), Text: fmt.Sprintf("Unknown server error when searching for train %s.", trainNumber),
}, },
} }, false
default: default:
log.Printf("ERROR: In handle train number: %s", err.Error()) log.Printf("ERROR: In handle train number: %s", err.Error())
return nil return nil, false
} }
if len(trainData.Groups) == 1 { if len(trainData.Groups) == 1 {
groupIndex = 0 groupIndex = 0
} }
message := bot.SendMessageParams{}
if groupIndex == -1 {
message.Text = fmt.Sprintf("Train %s %s contains multiple groups. Please choose one.", trainData.Rank, trainData.Number)
replyButtons := make([][]models.InlineKeyboardButton, 0, len(trainData.Groups)+1)
for i, group := range trainData.Groups {
replyButtons = append(replyButtons, []models.InlineKeyboardButton{
{
Text: fmt.Sprintf("%s ➔ %s", group.Route.From, group.Route.To),
CallbackData: fmt.Sprintf(TrainInfoChooseGroupCallbackQuery+"\x1b%s\x1b%d\x1b%d", trainNumber, date.Unix(), i),
},
})
}
kaiUrl, _ := url.Parse(viewInKaiBaseUrl) kaiUrl, _ := url.Parse(viewInKaiBaseUrl)
kaiUrlQuery := kaiUrl.Query() kaiUrlQuery := kaiUrl.Query()
kaiUrlQuery.Add("train", trainData.Number) kaiUrlQuery.Add("train", trainData.Number)
kaiUrlQuery.Add("date", trainData.Groups[0].Stations[0].Departure.ScheduleTime.Format(time.RFC3339)) kaiUrlQuery.Add("date", trainData.Groups[0].Stations[0].Departure.ScheduleTime.Format(time.RFC3339))
if groupIndex != -1 {
kaiUrlQuery.Add("groupIndex", strconv.Itoa(groupIndex))
}
kaiUrl.RawQuery = kaiUrlQuery.Encode() kaiUrl.RawQuery = kaiUrlQuery.Encode()
replyButtons = append(replyButtons, []models.InlineKeyboardButton{
message := bot.SendMessageParams{} {
if groupIndex == -1 {
message.Text = fmt.Sprintf("Train %s %s contains multiple groups. Please choose one.", trainData.Rank, trainData.Number)
replyButtons := make([][]models.InlineKeyboardButton, len(trainData.Groups)+1)
for i := range replyButtons {
if i == len(trainData.Groups) {
replyButtons[i] = append(replyButtons[i], models.InlineKeyboardButton{
Text: "Open in WebApp", Text: "Open in WebApp",
URL: kaiUrl.String(), URL: kaiUrl.String(),
},
}) })
} else {
group := &trainData.Groups[i]
replyButtons[i] = append(replyButtons[i], models.InlineKeyboardButton{
Text: fmt.Sprintf("%s ➔ %s", group.Route.From, group.Route.To),
CallbackData: fmt.Sprintf(TrainInfoChooseGroupCallbackQuery+"\x1b%s\x1b%d\x1b%d", trainNumber, date.Unix(), i),
})
}
}
message.ReplyMarkup = models.InlineKeyboardMarkup{ message.ReplyMarkup = models.InlineKeyboardMarkup{
InlineKeyboard: replyButtons, InlineKeyboard: replyButtons,
} }
@ -127,16 +135,11 @@ func HandleTrainNumberCommand(ctx context.Context, trainNumber string, date time
Length: len(fmt.Sprintf("%s %s", trainData.Rank, trainData.Number)), Length: len(fmt.Sprintf("%s %s", trainData.Rank, trainData.Number)),
}, },
} }
message.ReplyMarkup = models.InlineKeyboardMarkup{ buttonKind := TrainInfoResponseButtonIncludeSub
InlineKeyboard: [][]models.InlineKeyboardButton{ if isSubscribed {
{ buttonKind = TrainInfoResponseButtonIncludeUnsub
models.InlineKeyboardButton{
Text: "Open in WebApp",
URL: kaiUrl.String(),
},
},
},
} }
message.ReplyMarkup = GetTrainNumberCommandResponseButtons(trainData.Number, group.Stations[0].Departure.ScheduleTime, groupIndex, buttonKind)
} else { } else {
message.Text = fmt.Sprintf("The status of the train %s %s is unknown.", trainData.Rank, trainData.Number) message.Text = fmt.Sprintf("The status of the train %s %s is unknown.", trainData.Rank, trainData.Number)
message.Entities = []models.MessageEntity{ message.Entities = []models.MessageEntity{
@ -146,19 +149,47 @@ func HandleTrainNumberCommand(ctx context.Context, trainNumber string, date time
Length: len(fmt.Sprintf("%s %s", trainData.Rank, trainData.Number)), Length: len(fmt.Sprintf("%s %s", trainData.Rank, trainData.Number)),
}, },
} }
message.ReplyMarkup = models.InlineKeyboardMarkup{ message.ReplyMarkup = GetTrainNumberCommandResponseButtons(trainData.Number, trainData.Groups[0].Stations[0].Departure.ScheduleTime, groupIndex, TrainInfoResponseButtonExcludeSub)
InlineKeyboard: [][]models.InlineKeyboardButton{
{
models.InlineKeyboardButton{
Text: "Open in WebApp",
URL: kaiUrl.String(),
},
},
},
}
} }
return &HandlerResponse{ return &HandlerResponse{
Message: &message, Message: &message,
}, true
}
func GetTrainNumberCommandResponseButtons(trainNumber string, date time.Time, groupIndex int, responseButton int) models.ReplyMarkup {
kaiUrl, _ := url.Parse(viewInKaiBaseUrl)
kaiUrlQuery := kaiUrl.Query()
kaiUrlQuery.Add("train", trainNumber)
kaiUrlQuery.Add("date", date.Format(time.RFC3339))
if groupIndex != -1 {
kaiUrlQuery.Add("groupIndex", strconv.Itoa(groupIndex))
}
kaiUrl.RawQuery = kaiUrlQuery.Encode()
result := make([][]models.InlineKeyboardButton, 0)
if responseButton == TrainInfoResponseButtonIncludeSub {
result = append(result, []models.InlineKeyboardButton{
{
Text: subscribeButton,
CallbackData: fmt.Sprintf(TrainInfoSubscribeCallbackQuery+"\x1b%s\x1b%d\x1b%d", trainNumber, date.Unix(), groupIndex),
},
})
} else if responseButton == TrainInfoResponseButtonIncludeUnsub {
result = append(result, []models.InlineKeyboardButton{
{
Text: unsubscribeButton,
CallbackData: fmt.Sprintf(TrainInfoUnsubscribeCallbackQuery+"\x1b%s\x1b%d\x1b%d", trainNumber, date.Unix(), groupIndex),
},
})
}
result = append(result, []models.InlineKeyboardButton{
{
Text: openInWebAppButton,
URL: kaiUrl.String(),
},
})
return models.InlineKeyboardMarkup{
InlineKeyboard: result,
} }
} }

74
pkg/subscriptions/subscriptions.go

@ -2,7 +2,9 @@ package subscriptions
import ( import (
"context" "context"
"dcdev.ro/CfrTrainInfoTelegramBot/pkg/handlers"
"fmt" "fmt"
"github.com/go-telegram/bot"
"log" "log"
"sync" "sync"
"time" "time"
@ -17,14 +19,16 @@ type SubData struct {
MessageId int MessageId int
TrainNumber string TrainNumber string
Date time.Time Date time.Time
GroupIndex int
} }
type Subscriptions struct { type Subscriptions struct {
mutex sync.RWMutex mutex sync.RWMutex
data map[int64][]SubData data map[int64][]SubData
tgBot *bot.Bot
} }
func LoadSubscriptions() (*Subscriptions, error) { func LoadSubscriptions(tgBot *bot.Bot) (*Subscriptions, error) {
subs := make([]SubData, 0) subs := make([]SubData, 0)
_, err := database.ReadDB(func(db *gorm.DB) (*gorm.DB, error) { _, err := database.ReadDB(func(db *gorm.DB) (*gorm.DB, error) {
result := db.Find(&subs) result := db.Find(&subs)
@ -37,6 +41,7 @@ func LoadSubscriptions() (*Subscriptions, error) {
return &Subscriptions{ return &Subscriptions{
mutex: sync.RWMutex{}, mutex: sync.RWMutex{},
data: result, data: result,
tgBot: tgBot,
}, err }, err
} }
@ -58,12 +63,12 @@ func (sub *Subscriptions) Replace(chatId int64, data []SubData) error {
return err return err
} }
func (sub *Subscriptions) InsertSubscription(chatId int64, data SubData) error { func (sub *Subscriptions) InsertSubscription(data SubData) error {
sub.mutex.Lock() sub.mutex.Lock()
defer sub.mutex.Unlock() defer sub.mutex.Unlock()
datas := sub.data[chatId] datas := sub.data[data.ChatId]
datas = append(datas, data) datas = append(datas, data)
sub.data[chatId] = datas sub.data[data.ChatId] = datas
_, err := database.WriteDB(func(db *gorm.DB) (*gorm.DB, error) { _, err := database.WriteDB(func(db *gorm.DB) (*gorm.DB, error) {
db.Create(&data) db.Create(&data)
return db, db.Error return db, db.Error
@ -121,23 +126,70 @@ func (sub *Subscriptions) DeleteSubscription(chatId int64, messageId int) (*SubD
func (sub *Subscriptions) CheckSubscriptions(ctx context.Context) { func (sub *Subscriptions) CheckSubscriptions(ctx context.Context) {
ticker := time.NewTicker(time.Second * 90) ticker := time.NewTicker(time.Second * 90)
sub.executeChecks(ctx)
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
func() { sub.executeChecks(ctx)
case <-ctx.Done():
return
}
}
}
type workerData struct {
tgBot *bot.Bot
data SubData
}
func (sub *Subscriptions) executeChecks(ctx context.Context) {
sub.mutex.RLock() sub.mutex.RLock()
defer sub.mutex.RUnlock() defer sub.mutex.RUnlock()
for chatId, datas := range sub.data { // Only allow 8 concurrent requests
// TODO: Check for updates // TODO: Make configurable instead of hardcoded
workerCount := 8
workerChan := make(chan workerData, workerCount)
wg := &sync.WaitGroup{}
for i := 0; i < workerCount; i++ {
wg.Add(1)
go checkWorker(ctx, workerChan, wg)
}
for _, datas := range sub.data {
for i := range datas { for i := range datas {
data := &datas[i] workerChan <- workerData{
log.Printf("DEBUG: Timer tick, update for chat %d, train %s", chatId, data.TrainNumber) tgBot: sub.tgBot,
data: datas[i],
} }
} }
}() }
case <-ctx.Done(): close(workerChan)
wg.Wait()
}
func checkWorker(ctx context.Context, workerChan <-chan workerData, wg *sync.WaitGroup) {
defer wg.Done()
for wData := range workerChan {
data := wData.data
log.Printf("DEBUG: Timer tick, update for chat %d, train %s, date %s, group %d", data.ChatId, data.TrainNumber, data.Date.Format("2006-01-02"), data.GroupIndex)
resp, ok := handlers.HandleTrainNumberCommand(ctx, data.TrainNumber, data.Date, data.GroupIndex, true)
if !ok || resp == nil || resp.Message == nil {
// Silently discard update errors
log.Printf("DEBUG: Error when updating chat %d, train %s, date %s, group %d", data.ChatId, data.TrainNumber, data.Date.Format("2006-01-02"), data.GroupIndex)
return return
} }
_, _ = wData.tgBot.EditMessageText(ctx, &bot.EditMessageTextParams{
ChatID: data.ChatId,
MessageID: data.MessageId,
Text: resp.Message.Text,
ParseMode: resp.Message.ParseMode,
Entities: resp.Message.Entities,
DisableWebPagePreview: resp.Message.DisableWebPagePreview,
ReplyMarkup: resp.Message.ReplyMarkup,
})
} }
} }

Loading…
Cancel
Save