diff options
Diffstat (limited to '')
| -rw-r--r-- | internal/github/client.go | 345 |
1 files changed, 345 insertions, 0 deletions
diff --git a/internal/github/client.go b/internal/github/client.go new file mode 100644 index 0000000..c864a8c --- /dev/null +++ b/internal/github/client.go @@ -0,0 +1,345 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strconv" + "strings" + "time" +) + +const ( + maxRetries = 3 + retryDelay = 2 * time.Second +) + +// Client implements the GitHub API client. +type Client struct { + httpClient *http.Client + token string + apiURL string +} + +// NewClient creates a new GitHub API client. +func NewClient(token, apiURL string) *Client { + apiURL = strings.TrimRight(apiURL, "/") + if !strings.HasSuffix(apiURL, "/api/v3") { + apiURL = strings.TrimRight(apiURL, "/api") + apiURL = apiURL + "/api/v3" + } + + return &Client{ + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + token: token, + apiURL: apiURL, + } +} + +// GetPullRequests fetches pull requests created within the specified date range. +func (c *Client) GetPullRequests( + ctx context.Context, + owner, repo string, + since, until time.Time, +) ([]PullRequest, error) { + url := fmt.Sprintf("%s/repos/%s/%s/pulls", c.apiURL, owner, repo) + + params := []string{ + "state=all", + "sort=created", + "direction=desc", + "per_page=100", // Maximum allowed by GitHub + } + + log.Printf("Fetching PRs from: %s with params: %v", url, params) + + req, err := http.NewRequestWithContext(ctx, "GET", url+"?"+strings.Join(params, "&"), nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + + var allPRs []PullRequest + for { + prs, nextPage, err := c.doGetPullRequests(req) + if err != nil { + return nil, fmt.Errorf("fetching PRs: %w", err) + } + allPRs = append(allPRs, prs...) + + // Check if we've reached PRs outside our date range + if len(prs) > 0 && prs[len(prs)-1].CreatedAt.Before(since) { + break + } + + if nextPage == "" { + break + } + + req.URL.RawQuery = nextPage + } + + var filteredPRs []PullRequest + for _, pr := range allPRs { + if pr.CreatedAt.After(since) && pr.CreatedAt.Before(until) { + filteredPRs = append(filteredPRs, pr) + } + } + + log.Printf("Found %d PRs in date range", len(filteredPRs)) + return filteredPRs, nil +} + +// GetPullRequestReviews fetches all reviews for a specific pull request. +func (c *Client) GetPullRequestReviews( + ctx context.Context, + owner, repo string, + prNumber int, +) ([]Review, error) { + url := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews", c.apiURL, owner, repo, prNumber) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + + var reviews []Review + err = c.doRequest(req, &reviews) + if err != nil { + return nil, err + } + + return reviews, nil +} + +// GetPullRequestReviewComments fetches all review comments for a specific pull request. +func (c *Client) GetPullRequestReviewComments( + ctx context.Context, + owner, repo string, + prNumber int, +) ([]ReviewComment, error) { + url := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews", c.apiURL, owner, repo, prNumber) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + + var reviews []Review + err = c.doRequest(req, &reviews) + if err != nil { + return nil, err + } + + var allComments []ReviewComment + for _, review := range reviews { + comments, err := c.getReviewComments(ctx, owner, repo, prNumber, review.ID) + if err != nil { + return nil, err + } + allComments = append(allComments, comments...) + } + + return allComments, nil +} + +// GetRateLimit returns the current rate limit information. +func (c *Client) GetRateLimit(ctx context.Context) (*RateLimit, error) { + url := fmt.Sprintf("%s/rate_limit", c.apiURL) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + + var response struct { + Resources struct { + Core struct { + Limit int `json:"limit"` + Remaining int `json:"remaining"` + Reset time.Time `json:"reset"` + } `json:"core"` + } `json:"resources"` + } + + err = c.doRequest(req, &response) + if err != nil { + return nil, err + } + + return &RateLimit{ + Limit: response.Resources.Core.Limit, + Remaining: response.Resources.Core.Remaining, + Reset: response.Resources.Core.Reset, + }, nil +} + +func (c *Client) doGetPullRequests(req *http.Request) ([]PullRequest, string, error) { + var prs []PullRequest + nextPage, err := c.doRequestWithPagination(req, &prs) + if err != nil { + return nil, "", err + } + return prs, nextPage, nil +} + +func (c *Client) getReviewComments( + ctx context.Context, + owner, repo string, + prNumber int, + reviewID int64, +) ([]ReviewComment, error) { + url := fmt.Sprintf( + "%s/repos/%s/%s/pulls/%d/reviews/%d/comments", + c.apiURL, + owner, + repo, + prNumber, + reviewID, + ) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + + var comments []ReviewComment + err = c.doRequest(req, &comments) + if err != nil { + return nil, err + } + + return comments, nil +} + +func (c *Client) doRequest(req *http.Request, v any) error { + req.Header.Set("Authorization", "token "+c.token) + req.Header.Set("Accept", "application/vnd.github.v3+json") + + var lastErr error + for i := range maxRetries { + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = err + log.Printf("Request failed (attempt %d/%d): %v", i+1, maxRetries, err) + time.Sleep(retryDelay * time.Duration(i+1)) + continue + } + defer func() { + err := resp.Body.Close() + if err != nil { + log.Fatal(err) + } + }() + + // Check rate limit + if resp.StatusCode == http.StatusForbidden { + resetTime := resp.Header.Get("X-RateLimit-Reset") + if resetTime != "" { + reset, err := strconv.ParseInt(resetTime, 10, 64) + if err == nil { + waitTime := time.Until(time.Unix(reset, 0)) + if waitTime > 0 { + log.Printf("Rate limit exceeded. Waiting %v before retry", waitTime) + time.Sleep(waitTime) + continue + } + } + } + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + var apiErr APIError + if err := json.Unmarshal(body, &apiErr); err == nil { + log.Printf("API Error Response: %+v", apiErr) + return fmt.Errorf("API error: %s (Status: %d)", apiErr.Message, resp.StatusCode) + } + log.Printf("Unexpected response body: %s", string(body)) + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + if err := json.NewDecoder(resp.Body).Decode(v); err != nil { + return fmt.Errorf("decoding response: %w", err) + } + + return nil + } + + return fmt.Errorf("max retries exceeded: %w", lastErr) +} + +func (c *Client) doRequestWithPagination(req *http.Request, v any) (string, error) { + req.Header.Set("Authorization", "token "+c.token) + req.Header.Set("Accept", "application/vnd.github.v3+json") + + var lastErr error + for i := range maxRetries { + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = err + log.Printf("Request failed (attempt %d/%d): %v", i+1, maxRetries, err) + time.Sleep(retryDelay * time.Duration(i+1)) + continue + } + defer func() { + err := resp.Body.Close() + if err != nil { + log.Fatal(err) + } + }() + + if resp.StatusCode == http.StatusForbidden { + resetTime := resp.Header.Get("X-RateLimit-Reset") + if resetTime != "" { + reset, err := strconv.ParseInt(resetTime, 10, 64) + if err == nil { + waitTime := time.Until(time.Unix(reset, 0)) + if waitTime > 0 { + log.Printf("Rate limit exceeded. Waiting %v before retry", waitTime) + time.Sleep(waitTime) + continue + } + } + } + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + var apiErr APIError + if err := json.Unmarshal(body, &apiErr); err == nil { + log.Printf("API Error Response: %+v", apiErr) + return "", fmt.Errorf("API error: %s (Status: %d)", apiErr.Message, resp.StatusCode) + } + log.Printf("Unexpected response body: %s", string(body)) + return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + if err := json.NewDecoder(resp.Body).Decode(v); err != nil { + return "", fmt.Errorf("decoding response: %w", err) + } + + nextPage := "" + if link := resp.Header.Get("Link"); link != "" { + parts := strings.SplitSeq(link, ",") + for part := range parts { + if strings.Contains(part, "rel=\"next\"") { + start := strings.Index(part, "<") + end := strings.Index(part, ">") + if start >= 0 && end > start { + nextPage = part[start+1 : end] + break + } + } + } + } + + return nextPage, nil + } + + return "", fmt.Errorf("max retries exceeded: %w", lastErr) +} |
