gomog/internal/engine/stream_aggregate.go

707 lines
17 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package engine
import (
"context"
"fmt"
"math/rand"
"runtime"
"sort"
"strings"
"time"
"git.kingecg.top/kingecg/gomog/pkg/types"
)
// StreamAggregationEngine is a streaming aggregation engine: it runs
// aggregation pipelines over a collection one batch of documents at a time.
type StreamAggregationEngine struct {
// store is the MemoryStore the document iterator is obtained from.
store *MemoryStore
}
// NewStreamAggregationEngine builds a streaming aggregation engine that reads
// its documents from store.
func NewStreamAggregationEngine(store *MemoryStore) *StreamAggregationEngine {
	engine := new(StreamAggregationEngine)
	engine.store = store
	return engine
}
// StreamExecute runs pipeline over collection in a background goroutine and
// streams the processed batches back to the caller.
//
// Non-positive opts.BufferSize and opts.MaxConcurrency are replaced with 100
// and runtime.NumCPU() respectively. BufferSize is used both as the batch size
// requested from the store iterator and as the capacity of the result channel.
//
// The first returned channel yields one non-empty slice per processed batch.
// The second yields at most one error (iterator failure, stage failure, or
// ctx.Err() on cancellation). Both channels are closed when the goroutine
// exits, so a consumer can range over the results and then read the error.
//
// The goroutine stops as soon as ctx is done, including while it is waiting to
// hand a batch to a slow or departed consumer.
func (e *StreamAggregationEngine) StreamExecute(
	ctx context.Context,
	collection string,
	pipeline []types.AggregateStage,
	opts StreamAggregationOptions,
) (<-chan []types.Document, <-chan error) {
	if opts.BufferSize <= 0 {
		opts.BufferSize = 100
	}
	if opts.MaxConcurrency <= 0 {
		opts.MaxConcurrency = runtime.NumCPU()
	}

	resultChan := make(chan []types.Document, opts.BufferSize)
	// Capacity 1: every send below is immediately followed by return, so the
	// goroutine never blocks on reporting its single error.
	errChan := make(chan error, 1)

	go func() {
		defer close(resultChan)
		defer close(errChan)

		// Obtain the document iterator.
		docIter, err := e.store.GetDocumentIterator(collection, opts.BufferSize)
		if err != nil {
			errChan <- err
			return
		}
		defer docIter.Close()

		// Process the collection batch by batch.
		for docIter.HasNext() {
			select {
			case <-ctx.Done():
				errChan <- ctx.Err()
				return
			default:
			}

			batch, err := docIter.NextBatch()
			if err != nil {
				errChan <- err
				return
			}
			if len(batch) == 0 {
				continue
			}

			// Run the pipeline on this batch.
			processed, err := e.processBatch(ctx, batch, pipeline, opts)
			if err != nil {
				errChan <- err
				return
			}
			if len(processed) == 0 {
				continue
			}

			// Deliver the batch, but give up if the context is cancelled while
			// the result channel is full; a bare send here would leak this
			// goroutine (and the iterator) when the consumer stops reading.
			select {
			case resultChan <- processed:
			case <-ctx.Done():
				errChan <- ctx.Err()
				return
			}
		}
	}()

	return resultChan, errChan
}
// processBatch pushes one batch of documents through every pipeline stage in
// order and returns what is left. It returns ctx.Err() if the context is done
// before a stage starts, and the stage's error if a stage fails.
func (e *StreamAggregationEngine) processBatch(
	ctx context.Context,
	batch []types.Document,
	pipeline []types.AggregateStage,
	opts StreamAggregationOptions,
) ([]types.Document, error) {
	current := batch
	for idx := range pipeline {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		next, stageErr := e.executeStageStreaming(pipeline[idx], current, opts)
		if stageErr != nil {
			return nil, stageErr
		}
		current = next

		// Nothing left to feed to later stages: stop early.
		if len(current) == 0 {
			break
		}
	}
	return current, nil
}
// executeStageStreaming applies a single pipeline stage to docs.
//
// Stages that depend on more than the documents handed in ($group, $lookup,
// $graphLookup, $unionWith, $indexStats, $collStats) are rejected with an
// error. Unrecognised stage names are skipped and docs is returned unchanged.
// $sort is applied to the docs passed in only, not to the whole collection.
func (e *StreamAggregationEngine) executeStageStreaming(
	stage types.AggregateStage,
	docs []types.Document,
	opts StreamAggregationOptions,
) ([]types.Document, error) {
	// First reject the stages that cannot be evaluated in streaming mode.
	switch stage.Stage {
	case "$group":
		return nil, fmt.Errorf("$group stage cannot be processed in streaming mode, use regular aggregation instead")
	case "$lookup":
		return nil, fmt.Errorf("$lookup stage cannot be processed in streaming mode, use regular aggregation instead")
	case "$graphLookup":
		return nil, fmt.Errorf("$graphLookup stage cannot be processed in streaming mode, use regular aggregation instead")
	case "$unionWith":
		return nil, fmt.Errorf("$unionWith stage cannot be processed in streaming mode, use regular aggregation instead")
	case "$indexStats", "$collStats":
		return nil, fmt.Errorf("$indexStats and $collStats stages cannot be processed in streaming mode, use regular aggregation instead")
	}

	// Then dispatch the stages that have a handler.
	switch stage.Stage {
	case "$match":
		return e.executeMatch(stage.Spec, docs)
	case "$project":
		return e.executeProject(stage.Spec, docs)
	case "$limit":
		return e.executeLimit(stage.Spec, docs)
	case "$skip":
		return e.executeSkip(stage.Spec, docs)
	case "$sort":
		return e.executeSort(stage.Spec, docs)
	case "$unwind":
		return e.executeUnwind(stage.Spec, docs)
	case "$addFields", "$set":
		return e.executeAddFields(stage.Spec, docs)
	case "$unset":
		return e.executeUnset(stage.Spec, docs)
	case "$sample":
		return e.executeSample(stage.Spec, docs)
	case "$replaceRoot":
		return e.executeReplaceRoot(stage.Spec, docs)
	case "$replaceWith":
		return e.executeReplaceWith(stage.Spec, docs)
	case "$redact":
		return e.executeRedact(stage.Spec, docs)
	case "$out", "$merge":
		return e.executeOutputStages(stage, docs)
	}

	// Unknown stage: pass the documents through untouched.
	return docs, nil
}
// executeOutputStages handles the output stages. $out and $merge both yield an
// error saying they are unsupported in streaming mode (docs is returned
// alongside it); any other stage name returns docs with no error.
func (e *StreamAggregationEngine) executeOutputStages(
	stage types.AggregateStage,
	docs []types.Document,
) ([]types.Document, error) {
	if stage.Stage == "$out" {
		return docs, fmt.Errorf("$out not supported in streaming mode")
	}
	if stage.Stage == "$merge" {
		return docs, fmt.Errorf("$merge not supported in streaming mode")
	}
	return docs, nil
}
// executeAddFields implements the $addFields / $set stage.
//
// spec must be a map of field name to expression; any other spec type leaves
// docs unchanged. For every document a deep copy of its data is made, all
// expressions are evaluated against that copy, and only then are the results
// written into it. Evaluating everything before writing anything keeps the
// outcome independent of Go's randomised map iteration order when one
// expression reads a field that another entry of the same spec overwrites.
//
// Field names are used as plain top-level keys of the document data.
func (e *StreamAggregationEngine) executeAddFields(spec interface{}, docs []types.Document) ([]types.Document, error) {
	addFieldsSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	var results []types.Document
	for _, doc := range docs {
		// Work on a deep copy so the input document is never modified.
		newData := deepCopyMap(doc.Data)

		// Phase 1: compute every new value from the not-yet-modified copy.
		computed := make(map[string]interface{}, len(addFieldsSpec))
		for field, expr := range addFieldsSpec {
			computed[field] = e.evaluateExpression(newData, expr)
		}

		// Phase 2: apply the computed values.
		for field, value := range computed {
			newData[field] = value
		}

		results = append(results, types.Document{
			ID:   doc.ID,
			Data: newData,
		})
	}
	return results, nil
}
// executeUnset implements the $unset stage. spec is either a list of field
// names or a single field name; any other spec type leaves docs unchanged.
// Each output document carries a deep copy of the input data with the named
// top-level keys deleted. Non-string list entries are ignored.
func (e *StreamAggregationEngine) executeUnset(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var fields []interface{}
	switch s := spec.(type) {
	case []interface{}:
		fields = s
	case string:
		// A lone field name is treated as a one-element list.
		fields = []interface{}{s}
	default:
		return docs, nil
	}

	var out []types.Document
	for _, doc := range docs {
		// Deep copy so the input document keeps its fields.
		trimmed := deepCopyMap(doc.Data)
		for _, f := range fields {
			name, isName := f.(string)
			if !isName {
				continue
			}
			delete(trimmed, name)
		}
		out = append(out, types.Document{
			ID:   doc.ID,
			Data: trimmed,
		})
	}
	return out, nil
}
// executeSample implements the $sample stage on the docs passed in.
//
// spec must be a map whose "size" entry is a float64; otherwise docs is
// returned unchanged. If size is at least len(docs) all docs are returned as
// is, and a non-positive size yields an empty slice. Otherwise size documents
// are chosen uniformly at random without replacement. The input slice is not
// reordered; the selection happens on a copy.
func (e *StreamAggregationEngine) executeSample(spec interface{}, docs []types.Document) ([]types.Document, error) {
	sampleSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}
	size, ok := sampleSpec["size"].(float64)
	if !ok {
		return docs, nil
	}

	count := int(size)
	if count >= len(docs) {
		return docs, nil
	}
	if count <= 0 {
		return []types.Document{}, nil
	}

	shuffled := make([]types.Document, len(docs))
	copy(shuffled, docs)

	// A new generator seeded from the current time is created on every call.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))

	// Partial Fisher-Yates shuffle: position i receives a uniformly chosen
	// element from the not-yet-fixed tail [i, len). Because count < len here,
	// len-i is always positive, so Intn cannot panic. The previous bound
	// shrank by two per step and reached zero once i passed the midpoint.
	for i := 0; i < count; i++ {
		r := i + rng.Intn(len(shuffled)-i)
		shuffled[i], shuffled[r] = shuffled[r], shuffled[i]
	}
	return shuffled[:count], nil
}
// executeReplaceRoot implements the $replaceRoot stage.
//
// spec must be a map whose "newRoot" entry is a string naming the field to
// promote (e.g. "$address"); any other shape leaves docs unchanged. When the
// referenced value is a map it becomes the document's data. Otherwise the
// value is wrapped in a one-entry map keyed by the newRoot string exactly as
// given in the spec.
//
// The leading "$" of the field reference is removed before the lookup, which
// matches every other getNestedValue call site in this file.
// NOTE(review): assumes getNestedValue expects a bare path without "$" —
// confirm against its definition. A newRoot without "$" is looked up as-is.
func (e *StreamAggregationEngine) executeReplaceRoot(spec interface{}, docs []types.Document) ([]types.Document, error) {
	replaceRootSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}
	newRootField, ok := replaceRootSpec["newRoot"].(string)
	if !ok {
		return docs, nil
	}
	fieldPath := strings.TrimPrefix(newRootField, "$")

	var results []types.Document
	for _, doc := range docs {
		// Fetch the value that is to become the new root.
		newRoot := getNestedValue(doc.Data, fieldPath)
		if newRootMap, isMap := newRoot.(map[string]interface{}); isMap {
			results = append(results, types.Document{
				ID:   doc.ID,
				Data: newRootMap,
			})
			continue
		}
		// Not an object: wrap the value so the document still has map data.
		results = append(results, types.Document{
			ID:   doc.ID,
			Data: map[string]interface{}{newRootField: newRoot},
		})
	}
	return results, nil
}
// executeReplaceWith implements the $replaceWith stage: spec is evaluated as
// an expression against each document and the result replaces the document's
// data. A result that is not a map is wrapped as {"value": result}.
func (e *StreamAggregationEngine) executeReplaceWith(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var out []types.Document
	for i := range docs {
		var data map[string]interface{}
		switch v := e.evaluateExpression(docs[i].Data, spec).(type) {
		case map[string]interface{}:
			data = v
		default:
			// Not an object: keep the value under a fixed key.
			data = map[string]interface{}{"value": v}
		}
		out = append(out, types.Document{
			ID:   docs[i].ID,
			Data: data,
		})
	}
	return out, nil
}
// executeRedact is the $redact stage. It is not implemented for streaming
// mode and always returns an error.
func (e *StreamAggregationEngine) executeRedact(spec interface{}, docs []types.Document) ([]types.Document, error) {
	err := fmt.Errorf("$redact stage not yet implemented in streaming mode")
	return nil, err
}
// evaluateExpression evaluates expr against data.
//
//   - A string starting with "$" is a field reference and is resolved with
//     getNestedValue on the rest of the string.
//   - A map is scanned for a supported operator ($concat, $toUpper, $toLower,
//     $add, $multiply, $ifNull, $cond); the first one encountered is evaluated
//     and returned. Keys that are not supported operators are skipped.
//   - Anything else, including a map with no supported operator, is returned
//     unchanged as a literal.
func (e *StreamAggregationEngine) evaluateExpression(data map[string]interface{}, expr interface{}) interface{} {
	switch v := expr.(type) {
	case string:
		if strings.HasPrefix(v, "$") {
			// Drop the "$" marker and look the field up.
			return getNestedValue(data, v[1:])
		}
	case map[string]interface{}:
		for op, operand := range v {
			switch op {
			case "$concat":
				return e.concat(operand, data)
			case "$toUpper":
				return strings.ToUpper(e.getFieldValueStr(types.Document{Data: data}, operand))
			case "$toLower":
				return strings.ToLower(e.getFieldValueStr(types.Document{Data: data}, operand))
			case "$add":
				return e.add(operand, data)
			case "$multiply":
				return e.multiply(operand, data)
			case "$ifNull":
				return e.ifNull(operand, data)
			case "$cond":
				return e.cond(operand, data)
			}
		}
	}
	return expr
}
// Simple helper implementations used by evaluateExpression follow.

// concat implements $concat: each element of the operand list is evaluated
// against data, formatted with %v, and the pieces are joined in order. An
// operand that is not a list yields the empty string.
func (e *StreamAggregationEngine) concat(operand interface{}, data map[string]interface{}) interface{} {
	parts, isList := operand.([]interface{})
	if !isList {
		return ""
	}
	var b strings.Builder
	for _, part := range parts {
		fmt.Fprintf(&b, "%v", e.evaluateExpression(data, part))
	}
	return b.String()
}
// getFieldValueStr resolves field with getFieldValue and returns the result
// when it is a string; any non-string result yields "".
func (e *StreamAggregationEngine) getFieldValueStr(doc types.Document, field interface{}) string {
	// A failed assertion leaves s as the zero value, i.e. "".
	s, _ := e.getFieldValue(doc, field).(string)
	return s
}
// getFieldValue resolves field against doc. A string beginning with "$" is a
// field reference looked up in doc.Data (without the "$"); any other string,
// and any non-string value, is returned as is.
func (e *StreamAggregationEngine) getFieldValue(doc types.Document, field interface{}) interface{} {
	name, isString := field.(string)
	if !isString {
		return field
	}
	if strings.HasPrefix(name, "$") {
		return getNestedValue(doc.Data, name[1:])
	}
	return name
}
// add implements $add: every element of the operand list is evaluated against
// data, converted with toFloat64, and summed into a float64. An operand that
// is not a list yields the integer 0.
func (e *StreamAggregationEngine) add(operand interface{}, data map[string]interface{}) interface{} {
	terms, isList := operand.([]interface{})
	if !isList {
		return 0
	}
	var total float64
	for i := range terms {
		total += toFloat64(e.evaluateExpression(data, terms[i]))
	}
	return total
}
// multiply implements $multiply: every element of the operand list is
// evaluated against data, converted with toFloat64, and multiplied into a
// float64 that starts at 1. An operand that is not a list yields the
// integer 0.
func (e *StreamAggregationEngine) multiply(operand interface{}, data map[string]interface{}) interface{} {
	factors, isList := operand.([]interface{})
	if !isList {
		return 0
	}
	product := 1.0
	for i := range factors {
		product *= toFloat64(e.evaluateExpression(data, factors[i]))
	}
	return product
}
// ifNull implements $ifNull for a two-element operand list: the first element
// is evaluated and returned unless it is nil, in which case the evaluated
// second element is returned. Any other operand shape yields nil.
func (e *StreamAggregationEngine) ifNull(operand interface{}, data map[string]interface{}) interface{} {
	pair, isList := operand.([]interface{})
	if !isList || len(pair) != 2 {
		return nil
	}
	if primary := e.evaluateExpression(data, pair[0]); primary != nil {
		return primary
	}
	return e.evaluateExpression(data, pair[1])
}
// cond implements the object form of $cond. The operand must be a map that
// has all three of "if", "then" and "else"; otherwise nil is returned. The
// "if" expression is evaluated and tested with isTrue, and the matching
// branch expression is then evaluated and returned.
func (e *StreamAggregationEngine) cond(operand interface{}, data map[string]interface{}) interface{} {
	branches, isMap := operand.(map[string]interface{})
	if !isMap {
		return nil
	}
	ifExpr, hasIf := branches["if"]
	thenExpr, hasThen := branches["then"]
	elseExpr, hasElse := branches["else"]
	if !(hasIf && hasThen && hasElse) {
		return nil
	}

	chosen := elseExpr
	if isTrue(e.evaluateExpression(data, ifExpr)) {
		chosen = thenExpr
	}
	return e.evaluateExpression(data, chosen)
}
// executeMatch implements the $match stage: it keeps the documents whose data
// satisfies the filter according to MatchFilter. spec may be a types.Filter or
// a plain map; any other spec type leaves docs unchanged.
func (e *StreamAggregationEngine) executeMatch(spec interface{}, docs []types.Document) ([]types.Document, error) {
	filter, isMap := spec.(map[string]interface{})
	if !isMap {
		typed, isFilter := spec.(types.Filter)
		if !isFilter {
			return docs, nil
		}
		filter = typed
	}

	var kept []types.Document
	for i := range docs {
		if !MatchFilter(docs[i].Data, filter) {
			continue
		}
		kept = append(kept, docs[i])
	}
	return kept, nil
}
// executeProject implements the $project stage: every document's data is
// replaced by the result of projectDocument, keeping its ID. A spec that is
// not a map leaves docs unchanged.
func (e *StreamAggregationEngine) executeProject(spec interface{}, docs []types.Document) ([]types.Document, error) {
	projectSpec, isMap := spec.(map[string]interface{})
	if !isMap {
		return docs, nil
	}

	var out []types.Document
	for i := range docs {
		out = append(out, types.Document{
			ID:   docs[i].ID,
			Data: e.projectDocument(docs[i].Data, projectSpec),
		})
	}
	return out, nil
}
// projectDocument builds a new map from data according to spec.
//
// For each spec entry:
//   - "_id" is copied from data["_id"] unless its rule satisfies isFalse;
//   - a rule satisfying isTrue copies the field via getNestedValue;
//   - a rule satisfying isFalse contributes nothing;
//   - any other rule is evaluated as an expression and stored under the field.
//
// Fields not mentioned in spec are not carried over.
func (e *StreamAggregationEngine) projectDocument(data map[string]interface{}, spec map[string]interface{}) map[string]interface{} {
	out := make(map[string]interface{})
	for field, rule := range spec {
		switch {
		case field == "_id":
			// _id is kept unless it is explicitly excluded.
			if !isFalse(rule) {
				out["_id"] = data["_id"]
			}
		case isTrue(rule):
			out[field] = getNestedValue(data, field)
		case isFalse(rule):
			// Explicit exclusion: nothing to add to the result.
		default:
			out[field] = e.evaluateExpression(data, rule)
		}
	}
	return out
}
// executeLimit implements the $limit stage on the docs passed in. spec may be
// an int, int64 or float64. A limit that is non-positive, not numeric, or at
// least len(docs) returns docs unchanged; otherwise the first limit documents
// are returned.
func (e *StreamAggregationEngine) executeLimit(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var n int
	switch v := spec.(type) {
	case int:
		n = v
	case int64:
		n = int(v)
	case float64:
		n = int(v)
	}

	if n > 0 && n < len(docs) {
		return docs[:n], nil
	}
	return docs, nil
}
// executeSkip implements the $skip stage on the docs passed in. spec may be an
// int, int64 or float64. A non-positive or non-numeric skip returns docs
// unchanged, a skip of at least len(docs) returns an empty slice, and
// otherwise the first skip documents are dropped.
func (e *StreamAggregationEngine) executeSkip(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var n int
	switch v := spec.(type) {
	case int:
		n = v
	case int64:
		n = int(v)
	case float64:
		n = int(v)
	}

	switch {
	case n <= 0:
		return docs, nil
	case n >= len(docs):
		return []types.Document{}, nil
	default:
		return docs[n:], nil
	}
}
// executeUnwind implements the $unwind stage.
//
// spec is either the path string (e.g. "$tags") or a map with a "path" string
// and an optional "preserveNullAndEmptyArrays" bool. A missing path, or one
// that does not start with "$", leaves docs unchanged.
//
// For each document whose field holds a non-empty list, one output document
// is produced per element, with a deep copy of the data in which the field is
// set to that element. Documents whose field is nil, not a list, or an empty
// list are dropped unless preserveNullAndEmptyArrays is true, in which case
// they are passed through as is.
func (e *StreamAggregationEngine) executeUnwind(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var (
		path     string
		preserve bool
	)
	switch s := spec.(type) {
	case string:
		path = s
	case map[string]interface{}:
		if p, isString := s["path"].(string); isString {
			path = p
		}
		if keep, isBool := s["preserveNullAndEmptyArrays"].(bool); isBool {
			preserve = keep
		}
	}
	if !strings.HasPrefix(path, "$") {
		return docs, nil
	}
	fieldPath := path[1:]

	var out []types.Document
	for _, doc := range docs {
		// A nil value fails the assertion too, so it shares this branch.
		elems, isList := getNestedValue(doc.Data, fieldPath).([]interface{})
		if !isList || len(elems) == 0 {
			if preserve {
				out = append(out, doc)
			}
			continue
		}

		for _, elem := range elems {
			expanded := deepCopyMap(doc.Data)
			setNestedValue(expanded, fieldPath, elem)
			out = append(out, types.Document{
				ID:   doc.ID,
				Data: expanded,
			})
		}
	}
	return out, nil
}
// executeSort implements the $sort stage on the docs passed in; it does not
// see documents outside that slice.
//
// spec must be a map of field name to direction (int, int64 or float64; a
// negative value sorts descending, anything else ascending). A spec that is
// not a map leaves docs unchanged. The input slice is not reordered; a sorted
// copy is returned. Documents that compare equal keep their input order.
//
// Because the spec arrives as a Go map, the order in which the caller listed
// the keys is not available here. Keys are applied in lexicographic order of
// field name so that every comparison uses the same priority.
func (e *StreamAggregationEngine) executeSort(spec interface{}, docs []types.Document) ([]types.Document, error) {
	sortSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	// Normalise the directions to ints.
	sortFields := make(map[string]int, len(sortSpec))
	for field, direction := range sortSpec {
		dir := 1
		switch d := direction.(type) {
		case int:
			dir = d
		case int64:
			dir = int(d)
		case float64:
			dir = int(d)
		}
		sortFields[field] = dir
	}

	// Fix the key priority once, outside the comparator.
	keyOrder := streamSortKeyOrder(sortFields)

	// Sort a copy so the caller's slice is left alone.
	sorted := make([]types.Document, len(docs))
	copy(sorted, docs)
	sort.SliceStable(sorted, func(i, j int) bool {
		return e.lessInKeyOrder(sorted[i], sorted[j], keyOrder, sortFields)
	})
	return sorted, nil
}

// compareDocs reports whether a sorts before b under sortFields (field name to
// direction, negative meaning descending). Fields are consulted in
// lexicographic order of name so repeated calls agree with one another;
// ranging over the map directly would change the priority from call to call.
func (e *StreamAggregationEngine) compareDocs(a, b types.Document, sortFields map[string]int) bool {
	return e.lessInKeyOrder(a, b, streamSortKeyOrder(sortFields), sortFields)
}

// streamSortKeyOrder returns the keys of sortFields in lexicographic order.
func streamSortKeyOrder(sortFields map[string]int) []string {
	keys := make([]string, 0, len(sortFields))
	for field := range sortFields {
		keys = append(keys, field)
	}
	sort.Strings(keys)
	return keys
}

// lessInKeyOrder compares a and b field by field in the given key order and
// decides on the first field whose values differ according to compareValues.
// It returns false when all fields compare equal.
func (e *StreamAggregationEngine) lessInKeyOrder(a, b types.Document, keyOrder []string, sortFields map[string]int) bool {
	for _, field := range keyOrder {
		cmp := compareValues(getNestedValue(a.Data, field), getNestedValue(b.Data, field))
		if cmp == 0 {
			continue
		}
		if sortFields[field] < 0 {
			return cmp > 0
		}
		return cmp < 0
	}
	return false
}