前面谈过gRPC的SSL/TLS安全机制,发现设置过程比较复杂:比如证书签名,需要在服务端、客户端两头都进行设置等。想想实际上用JWT会更加便捷,而且更安全、功能更强大,因为除了JWT的签名验证之外,还可以把用户信息放在JWT里在服务端和客户端之间传递(需要注意:JWT的内容默认只是经过签名和Base64编码,并没有加密,真正私密的信息仍需配合TLS或JWE加密传输)。当然,最基本的是通过对JWT的验证机制可以控制客户端对某些功能的使用权限。
通过JWT实现gRPC的函数调用权限管理原理其实很简单:客户端首先从服务端通过身份验证获取JWT,然后在调用服务函数时把这个JWT同时传给服务端进行权限验证。客户端提交身份验证请求返回JWT可以用一个独立的服务函数实现,如下面.proto文件里的GetAuthToken:
- // Credentials a client submits to GetAuthToken.
- message PBPOSCredential {
- string userid = 1;
- string password = 2;
- }
- // Reply of GetAuthToken; the jwt field carries the token string.
- message PBPOSToken {
- string jwt = 1;
- }
- // Service with two command RPCs plus the token-issuing RPC.
- service SendCommand {
- rpc SingleResponse(PBPOSCommand) returns (PBPOSResponse) {};
- rpc GetTxnItems(PBPOSCommand) returns (stream PBTxnItem) {};
- // Exchanges a PBPOSCredential for a PBPOSToken.
- rpc GetAuthToken(PBPOSCredential) returns (PBPOSToken) {};
- }
比较棘手的是如何把JWT从客户端传送至服务端,因为gRPC基本上劫持了Request和Response。其中一个方法是通过Interceptor来截取Request的header即metadata。客户端将JWT写入metadata,服务端从metadata读取JWT。
我们先看看客户端的Interceptor设置和使用:
- /**
-  * Client interceptor that attaches a JWT to the request headers of every call
-  * made through the intercepted channel.
-  *
-  * @param jwt token string sent under the ASCII metadata key "jwt"
-  */
- class AuthClientInterceptor(jwt: String) extends ClientInterceptor {
-   // Built once per interceptor instead of once per call.
-   private val jwtKey: Metadata.Key[String] = Key.of("jwt", Metadata.ASCII_STRING_MARSHALLER)
-
-   override def interceptCall[ReqT, RespT](
-       methodDescriptor: MethodDescriptor[ReqT, RespT],
-       callOptions: CallOptions,
-       channel: io.grpc.Channel): ClientCall[ReqT, RespT] =
-     new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
-       channel.newCall(methodDescriptor, callOptions)) {
-       // Add the header before the call is started, then continue as usual.
-       override def start(responseListener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
-         headers.put(jwtKey, jwt)
-         super.start(responseListener, headers)
-       }
-     }
- }
- ...
- // Plain-text (no TLS) channel to the server.
- val unsafeChannel = NettyChannelBuilder
- .forAddress("192.168.0.189",50051)
- .negotiationType(NegotiationType.PLAINTEXT)
- .build()
- // Wrap the channel so every call goes through AuthClientInterceptor, which adds the "jwt" header.
- // `jwt` is not defined in this fragment; it comes from the GetAuthToken call shown separately.
- val securedChannel = ClientInterceptors.intercept(unsafeChannel, new AuthClientInterceptor(jwt))
- val securedClient = SendCommandGrpc.blockingStub(securedChannel)
- // Call made with a default (empty) PBPOSCommand through the intercepted channel.
- val resp = securedClient.singleResponse(PBPOSCommand())
身份验证请求即JWT获取是不需要Interceptor的,所以要用没有Interceptor的unsafeChannel:
- //build connection channel
- val unsafeChannel = NettyChannelBuilder
- .forAddress("192.168.0.189",50051)
- .negotiationType(NegotiationType.PLAINTEXT)
- .build()
- // Stub on the raw channel: no interceptor, so no "jwt" header is attached to this call.
- val authClient = SendCommandGrpc.blockingStub(unsafeChannel)
- // Submit the credentials and read the token string out of the reply.
- val jwt = authClient.getAuthToken(PBPOSCredential(userid="johnny",password="p4ssw0rd")).jwt
- println(s"got jwt: $jwt")
-
JWT的构建和使用已经在前面的几篇博文里讨论过了:
- package com.datatech.auth
- import pdi.jwt._
- import org.json4s.native.Json
- import org.json4s._
- import org.json4s.jackson.JsonMethods._
- import pdi.jwt.algorithms._
- import scala.util._
- object AuthBase {
-   /** User attributes stored in, and recovered from, the "userinfo" claim of a token. */
-   type UserInfo = Map[String, Any]
-
-   /**
-    * Immutable JWT helper that issues tokens carrying a [[UserInfo]] payload and
-    * validates / decodes them with the configured algorithm and secret.
-    *
-    * @param algorithm   algorithm used both to sign and to verify tokens
-    * @param secret      key passed to the JWT library for signing and verification
-    * @param getUserInfo credential lookup `(userid, password) => user info`; it is only
-    *                    stored here (default `null`) and never invoked by this class
-    */
-   case class AuthBase(
-       algorithm: JwtAlgorithm = JwtAlgorithm.HMD5,
-       secret: String = "OpenSesame",
-       getUserInfo: (String, String) => Option[UserInfo] = null) {
-     ctx =>
-
-     /** Copy of this helper using `algo`. */
-     def withAlgorithm(algo: JwtAlgorithm): AuthBase = ctx.copy(algorithm = algo)
-
-     /** Copy of this helper using `key` as the secret. */
-     def withSecretKey(key: String): AuthBase = ctx.copy(secret = key)
-
-     /** Copy of this helper using `f` as the credential lookup. */
-     def withUserFunc(f: (String, String) => Option[UserInfo]): AuthBase = ctx.copy(getUserInfo = f)
-
-     /**
-      * Verifies `token` and returns the raw JSON of its claim section.
-      * Yields `None` for a null token, a failed verification or an algorithm that is
-      * neither asymmetric nor HMAC (previously an unchecked cast that could throw).
-      */
-     private def verifiedClaim(token: String): Option[String] =
-       Option(token).flatMap { tok =>
-         algorithm match {
-           case asym: JwtAsymmetricAlgorithm =>
-             Jwt.decodeRawAll(tok, secret, Seq(asym)).toOption.map(_._2)
-           case hmac: JwtHmacAlgorithm =>
-             Jwt.decodeRawAll(tok, secret, Seq(hmac)).toOption.map(_._2)
-           case _ => None
-         }
-       }
-
-     /**
-      * Checks the token against the configured secret and algorithm.
-      *
-      * @return `Some(token)` when the library reports it valid, `None` otherwise
-      */
-     def authenticateToken(token: String): Option[String] = {
-       val valid = token != null && (algorithm match {
-         case asym: JwtAsymmetricAlgorithm => Jwt.isValid(token, secret, Seq(asym))
-         case hmac: JwtHmacAlgorithm       => Jwt.isValid(token, secret, Seq(hmac))
-         case _                            => false
-       })
-       if (valid) Some(token) else None
-     }
-
-     /**
-      * Decodes a token and extracts its "userinfo" claim. This overload takes the token;
-      * the constructor field of the same name is the credential lookup function.
-      *
-      * @return the claim as a map, or `None` when the token cannot be verified, the claim
-      *         JSON cannot be parsed, or "userinfo" is missing / not a JSON object
-      */
-     def getUserInfo(token: String): Option[UserInfo] =
-       verifiedClaim(token).flatMap { claimJson =>
-         // Matching on JObject replaces the former unchecked cast of `.values`.
-         Try(parse(claimJson) \ "userinfo").toOption.collect {
-           case obj: JObject => obj.values
-         }
-       }
-
-     /** Signs a new token whose claim contains `userinfo` under the "userinfo" field. */
-     def issueJwt(userinfo: UserInfo): String = {
-       val claims = JwtClaim() + Json(DefaultFormats).write(("userinfo", userinfo))
-       Jwt.encode(claims, secret, algorithm)
-     }
-   }
- }
服务端Interceptor的构建和设置如下:
- /**
-  * Server-call listener that forwards each event to a listener which only becomes
-  * available once `delegate` completes.
-  *
-  * Events are forwarded with `Future.foreach`, so nothing is forwarded if `delegate` fails.
-  * NOTE(review): every event registers its own callback; whether callbacks run in arrival
-  * order depends on the supplied ExecutionContext - confirm for multi-threaded executors.
-  */
- abstract class FutureListener[Q](implicit ec: ExecutionContext) extends Listener[Q] {
-   /** Future producing the real listener; supplied by the concrete subclass. */
-   protected val delegate: Future[Listener[Q]]
-
-   // A method rather than a constructor-time `val`: subclasses assign `delegate` in their
-   // own initializer, so it has to be read when an event arrives, not while this class
-   // body is being initialized.
-   private def eventually(f: Listener[Q] => Unit): Unit = delegate.foreach(f)
-
-   override def onComplete(): Unit = eventually(_.onComplete())
-   override def onCancel(): Unit = eventually(_.onCancel())
-   override def onMessage(message: Q): Unit = eventually(_.onMessage(message))
-   override def onHalfClose(): Unit = eventually(_.onHalfClose())
-   override def onReady(): Unit = eventually(_.onReady())
- }
- /** Keys under which the JWT travels: one for request metadata, one for the gRPC Context. */
- object Keys {
-   /** ASCII request-header key named "jwt". */
-   val AUTH_META_KEY: Metadata.Key[String] =
-     Metadata.Key.of("jwt", Metadata.ASCII_STRING_MARSHALLER)
-
-   /** Context key named "jwt". */
-   val AUTH_CTX_KEY: Context.Key[String] =
-     Context.key("jwt")
- }
- /**
-  * Server interceptor that copies the "jwt" request header into the gRPC Context so
-  * service methods can read it through `Keys.AUTH_CTX_KEY`.
-  *
-  * The header is read synchronously, so the call is started directly with
-  * `Contexts.interceptCall` instead of being deferred to a Future.
-  *
-  * @param ec no longer used; kept so existing `new AuthorizationInterceptor` call sites compile
-  */
- class AuthorizationInterceptor(implicit ec: ExecutionContext) extends ServerInterceptor {
-   override def interceptCall[Q, R](
-       call: ServerCall[Q, R],
-       headers: Metadata,
-       next: ServerCallHandler[Q, R]
-   ): Listener[Q] = {
-     val jwt = headers.get(Keys.AUTH_META_KEY)
-     // Debug trace kept from the original demo; it prints the bearer token in clear text.
-     println(s"!!!!!!!!!!! $jwt !!!!!!!!!!")
-     val ctxWithJwt = Context.current.withValue(Keys.AUTH_CTX_KEY, jwt)
-     Contexts.interceptCall(ctxWithJwt, call, headers, next)
-   }
- }
- /** Mix-in providing a helper that hosts a service behind [[AuthorizationInterceptor]]. */
- trait gRPCServer {
-   /**
-    * Builds and starts a Netty gRPC server exposing `service`, then returns.
-    * A JVM shutdown hook shuts the server down and waits for its termination.
-    *
-    * @param service  service definition to expose
-    * @param port     port to listen on; defaults to 50051, the value that used to be hard-coded
-    * @param actorSys its dispatcher is the ExecutionContext handed to the interceptor
-    */
-   def runServer(service: ServerServiceDefinition, port: Int = 50051)(implicit actorSys: ActorSystem): Unit = {
-     import actorSys.dispatcher
-     val server = NettyServerBuilder
-       .forPort(port)
-       .addService(ServerInterceptors.intercept(service, new AuthorizationInterceptor))
-       .build()
-       .start()
-     // make sure our server is stopped when jvm is shut down
-     Runtime.getRuntime.addShutdownHook(new Thread() {
-       override def run(): Unit = {
-         server.shutdown()
-         server.awaitTermination()
-       }
-     })
-   }
- }
注意:客户端上传的request-header只能在构建server时接触到,在具体服务函数里是无法调用request-header的,但gRPC有一个结构Context可以在两个地方都能调用。所以,我们可以在构建server时把JWT从header搬到Context里。不过,千万注意这个Context的读写必须在同一个线程里。在服务端的Interceptor里我们把JWT从metadata里读出然后写入Context。在需要权限管理的服务函数里再从Context里读取JWT进行验证:
- /**
-  * Protected call: reads the JWT stored in the gRPC Context, decodes its user info and
-  * replies with the caller's "shopid".
-  *
-  * @return a response whose msg is "shopid:<value>", or "shopid:invalid token!" when the
-  *         token is absent, cannot be decoded, or carries no "shopid" entry
-  */
- override def singleResponse(request: PBPOSCommand): Future[PBPOSResponse] = {
-   val jwt = AUTH_CTX_KEY.get
-   // Debug trace kept from the original demo; it prints the bearer token in clear text.
-   println(s"***********$jwt**************")
-   // Option(...) covers a null context value; `get("shopid")` avoids the
-   // NoSuchElementException that `m("shopid")` raised for a token lacking that entry.
-   val shopid: Any = Option(jwt)
-     .flatMap(token => authenticator.getUserInfo(token))
-     .flatMap(_.get("shopid"))
-     .getOrElse("invalid token!")
-   FastFuture.successful(PBPOSResponse(msg = s"shopid:$shopid"))
- }
JWT的构建也是一个服务函数:
- // Token helper configured with HS256, a fixed demo secret and the user lookup function.
- val authenticator = new AuthBase()
-   .withAlgorithm(JwtAlgorithm.HS256)
-   .withSecretKey("OpenSesame")
-   .withUserFunc(getValidUser)
-
- /**
-  * Issues a token for valid credentials.
-  *
-  * @return a PBPOSToken holding a signed JWT for a known user/password pair, or the
-  *         literal text "Invalid Token!" when the credentials are not recognised
-  */
- override def getAuthToken(request: PBPOSCredential): Future[PBPOSToken] = {
-   val tokenText = getValidUser(request.userid, request.password)
-     .map(info => authenticator.issueJwt(info))
-     .getOrElse("Invalid Token!")
-   FastFuture.successful(PBPOSToken(tokenText))
- }
还需要一个模拟的身份验证服务函数:
- package com.datatech.auth
- /** In-memory stand-in for a user store, used to check credentials in the demo. */
- object MockUserAuthService {
-   type UserInfo = Map[String, Any]
-
-   /** A known user: login name, password and the attributes returned on success. */
-   case class User(username: String, password: String, userInfo: UserInfo)
-
-   val validUsers = Seq(
-     User("johnny", "p4ssw0rd", Map("shopid" -> "1101", "userid" -> "101")),
-     User("tiger", "secret", Map("shopid" -> "1101", "userid" -> "102")))
-
-   /**
-    * Looks up the first user whose name and password both match.
-    *
-    * @return that user's attributes, or `None` when no entry matches
-    */
-   def getValidUser(userid: String, pswd: String): Option[UserInfo] =
-     validUsers.collectFirst {
-       case User(name, secret, info) if name == userid && secret == pswd => info
-     }
- }
下面是本次示范的源代码:
project/plugins.sbt
- // Packaging, dependency inspection and protobuf code generation plugins.
- addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.9")
- addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2")
- addSbtPlugin("com.typesafe.sbt" % "sbt-native-packager" % "1.3.15")
- addSbtPlugin("com.thesamet" % "sbt-protoc" % "0.99.21")
- // ScalaPB code generator used by sbt-protoc (build-time dependency).
- libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.9.0-M6"
build.sbt
- name := "grpc-jwt"
-
- version := "0.1"
-
- scalaVersion := "2.12.8"
-
- scalacOptions += "-Ypartial-unification"
-
- val akkaversion = "2.5.23"
-
- // `++=` appends to the dependencies already contributed by sbt and the enabled plugins;
- // `:=` would replace them.
- libraryDependencies ++= Seq(
-   "com.typesafe.akka" %% "akka-cluster-metrics" % akkaversion,
-   "com.typesafe.akka" %% "akka-cluster-sharding" % akkaversion,
-   "com.typesafe.akka" %% "akka-persistence" % akkaversion,
-   "com.lightbend.akka" %% "akka-stream-alpakka-cassandra" % "1.0.1",
-   "org.mongodb.scala" %% "mongo-scala-driver" % "2.6.0",
-   "com.lightbend.akka" %% "akka-stream-alpakka-mongodb" % "1.0.1",
-   "com.typesafe.akka" %% "akka-persistence-query" % akkaversion,
-   "com.typesafe.akka" %% "akka-persistence-cassandra" % "0.97",
-   "com.datastax.cassandra" % "cassandra-driver-core" % "3.6.0",
-   "com.datastax.cassandra" % "cassandra-driver-extras" % "3.6.0",
-   "ch.qos.logback" % "logback-classic" % "1.2.3",
-   "io.monix" %% "monix" % "3.0.0-RC2",
-   "org.typelevel" %% "cats-core" % "2.0.0-M1",
-   "io.grpc" % "grpc-netty" % scalapb.compiler.Version.grpcJavaVersion,
-   "io.netty" % "netty-tcnative-boringssl-static" % "2.0.22.Final",
-   "com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.compiler.Version.scalapbVersion % "protobuf",
-   "com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapb.compiler.Version.scalapbVersion,
-
-   "com.pauldijou" %% "jwt-core" % "3.0.1",
-   "de.heikoseeberger" %% "akka-http-json4s" % "1.22.0",
-   // NOTE(review): json4s-native is 3.6.1 while json4s-jackson/ext are 3.6.7 - confirm the mix is intended.
-   "org.json4s" %% "json4s-native" % "3.6.1",
-   "com.typesafe.akka" %% "akka-http-spray-json" % "10.1.8",
-   "org.json4s" %% "json4s-jackson" % "3.6.7",
-   "org.json4s" %% "json4s-ext" % "3.6.7"
-
- )
-
- // (optional) If you need scalapb/scalapb.proto or anything from
- // google/protobuf/*.proto
- //libraryDependencies += "com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.compiler.Version.scalapbVersion % "protobuf"
-
-
- // Generate Scala sources from the .proto files into the managed source directory.
- PB.targets in Compile := Seq(
-   scalapb.gen() -> (sourceManaged in Compile).value
- )
-
- enablePlugins(JavaAppPackaging)
main/protobuf/posmessages.proto
- syntax = "proto3";
-
- import "google/protobuf/wrappers.proto";
- import "google/protobuf/any.proto";
- import "scalapb/scalapb.proto";
-
- // File-level ScalaPB options; only flat_package and single_file are active below.
- option (scalapb.options) = {
- // use a custom Scala package name
- // package_name: "io.ontherocks.introgrpc.demo"
-
- // don't append file name to package
- flat_package: true
-
- // generate one Scala file for all messages (services still get their own file)
- single_file: true
-
- // add imports to generated file
- // useful when extending traits or using custom types
- // import: "io.ontherocks.hellogrpc.RockingMessage"
-
- // code to put at the top of generated file
- // works only with `single_file: true`
- //preamble: "sealed trait SomeSealedTrait"
- };
-
- package com.datatech.pos.messages;
-
- message PBVchState { // voucher state
- string opr = 1; // cashier
- int64 jseq = 2; //begin journal sequence for read-side replay
- int32 num = 3; // current voucher number
- int32 seq = 4; // current sequence number
- bool void = 5; // void mode
- bool refd = 6; // refund mode
- bool susp = 7; // voucher suspended (on hold)
- bool canc = 8; // voucher cancelled
- bool due = 9; // current balance - NOTE(review): declared as bool, confirm the intended type
- string su = 10; // supervisor id
- string mbr = 11; // member number
- int32 mode = 12; // current workflow step: 0=logOff, 1=LogOn, 2=Payment
- }
-
- message PBTxnItem { // transaction record
- string txndate = 1; // transaction date
- string txntime = 2; // entry time
- string opr = 3; // operator
- int32 num = 4; // sales voucher number
- int32 seq = 5; // transaction sequence number
- int32 txntype = 6; // transaction type
- int32 salestype = 7; // sales type
- int32 qty = 8; // quantity
- int32 price = 9; // unit price (in fen, i.e. cents)
- int32 amount = 10; // list amount (in fen)
- int32 disc = 11; // discount rate (%)
- int32 dscamt = 12; // discount amount, a negative value; net amount = amount + dscamt
- string member = 13; // member card number
- string code = 14; // code (item, card number, ...)
- string acct = 15; // account number
- string dpt = 16; // department
- }
-
- // Reply of SingleResponse: a numeric status, a text message, an optional voucher state
- // and a list of transaction items.
- message PBPOSResponse {
- int32 sts = 1;
- string msg = 2;
- PBVchState voucher = 3;
- repeated PBTxnItem txnitems = 4;
-
- }
-
- // Request of the command RPCs: a command name plus its parameters as one string.
- // NOTE(review): the parameter delimiter is not defined in this file - confirm with the server code.
- message PBPOSCommand {
- string commandname = 1;
- string delimitedparams = 2;
- }
-
- // Credentials a client submits to GetAuthToken.
- message PBPOSCredential {
- string userid = 1;
- string password = 2;
- }
- // Reply of GetAuthToken; the jwt field carries the token string.
- message PBPOSToken {
- string jwt = 1;
- }
-
- // Service with two command RPCs (one unary, one server-streaming) plus the token-issuing RPC.
- service SendCommand {
- rpc SingleResponse(PBPOSCommand) returns (PBPOSResponse) {};
- rpc GetTxnItems(PBPOSCommand) returns (stream PBTxnItem) {};
- // Exchanges a PBPOSCredential for a PBPOSToken.
- rpc GetAuthToken(PBPOSCredential) returns (PBPOSToken) {};
-
- }
gRPCServer.scala
- package com.datatech.grpc.server
-
- import io.grpc.ServerServiceDefinition
- import io.grpc.netty.NettyServerBuilder
- import io.grpc.ServerInterceptors
- import scala.concurrent._
- import io.grpc.Context
- import io.grpc.Contexts
- import io.grpc.ServerCall
- import io.grpc.ServerCallHandler
- import io.grpc.ServerInterceptor
- import io.grpc.Metadata
- import io.grpc.Metadata.Key.of
- import io.grpc.Context.key
- import io.grpc.ServerCall.Listener
- import akka.actor._
-
-
- /**
-  * Server-call listener that forwards each event to a listener which only becomes
-  * available once `delegate` completes.
-  *
-  * Events are forwarded with `Future.foreach`, so nothing is forwarded if `delegate` fails.
-  * NOTE(review): every event registers its own callback; whether callbacks run in arrival
-  * order depends on the supplied ExecutionContext - confirm for multi-threaded executors.
-  */
- abstract class FutureListener[Q](implicit ec: ExecutionContext) extends Listener[Q] {
-
-   /** Future producing the real listener; supplied by the concrete subclass. */
-   protected val delegate: Future[Listener[Q]]
-
-   // A method rather than a constructor-time `val`: subclasses assign `delegate` in their
-   // own initializer, so it has to be read when an event arrives, not while this class
-   // body is being initialized.
-   private def eventually(f: Listener[Q] => Unit): Unit = delegate.foreach(f)
-
-   override def onComplete(): Unit = eventually(_.onComplete())
-   override def onCancel(): Unit = eventually(_.onCancel())
-   override def onMessage(message: Q): Unit = eventually(_.onMessage(message))
-   override def onHalfClose(): Unit = eventually(_.onHalfClose())
-   override def onReady(): Unit = eventually(_.onReady())
-
- }
-
- /** Keys under which the JWT travels: one for request metadata, one for the gRPC Context. */
- object Keys {
-   /** ASCII request-header key named "jwt". */
-   val AUTH_META_KEY: Metadata.Key[String] =
-     Metadata.Key.of("jwt", Metadata.ASCII_STRING_MARSHALLER)
-
-   /** Context key named "jwt". */
-   val AUTH_CTX_KEY: Context.Key[String] =
-     Context.key("jwt")
- }
-
- /**
-  * Server interceptor that copies the "jwt" request header into the gRPC Context so
-  * service methods can read it through `Keys.AUTH_CTX_KEY`.
-  *
-  * The header is read synchronously, so the call is started directly with
-  * `Contexts.interceptCall` instead of being deferred to a Future.
-  *
-  * @param ec no longer used; kept so existing `new AuthorizationInterceptor` call sites compile
-  */
- class AuthorizationInterceptor(implicit ec: ExecutionContext) extends ServerInterceptor {
-   override def interceptCall[Q, R](
-       call: ServerCall[Q, R],
-       headers: Metadata,
-       next: ServerCallHandler[Q, R]
-   ): Listener[Q] = {
-     val jwt = headers.get(Keys.AUTH_META_KEY)
-
-     // Debug trace kept from the original demo; it prints the bearer token in clear text.
-     println(s"!!!!!!!!!!! $jwt !!!!!!!!!!")
-
-     val ctxWithJwt = Context.current.withValue(Keys.AUTH_CTX_KEY, jwt)
-     Contexts.interceptCall(ctxWithJwt, call, headers, next)
-   }
- }
-
- /** Mix-in providing a helper that hosts a service behind [[AuthorizationInterceptor]]. */
- trait gRPCServer {
-
-   /**
-    * Builds and starts a Netty gRPC server exposing `service`, then returns.
-    * A JVM shutdown hook shuts the server down and waits for its termination.
-    *
-    * @param service  service definition to expose
-    * @param port     port to listen on; defaults to 50051, the value that used to be hard-coded
-    * @param actorSys its dispatcher is the ExecutionContext handed to the interceptor
-    */
-   def runServer(service: ServerServiceDefinition, port: Int = 50051)(implicit actorSys: ActorSystem): Unit = {
-     import actorSys.dispatcher
-     val server = NettyServerBuilder
-       .forPort(port)
-       .addService(ServerInterceptors.intercept(service, new AuthorizationInterceptor))
-       .build()
-       .start()
-     // make sure our server is stopped when jvm is shut down
-     Runtime.getRuntime.addShutdownHook(new Thread() {
-       override def run(): Unit = {
-         server.shutdown()
-         server.awaitTermination()
-       }
-     })
-   }
-
- }
POSServices.scala
- package com.datatech.pos.service
- import com.datatech.grpc.server.Keys._
- import akka.http.scaladsl.util.FastFuture
- import com.datatech.pos.messages._
- import com.datatech.grpc.server._
- import com.datatech.auth.MockUserAuthService._
-
- import scala.concurrent.Future
- import com.datatech.auth.AuthBase._
- import pdi.jwt._
- import akka.actor._
- import io.grpc.stub.StreamObserver
-
-
- /** Demo server entry point plus the implementation of the SendCommand service. */
- object POSServices extends gRPCServer {
-   type UserInfo = Map[String, Any]
-
-   /** SendCommand implementation: issues tokens and serves one token-protected call. */
-   class POSServices extends SendCommandGrpc.SendCommand {
-
-     // Token helper configured with HS256, a fixed demo secret and the user lookup function.
-     val authenticator = new AuthBase()
-       .withAlgorithm(JwtAlgorithm.HS256)
-       .withSecretKey("OpenSesame")
-       .withUserFunc(getValidUser)
-
-     /**
-      * Not implemented in this demo. The call is closed with status UNIMPLEMENTED
-      * instead of throwing `NotImplementedError` from `???` inside the gRPC handler.
-      */
-     override def getTxnItems(request: PBPOSCommand, responseObserver: StreamObserver[PBTxnItem]): Unit =
-       responseObserver.onError(
-         io.grpc.Status.UNIMPLEMENTED
-           .withDescription("GetTxnItems is not implemented")
-           .asRuntimeException())
-
-     /**
-      * Protected call: reads the JWT stored in the gRPC Context, decodes its user info and
-      * replies with the caller's "shopid".
-      *
-      * @return a response whose msg is "shopid:<value>", or "shopid:invalid token!" when the
-      *         token is absent, cannot be decoded, or carries no "shopid" entry
-      */
-     override def singleResponse(request: PBPOSCommand): Future[PBPOSResponse] = {
-       val jwt = AUTH_CTX_KEY.get
-       // Debug trace kept from the original demo; it prints the bearer token in clear text.
-       println(s"***********$jwt**************")
-       // Option(...) covers a null context value; `get("shopid")` avoids the
-       // NoSuchElementException that `m("shopid")` raised for a token lacking that entry.
-       val shopid: Any = Option(jwt)
-         .flatMap(token => authenticator.getUserInfo(token))
-         .flatMap(_.get("shopid"))
-         .getOrElse("invalid token!")
-       FastFuture.successful(PBPOSResponse(msg = s"shopid:$shopid"))
-     }
-
-     /**
-      * Issues a token for valid credentials.
-      *
-      * @return a PBPOSToken holding a signed JWT for a known user/password pair, or the
-      *         literal text "Invalid Token!" when the credentials are not recognised
-      */
-     override def getAuthToken(request: PBPOSCredential): Future[PBPOSToken] = {
-       val tokenText = getValidUser(request.userid, request.password)
-         .map(info => authenticator.issueJwt(info))
-         .getOrElse("Invalid Token!")
-       FastFuture.successful(PBPOSToken(tokenText))
-     }
-   }
-
-   /** Creates the actor system, binds the service to its dispatcher and starts the server. */
-   def main(args: Array[String]): Unit = {
-     implicit val system: ActorSystem = ActorSystem("grpc-system")
-     val svc = SendCommandGrpc.bindService(new POSServices, system.dispatcher)
-     runServer(svc)
-   }
- }
AuthBase.scala
- package com.datatech.auth
-
- import pdi.jwt._
- import org.json4s.native.Json
- import org.json4s._
- import org.json4s.jackson.JsonMethods._
- import pdi.jwt.algorithms._
- import scala.util._
-
- object AuthBase {
-   /** User attributes stored in, and recovered from, the "userinfo" claim of a token. */
-   type UserInfo = Map[String, Any]
-
-   /**
-    * Immutable JWT helper that issues tokens carrying a [[UserInfo]] payload and
-    * validates / decodes them with the configured algorithm and secret.
-    *
-    * @param algorithm   algorithm used both to sign and to verify tokens
-    * @param secret      key passed to the JWT library for signing and verification
-    * @param getUserInfo credential lookup `(userid, password) => user info`; it is only
-    *                    stored here (default `null`) and never invoked by this class
-    */
-   case class AuthBase(
-       algorithm: JwtAlgorithm = JwtAlgorithm.HMD5,
-       secret: String = "OpenSesame",
-       getUserInfo: (String, String) => Option[UserInfo] = null) {
-     ctx =>
-
-     /** Copy of this helper using `algo`. */
-     def withAlgorithm(algo: JwtAlgorithm): AuthBase = ctx.copy(algorithm = algo)
-
-     /** Copy of this helper using `key` as the secret. */
-     def withSecretKey(key: String): AuthBase = ctx.copy(secret = key)
-
-     /** Copy of this helper using `f` as the credential lookup. */
-     def withUserFunc(f: (String, String) => Option[UserInfo]): AuthBase = ctx.copy(getUserInfo = f)
-
-     /**
-      * Verifies `token` and returns the raw JSON of its claim section.
-      * Yields `None` for a null token, a failed verification or an algorithm that is
-      * neither asymmetric nor HMAC (previously an unchecked cast that could throw).
-      */
-     private def verifiedClaim(token: String): Option[String] =
-       Option(token).flatMap { tok =>
-         algorithm match {
-           case asym: JwtAsymmetricAlgorithm =>
-             Jwt.decodeRawAll(tok, secret, Seq(asym)).toOption.map(_._2)
-           case hmac: JwtHmacAlgorithm =>
-             Jwt.decodeRawAll(tok, secret, Seq(hmac)).toOption.map(_._2)
-           case _ => None
-         }
-       }
-
-     /**
-      * Checks the token against the configured secret and algorithm.
-      *
-      * @return `Some(token)` when the library reports it valid, `None` otherwise
-      */
-     def authenticateToken(token: String): Option[String] = {
-       val valid = token != null && (algorithm match {
-         case asym: JwtAsymmetricAlgorithm => Jwt.isValid(token, secret, Seq(asym))
-         case hmac: JwtHmacAlgorithm       => Jwt.isValid(token, secret, Seq(hmac))
-         case _                            => false
-       })
-       if (valid) Some(token) else None
-     }
-
-     /**
-      * Decodes a token and extracts its "userinfo" claim. This overload takes the token;
-      * the constructor field of the same name is the credential lookup function.
-      *
-      * @return the claim as a map, or `None` when the token cannot be verified, the claim
-      *         JSON cannot be parsed, or "userinfo" is missing / not a JSON object
-      */
-     def getUserInfo(token: String): Option[UserInfo] =
-       verifiedClaim(token).flatMap { claimJson =>
-         // Matching on JObject replaces the former unchecked cast of `.values`.
-         Try(parse(claimJson) \ "userinfo").toOption.collect {
-           case obj: JObject => obj.values
-         }
-       }
-
-     /** Signs a new token whose claim contains `userinfo` under the "userinfo" field. */
-     def issueJwt(userinfo: UserInfo): String = {
-       val claims = JwtClaim() + Json(DefaultFormats).write(("userinfo", userinfo))
-       Jwt.encode(claims, secret, algorithm)
-     }
-   }
-
- }
POSClient.scala
- package com.datatech.pos.client
-
- import com.datatech.pos.messages.{PBPOSCommand, PBPOSCredential, SendCommandGrpc}
- import io.grpc.stub.StreamObserver
- import io.grpc.netty.{ NegotiationType, NettyChannelBuilder}
- import io.grpc.CallOptions
- import io.grpc.ClientCall
- import io.grpc.ClientInterceptor
- import io.grpc.ForwardingClientCall
- import io.grpc.Metadata
- import io.grpc.Metadata.Key
- import io.grpc.MethodDescriptor
- import io.grpc.ClientInterceptors
-
- /** Demo client: obtains a JWT with a plain call, then makes a call that carries the token. */
- object POSClient {
-
-   /**
-    * Client interceptor that attaches a JWT to the request headers of every call
-    * made through the intercepted channel.
-    *
-    * @param jwt token string sent under the ASCII metadata key "jwt"
-    */
-   class AuthClientInterceptor(jwt: String) extends ClientInterceptor {
-     // Built once per interceptor instead of once per call.
-     private val jwtKey: Metadata.Key[String] = Key.of("jwt", Metadata.ASCII_STRING_MARSHALLER)
-
-     override def interceptCall[ReqT, RespT](
-         methodDescriptor: MethodDescriptor[ReqT, RespT],
-         callOptions: CallOptions,
-         channel: io.grpc.Channel): ClientCall[ReqT, RespT] =
-       new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
-         channel.newCall(methodDescriptor, callOptions)) {
-         // Add the header before the call is started, then continue as usual.
-         override def start(responseListener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
-           headers.put(jwtKey, jwt)
-           super.start(responseListener, headers)
-         }
-       }
-   }
-
-   def main(args: Array[String]): Unit = {
-
-     //build connection channel
-     val unsafeChannel = NettyChannelBuilder
-       .forAddress("192.168.0.189", 50051)
-       .negotiationType(NegotiationType.PLAINTEXT)
-       .build()
-
-     try {
-       // The token request goes through the raw channel: no interceptor, no "jwt" header.
-       val authClient = SendCommandGrpc.blockingStub(unsafeChannel)
-       val jwt = authClient.getAuthToken(PBPOSCredential(userid = "johnny", password = "p4ssw0rd")).jwt
-       println(s"got jwt: $jwt")
-
-       // Every call made through this wrapped channel carries the token header.
-       val securedChannel = ClientInterceptors.intercept(unsafeChannel, new AuthClientInterceptor(jwt))
-       val securedClient = SendCommandGrpc.blockingStub(securedChannel)
-
-       val resp = securedClient.singleResponse(PBPOSCommand())
-       println(s"secured response: $resp")
-
-       // wait for async execution
-       scala.io.StdIn.readLine()
-     } finally {
-       // Release the channel even if one of the calls above throws.
-       unsafeChannel.shutdown()
-     }
-   }
-
- }