Add search page
[bloat] / service / service.go
1 package service
2
3 import (
4         "bytes"
5         "context"
6         "encoding/json"
7         "errors"
8         "fmt"
9         "io"
10         "mime/multipart"
11         "net/http"
12         "net/url"
13         "strings"
14
15         "mastodon"
16         "web/model"
17         "web/renderer"
18         "web/util"
19 )
20
21 var (
22         ErrInvalidArgument = errors.New("invalid argument")
23         ErrInvalidToken    = errors.New("invalid token")
24         ErrInvalidClient   = errors.New("invalid client")
25         ErrInvalidTimeline = errors.New("invalid timeline")
26 )
27
28 type Service interface {
29         ServeHomePage(ctx context.Context, client io.Writer) (err error)
30         GetAuthUrl(ctx context.Context, instance string) (url string, sessionID string, err error)
31         GetUserToken(ctx context.Context, sessionID string, c *model.Client, token string) (accessToken string, err error)
32         ServeErrorPage(ctx context.Context, client io.Writer, err error)
33         ServeSigninPage(ctx context.Context, client io.Writer) (err error)
34         ServeTimelinePage(ctx context.Context, client io.Writer, c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error)
35         ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error)
36         ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error)
37         ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error)
38         ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error)
39         ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error)
40         ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
41         ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
42         ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error)
43         Like(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
44         UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
45         Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
46         UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
47         PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error)
48         Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
49         UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
50 }
51
52 type service struct {
53         clientName    string
54         clientScope   string
55         clientWebsite string
56         customCSS     string
57         postFormats   []model.PostFormat
58         renderer      renderer.Renderer
59         sessionRepo   model.SessionRepository
60         appRepo       model.AppRepository
61 }
62
63 func NewService(clientName string, clientScope string, clientWebsite string,
64         customCSS string, postFormats []model.PostFormat, renderer renderer.Renderer,
65         sessionRepo model.SessionRepository, appRepo model.AppRepository) Service {
66         return &service{
67                 clientName:    clientName,
68                 clientScope:   clientScope,
69                 clientWebsite: clientWebsite,
70                 customCSS:     customCSS,
71                 postFormats:   postFormats,
72                 renderer:      renderer,
73                 sessionRepo:   sessionRepo,
74                 appRepo:       appRepo,
75         }
76 }
77
78 func (svc *service) GetAuthUrl(ctx context.Context, instance string) (
79         redirectUrl string, sessionID string, err error) {
80         var instanceURL string
81         if strings.HasPrefix(instance, "https://") {
82                 instanceURL = instance
83                 instance = strings.TrimPrefix(instance, "https://")
84         } else {
85                 instanceURL = "https://" + instance
86         }
87
88         sessionID = util.NewSessionId()
89         err = svc.sessionRepo.Add(model.Session{
90                 ID:             sessionID,
91                 InstanceDomain: instance,
92         })
93         if err != nil {
94                 return
95         }
96
97         app, err := svc.appRepo.Get(instance)
98         if err != nil {
99                 if err != model.ErrAppNotFound {
100                         return
101                 }
102
103                 var mastoApp *mastodon.Application
104                 mastoApp, err = mastodon.RegisterApp(ctx, &mastodon.AppConfig{
105                         Server:       instanceURL,
106                         ClientName:   svc.clientName,
107                         Scopes:       svc.clientScope,
108                         Website:      svc.clientWebsite,
109                         RedirectURIs: svc.clientWebsite + "/oauth_callback",
110                 })
111                 if err != nil {
112                         return
113                 }
114
115                 app = model.App{
116                         InstanceDomain: instance,
117                         InstanceURL:    instanceURL,
118                         ClientID:       mastoApp.ClientID,
119                         ClientSecret:   mastoApp.ClientSecret,
120                 }
121
122                 err = svc.appRepo.Add(app)
123                 if err != nil {
124                         return
125                 }
126         }
127
128         u, err := url.Parse("/oauth/authorize")
129         if err != nil {
130                 return
131         }
132
133         q := make(url.Values)
134         q.Set("scope", "read write follow")
135         q.Set("client_id", app.ClientID)
136         q.Set("response_type", "code")
137         q.Set("redirect_uri", svc.clientWebsite+"/oauth_callback")
138         u.RawQuery = q.Encode()
139
140         redirectUrl = instanceURL + u.String()
141
142         return
143 }
144
145 func (svc *service) GetUserToken(ctx context.Context, sessionID string, c *model.Client,
146         code string) (token string, err error) {
147         if len(code) < 1 {
148                 err = ErrInvalidArgument
149                 return
150         }
151
152         session, err := svc.sessionRepo.Get(sessionID)
153         if err != nil {
154                 return
155         }
156
157         app, err := svc.appRepo.Get(session.InstanceDomain)
158         if err != nil {
159                 return
160         }
161
162         data := &bytes.Buffer{}
163         err = json.NewEncoder(data).Encode(map[string]string{
164                 "client_id":     app.ClientID,
165                 "client_secret": app.ClientSecret,
166                 "grant_type":    "authorization_code",
167                 "code":          code,
168                 "redirect_uri":  svc.clientWebsite + "/oauth_callback",
169         })
170         if err != nil {
171                 return
172         }
173
174         resp, err := http.Post(app.InstanceURL+"/oauth/token", "application/json", data)
175         if err != nil {
176                 return
177         }
178         defer resp.Body.Close()
179
180         var res struct {
181                 AccessToken string `json:"access_token"`
182         }
183
184         err = json.NewDecoder(resp.Body).Decode(&res)
185         if err != nil {
186                 return
187         }
188         /*
189                 err = c.AuthenticateToken(ctx, code, svc.clientWebsite+"/oauth_callback")
190                 if err != nil {
191                         return
192                 }
193                 err = svc.sessionRepo.Update(sessionID, c.GetAccessToken(ctx))
194         */
195
196         return res.AccessToken, nil
197 }
198
199 func (svc *service) ServeHomePage(ctx context.Context, client io.Writer) (err error) {
200         commonData, err := svc.getCommonData(ctx, client, nil)
201         if err != nil {
202                 return
203         }
204
205         data := &renderer.HomePageData{
206                 CommonData: commonData,
207         }
208
209         return svc.renderer.RenderHomePage(ctx, client, data)
210 }
211
212 func (svc *service) ServeErrorPage(ctx context.Context, client io.Writer, err error) {
213         var errStr string
214         if err != nil {
215                 errStr = err.Error()
216         }
217
218         commonData, err := svc.getCommonData(ctx, client, nil)
219         if err != nil {
220                 return
221         }
222
223         data := &renderer.ErrorData{
224                 CommonData: commonData,
225                 Error:      errStr,
226         }
227
228         svc.renderer.RenderErrorPage(ctx, client, data)
229 }
230
231 func (svc *service) ServeSigninPage(ctx context.Context, client io.Writer) (err error) {
232         commonData, err := svc.getCommonData(ctx, client, nil)
233         if err != nil {
234                 return
235         }
236
237         data := &renderer.SigninData{
238                 CommonData: commonData,
239         }
240
241         return svc.renderer.RenderSigninPage(ctx, client, data)
242 }
243
244 func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer,
245         c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) {
246
247         var hasNext, hasPrev bool
248         var nextLink, prevLink string
249
250         var pg = mastodon.Pagination{
251                 MaxID: maxID,
252                 MinID: minID,
253                 Limit: 20,
254         }
255
256         var statuses []*mastodon.Status
257         var title string
258         switch timelineType {
259         default:
260                 return ErrInvalidTimeline
261         case "home":
262                 statuses, err = c.GetTimelineHome(ctx, &pg)
263                 title = "Timeline"
264         case "local":
265                 statuses, err = c.GetTimelinePublic(ctx, true, &pg)
266                 title = "Local Timeline"
267         case "twkn":
268                 statuses, err = c.GetTimelinePublic(ctx, false, &pg)
269                 title = "The Whole Known Network"
270         }
271         if err != nil {
272                 return err
273         }
274
275         if len(maxID) > 0 && len(statuses) > 0 {
276                 hasPrev = true
277                 prevLink = fmt.Sprintf("/timeline/$s?min_id=%s", timelineType, statuses[0].ID)
278         }
279         if len(minID) > 0 && len(pg.MinID) > 0 {
280                 newStatuses, err := c.GetTimelineHome(ctx, &mastodon.Pagination{MinID: pg.MinID, Limit: 20})
281                 if err != nil {
282                         return err
283                 }
284                 newStatusesLen := len(newStatuses)
285                 if newStatusesLen == 20 {
286                         hasPrev = true
287                         prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", timelineType, pg.MinID)
288                 } else {
289                         i := 20 - newStatusesLen - 1
290                         if len(statuses) > i {
291                                 hasPrev = true
292                                 prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", timelineType, statuses[i].ID)
293                         }
294                 }
295         }
296         if len(pg.MaxID) > 0 {
297                 hasNext = true
298                 nextLink = fmt.Sprintf("/timeline/%s?max_id=%s", timelineType, pg.MaxID)
299         }
300
301         postContext := model.PostContext{
302                 DefaultVisibility: c.Session.Settings.DefaultVisibility,
303                 Formats:           svc.postFormats,
304         }
305
306         commonData, err := svc.getCommonData(ctx, client, c)
307         if err != nil {
308                 return
309         }
310
311         data := &renderer.TimelineData{
312                 Title:       title,
313                 Statuses:    statuses,
314                 HasNext:     hasNext,
315                 NextLink:    nextLink,
316                 HasPrev:     hasPrev,
317                 PrevLink:    prevLink,
318                 PostContext: postContext,
319                 CommonData:  commonData,
320         }
321
322         err = svc.renderer.RenderTimelinePage(ctx, client, data)
323         if err != nil {
324                 return
325         }
326
327         return
328 }
329
330 func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) {
331         status, err := c.GetStatus(ctx, id)
332         if err != nil {
333                 return
334         }
335
336         u, err := c.GetAccountCurrentUser(ctx)
337         if err != nil {
338                 return
339         }
340
341         var postContext model.PostContext
342         if reply {
343                 var content string
344                 if u.ID != status.Account.ID {
345                         content += "@" + status.Account.Acct + " "
346                 }
347                 for i := range status.Mentions {
348                         if status.Mentions[i].ID != u.ID && status.Mentions[i].ID != status.Account.ID {
349                                 content += "@" + status.Mentions[i].Acct + " "
350                         }
351                 }
352
353                 s, err := c.GetStatus(ctx, id)
354                 if err != nil {
355                         return err
356                 }
357
358                 postContext = model.PostContext{
359                         DefaultVisibility: s.Visibility,
360                         Formats:           svc.postFormats,
361                         ReplyContext: &model.ReplyContext{
362                                 InReplyToID:   id,
363                                 InReplyToName: status.Account.Acct,
364                                 ReplyContent:  content,
365                         },
366                 }
367         }
368
369         context, err := c.GetStatusContext(ctx, id)
370         if err != nil {
371                 return
372         }
373
374         statuses := append(append(context.Ancestors, status), context.Descendants...)
375
376         replyMap := make(map[string][]mastodon.ReplyInfo)
377
378         for i := range statuses {
379                 statuses[i].ShowReplies = true
380                 statuses[i].ReplyMap = replyMap
381                 addToReplyMap(replyMap, statuses[i].InReplyToID, statuses[i].ID, i+1)
382         }
383
384         commonData, err := svc.getCommonData(ctx, client, c)
385         if err != nil {
386                 return
387         }
388
389         data := &renderer.ThreadData{
390                 Statuses:    statuses,
391                 PostContext: postContext,
392                 ReplyMap:    replyMap,
393                 CommonData:  commonData,
394         }
395
396         err = svc.renderer.RenderThreadPage(ctx, client, data)
397         if err != nil {
398                 return
399         }
400
401         return
402 }
403
404 func (svc *service) ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) {
405         var hasNext bool
406         var nextLink string
407
408         var pg = mastodon.Pagination{
409                 MaxID: maxID,
410                 MinID: minID,
411                 Limit: 20,
412         }
413
414         notifications, err := c.GetNotifications(ctx, &pg)
415         if err != nil {
416                 return
417         }
418
419         var unreadCount int
420         for i := range notifications {
421                 switch notifications[i].Type {
422                 case "reblog", "favourite":
423                         if notifications[i].Status != nil {
424                                 notifications[i].Status.HideAccountInfo = true
425                         }
426                 }
427                 if notifications[i].Pleroma != nil && notifications[i].Pleroma.IsSeen {
428                         unreadCount++
429                 }
430         }
431
432         if unreadCount > 0 {
433                 err := c.ReadNotifications(ctx, notifications[0].ID)
434                 if err != nil {
435                         return err
436                 }
437         }
438
439         if len(pg.MaxID) > 0 {
440                 hasNext = true
441                 nextLink = "/notifications?max_id=" + pg.MaxID
442         }
443
444         commonData, err := svc.getCommonData(ctx, client, c)
445         if err != nil {
446                 return
447         }
448
449         data := &renderer.NotificationData{
450                 Notifications: notifications,
451                 HasNext:       hasNext,
452                 NextLink:      nextLink,
453                 CommonData:    commonData,
454         }
455         err = svc.renderer.RenderNotificationPage(ctx, client, data)
456         if err != nil {
457                 return
458         }
459
460         return
461 }
462
463 func (svc *service) ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) {
464         user, err := c.GetAccount(ctx, id)
465         if err != nil {
466                 return
467         }
468
469         var hasNext bool
470         var nextLink string
471
472         var pg = mastodon.Pagination{
473                 MaxID: maxID,
474                 MinID: minID,
475                 Limit: 20,
476         }
477
478         statuses, err := c.GetAccountStatuses(ctx, id, &pg)
479         if err != nil {
480                 return
481         }
482
483         if len(pg.MaxID) > 0 {
484                 hasNext = true
485                 nextLink = "/user/" + id + "?max_id=" + pg.MaxID
486         }
487
488         commonData, err := svc.getCommonData(ctx, client, c)
489         if err != nil {
490                 return
491         }
492
493         data := &renderer.UserData{
494                 User:       user,
495                 Statuses:   statuses,
496                 HasNext:    hasNext,
497                 NextLink:   nextLink,
498                 CommonData: commonData,
499         }
500
501         err = svc.renderer.RenderUserPage(ctx, client, data)
502         if err != nil {
503                 return
504         }
505
506         return
507 }
508
509 func (svc *service) ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) {
510         commonData, err := svc.getCommonData(ctx, client, c)
511         if err != nil {
512                 return
513         }
514
515         data := &renderer.AboutData{
516                 CommonData: commonData,
517         }
518         err = svc.renderer.RenderAboutPage(ctx, client, data)
519         if err != nil {
520                 return
521         }
522
523         return
524 }
525
526 func (svc *service) ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) {
527         commonData, err := svc.getCommonData(ctx, client, c)
528         if err != nil {
529                 return
530         }
531
532         emojis, err := c.GetInstanceEmojis(ctx)
533         if err != nil {
534                 return
535         }
536
537         data := &renderer.EmojiData{
538                 Emojis:     emojis,
539                 CommonData: commonData,
540         }
541
542         err = svc.renderer.RenderEmojiPage(ctx, client, data)
543         if err != nil {
544                 return
545         }
546
547         return
548 }
549
550 func (svc *service) ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
551         likers, err := c.GetFavouritedBy(ctx, id, nil)
552         if err != nil {
553                 return
554         }
555
556         commonData, err := svc.getCommonData(ctx, client, c)
557         if err != nil {
558                 return
559         }
560
561         data := &renderer.LikedByData{
562                 CommonData: commonData,
563                 Users:      likers,
564         }
565
566         err = svc.renderer.RenderLikedByPage(ctx, client, data)
567         if err != nil {
568                 return
569         }
570
571         return
572 }
573
574 func (svc *service) ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
575         retweeters, err := c.GetRebloggedBy(ctx, id, nil)
576         if err != nil {
577                 return
578         }
579
580         commonData, err := svc.getCommonData(ctx, client, c)
581         if err != nil {
582                 return
583         }
584
585         data := &renderer.RetweetedByData{
586                 CommonData: commonData,
587                 Users:      retweeters,
588         }
589
590         err = svc.renderer.RenderRetweetedByPage(ctx, client, data)
591         if err != nil {
592                 return
593         }
594
595         return
596 }
597
598 func (svc *service) ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) {
599         var hasNext bool
600         var nextLink string
601
602         results, err := c.Search(ctx, q, qType, 20, true, offset)
603         if err != nil {
604                 return
605         }
606
607         switch qType {
608         case "accounts":
609                 hasNext = len(results.Accounts) == 20
610         case "statuses":
611                 hasNext = len(results.Statuses) == 20
612         }
613
614         if hasNext {
615                 offset += 20
616                 nextLink = fmt.Sprintf("/search?q=%s&type=%s&offset=%d", q, qType, offset)
617         }
618
619         commonData, err := svc.getCommonData(ctx, client, c)
620         if err != nil {
621                 return
622         }
623
624         data := &renderer.SearchData{
625                 CommonData: commonData,
626                 Q:          q,
627                 Type:       qType,
628                 Users:      results.Accounts,
629                 Statuses:   results.Statuses,
630                 HasNext:    hasNext,
631                 NextLink:   nextLink,
632         }
633
634         err = svc.renderer.RenderSearchPage(ctx, client, data)
635         if err != nil {
636                 return
637         }
638
639         return
640 }
641
642 func (svc *service) getCommonData(ctx context.Context, client io.Writer, c *model.Client) (data *renderer.CommonData, err error) {
643         data = new(renderer.CommonData)
644
645         data.HeaderData = &renderer.HeaderData{
646                 Title:             "Web",
647                 NotificationCount: 0,
648                 CustomCSS:         svc.customCSS,
649         }
650
651         if c != nil && c.Session.IsLoggedIn() {
652                 notifications, err := c.GetNotifications(ctx, nil)
653                 if err != nil {
654                         return nil, err
655                 }
656
657                 var notificationCount int
658                 for i := range notifications {
659                         if notifications[i].Pleroma != nil && !notifications[i].Pleroma.IsSeen {
660                                 notificationCount++
661                         }
662                 }
663
664                 u, err := c.GetAccountCurrentUser(ctx)
665                 if err != nil {
666                         return nil, err
667                 }
668
669                 data.NavbarData = &renderer.NavbarData{
670                         User:              u,
671                         NotificationCount: notificationCount,
672                 }
673
674                 data.HeaderData.NotificationCount = notificationCount
675         }
676
677         return
678 }
679
680 func (svc *service) Like(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
681         _, err = c.Favourite(ctx, id)
682         return
683 }
684
685 func (svc *service) UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
686         _, err = c.Unfavourite(ctx, id)
687         return
688 }
689
690 func (svc *service) Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
691         _, err = c.Reblog(ctx, id)
692         return
693 }
694
695 func (svc *service) UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
696         _, err = c.Unreblog(ctx, id)
697         return
698 }
699
700 func (svc *service) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) {
701         var mediaIds []string
702         for _, f := range files {
703                 a, err := c.UploadMediaFromMultipartFileHeader(ctx, f)
704                 if err != nil {
705                         return "", err
706                 }
707                 mediaIds = append(mediaIds, a.ID)
708         }
709
710         // save visibility if it's a non-reply post
711         if len(replyToID) < 1 && visibility != c.Session.Settings.DefaultVisibility {
712                 c.Session.Settings.DefaultVisibility = visibility
713                 svc.sessionRepo.Add(c.Session)
714         }
715
716         tweet := &mastodon.Toot{
717                 Status:      content,
718                 InReplyToID: replyToID,
719                 MediaIDs:    mediaIds,
720                 ContentType: format,
721                 Visibility:  visibility,
722                 Sensitive:   isNSFW,
723         }
724
725         s, err := c.PostStatus(ctx, tweet)
726         if err != nil {
727                 return
728         }
729
730         return s.ID, nil
731 }
732
733 func (svc *service) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
734         _, err = c.AccountFollow(ctx, id)
735         return
736 }
737
738 func (svc *service) UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
739         _, err = c.AccountUnfollow(ctx, id)
740         return
741 }
742
743 func addToReplyMap(m map[string][]mastodon.ReplyInfo, key interface{}, val string, number int) {
744         if key == nil {
745                 return
746         }
747
748         keyStr, ok := key.(string)
749         if !ok {
750                 return
751         }
752         _, ok = m[keyStr]
753         if !ok {
754                 m[keyStr] = []mastodon.ReplyInfo{}
755         }
756
757         m[keyStr] = append(m[keyStr], mastodon.ReplyInfo{val, number})
758 }