From b44f1c92d25ab5f0e88fa10f15ce8e991a701bba Mon Sep 17 00:00:00 2001 From: Dan Cojocaru Date: Wed, 27 Sep 2023 04:31:02 +0200 Subject: [PATCH] Initial commit --- .gitignore | 26 +++ .idea/.gitignore | 8 + .idea/CfrTrainInfoTelegramBot.iml | 9 + .idea/modules.xml | 8 + .idea/vcs.xml | 15 ++ go.mod | 15 ++ go.sum | 12 + main.go | 349 +++++++++++++++++++++++++++++ pkg/api/trains.go | 111 +++++++++ pkg/database/database.go | 28 +++ pkg/handlers/chatFlow.go | 57 +++++ pkg/handlers/findTrain.go | 164 ++++++++++++++ pkg/handlers/response.go | 15 ++ pkg/subscriptions/subscriptions.go | 143 ++++++++++++ pkg/utils/location.go | 7 + pkg/utils/parseDate.go | 56 +++++ 16 files changed, 1023 insertions(+) create mode 100644 .gitignore create mode 100644 .idea/.gitignore create mode 100644 .idea/CfrTrainInfoTelegramBot.iml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml create mode 100644 go.mod create mode 100644 go.sum create mode 100644 main.go create mode 100644 pkg/api/trains.go create mode 100644 pkg/database/database.go create mode 100644 pkg/handlers/chatFlow.go create mode 100644 pkg/handlers/findTrain.go create mode 100644 pkg/handlers/response.go create mode 100644 pkg/subscriptions/subscriptions.go create mode 100644 pkg/utils/location.go create mode 100644 pkg/utils/parseDate.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..415e529 --- /dev/null +++ b/.gitignore @@ -0,0 +1,26 @@ +debug +bot_db.sqlite + +## Go Template from https://github.com/github/gitignore/blob/main/Go.gitignore + +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/CfrTrainInfoTelegramBot.iml b/.idea/CfrTrainInfoTelegramBot.iml new file mode 100644 index 0000000..5e764c4 --- /dev/null +++ b/.idea/CfrTrainInfoTelegramBot.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..6df9409 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..0b0005f --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a8a03ff --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module dcdev.ro/CfrTrainInfoTelegramBot + +go 1.20 + +require ( + github.com/go-telegram/bot v0.7.15 + gorm.io/driver/sqlite v1.5.3 + gorm.io/gorm v1.25.4 +) + +require ( + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/mattn/go-sqlite3 v1.14.17 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8916b78 --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/go-telegram/bot v0.7.15 h1:Xi1PGEUjcJvZ4qG0EssFPUkcxlDbEIx1VWStMeG6GvE= +github.com/go-telegram/bot v0.7.15/go.mod h1:i2TRs7fXWIeaceF3z7KzsMt/he0TwkVC680mvdTFYeM= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +gorm.io/driver/sqlite v1.5.3 h1:7/0dUgX28KAcopdfbRWWl68Rflh6osa4rDh+m51KL2g= +gorm.io/driver/sqlite v1.5.3/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= +gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw= +gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= diff --git a/main.go b/main.go new file mode 100644 index 0000000..8612c8c --- /dev/null +++ b/main.go @@ -0,0 +1,349 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "strconv" + "strings" + "syscall" + "time" + + "dcdev.ro/CfrTrainInfoTelegramBot/pkg/database" + "dcdev.ro/CfrTrainInfoTelegramBot/pkg/handlers" + "dcdev.ro/CfrTrainInfoTelegramBot/pkg/subscriptions" + "dcdev.ro/CfrTrainInfoTelegramBot/pkg/utils" + tgBot "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +const ( + trainInfoCommand = "/train_info" + stationInfoCommand = "/station_info" + routeCommand = "/route" + cancelCommand = "/cancel" + + initialMessage = `Hello. 😄 + +You can send the following commands: + +` + trainInfoCommand + ` - Find information about a certain train. +` + stationInfoCommand + ` - Find departures or arrivals at a certain station. +` + routeCommand + ` - Find trains for a certain route. + +You may use ` + cancelCommand + ` to cancel any ongoing command.` + waitingForTrainNumberMessage = "Please send the number of the train you want information for." + pleaseWaitMessage = "Please wait..." + cancelResponseMessage = "Command cancelled." + chooseDateMessage = `Please choose the date of departure from the first station for this train. + +You may also send the date as a message in the following formats: dd.mm.yyyy, m/d/yyyy, yyyy-mm-dd, UNIX timestamp. + +Keep in mind that, for night trains, this date might be yesterday.` + invalidDateMessage = "Invalid date. Please try again or us " + cancelCommand + " to cancel." +) + +func main() { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + log.SetOutput(os.Stderr) + + botToken := os.Getenv("CFR_BOT.TOKEN") + if len(botToken) == 0 { + log.Fatal("ERROR: No bot token supplied; supply with CFR_BOT.TOKEN") + } + + db, err := gorm.Open(sqlite.Open("bot_db.sqlite"), &gorm.Config{}) + if err != nil { + panic(err) + } + if err := db.AutoMigrate(&handlers.ChatFlow{}); err != nil { + panic(err) + } + if err := db.AutoMigrate(&subscriptions.SubData{}); err != nil { + panic(err) + } + database.SetDatabase(db) + + subs, err := subscriptions.LoadSubscriptions() + if err != nil { + subs = nil + fmt.Printf("WARN : Could not load subscriptions: %s\n", err.Error()) + } + + go subs.CheckSubscriptions(ctx) + + bot, err := tgBot.New(botToken, tgBot.WithDefaultHandler(handlerBuilder(subs))) + if err != nil { + panic(err) + } + + log.Print("INFO : Starting...") + bot.Start(ctx) +} + +func handlerBuilder(subs *subscriptions.Subscriptions) func(context.Context, *tgBot.Bot, *models.Update) { + return func(ctx context.Context, b *tgBot.Bot, update *models.Update) { + handler(ctx, b, update, subs) + } +} + +func handler(ctx context.Context, b *tgBot.Bot, update *models.Update, subs *subscriptions.Subscriptions) { + var response *handlers.HandlerResponse + var toEditId int + defer func() { + if response == nil { + return + } + if response.ProgressMessageToEditId != 0 { + toEditId = response.ProgressMessageToEditId + } + if response.Message != nil { + response.Message.ChatID = response.Injected.ChatId + if toEditId != 0 { + b.EditMessageText(ctx, &tgBot.EditMessageTextParams{ + ChatID: response.Message.ChatID, + MessageID: toEditId, + Text: response.Message.Text, + ParseMode: response.Message.ParseMode, + Entities: response.Message.Entities, + DisableWebPagePreview: response.Message.DisableWebPagePreview, + ReplyMarkup: response.Message.ReplyMarkup, + }) + } else { + b.SendMessage(ctx, response.Message) + } + } + if response.CallbackAnswer != nil { + b.AnswerCallbackQuery(ctx, response.CallbackAnswer) + } + for _, edit := range response.MessageEdits { + if (edit.ChatID == nil || edit.MessageID == 0) && edit.InlineMessageID == "" { + edit.ChatID = response.Injected.ChatId + edit.MessageID = response.Injected.MessageId + } + b.EditMessageText(ctx, edit) + } + for _, edit := range response.MessageMarkupEdits { + if (edit.ChatID == nil || edit.MessageID == 0) && edit.InlineMessageID == "" { + edit.ChatID = response.Injected.ChatId + edit.MessageID = response.Injected.MessageId + } + b.EditMessageReplyMarkup(ctx, edit) + } + }() + + if update.Message != nil { + defer func() { + if response == nil { + response = &handlers.HandlerResponse{} + } + response.Injected.ChatId = update.Message.Chat.ID + response.Injected.MessageId = update.Message.ID + }() + log.Printf("DEBUG: Got message: %s\n", update.Message.Text) + + chatFlow := handlers.GetChatFlow(update.Message.Chat.ID) + + switch { + case strings.HasPrefix(update.Message.Text, trainInfoCommand): + response = handleFindTrainStages(ctx, b, update, subs) + case strings.HasPrefix(update.Message.Text, cancelCommand): + handlers.SetChatFlow(chatFlow, handlers.InitialFlowType, handlers.InitialFlowType, "") + response = &handlers.HandlerResponse{ + Message: &tgBot.SendMessageParams{ + Text: cancelResponseMessage, + }, + } + default: + switch chatFlow.Type { + case handlers.InitialFlowType: + b.SendMessage(ctx, &tgBot.SendMessageParams{ + ChatID: update.Message.Chat.ID, + Text: initialMessage, + }) + case handlers.TrainInfoFlowType: + log.Printf("DEBUG: trainInfoFlowType with stage %s\n", chatFlow.Stage) + response = handleFindTrainStages(ctx, b, update, subs) + } + } + } + if update.CallbackQuery != nil { + defer func() { + if response == nil { + response = &handlers.HandlerResponse{ + CallbackAnswer: &tgBot.AnswerCallbackQueryParams{ + CallbackQueryID: update.CallbackQuery.ID, + }, + } + } + response.Injected.ChatId = update.CallbackQuery.Message.Chat.ID + response.Injected.MessageId = update.CallbackQuery.Message.ID + if response.CallbackAnswer == nil { + response.CallbackAnswer = &tgBot.AnswerCallbackQueryParams{ + CallbackQueryID: update.CallbackQuery.ID, + } + } + if response.CallbackAnswer.CallbackQueryID == "" { + response.CallbackAnswer.CallbackQueryID = update.CallbackQuery.ID + } + }() + + chatFlow := handlers.GetChatFlow(update.CallbackQuery.Message.Chat.ID) + + if len(update.CallbackQuery.Data) != 0 { + splitted := strings.Split(update.CallbackQuery.Data, "\x1b") + switch splitted[0] { + case handlers.TrainInfoChooseDateCallbackQuery: + trainNumber := splitted[1] + dateInt, _ := strconv.ParseInt(splitted[2], 10, 64) + date := time.Unix(dateInt, 0) + message, err := b.SendMessage(ctx, &tgBot.SendMessageParams{ + ChatID: update.CallbackQuery.Message.Chat.ID, + Text: pleaseWaitMessage, + }) + response = handlers.HandleTrainNumberCommand(ctx, trainNumber, date, -1) + if err == nil { + response.ProgressMessageToEditId = message.ID + } + handlers.SetChatFlow(chatFlow, handlers.InitialFlowType, handlers.InitialFlowType, "") + + case handlers.TrainInfoChooseGroupCallbackQuery: + dateInt, _ := strconv.ParseInt(splitted[2], 10, 64) + date := time.Unix(dateInt, 0) + groupIndex, _ := strconv.ParseInt(splitted[3], 10, 31) + log.Printf("%s, %v, %d", update.CallbackQuery.Data, splitted, groupIndex) + originalResponse := handlers.HandleTrainNumberCommand(ctx, splitted[1], date, int(groupIndex)) + response = &handlers.HandlerResponse{ + MessageEdits: []*tgBot.EditMessageTextParams{ + { + Text: originalResponse.Message.Text, + ParseMode: originalResponse.Message.ParseMode, + Entities: originalResponse.Message.Entities, + DisableWebPagePreview: originalResponse.Message.DisableWebPagePreview, + ReplyMarkup: originalResponse.Message.ReplyMarkup, + }, + }, + } + } + } + } +} + +func handleFindTrainStages(ctx context.Context, b *tgBot.Bot, update *models.Update, subs *subscriptions.Subscriptions) *handlers.HandlerResponse { + log.Println("DEBUG: handleFindTrainStages") + var response *handlers.HandlerResponse + + var chatId int64 + if update.Message != nil { + chatId = update.Message.Chat.ID + } + if update.CallbackQuery != nil { + chatId = update.CallbackQuery.Message.Chat.ID + } + chatFlow := handlers.GetChatFlow(chatId) + switch chatFlow.Type { + case handlers.InitialFlowType: + // Only command is possible here + commandParamsString := strings.TrimPrefix(update.Message.Text, trainInfoCommand) + commandParamsString = strings.TrimSpace(commandParamsString) + commandParams := strings.Split(commandParamsString, " ") + if len(commandParams) > 1 { + message, err := b.SendMessage(ctx, &tgBot.SendMessageParams{ + ChatID: update.Message.Chat.ID, + Text: pleaseWaitMessage, + }) + trainNumber := commandParams[0] + date := time.Now() + groupIndex := -1 + + if len(commandParams) > 1 { + date, _ = time.Parse(time.RFC3339, commandParams[1]) + } + if len(commandParams) > 2 { + groupIndex, _ = strconv.Atoi(commandParams[2]) + } + + response = handlers.HandleTrainNumberCommand(ctx, trainNumber, date, groupIndex) + if err == nil { + response.ProgressMessageToEditId = message.ID + } + } else if len(commandParams) > 0 && len(commandParams[0]) != 0 { + // Got only train number + trainNumber := commandParams[0] + response = getTrainInfoChooseDateResponse(trainNumber) + handlers.SetChatFlow(chatFlow, handlers.TrainInfoFlowType, handlers.WaitingForDateStage, trainNumber) + } else { + response = &handlers.HandlerResponse{ + Message: &tgBot.SendMessageParams{ + Text: waitingForTrainNumberMessage, + }, + } + handlers.SetChatFlow(chatFlow, handlers.TrainInfoFlowType, handlers.WaitingForTrainNumberStage, "") + } + case handlers.TrainInfoFlowType: + switch chatFlow.Stage { + case handlers.WaitingForTrainNumberStage: + trainNumber := update.Message.Text + response = getTrainInfoChooseDateResponse(trainNumber) + handlers.SetChatFlow(chatFlow, handlers.TrainInfoFlowType, handlers.WaitingForDateStage, trainNumber) + case handlers.WaitingForDateStage: + date, err := utils.ParseDate(update.Message.Text) + if err != nil { + response = &handlers.HandlerResponse{ + Message: &tgBot.SendMessageParams{ + Text: invalidDateMessage, + }, + } + } else { + message, err := b.SendMessage(ctx, &tgBot.SendMessageParams{ + ChatID: update.Message.Chat.ID, + Text: pleaseWaitMessage, + }) + response = handlers.HandleTrainNumberCommand(ctx, chatFlow.Extra, date, -1) + if err == nil { + response.ProgressMessageToEditId = message.ID + } + handlers.SetChatFlow(chatFlow, handlers.InitialFlowType, handlers.InitialFlowType, "") + } + } + } + return response +} + +func getTrainInfoChooseDateResponse(trainNumber string) *handlers.HandlerResponse { + replyButtons := make([][]models.InlineKeyboardButton, 0, 4) + replyButtons = append(replyButtons, []models.InlineKeyboardButton{ + { + Text: fmt.Sprintf("Yesterday (%s)", time.Now().Add(time.Hour*-24).In(utils.Location).Format("02.01.2006")), + CallbackData: fmt.Sprintf(handlers.TrainInfoChooseDateCallbackQuery+"\x1b%s\x1b%d", trainNumber, time.Now().Add(time.Hour*-24).Unix()), + }, { + Text: fmt.Sprintf("Today (%s)", time.Now().In(utils.Location).Format("02.01.2006")), + CallbackData: fmt.Sprintf(handlers.TrainInfoChooseDateCallbackQuery+"\x1b%s\x1b%d", trainNumber, time.Now().Unix()), + }, + }) + for i := 1; i < 4; i++ { + arr := make([]models.InlineKeyboardButton, 0, 7) + for j := 0; j < 7; j++ { + ts := time.Now().Add(time.Hour * time.Duration(24*(j+(i-1)*7+1))).In(utils.Location) + arr = append(arr, models.InlineKeyboardButton{ + Text: ts.Format("02.01"), + CallbackData: fmt.Sprintf(handlers.TrainInfoChooseDateCallbackQuery+"\x1b%s\x1b%d", trainNumber, ts.Unix()), + }) + } + replyButtons = append(replyButtons, arr) + } + return &handlers.HandlerResponse{ + Message: &tgBot.SendMessageParams{ + Text: chooseDateMessage, + ReplyMarkup: models.InlineKeyboardMarkup{ + InlineKeyboard: replyButtons, + }, + }, + } +} diff --git a/pkg/api/trains.go b/pkg/api/trains.go new file mode 100644 index 0000000..cbd830c --- /dev/null +++ b/pkg/api/trains.go @@ -0,0 +1,111 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +type TrainResponse struct { + Rank string `json:"rank"` + Number string `json:"number"` + Date string `json:"date"` + Operator string `json:"operator"` + Groups []TrainGroup `json:"groups"` +} + +type TrainGroup struct { + Route struct { + From string `json:"from"` + To string `json:"to"` + } `json:"route"` + Status *struct { + Delay int `json:"delay"` + Station string `json:"station"` + State string `json:"state"` + } `json:"status"` + Stations []TrainStation `json:"stations"` +} + +type TrainStation struct { + Name string `json:"name"` + LinkName string `json:"linkName"` + Km int `json:"km"` + StoppingTime *int `json:"stoppingTime"` + Platform *string `json:"platform"` + Arrival *TrainArrDep `json:"arrival"` + Departure *TrainArrDep `json:"departure"` + Notes []any `json:"notes"` +} + +type TrainArrDep struct { + ScheduleTime time.Time `json:"scheduleTime"` + Status *struct { + Delay int `json:"delay"` + Real bool `json:"real"` + Cancelled bool `json:"cancelled"` + } `json:"status"` +} + +const ( + trainApiEndpoint = "https://scraper.infotren.dcdev.ro/v3" +) + +var ( + TrainNotFound = fmt.Errorf("train not found") + ServerError = fmt.Errorf("server error") +) + +func GetTrain(ctx context.Context, trainNumber string, date time.Time) (*TrainResponse, error) { + u, _ := url.Parse(trainApiEndpoint) + u.Path, _ = url.JoinPath(u.Path, "trains", trainNumber) + query := u.Query() + query.Add("date", date.Format(time.RFC3339)) + u.RawQuery = query.Encode() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, fmt.Errorf("error getting train %s: %w", trainNumber, err) + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error getting train %s: %w", trainNumber, err) + } + defer func() { + _ = res.Body.Close() + }() + + switch { + case res.StatusCode == http.StatusNotFound: + return nil, fmt.Errorf("error getting train %s: %w", trainNumber, TrainNotFound) + case res.StatusCode/100 != 2: + return nil, fmt.Errorf("error getting train %s: status code %d: %w", trainNumber, res.StatusCode, ServerError) + } + + var body []byte + if res.ContentLength > 0 { + body = make([]byte, res.ContentLength) + n, err := io.ReadFull(res.Body, body) + if err != nil && err != io.EOF { + return nil, fmt.Errorf("error getting train %s: %w", trainNumber, err) + } else if n != int(res.ContentLength) { + body = body[0:n] + } + } else { + body, err = io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("error getting train %s: %w", trainNumber, err) + } + } + + var trainData TrainResponse + if err := json.Unmarshal(body, &trainData); err != nil { + return nil, fmt.Errorf("error getting train %s: %w", trainNumber, err) + } + + return &trainData, nil +} diff --git a/pkg/database/database.go b/pkg/database/database.go new file mode 100644 index 0000000..d0aa170 --- /dev/null +++ b/pkg/database/database.go @@ -0,0 +1,28 @@ +package database + +import ( + "sync" + + "gorm.io/gorm" +) + +var ( + db *gorm.DB + mutex sync.RWMutex +) + +func SetDatabase(d *gorm.DB) { + db = d +} + +func ReadDB[T any](callback func(*gorm.DB) (T, error)) (T, error) { + mutex.RLock() + defer mutex.RUnlock() + return callback(db) +} + +func WriteDB[T any](callback func(*gorm.DB) (T, error)) (T, error) { + mutex.Lock() + defer mutex.Unlock() + return callback(db) +} diff --git a/pkg/handlers/chatFlow.go b/pkg/handlers/chatFlow.go new file mode 100644 index 0000000..9649aa6 --- /dev/null +++ b/pkg/handlers/chatFlow.go @@ -0,0 +1,57 @@ +package handlers + +import ( + "log" + + "dcdev.ro/CfrTrainInfoTelegramBot/pkg/database" + "gorm.io/gorm" +) + +const ( + InitialFlowType = "initial" + TrainInfoFlowType = "trainInfo" + StationInfoFlowType = "stationInfo" + RouteFlowType = "route" + + WaitingForTrainNumberStage = "waitingForTrainNumber" + WaitingForDateStage = "waitingForDate" +) + +type ChatFlow struct { + gorm.Model + ChatId int64 + Type string + Stage string + Extra string +} + +func GetChatFlow(chatId int64) *ChatFlow { + chatFlow := &ChatFlow{} + result, _ := database.ReadDB(func(db *gorm.DB) (*gorm.DB, error) { + return db.First(chatFlow, "chat_id = ?", chatId), nil + }) + if result.RowsAffected == 0 { + log.Printf("DEBUG: Chat not found in DB: %d\n", chatId) + chatFlow = &ChatFlow{ + ChatId: chatId, + Type: InitialFlowType, + } + _, _ = database.WriteDB(func(db *gorm.DB) (*gorm.DB, error) { + return db.Create(chatFlow), nil + }) + } else { + log.Printf("DEBUG: Chat found in DB: %d, type %s, stage %s\n", chatId, chatFlow.Type, chatFlow.Stage) + } + return chatFlow +} + +func SetChatFlow(chatFlow *ChatFlow, flowType string, stage string, extra string) { + _, _ = database.WriteDB(func(db *gorm.DB) (*gorm.DB, error) { + return db.Model(chatFlow).Updates(ChatFlow{ + Type: flowType, + Stage: stage, + Extra: extra, + }), nil + }) + log.Printf("DEBUG: setChatFlow type %s, stage %s", flowType, stage) +} diff --git a/pkg/handlers/findTrain.go b/pkg/handlers/findTrain.go new file mode 100644 index 0000000..519e39b --- /dev/null +++ b/pkg/handlers/findTrain.go @@ -0,0 +1,164 @@ +package handlers + +import ( + "context" + "errors" + "fmt" + "log" + "net/url" + "strconv" + "strings" + "time" + + "dcdev.ro/CfrTrainInfoTelegramBot/pkg/api" + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" +) + +const ( + TrainInfoChooseDateCallbackQuery = "TI_CHOOSE_DATE" + TrainInfoChooseGroupCallbackQuery = "TI_CHOOSE_GROUP" + + viewInKaiBaseUrl = "https://kai.infotren.dcdev.ro/view-train.html" +) + +func HandleTrainNumberCommand(ctx context.Context, trainNumber string, date time.Time, groupIndex int) *HandlerResponse { + trainData, err := api.GetTrain(ctx, trainNumber, date) + + switch { + case err == nil: + break + case errors.Is(err, api.TrainNotFound): + log.Printf("ERROR: In handle train number: %s", err.Error()) + return &HandlerResponse{ + Message: &bot.SendMessageParams{ + Text: fmt.Sprintf("The train %s was not found.", trainNumber), + }, + } + case errors.Is(err, api.ServerError): + log.Printf("ERROR: In handle train number: %s", err.Error()) + return &HandlerResponse{ + Message: &bot.SendMessageParams{ + Text: fmt.Sprintf("Unknown server error when searching for train %s.", trainNumber), + }, + } + default: + log.Printf("ERROR: In handle train number: %s", err.Error()) + return nil + } + + if len(trainData.Groups) == 1 { + groupIndex = 0 + } + + kaiUrl, _ := url.Parse(viewInKaiBaseUrl) + kaiUrlQuery := kaiUrl.Query() + kaiUrlQuery.Add("train", trainData.Number) + kaiUrlQuery.Add("date", trainData.Groups[0].Stations[0].Departure.ScheduleTime.Format(time.RFC3339)) + if groupIndex != -1 { + kaiUrlQuery.Add("groupIndex", strconv.Itoa(groupIndex)) + } + kaiUrl.RawQuery = kaiUrlQuery.Encode() + + message := bot.SendMessageParams{} + if groupIndex == -1 { + message.Text = fmt.Sprintf("Train %s %s contains multiple groups. Please choose one.", trainData.Rank, trainData.Number) + replyButtons := make([][]models.InlineKeyboardButton, len(trainData.Groups)+1) + for i := range replyButtons { + if i == len(trainData.Groups) { + replyButtons[i] = append(replyButtons[i], models.InlineKeyboardButton{ + Text: "Open in WebApp", + URL: kaiUrl.String(), + }) + } else { + group := &trainData.Groups[i] + replyButtons[i] = append(replyButtons[i], models.InlineKeyboardButton{ + Text: fmt.Sprintf("%s ➔ %s", group.Route.From, group.Route.To), + CallbackData: fmt.Sprintf(TrainInfoChooseGroupCallbackQuery+"\x1b%s\x1b%d\x1b%d", trainNumber, date.Unix(), i), + }) + } + } + message.ReplyMarkup = models.InlineKeyboardMarkup{ + InlineKeyboard: replyButtons, + } + } else if len(trainData.Groups) > groupIndex { + group := &trainData.Groups[groupIndex] + + messageText := strings.Builder{} + messageText.WriteString(fmt.Sprintf("Train %s %s\n%s ➔ %s\n\n", trainData.Rank, trainData.Number, group.Route.From, group.Route.To)) + + messageText.WriteString(fmt.Sprintf("Date: %s\n", trainData.Date)) + messageText.WriteString(fmt.Sprintf("Operator: %s\n", trainData.Operator)) + if group.Status != nil { + messageText.WriteString("Status: ") + if group.Status.Delay == 0 { + messageText.WriteString("on time when ") + } else { + messageText.WriteString(fmt.Sprintf("%d min ", func(x int) int { + if x < 0 { + return -x + } else { + return x + } + }(group.Status.Delay))) + if group.Status.Delay < 0 { + messageText.WriteString("early when ") + } else { + messageText.WriteString("late when ") + } + } + switch group.Status.State { + case "arrival": + messageText.WriteString("arriving at ") + case "departure": + messageText.WriteString("departing from ") + case "passing": + messageText.WriteString("passing through ") + } + messageText.WriteString(group.Status.Station) + messageText.WriteString("\n") + } + + message.Text = messageText.String() + message.Entities = []models.MessageEntity{ + { + Type: models.MessageEntityTypeBold, + Offset: 6, + Length: len(fmt.Sprintf("%s %s", trainData.Rank, trainData.Number)), + }, + } + message.ReplyMarkup = models.InlineKeyboardMarkup{ + InlineKeyboard: [][]models.InlineKeyboardButton{ + { + models.InlineKeyboardButton{ + Text: "Open in WebApp", + URL: kaiUrl.String(), + }, + }, + }, + } + } else { + message.Text = fmt.Sprintf("The status of the train %s %s is unknown.", trainData.Rank, trainData.Number) + message.Entities = []models.MessageEntity{ + { + Type: models.MessageEntityTypeBold, + Offset: 24, + Length: len(fmt.Sprintf("%s %s", trainData.Rank, trainData.Number)), + }, + } + message.ReplyMarkup = models.InlineKeyboardMarkup{ + InlineKeyboard: [][]models.InlineKeyboardButton{ + { + models.InlineKeyboardButton{ + Text: "Open in WebApp", + URL: kaiUrl.String(), + }, + }, + }, + } + } + + return &HandlerResponse{ + Message: &message, + } +} diff --git a/pkg/handlers/response.go b/pkg/handlers/response.go new file mode 100644 index 0000000..4554ae9 --- /dev/null +++ b/pkg/handlers/response.go @@ -0,0 +1,15 @@ +package handlers + +import "github.com/go-telegram/bot" + +type HandlerResponse struct { + Message *bot.SendMessageParams + ProgressMessageToEditId int + CallbackAnswer *bot.AnswerCallbackQueryParams + MessageEdits []*bot.EditMessageTextParams + MessageMarkupEdits []*bot.EditMessageReplyMarkupParams + Injected struct { + ChatId int64 + MessageId int + } +} diff --git a/pkg/subscriptions/subscriptions.go b/pkg/subscriptions/subscriptions.go new file mode 100644 index 0000000..8e5635d --- /dev/null +++ b/pkg/subscriptions/subscriptions.go @@ -0,0 +1,143 @@ +package subscriptions + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "dcdev.ro/CfrTrainInfoTelegramBot/pkg/database" + "gorm.io/gorm" +) + +type SubData struct { + gorm.Model + ChatId int64 + MessageId int + TrainNumber string + Date time.Time +} + +type Subscriptions struct { + mutex sync.RWMutex + data map[int64][]SubData +} + +func LoadSubscriptions() (*Subscriptions, error) { + subs := make([]SubData, 0) + _, err := database.ReadDB(func(db *gorm.DB) (*gorm.DB, error) { + result := db.Find(&subs) + return result, result.Error + }) + result := map[int64][]SubData{} + for _, sub := range subs { + result[sub.ChatId] = append(result[sub.ChatId], sub) + } + return &Subscriptions{ + mutex: sync.RWMutex{}, + data: result, + }, err +} + +func (sub *Subscriptions) Replace(chatId int64, data []SubData) error { + // Only allow replacing if all records use same chatId + for _, d := range data { + if d.ChatId != chatId { + return fmt.Errorf("data contains item whose ChatId (%d) doesn't match chatId (%d)", d.ChatId, chatId) + } + } + sub.mutex.Lock() + defer sub.mutex.Unlock() + sub.data[chatId] = data + _, err := database.WriteDB(func(db *gorm.DB) (*gorm.DB, error) { + db.Delete(&SubData{}, "chat_id = ?", chatId) + db.Create(&data) + return db, db.Error + }) + return err +} + +func (sub *Subscriptions) InsertSubscription(chatId int64, data SubData) error { + sub.mutex.Lock() + defer sub.mutex.Unlock() + datas := sub.data[chatId] + datas = append(datas, data) + sub.data[chatId] = datas + _, err := database.WriteDB(func(db *gorm.DB) (*gorm.DB, error) { + db.Create(&data) + return db, db.Error + }) + return err +} + +func (sub *Subscriptions) DeleteChat(chatId int64) error { + sub.mutex.Lock() + defer sub.mutex.Unlock() + delete(sub.data, chatId) + _, err := database.WriteDB(func(db *gorm.DB) (*gorm.DB, error) { + db.Delete(&SubData{}, "chat_id = ?", chatId) + return db, db.Error + }) + return err +} + +func (sub *Subscriptions) DeleteSubscription(chatId int64, messageId int) (*SubData, error) { + sub.mutex.Lock() + defer sub.mutex.Unlock() + datas := sub.data[chatId] + deleteIndex := -1 + for i := range datas { + if datas[i].MessageId == messageId { + deleteIndex = i + break + } + } + var result *SubData + if deleteIndex != -1 { + result = &SubData{} + *result = datas[deleteIndex] + datas[deleteIndex] = datas[len(datas)-1] + datas = datas[:len(datas)-1] + + _, err := database.WriteDB(func(db *gorm.DB) (*gorm.DB, error) { + db.Delete(result) + return db, db.Error + }) + if err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("subscription chatId %d messageId %d not found", chatId, messageId) + } + if len(datas) == 0 { + delete(sub.data, chatId) + } else { + sub.data[chatId] = datas + } + return result, nil +} + +func (sub *Subscriptions) CheckSubscriptions(ctx context.Context) { + ticker := time.NewTicker(time.Second * 90) + + for { + select { + case <-ticker.C: + func() { + sub.mutex.RLock() + defer sub.mutex.RUnlock() + + for chatId, datas := range sub.data { + // TODO: Check for updates + for i := range datas { + data := &datas[i] + log.Printf("DEBUG: Timer tick, update for chat %d, train %s", chatId, data.TrainNumber) + } + } + }() + case <-ctx.Done(): + return + } + } +} diff --git a/pkg/utils/location.go b/pkg/utils/location.go new file mode 100644 index 0000000..da4759c --- /dev/null +++ b/pkg/utils/location.go @@ -0,0 +1,7 @@ +package utils + +import "time" + +var ( + Location, _ = time.LoadLocation("Europe/Bucharest") +) diff --git a/pkg/utils/parseDate.go b/pkg/utils/parseDate.go new file mode 100644 index 0000000..9766bb8 --- /dev/null +++ b/pkg/utils/parseDate.go @@ -0,0 +1,56 @@ +package utils + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +var ( + InvalidDateFormat = fmt.Errorf("invalid date format") +) + +func ParseDate(input string) (time.Time, error) { + if strings.Contains(input, "-") { + return parse3Part(input, "-", 0, 1, 2) + } else if strings.Contains(input, "/") { + return parse3Part(input, "/", 2, 0, 1) + } else if strings.Contains(input, ".") { + return parse3Part(input, ".", 2, 1, 0) + } else { + parsed, err := strconv.ParseInt(input, 10, 63) + if err != nil { + return time.Time{}, err + } + return time.Unix(parsed, 0), nil + } +} + +func parse3Part(input string, sep string, yearIndex int, monthIndex int, dayIndex int) (time.Time, error) { + splitted := strings.Split(input, sep) + if len(splitted) == 2 && yearIndex == 2 { + // If the year is the last part of the format, allow omitting it + splitted = append(splitted, fmt.Sprintf("%d", time.Now().Year())) + } + if len(splitted) != 3 { + return time.Time{}, InvalidDateFormat + } + year, err := strconv.Atoi(splitted[yearIndex]) + if err != nil { + return time.Time{}, InvalidDateFormat + } + if year < 100 { + // Assume xx.xx.23 or x/x/23 => 2023 + year = (time.Now().Year() / 100 * 100) + year + } + month, err := strconv.Atoi(splitted[monthIndex]) + if err != nil { + return time.Time{}, InvalidDateFormat + } + day, err := strconv.Atoi(splitted[dayIndex]) + if err != nil { + return time.Time{}, InvalidDateFormat + } + return time.Date(year, time.Month(month), day, 12, 0, 0, 0, Location), nil +}