Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions plugins/connectors/sharepoint/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package sharepoint

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"

log "github.com/cihub/seelog"
"golang.org/x/oauth2"
"golang.org/x/oauth2/microsoft"
)

type SharePointAPIClient struct {
config *SharePointConfig
httpClient *http.Client
oauthConfig *oauth2.Config
token *oauth2.Token
retryClient *RetryClient
}

func NewSharePointAPIClient(config *SharePointConfig) (*SharePointAPIClient, error) {
client := &SharePointAPIClient{
config: config,
}

// 初始化OAuth配置
client.oauthConfig = &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
Endpoint: microsoft.AzureADEndpoint(config.TenantID),
Scopes: []string{"https://graph.microsoft.com/.default"},
}

// 初始化token
if config.AccessToken != "" {
client.token = &oauth2.Token{
AccessToken: config.AccessToken,
RefreshToken: config.RefreshToken,
Expiry: config.TokenExpiry,
}
}

// 初始化HTTP客户端
ctx := context.Background()
client.httpClient = client.oauthConfig.Client(ctx, client.token)

// 初始化重试客户端
client.retryClient = NewRetryClient(config.RetryConfig)

return client, nil
}

func (c *SharePointAPIClient) GetSites(ctx context.Context) ([]SharePointSite, error) {
url := "https://graph.microsoft.com/v1.0/sites"

var allSites []SharePointSite
for {
resp, err := c.retryClient.DoWithRetry(ctx, func() (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
return c.httpClient.Do(req)
})

if err != nil {
return nil, fmt.Errorf("failed to get sites: %w", err)
}

var response struct {
Value []SharePointSite `json:"value"`
NextLink string `json:"@odata.nextLink"`
}

if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
resp.Body.Close()
return nil, fmt.Errorf("failed to decode response: %w", err)
}
resp.Body.Close()

allSites = append(allSites, response.Value...)

if response.NextLink == "" {
break
}
url = response.NextLink
}

return allSites, nil
}

func (c *SharePointAPIClient) GetDocumentLibraries(ctx context.Context, siteID string) ([]SharePointList, error) {
url := fmt.Sprintf("https://graph.microsoft.com/v1.0/sites/%s/lists?$filter=list/template eq 'documentLibrary'", siteID)

var allLists []SharePointList
for {
resp, err := c.retryClient.DoWithRetry(ctx, func() (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
return c.httpClient.Do(req)
})

if err != nil {
return nil, fmt.Errorf("failed to get document libraries: %w", err)
}

var response struct {
Value []SharePointList `json:"value"`
NextLink string `json:"@odata.nextLink"`
}

if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
resp.Body.Close()
return nil, fmt.Errorf("failed to decode response: %w", err)
}
resp.Body.Close()

allLists = append(allLists, response.Value...)

if response.NextLink == "" {
break
}
url = response.NextLink
}

return allLists, nil
}

func (c *SharePointAPIClient) GetItems(ctx context.Context, siteID, listID string, pageSize int) ([]SharePointItem, string, error) {
url := fmt.Sprintf("https://graph.microsoft.com/v1.0/sites/%s/lists/%s/items?$expand=fields,driveItem&$top=%d", siteID, listID, pageSize)

resp, err := c.retryClient.DoWithRetry(ctx, func() (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
return c.httpClient.Do(req)
})

if err != nil {
return nil, "", fmt.Errorf("failed to get items: %w", err)
}
defer resp.Body.Close()

var response struct {
Value []SharePointItem `json:"value"`
NextLink string `json:"@odata.nextLink"`
}

if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
return nil, "", fmt.Errorf("failed to decode response: %w", err)
}

return response.Value, response.NextLink, nil
}

func (c *SharePointAPIClient) DownloadFile(ctx context.Context, downloadURL string) ([]byte, error) {
resp, err := c.retryClient.DoWithRetry(ctx, func() (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
if err != nil {
return nil, err
}
return c.httpClient.Do(req)
})

if err != nil {
return nil, fmt.Errorf("failed to download file: %w", err)
}
defer resp.Body.Close()

return io.ReadAll(resp.Body)
}
155 changes: 155 additions & 0 deletions plugins/connectors/sharepoint/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package sharepoint

import (
"context"
"fmt"
"net/http"

"github.com/julienschmidt/httprouter"
log "github.com/cihub/seelog"
"golang.org/x/oauth2"
"golang.org/x/oauth2/microsoft"
"infini.sh/coco/modules/common"
"infini.sh/framework/core/api"
"infini.sh/framework/core/orm"
)

func (p *Plugin) connect(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
datasourceID := req.URL.Query().Get("datasource_id")
if datasourceID == "" {
api.WriteError(w, fmt.Errorf("datasource_id is required"), http.StatusBadRequest)
return
}

// 获取数据源配置
datasource := &common.DataSource{}
datasource.ID = datasourceID
exists, err := orm.Get(datasource)
if !exists || err != nil {
api.WriteError(w, fmt.Errorf("datasource not found"), http.StatusNotFound)
return
}

config, err := parseSharePointConfig(datasource)
if err != nil {
api.WriteError(w, err, http.StatusBadRequest)
return
}

// 创建OAuth配置
oauthConfig := &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
Endpoint: microsoft.AzureADEndpoint(config.TenantID),
Scopes: []string{"https://graph.microsoft.com/.default"},
RedirectURL: fmt.Sprintf("%s/connector/sharepoint/oauth_redirect?datasource_id=%s",
getBaseURL(req), datasourceID),
}

// 生成授权URL
authURL := oauthConfig.AuthCodeURL("state", oauth2.AccessTypeOffline)

api.WriteJSON(w, map[string]interface{}{
"auth_url": authURL,
}, http.StatusOK)
}

func (p *Plugin) oAuthRedirect(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
datasourceID := req.URL.Query().Get("datasource_id")
code := req.URL.Query().Get("code")

if datasourceID == "" || code == "" {
api.WriteError(w, fmt.Errorf("missing required parameters"), http.StatusBadRequest)
return
}

// 获取数据源
datasource := &common.DataSource{}
datasource.ID = datasourceID
exists, err := orm.Get(datasource)
if !exists || err != nil {
api.WriteError(w, fmt.Errorf("datasource not found"), http.StatusNotFound)
return
}

config, err := parseSharePointConfig(datasource)
if err != nil {
api.WriteError(w, err, http.StatusBadRequest)
return
}

// 交换token
oauthConfig := &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
Endpoint: microsoft.AzureADEndpoint(config.TenantID),
Scopes: []string{"https://graph.microsoft.com/.default"},
RedirectURL: fmt.Sprintf("%s/connector/sharepoint/oauth_redirect?datasource_id=%s",
getBaseURL(req), datasourceID),
}

ctx := context.Background()
token, err := oauthConfig.Exchange(ctx, code)
if err != nil {
api.WriteError(w, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError)
return
}

// 更新数据源配置
configMap := datasource.Connector.Config.(map[string]interface{})
configMap["access_token"] = token.AccessToken
configMap["refresh_token"] = token.RefreshToken
configMap["token_expiry"] = token.Expiry

datasource.Connector.Config = configMap
err = orm.Update(datasource)
if err != nil {
api.WriteError(w, fmt.Errorf("failed to update datasource: %w", err), http.StatusInternalServerError)
return
}

// 重定向到成功页面
http.Redirect(w, req, "/datasource/edit/"+datasourceID+"?connected=true", http.StatusFound)
}

func (p *Plugin) reset(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
datasourceID := req.URL.Query().Get("datasource_id")
if datasourceID == "" {
api.WriteError(w, fmt.Errorf("datasource_id is required"), http.StatusBadRequest)
return
}

// 获取数据源
datasource := &common.DataSource{}
datasource.ID = datasourceID
exists, err := orm.Get(datasource)
if !exists || err != nil {
api.WriteError(w, fmt.Errorf("datasource not found"), http.StatusNotFound)
return
}

// 清除token
configMap := datasource.Connector.Config.(map[string]interface{})
delete(configMap, "access_token")
delete(configMap, "refresh_token")
delete(configMap, "token_expiry")

datasource.Connector.Config = configMap
err = orm.Update(datasource)
if err != nil {
api.WriteError(w, fmt.Errorf("failed to update datasource: %w", err), http.StatusInternalServerError)
return
}

api.WriteJSON(w, map[string]interface{}{
"success": true,
}, http.StatusOK)
}

func getBaseURL(req *http.Request) string {
scheme := "http"
if req.TLS != nil {
scheme = "https"
}
return fmt.Sprintf("%s://%s", scheme, req.Host)
}
59 changes: 59 additions & 0 deletions plugins/connectors/sharepoint/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package sharepoint

import (
"fmt"
"time"

"infini.sh/coco/modules/common"
"infini.sh/framework/core/config"
)

func parseSharePointConfig(datasource *common.DataSource) (*SharePointConfig, error) {
if datasource.Connector.Config == nil {
return nil, fmt.Errorf("connector config is nil")
}

cfg, err := config.NewConfigFrom(datasource.Connector.Config)
if err != nil {
return nil, fmt.Errorf("failed to parse config: %w", err)
}

sharePointConfig := &SharePointConfig{}
err = cfg.Unpack(sharePointConfig)
if err != nil {
return nil, fmt.Errorf("failed to unpack config: %w", err)
}

// 设置默认的重试配置
if sharePointConfig.RetryConfig.MaxRetries == 0 {
sharePointConfig.RetryConfig.MaxRetries = 3
}
if sharePointConfig.RetryConfig.InitialDelay == 0 {
sharePointConfig.RetryConfig.InitialDelay = time.Second
}
if sharePointConfig.RetryConfig.MaxDelay == 0 {
sharePointConfig.RetryConfig.MaxDelay = time.Minute
}
if sharePointConfig.RetryConfig.BackoffFactor == 0 {
sharePointConfig.RetryConfig.BackoffFactor = 2.0
}

return sharePointConfig, nil
}

func validateSharePointConfig(config *SharePointConfig) error {
if config.SiteURL == "" {
return fmt.Errorf("site_url is required")
}
if config.TenantID == "" {
return fmt.Errorf("tenant_id is required")
}
if config.ClientID == "" {
return fmt.Errorf("client_id is required")
}
if config.ClientSecret == "" {
return fmt.Errorf("client_secret is required")
}

return nil
}
Loading