Add post format selection
[bloat] / service / service.go
1 package service
2
3 import (
4         "bytes"
5         "context"
6         "encoding/json"
7         "errors"
8         "fmt"
9         "io"
10         "mime/multipart"
11         "net/http"
12         "net/url"
13         "strings"
14
15         "mastodon"
16         "web/model"
17         "web/renderer"
18         "web/util"
19 )
20
21 var (
22         ErrInvalidArgument = errors.New("invalid argument")
23         ErrInvalidToken    = errors.New("invalid token")
24         ErrInvalidClient   = errors.New("invalid client")
25         ErrInvalidTimeline = errors.New("invalid timeline")
26 )
27
28 type Service interface {
29         ServeHomePage(ctx context.Context, client io.Writer) (err error)
30         GetAuthUrl(ctx context.Context, instance string) (url string, sessionID string, err error)
31         GetUserToken(ctx context.Context, sessionID string, c *model.Client, token string) (accessToken string, err error)
32         ServeErrorPage(ctx context.Context, client io.Writer, err error)
33         ServeSigninPage(ctx context.Context, client io.Writer) (err error)
34         ServeTimelinePage(ctx context.Context, client io.Writer, c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error)
35         ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error)
36         ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error)
37         ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error)
38         ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error)
39         ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error)
40         ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
41         ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
42         Like(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
43         UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
44         Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
45         UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
46         PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error)
47         Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
48         UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error)
49 }
50
51 type service struct {
52         clientName    string
53         clientScope   string
54         clientWebsite string
55         customCSS     string
56         postFormats   []model.PostFormat
57         renderer      renderer.Renderer
58         sessionRepo   model.SessionRepository
59         appRepo       model.AppRepository
60 }
61
62 func NewService(clientName string, clientScope string, clientWebsite string,
63         customCSS string, postFormats []model.PostFormat, renderer renderer.Renderer,
64         sessionRepo model.SessionRepository, appRepo model.AppRepository) Service {
65         return &service{
66                 clientName:    clientName,
67                 clientScope:   clientScope,
68                 clientWebsite: clientWebsite,
69                 customCSS:     customCSS,
70                 postFormats:   postFormats,
71                 renderer:      renderer,
72                 sessionRepo:   sessionRepo,
73                 appRepo:       appRepo,
74         }
75 }
76
77 func (svc *service) GetAuthUrl(ctx context.Context, instance string) (
78         redirectUrl string, sessionID string, err error) {
79         var instanceURL string
80         if strings.HasPrefix(instance, "https://") {
81                 instanceURL = instance
82                 instance = strings.TrimPrefix(instance, "https://")
83         } else {
84                 instanceURL = "https://" + instance
85         }
86
87         sessionID = util.NewSessionId()
88         err = svc.sessionRepo.Add(model.Session{
89                 ID:             sessionID,
90                 InstanceDomain: instance,
91         })
92         if err != nil {
93                 return
94         }
95
96         app, err := svc.appRepo.Get(instance)
97         if err != nil {
98                 if err != model.ErrAppNotFound {
99                         return
100                 }
101
102                 var mastoApp *mastodon.Application
103                 mastoApp, err = mastodon.RegisterApp(ctx, &mastodon.AppConfig{
104                         Server:       instanceURL,
105                         ClientName:   svc.clientName,
106                         Scopes:       svc.clientScope,
107                         Website:      svc.clientWebsite,
108                         RedirectURIs: svc.clientWebsite + "/oauth_callback",
109                 })
110                 if err != nil {
111                         return
112                 }
113
114                 app = model.App{
115                         InstanceDomain: instance,
116                         InstanceURL:    instanceURL,
117                         ClientID:       mastoApp.ClientID,
118                         ClientSecret:   mastoApp.ClientSecret,
119                 }
120
121                 err = svc.appRepo.Add(app)
122                 if err != nil {
123                         return
124                 }
125         }
126
127         u, err := url.Parse("/oauth/authorize")
128         if err != nil {
129                 return
130         }
131
132         q := make(url.Values)
133         q.Set("scope", "read write follow")
134         q.Set("client_id", app.ClientID)
135         q.Set("response_type", "code")
136         q.Set("redirect_uri", svc.clientWebsite+"/oauth_callback")
137         u.RawQuery = q.Encode()
138
139         redirectUrl = instanceURL + u.String()
140
141         return
142 }
143
144 func (svc *service) GetUserToken(ctx context.Context, sessionID string, c *model.Client,
145         code string) (token string, err error) {
146         if len(code) < 1 {
147                 err = ErrInvalidArgument
148                 return
149         }
150
151         session, err := svc.sessionRepo.Get(sessionID)
152         if err != nil {
153                 return
154         }
155
156         app, err := svc.appRepo.Get(session.InstanceDomain)
157         if err != nil {
158                 return
159         }
160
161         data := &bytes.Buffer{}
162         err = json.NewEncoder(data).Encode(map[string]string{
163                 "client_id":     app.ClientID,
164                 "client_secret": app.ClientSecret,
165                 "grant_type":    "authorization_code",
166                 "code":          code,
167                 "redirect_uri":  svc.clientWebsite + "/oauth_callback",
168         })
169         if err != nil {
170                 return
171         }
172
173         resp, err := http.Post(app.InstanceURL+"/oauth/token", "application/json", data)
174         if err != nil {
175                 return
176         }
177         defer resp.Body.Close()
178
179         var res struct {
180                 AccessToken string `json:"access_token"`
181         }
182
183         err = json.NewDecoder(resp.Body).Decode(&res)
184         if err != nil {
185                 return
186         }
187         /*
188                 err = c.AuthenticateToken(ctx, code, svc.clientWebsite+"/oauth_callback")
189                 if err != nil {
190                         return
191                 }
192                 err = svc.sessionRepo.Update(sessionID, c.GetAccessToken(ctx))
193         */
194
195         return res.AccessToken, nil
196 }
197
198 func (svc *service) ServeHomePage(ctx context.Context, client io.Writer) (err error) {
199         commonData, err := svc.getCommonData(ctx, client, nil)
200         if err != nil {
201                 return
202         }
203
204         data := &renderer.HomePageData{
205                 CommonData: commonData,
206         }
207
208         return svc.renderer.RenderHomePage(ctx, client, data)
209 }
210
211 func (svc *service) ServeErrorPage(ctx context.Context, client io.Writer, err error) {
212         var errStr string
213         if err != nil {
214                 errStr = err.Error()
215         }
216
217         commonData, err := svc.getCommonData(ctx, client, nil)
218         if err != nil {
219                 return
220         }
221
222         data := &renderer.ErrorData{
223                 CommonData: commonData,
224                 Error:      errStr,
225         }
226
227         svc.renderer.RenderErrorPage(ctx, client, data)
228 }
229
230 func (svc *service) ServeSigninPage(ctx context.Context, client io.Writer) (err error) {
231         commonData, err := svc.getCommonData(ctx, client, nil)
232         if err != nil {
233                 return
234         }
235
236         data := &renderer.SigninData{
237                 CommonData: commonData,
238         }
239
240         return svc.renderer.RenderSigninPage(ctx, client, data)
241 }
242
243 func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer,
244         c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) {
245
246         var hasNext, hasPrev bool
247         var nextLink, prevLink string
248
249         var pg = mastodon.Pagination{
250                 MaxID: maxID,
251                 MinID: minID,
252                 Limit: 20,
253         }
254
255         var statuses []*mastodon.Status
256         var title string
257         switch timelineType {
258         default:
259                 return ErrInvalidTimeline
260         case "home":
261                 statuses, err = c.GetTimelineHome(ctx, &pg)
262                 title = "Timeline"
263         case "local":
264                 statuses, err = c.GetTimelinePublic(ctx, true, &pg)
265                 title = "Local Timeline"
266         case "twkn":
267                 statuses, err = c.GetTimelinePublic(ctx, false, &pg)
268                 title = "The Whole Known Network"
269         }
270         if err != nil {
271                 return err
272         }
273
274         if len(maxID) > 0 && len(statuses) > 0 {
275                 hasPrev = true
276                 prevLink = fmt.Sprintf("/timeline/$s?min_id=%s", timelineType, statuses[0].ID)
277         }
278         if len(minID) > 0 && len(pg.MinID) > 0 {
279                 newStatuses, err := c.GetTimelineHome(ctx, &mastodon.Pagination{MinID: pg.MinID, Limit: 20})
280                 if err != nil {
281                         return err
282                 }
283                 newStatusesLen := len(newStatuses)
284                 if newStatusesLen == 20 {
285                         hasPrev = true
286                         prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", timelineType, pg.MinID)
287                 } else {
288                         i := 20 - newStatusesLen - 1
289                         if len(statuses) > i {
290                                 hasPrev = true
291                                 prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", timelineType, statuses[i].ID)
292                         }
293                 }
294         }
295         if len(pg.MaxID) > 0 {
296                 hasNext = true
297                 nextLink = fmt.Sprintf("/timeline/%s?max_id=%s", timelineType, pg.MaxID)
298         }
299
300         postContext := model.PostContext{
301                 DefaultVisibility: c.Session.Settings.DefaultVisibility,
302                 Formats:           svc.postFormats,
303         }
304
305         commonData, err := svc.getCommonData(ctx, client, c)
306         if err != nil {
307                 return
308         }
309
310         data := &renderer.TimelineData{
311                 Title:       title,
312                 Statuses:    statuses,
313                 HasNext:     hasNext,
314                 NextLink:    nextLink,
315                 HasPrev:     hasPrev,
316                 PrevLink:    prevLink,
317                 PostContext: postContext,
318                 CommonData:  commonData,
319         }
320
321         err = svc.renderer.RenderTimelinePage(ctx, client, data)
322         if err != nil {
323                 return
324         }
325
326         return
327 }
328
329 func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) {
330         status, err := c.GetStatus(ctx, id)
331         if err != nil {
332                 return
333         }
334
335         u, err := c.GetAccountCurrentUser(ctx)
336         if err != nil {
337                 return
338         }
339
340         var postContext model.PostContext
341         if reply {
342                 var content string
343                 if u.ID != status.Account.ID {
344                         content += "@" + status.Account.Acct + " "
345                 }
346                 for i := range status.Mentions {
347                         if status.Mentions[i].ID != u.ID && status.Mentions[i].ID != status.Account.ID {
348                                 content += "@" + status.Mentions[i].Acct + " "
349                         }
350                 }
351
352                 s, err := c.GetStatus(ctx, id)
353                 if err != nil {
354                         return err
355                 }
356
357                 postContext = model.PostContext{
358                         DefaultVisibility: s.Visibility,
359                         Formats:           svc.postFormats,
360                         ReplyContext: &model.ReplyContext{
361                                 InReplyToID:   id,
362                                 InReplyToName: status.Account.Acct,
363                                 ReplyContent:  content,
364                         },
365                 }
366         }
367
368         context, err := c.GetStatusContext(ctx, id)
369         if err != nil {
370                 return
371         }
372
373         statuses := append(append(context.Ancestors, status), context.Descendants...)
374
375         replyMap := make(map[string][]mastodon.ReplyInfo)
376
377         for i := range statuses {
378                 statuses[i].ShowReplies = true
379                 statuses[i].ReplyMap = replyMap
380                 addToReplyMap(replyMap, statuses[i].InReplyToID, statuses[i].ID, i+1)
381         }
382
383         commonData, err := svc.getCommonData(ctx, client, c)
384         if err != nil {
385                 return
386         }
387
388         data := &renderer.ThreadData{
389                 Statuses:    statuses,
390                 PostContext: postContext,
391                 ReplyMap:    replyMap,
392                 CommonData:  commonData,
393         }
394
395         err = svc.renderer.RenderThreadPage(ctx, client, data)
396         if err != nil {
397                 return
398         }
399
400         return
401 }
402
403 func (svc *service) ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) {
404         var hasNext bool
405         var nextLink string
406
407         var pg = mastodon.Pagination{
408                 MaxID: maxID,
409                 MinID: minID,
410                 Limit: 20,
411         }
412
413         notifications, err := c.GetNotifications(ctx, &pg)
414         if err != nil {
415                 return
416         }
417
418         var unreadCount int
419         for i := range notifications {
420                 switch notifications[i].Type {
421                 case "reblog", "favourite":
422                         if notifications[i].Status != nil {
423                                 notifications[i].Status.HideAccountInfo = true
424                         }
425                 }
426                 if notifications[i].Pleroma != nil && notifications[i].Pleroma.IsSeen {
427                         unreadCount++
428                 }
429         }
430
431         if unreadCount > 0 {
432                 err := c.ReadNotifications(ctx, notifications[0].ID)
433                 if err != nil {
434                         return err
435                 }
436         }
437
438         if len(pg.MaxID) > 0 {
439                 hasNext = true
440                 nextLink = "/notifications?max_id=" + pg.MaxID
441         }
442
443         commonData, err := svc.getCommonData(ctx, client, c)
444         if err != nil {
445                 return
446         }
447
448         data := &renderer.NotificationData{
449                 Notifications: notifications,
450                 HasNext:       hasNext,
451                 NextLink:      nextLink,
452                 CommonData:    commonData,
453         }
454         err = svc.renderer.RenderNotificationPage(ctx, client, data)
455         if err != nil {
456                 return
457         }
458
459         return
460 }
461
462 func (svc *service) ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) {
463         user, err := c.GetAccount(ctx, id)
464         if err != nil {
465                 return
466         }
467
468         var hasNext bool
469         var nextLink string
470
471         var pg = mastodon.Pagination{
472                 MaxID: maxID,
473                 MinID: minID,
474                 Limit: 20,
475         }
476
477         statuses, err := c.GetAccountStatuses(ctx, id, &pg)
478         if err != nil {
479                 return
480         }
481
482         if len(pg.MaxID) > 0 {
483                 hasNext = true
484                 nextLink = "/user/" + id + "?max_id=" + pg.MaxID
485         }
486
487         commonData, err := svc.getCommonData(ctx, client, c)
488         if err != nil {
489                 return
490         }
491
492         data := &renderer.UserData{
493                 User:       user,
494                 Statuses:   statuses,
495                 HasNext:    hasNext,
496                 NextLink:   nextLink,
497                 CommonData: commonData,
498         }
499
500         err = svc.renderer.RenderUserPage(ctx, client, data)
501         if err != nil {
502                 return
503         }
504
505         return
506 }
507
508 func (svc *service) ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) {
509         commonData, err := svc.getCommonData(ctx, client, c)
510         if err != nil {
511                 return
512         }
513
514         data := &renderer.AboutData{
515                 CommonData: commonData,
516         }
517         err = svc.renderer.RenderAboutPage(ctx, client, data)
518         if err != nil {
519                 return
520         }
521
522         return
523 }
524
525 func (svc *service) ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) {
526         commonData, err := svc.getCommonData(ctx, client, c)
527         if err != nil {
528                 return
529         }
530
531         emojis, err := c.GetInstanceEmojis(ctx)
532         if err != nil {
533                 return
534         }
535
536         data := &renderer.EmojiData{
537                 Emojis:     emojis,
538                 CommonData: commonData,
539         }
540
541         err = svc.renderer.RenderEmojiPage(ctx, client, data)
542         if err != nil {
543                 return
544         }
545
546         return
547 }
548
549 func (svc *service) ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
550         likers, err := c.GetFavouritedBy(ctx, id, nil)
551         if err != nil {
552                 return
553         }
554
555         commonData, err := svc.getCommonData(ctx, client, c)
556         if err != nil {
557                 return
558         }
559
560         data := &renderer.LikedByData{
561                 CommonData: commonData,
562                 Users:      likers,
563         }
564
565         err = svc.renderer.RenderLikedByPage(ctx, client, data)
566         if err != nil {
567                 return
568         }
569
570         return
571 }
572
573 func (svc *service) ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
574         retweeters, err := c.GetRebloggedBy(ctx, id, nil)
575         if err != nil {
576                 return
577         }
578
579         commonData, err := svc.getCommonData(ctx, client, c)
580         if err != nil {
581                 return
582         }
583
584         data := &renderer.RetweetedByData{
585                 CommonData: commonData,
586                 Users:      retweeters,
587         }
588
589         err = svc.renderer.RenderRetweetedByPage(ctx, client, data)
590         if err != nil {
591                 return
592         }
593
594         return
595 }
596 func (svc *service) getCommonData(ctx context.Context, client io.Writer, c *model.Client) (data *renderer.CommonData, err error) {
597         data = new(renderer.CommonData)
598
599         data.HeaderData = &renderer.HeaderData{
600                 Title:             "Web",
601                 NotificationCount: 0,
602                 CustomCSS:         svc.customCSS,
603         }
604
605         if c != nil && c.Session.IsLoggedIn() {
606                 notifications, err := c.GetNotifications(ctx, nil)
607                 if err != nil {
608                         return nil, err
609                 }
610
611                 var notificationCount int
612                 for i := range notifications {
613                         if notifications[i].Pleroma != nil && !notifications[i].Pleroma.IsSeen {
614                                 notificationCount++
615                         }
616                 }
617
618                 u, err := c.GetAccountCurrentUser(ctx)
619                 if err != nil {
620                         return nil, err
621                 }
622
623                 data.NavbarData = &renderer.NavbarData{
624                         User:              u,
625                         NotificationCount: notificationCount,
626                 }
627
628                 data.HeaderData.NotificationCount = notificationCount
629         }
630
631         return
632 }
633
634 func (svc *service) Like(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
635         _, err = c.Favourite(ctx, id)
636         return
637 }
638
639 func (svc *service) UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
640         _, err = c.Unfavourite(ctx, id)
641         return
642 }
643
644 func (svc *service) Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
645         _, err = c.Reblog(ctx, id)
646         return
647 }
648
649 func (svc *service) UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
650         _, err = c.Unreblog(ctx, id)
651         return
652 }
653
654 func (svc *service) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) {
655         var mediaIds []string
656         for _, f := range files {
657                 a, err := c.UploadMediaFromMultipartFileHeader(ctx, f)
658                 if err != nil {
659                         return "", err
660                 }
661                 mediaIds = append(mediaIds, a.ID)
662         }
663
664         // save visibility if it's a non-reply post
665         if len(replyToID) < 1 && visibility != c.Session.Settings.DefaultVisibility {
666                 c.Session.Settings.DefaultVisibility = visibility
667                 svc.sessionRepo.Add(c.Session)
668         }
669
670         tweet := &mastodon.Toot{
671                 Status:      content,
672                 InReplyToID: replyToID,
673                 MediaIDs:    mediaIds,
674                 ContentType: format,
675                 Visibility:  visibility,
676                 Sensitive:   isNSFW,
677         }
678
679         s, err := c.PostStatus(ctx, tweet)
680         if err != nil {
681                 return
682         }
683
684         return s.ID, nil
685 }
686
687 func (svc *service) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
688         _, err = c.AccountFollow(ctx, id)
689         return
690 }
691
692 func (svc *service) UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
693         _, err = c.AccountUnfollow(ctx, id)
694         return
695 }
696
697 func addToReplyMap(m map[string][]mastodon.ReplyInfo, key interface{}, val string, number int) {
698         if key == nil {
699                 return
700         }
701
702         keyStr, ok := key.(string)
703         if !ok {
704                 return
705         }
706         _, ok = m[keyStr]
707         if !ok {
708                 m[keyStr] = []mastodon.ReplyInfo{}
709         }
710
711         m[keyStr] = append(m[keyStr], mastodon.ReplyInfo{val, number})
712 }