Diffstat (limited to 'internal/github/client.go')
-rw-r--r--  internal/github/client.go  345
1 files changed, 345 insertions, 0 deletions
diff --git a/internal/github/client.go b/internal/github/client.go
new file mode 100644
index 0000000..c864a8c
--- /dev/null
+++ b/internal/github/client.go
@@ -0,0 +1,345 @@
+package github
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+)
+
+const (
+ maxRetries = 3
+ retryDelay = 2 * time.Second
+)
+
+// Client implements the GitHub API client.
+type Client struct {
+ httpClient *http.Client
+ token string
+ apiURL string
+}
+
+// NewClient creates a new GitHub API client.
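+// The base URL is normalized so the stored apiURL always ends in /api/v3
+// (the GitHub Enterprise REST prefix). A minimal usage sketch, with a
+// hypothetical host and token:
+//
+//	c := NewClient("ghp_exampletoken", "https://ghe.example.com")
+//	// c.apiURL is now "https://ghe.example.com/api/v3"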
+func NewClient(token, apiURL string) *Client {
+ apiURL = strings.TrimRight(apiURL, "/")
+ if !strings.HasSuffix(apiURL, "/api/v3") {
+ apiURL = strings.TrimRight(apiURL, "/api")
+ apiURL = apiURL + "/api/v3"
+ }
+
+ return &Client{
+ httpClient: &http.Client{
+ Timeout: 10 * time.Second,
+ },
+ token: token,
+ apiURL: apiURL,
+ }
+}
+
+// GetPullRequests fetches pull requests created within the specified date range.
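+// Pages of up to 100 PRs are fetched (newest first) until a page ends before
+// `since`; the combined results are then filtered so that only PRs with
+// since < CreatedAt < until are returned (both bounds exclusive).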
+func (c *Client) GetPullRequests(
+ ctx context.Context,
+ owner, repo string,
+ since, until time.Time,
+) ([]PullRequest, error) {
+ url := fmt.Sprintf("%s/repos/%s/%s/pulls", c.apiURL, owner, repo)
+
+ params := []string{
+ "state=all",
+ "sort=created",
+ "direction=desc",
+ "per_page=100", // Maximum allowed by GitHub
+ }
+
+ log.Printf("Fetching PRs from: %s with params: %v", url, params)
+
+ req, err := http.NewRequestWithContext(ctx, "GET", url+"?"+strings.Join(params, "&"), nil)
+ if err != nil {
+ return nil, fmt.Errorf("creating request: %w", err)
+ }
+
+ var allPRs []PullRequest
+ for {
+ prs, nextPage, err := c.doGetPullRequests(req)
+ if err != nil {
+ return nil, fmt.Errorf("fetching PRs: %w", err)
+ }
+ allPRs = append(allPRs, prs...)
+
+ // Check if we've reached PRs outside our date range
+ if len(prs) > 0 && prs[len(prs)-1].CreatedAt.Before(since) {
+ break
+ }
+
+ if nextPage == "" {
+ break
+ }
+
+ // nextPage is a full URL taken from the Link header, so build a fresh request for it.
+ req, err = http.NewRequestWithContext(ctx, "GET", nextPage, nil)
+ if err != nil {
+ return nil, fmt.Errorf("creating next page request: %w", err)
+ }
+ }
+
+ var filteredPRs []PullRequest
+ for _, pr := range allPRs {
+ if pr.CreatedAt.After(since) && pr.CreatedAt.Before(until) {
+ filteredPRs = append(filteredPRs, pr)
+ }
+ }
+
+ log.Printf("Found %d PRs in date range", len(filteredPRs))
+ return filteredPRs, nil
+}
+
+// GetPullRequestReviews fetches all reviews for a specific pull request.
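+// Only the first page of results is returned: doRequest does not follow
+// Link headers.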
+func (c *Client) GetPullRequestReviews(
+ ctx context.Context,
+ owner, repo string,
+ prNumber int,
+) ([]Review, error) {
+ url := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews", c.apiURL, owner, repo, prNumber)
+
+ req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+ if err != nil {
+ return nil, fmt.Errorf("creating request: %w", err)
+ }
+
+ var reviews []Review
+ err = c.doRequest(req, &reviews)
+ if err != nil {
+ return nil, err
+ }
+
+ return reviews, nil
+}
+
+// GetPullRequestReviewComments fetches all review comments for a specific pull request.
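+// Reviews are listed first, then comments are collected per review via the
+// per-review comments endpoint, so only comments attached to a submitted
+// review are returned.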
+func (c *Client) GetPullRequestReviewComments(
+ ctx context.Context,
+ owner, repo string,
+ prNumber int,
+) ([]ReviewComment, error) {
+ url := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews", c.apiURL, owner, repo, prNumber)
+
+ req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+ if err != nil {
+ return nil, fmt.Errorf("creating request: %w", err)
+ }
+
+ var reviews []Review
+ err = c.doRequest(req, &reviews)
+ if err != nil {
+ return nil, err
+ }
+
+ var allComments []ReviewComment
+ for _, review := range reviews {
+ comments, err := c.getReviewComments(ctx, owner, repo, prNumber, review.ID)
+ if err != nil {
+ return nil, err
+ }
+ allComments = append(allComments, comments...)
+ }
+
+ return allComments, nil
+}
+
+// GetRateLimit returns the current rate limit information.
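+// Only the core resource of the /rate_limit response is decoded; other
+// rate-limit resources are ignored.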
+func (c *Client) GetRateLimit(ctx context.Context) (*RateLimit, error) {
+ url := fmt.Sprintf("%s/rate_limit", c.apiURL)
+
+ req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+ if err != nil {
+ return nil, fmt.Errorf("creating request: %w", err)
+ }
+
+ var response struct {
+ Resources struct {
+ Core struct {
+ Limit int `json:"limit"`
+ Remaining int `json:"remaining"`
+ Reset time.Time `json:"reset"`
+ } `json:"core"`
+ } `json:"resources"`
+ }
+
+ err = c.doRequest(req, &response)
+ if err != nil {
+ return nil, err
+ }
+
+ return &RateLimit{
+ Limit: response.Resources.Core.Limit,
+ Remaining: response.Resources.Core.Remaining,
+ Reset: response.Resources.Core.Reset,
+ }, nil
+}
+
+func (c *Client) doGetPullRequests(req *http.Request) ([]PullRequest, string, error) {
+ var prs []PullRequest
+ nextPage, err := c.doRequestWithPagination(req, &prs)
+ if err != nil {
+ return nil, "", err
+ }
+ return prs, nextPage, nil
+}
+
+func (c *Client) getReviewComments(
+ ctx context.Context,
+ owner, repo string,
+ prNumber int,
+ reviewID int64,
+) ([]ReviewComment, error) {
+ url := fmt.Sprintf(
+ "%s/repos/%s/%s/pulls/%d/reviews/%d/comments",
+ c.apiURL,
+ owner,
+ repo,
+ prNumber,
+ reviewID,
+ )
+
+ req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+ if err != nil {
+ return nil, fmt.Errorf("creating request: %w", err)
+ }
+
+ var comments []ReviewComment
+ err = c.doRequest(req, &comments)
+ if err != nil {
+ return nil, err
+ }
+
+ return comments, nil
+}
+
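+// doRequest sends an authenticated request, making up to maxRetries attempts
+// with a linearly growing delay between failed attempts (retryDelay,
+// 2*retryDelay, ...). A 403 response whose X-RateLimit-Reset header points
+// into the future causes a sleep until that time before the next attempt.
+// A 200 response is decoded into v; any other status returns an error.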
+func (c *Client) doRequest(req *http.Request, v any) error {
+ req.Header.Set("Authorization", "token "+c.token)
+ req.Header.Set("Accept", "application/vnd.github.v3+json")
+
+ var lastErr error
+ for i := range maxRetries {
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ lastErr = err
+ log.Printf("Request failed (attempt %d/%d): %v", i+1, maxRetries, err)
+ time.Sleep(retryDelay * time.Duration(i+1))
+ continue
+ }
+ defer func() {
+ if err := resp.Body.Close(); err != nil {
+ log.Printf("closing response body: %v", err)
+ }
+ }()
+
+ // Check rate limit
+ if resp.StatusCode == http.StatusForbidden {
+ resetTime := resp.Header.Get("X-RateLimit-Reset")
+ if resetTime != "" {
+ reset, err := strconv.ParseInt(resetTime, 10, 64)
+ if err == nil {
+ waitTime := time.Until(time.Unix(reset, 0))
+ if waitTime > 0 {
+ log.Printf("Rate limit exceeded. Waiting %v before retry", waitTime)
+ time.Sleep(waitTime)
+ continue
+ }
+ }
+ }
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ var apiErr APIError
+ if err := json.Unmarshal(body, &apiErr); err == nil {
+ log.Printf("API Error Response: %+v", apiErr)
+ return fmt.Errorf("API error: %s (Status: %d)", apiErr.Message, resp.StatusCode)
+ }
+ log.Printf("Unexpected response body: %s", string(body))
+ return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+ }
+
+ if err := json.NewDecoder(resp.Body).Decode(v); err != nil {
+ return fmt.Errorf("decoding response: %w", err)
+ }
+
+ return nil
+ }
+
+ return fmt.Errorf("max retries exceeded: %w", lastErr)
+}
+
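+// doRequestWithPagination behaves like doRequest but additionally returns the
+// URL of the next page, taken from the rel="next" entry of the Link response
+// header ("" when there is no further page).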
+func (c *Client) doRequestWithPagination(req *http.Request, v any) (string, error) {
+ req.Header.Set("Authorization", "token "+c.token)
+ req.Header.Set("Accept", "application/vnd.github.v3+json")
+
+ var lastErr error
+ for i := range maxRetries {
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ lastErr = err
+ log.Printf("Request failed (attempt %d/%d): %v", i+1, maxRetries, err)
+ time.Sleep(retryDelay * time.Duration(i+1))
+ continue
+ }
+ defer func() {
+ if err := resp.Body.Close(); err != nil {
+ log.Printf("closing response body: %v", err)
+ }
+ }()
+
+ if resp.StatusCode == http.StatusForbidden {
+ resetTime := resp.Header.Get("X-RateLimit-Reset")
+ if resetTime != "" {
+ reset, err := strconv.ParseInt(resetTime, 10, 64)
+ if err == nil {
+ waitTime := time.Until(time.Unix(reset, 0))
+ if waitTime > 0 {
+ log.Printf("Rate limit exceeded. Waiting %v before retry", waitTime)
+ time.Sleep(waitTime)
+ continue
+ }
+ }
+ }
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ var apiErr APIError
+ if err := json.Unmarshal(body, &apiErr); err == nil {
+ log.Printf("API Error Response: %+v", apiErr)
+ return "", fmt.Errorf("API error: %s (Status: %d)", apiErr.Message, resp.StatusCode)
+ }
+ log.Printf("Unexpected response body: %s", string(body))
+ return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+ }
+
+ if err := json.NewDecoder(resp.Body).Decode(v); err != nil {
+ return "", fmt.Errorf("decoding response: %w", err)
+ }
+
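+ // A Link header has the form (hypothetical URL):
+ //   <https://ghe.example.com/api/v3/repos/o/r/pulls?page=2>; rel="next", <...>; rel="last"
+ // Keep only the URL of the rel="next" entry, if any.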
+ nextPage := ""
+ if link := resp.Header.Get("Link"); link != "" {
+ parts := strings.SplitSeq(link, ",")
+ for part := range parts {
+ if strings.Contains(part, "rel=\"next\"") {
+ start := strings.Index(part, "<")
+ end := strings.Index(part, ">")
+ if start >= 0 && end > start {
+ nextPage = part[start+1 : end]
+ break
+ }
+ }
+ }
+ }
+
+ return nextPage, nil
+ }
+
+ return "", fmt.Errorf("max retries exceeded: %w", lastErr)
+}