thrift TProcessor

大家有没有想过,调用方调用一个方法,服务端是怎么找到对应方法并处理的?

很多RPC里是通过反射实现的,但是thrift里用的是另一种方式。

(1) TProcessor作用

thrift idl 生成的代码 隐藏了复杂的序列化、业务接口调用等逻辑

在调用方的程序里面,thrift idl根据调用的服务接口提前生成代理实现类,并通过依赖注入等技术注入到声明了该接口的相关业务逻辑里面。该代理实现类会拦截所有的方法调用,在提供的方法处理逻辑里面完成一整套的远程调用,并把远程调用结果返回给调用方,这样调用方在调用远程方法的时候就获得了像调用本地接口一样的体验。

// 

(2) Demo-go

func main() {
	
	addr := "localhost:9090"

	// 协议工厂
	var protocolFactory thrift.TProtocolFactory
	protocolFactory = thrift.NewTBinaryProtocolFactoryConf(nil)

	// 传输工厂
	var transportFactory thrift.TTransportFactory
	transportFactory = thrift.NewTTransportFactory()
	// 传输方式
	var transport thrift.TServerTransport
	transport, err := thrift.NewTServerSocket(addr)
	if err != nil {
		fmt.Println(err)
	}

	// (src/handler.go) CalculatorHandler 是 Calculator接口的实现
	calculatorHandler := NewCalculatorHandler()
	// 入参要求是  (gen-go/tutorial/tutorial.go)里的Calculator接口
	// 返回结果 CalculatorProcessor的父类SharedServiceProcessor (gen-go/tutorial/tutorial.go) 是 TProcessor接口 (thrift/processor_factory.go) 的实现
	processor := tutorial.NewCalculatorProcessor(calculatorHandler)
	// 根据 processor, transport, transportFactory, protocolFactory 创建server
	// 这些信息会保存在server结构里,在处理请求的时候会用到
	server := thrift.NewTSimpleServer4(processor, transport, transportFactory, protocolFactory)

	fmt.Println("Starting the simple server... on ", addr)
  return server.Serve()
}

(3) thrift idl 生成代码分析-go

// TODO 调用流程图

// idl生成的文件  /gen-go/tutorial/tutorial.go
package tutorial

type CalculatorProcessor struct {
  *shared.SharedServiceProcessor
}

// 创建计算处理类
// CalculatorProcessor 实现了 TProcessor接口
func NewCalculatorProcessor(handler Calculator) *CalculatorProcessor {
  // 
  self10 := &CalculatorProcessor{shared.NewSharedServiceProcessor(handler)}
  // 按照 接口的方法名、方法对应的处理类 放入processorMap
  self10.AddToProcessorMap("ping", &calculatorProcessorPing{handler:handler})
  self10.AddToProcessorMap("add", &calculatorProcessorAdd{handler:handler})
  self10.AddToProcessorMap("calculate", &calculatorProcessorCalculate{handler:handler})
  self10.AddToProcessorMap("zip", &calculatorProcessorZip{handler:handler})
  return self10
}

// 按照方法名、处理类放入processorMap 
func (p *SharedServiceProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) {
  p.processorMap[key] = processor
}

(3.1) server处理请求逻辑

// 简化了部分代码 
func (p *TSimpleServer) processRequests(client TTransport) (err error) {

  // 从client获取processor
	processor := p.processorFactory.GetProcessor(client)
	inputTransport, err := p.inputTransportFactory.GetTransport(client)
	if err != nil {
		return err
	}
	inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport)
	var outputTransport TTransport
  var outputProtocol TProtocol
  oTrans, err := p.outputTransportFactory.GetTransport(client)
	outputTransport = oTrans
  outputProtocol = p.outputProtocolFactory.GetProtocol(outputTransport)

  // 调server-proxy 
	ok, err := processor.Process(ctx, inputProtocol, outputProtocol)

	return nil
}

(3.2) service-processor逻辑 (调用idl的Process方法)

// file /gen-go/shared/shared.go
package shared

// 
type SharedServiceProcessor struct {
  processorMap map[string]thrift.TProcessorFunction  // 存储方法名和对应的方法处理函数
  handler SharedService                              // 
}

// 处理
func (p *SharedServiceProcessor) Process(ctx context.Context, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {

  // 从传输协议里获取 方法名、版本号、序列Id
  name, _, seqId, err2 := iprot.ReadMessageBegin(ctx)
  if err2 != nil { return false, thrift.WrapTException(err2) }

  // 根据方法名 从方法映射 获取对应的处理函数 
  // 从map里根据方法名拿到对应processor 
  if processor, ok := p.GetProcessorFunction(name); ok {
    // 调用方法处理  返回成功 
    // 返回数据会写到oprot 
    return processor.Process(ctx, seqId, iprot, oprot)
  }
  // 
  iprot.Skip(ctx, thrift.STRUCT)
  iprot.ReadMessageEnd(ctx)
  // 写失败原因消息 
  x5 := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function " + name)
  oprot.WriteMessageBegin(ctx, name, thrift.EXCEPTION, seqId)
  x5.Write(ctx, oprot)
  oprot.WriteMessageEnd(ctx)
  oprot.Flush(ctx)
  // 返回失败
  return false, x5
}


// 根据方法名 从方法映射 获取对应的处理函数 
// 从map里根据方法名拿到对应processor 
func (p *SharedServiceProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) {
  processor, ok = p.processorMap[key]
  return processor, ok
}

(3.3) method-processor逻辑 (包含调用业务方法)

// file  /gen-go/tutorial/tutorial.go
package tutorial 

// 
func (p *calculatorProcessorAdd) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
  args := CalculatorAddArgs{}
  var err2 error

  // 读取入参
  if err2 = args.Read(ctx, iprot); err2 != nil {
    // 读取信息结束  
    iprot.ReadMessageEnd(ctx)
    // 
    x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())

    // 写返回值
    // 写返回消息开始  消息方法"add"  消息类型=thrift.EXCEPTION 序列Id使用入参里的seqId
    oprot.WriteMessageBegin(ctx, "add", thrift.EXCEPTION, seqId)
    // 写返回内容
    x.Write(ctx, oprot)
    // 写返回消息结束
    oprot.WriteMessageEnd(ctx)
    // flush
    oprot.Flush(ctx)
    // 返回
    return false, thrift.WrapTException(err2)
  }
  // 读取入参消息结束
  iprot.ReadMessageEnd(ctx)

  tickerCancel := func() {}
  // Start a goroutine to do server side connectivity check.
  if thrift.ServerConnectivityCheckInterval > 0 {
    var cancel context.CancelFunc
    ctx, cancel = context.WithCancel(ctx)
    defer cancel()
    var tickerCtx context.Context
    tickerCtx, tickerCancel = context.WithCancel(context.Background())
    defer tickerCancel()
    // 开一个goroute
    go func(ctx context.Context, cancel context.CancelFunc) {
      ticker := time.NewTicker(thrift.ServerConnectivityCheckInterval)
      defer ticker.Stop()
      for {
        select {
        case <-ctx.Done():
          return
        case <-ticker.C:
          if !iprot.Transport().IsOpen() {
            cancel()
            return
          }
        }
      }
    }(tickerCtx, cancel)
  }

  // 
  result := CalculatorAddResult{}
  // 声明返回值
  var retval int32
  // 调对应实现方法
  if retval, err2 = p.handler.Add(ctx, args.Num1, args.Num2); err2 != nil {
    // 如果调方法异常  
    // 
    tickerCancel()
    // 丢弃请求
    if err2 == thrift.ErrAbandonRequest {
      // 返回处理失败(不处理)  返回对应异常  
      return false, thrift.WrapTException(err2)
    }
    //
    x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing add: " + err2.Error())
    oprot.WriteMessageBegin(ctx, "add", thrift.EXCEPTION, seqId)
    x.Write(ctx, oprot)
    oprot.WriteMessageEnd(ctx)
    oprot.Flush(ctx)
    // 返回处理成功  方法调用异常
    return true, thrift.WrapTException(err2)
  } else {
    // 调用成功  设置返回值  
    result.Success = &retval
  }
  tickerCancel()
  // 返回消息写到oprot 
  // 写消息开始 写入方法名、消息类型=thrift.REPLY、消息序列=入参序列Id
  if err2 = oprot.WriteMessageBegin(ctx, "add", thrift.REPLY, seqId); err2 != nil {
    err = thrift.WrapTException(err2)
  }
  // 写消息内容
  if err2 = result.Write(ctx, oprot); err == nil && err2 != nil {
    err = thrift.WrapTException(err2)
  }
  // 写消息结束
  if err2 = oprot.WriteMessageEnd(ctx); err == nil && err2 != nil {
    err = thrift.WrapTException(err2)
  }
  // flush
  if err2 = oprot.Flush(ctx); err == nil && err2 != nil {
    err = thrift.WrapTException(err2)
  }
  if err != nil {
    return
  }
  // 返回处理成功
  return true, err
}

(3.4) 从消息读取入参逻辑

// 读取入参并设置到p(*CalculatorAddArgs)
func (p *CalculatorAddArgs) Read(ctx context.Context, iprot thrift.TProtocol) error {
  // 
  if _, err := iprot.ReadStructBegin(ctx); err != nil {
    return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
  }


  for {
    // 读取字段开始
    // 读取 字段名(空的,所以用_)、字段类型、字段Id(对应idl里指定的序号Id)  
    _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx)
    if err != nil {
      return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
    }

    // 如果字段类型是STOP,直接跳出循环
    if fieldTypeId == thrift.STOP { break; }

    // 这个例子里有2个字段,所以switch的时候有2个case   如果有10个字段,这儿就有10个case
    switch fieldId {
    case 1:
      // 读取时会按照idl里的类型做校验,如果字段类型不一致,跳过
      // 校验第一个字段是否是I32类型
      if fieldTypeId == thrift.I32 {
        // 读取第一个字段  
        if err := p.ReadField1(ctx, iprot); err != nil {
          return err
        }
      } else {
        // 跳过  
        if err := iprot.Skip(ctx, fieldTypeId); err != nil {
          return err
        }
      }
    case 2:
      if fieldTypeId == thrift.I32 {
        if err := p.ReadField2(ctx, iprot); err != nil {
          return err
        }
      } else {
        if err := iprot.Skip(ctx, fieldTypeId); err != nil {
          return err
        }
      }
    default:
      // 默认 跳过
      if err := iprot.Skip(ctx, fieldTypeId); err != nil {
        return err
      }
    }

    // 读取字段结束
    if err := iprot.ReadFieldEnd(ctx); err != nil {
      return err
    }
  }
  // 有异常时返回异常
  if err := iprot.ReadStructEnd(ctx); err != nil {
    return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
  }

  // 返回空
  return nil
}

(3.4.1) 读取某个字段

// 
func (p *CalculatorAddArgs)  ReadField1(ctx context.Context, iprot thrift.TProtocol) error {

  //   
  if v, err := iprot.ReadI32(ctx); err != nil {
  return thrift.PrependError("error reading field 1: ", err)
} else {
  // 把读取到的值赋到第一个参数上  
  p.Num1 = v
}
  return nil
}


func (p *CalculatorAddArgs)  ReadField2(ctx context.Context, iprot thrift.TProtocol) error {
  if v, err := iprot.ReadI32(ctx); err != nil {
  return thrift.PrependError("error reading field 2: ", err)
} else {
  // 把读取到的值赋到第二个参数上  
  p.Num2 = v
}
  return nil
}

(4) thrift TProcessor 源码分析-go

// thrift/processor_factory.go 
package thrift

// 
// 
// A processor is a generic object which operates upon an input stream and
// writes to some output stream.
type TProcessor interface {
	Process(ctx context.Context, in, out TProtocol) (bool, TException)

	// ProcessorMap returns a map of thrift method names to TProcessorFunctions.
	ProcessorMap() map[string]TProcessorFunction

	// AddToProcessorMap adds the given TProcessorFunction to the internal
	// processor map at the given key.
	//
	// If one is already set at the given key, it will be replaced with the new
	// TProcessorFunction.
	AddToProcessorMap(string, TProcessorFunction)
}

type TProcessorFunction interface {
	Process(ctx context.Context, seqId int32, in, out TProtocol) (bool, TException)
}

参考资料

[1] (四)– 方法调用模型分析
[2] thrift-processor
[3] Thrift之TProcess类体系原理及源码详细解析
[4] 再识RPC-thrift