View on GitHub
File Changes

                      
  object KRequest {
    case class FindNodes[A](requestId: UUID, nodeRecord: NodeRecord[A], targetNodeId: BitVector) extends KRequest[A]
+

                      
+
    case class Ping[A](requestId: UUID, nodeRecord: NodeRecord[A]) extends KRequest[A]
  }

                      
  sealed trait KResponse[A] extends KMessage[A]

                      
  object KResponse {
    case class Nodes[A](requestId: UUID, nodeRecord: NodeRecord[A], nodes: Seq[NodeRecord[A]]) extends KResponse[A]
+

                      
+
    case class Pong[A](requestId: UUID, nodeRecord: NodeRecord[A]) extends KResponse[A]
  }
}
package io.iohk.scalanet.peergroup.kademlia

                      
-
import io.iohk.scalanet.peergroup.kademlia.KMessage.KRequest.FindNodes
-
import io.iohk.scalanet.peergroup.kademlia.KMessage.KResponse.Nodes
+
import io.iohk.scalanet.peergroup.kademlia.KMessage.{KRequest, KResponse}
+
import io.iohk.scalanet.peergroup.kademlia.KMessage.KRequest.{FindNodes, Ping}
+
import io.iohk.scalanet.peergroup.kademlia.KMessage.KResponse.{Nodes, Pong}
import io.iohk.scalanet.peergroup.kademlia.KRouter.NodeRecord
import io.iohk.scalanet.peergroup.{Channel, PeerGroup}
import monix.eval.Task
    * @return the future response
    */
  def findNodes(to: NodeRecord[A], request: FindNodes[A]): Task[Nodes[A]]
+

                      
+
  /**
+
    * Server side PING handler.
+
    * @return An Observable for receiving PING requests.
+
    *         Each element contains a tuple consisting of a PING request
+
    *         with a function for accepting the require PONG response.
+
    */
+
  def ping: Observable[(Ping[A], Pong[A] => Task[Unit])]
+

                      
+
  /**
+
    * Send a PING message to another peer.
+
    * @param to the peer to send the message to
+
    * @param request the PING request
+
    * @return the future response
+
    */
+
  def ping(to: NodeRecord[A], request: Ping[A]): Task[Pong[A]]
}

                      
object KNetwork {
  )(implicit scheduler: Scheduler)
      extends KNetwork[A] {

                      
-
    override def findNodes(to: NodeRecord[A], message: FindNodes[A]): Task[Nodes[A]] = {
+
    override def findNodes: Observable[(FindNodes[A], Nodes[A] => Task[Unit])] = serverTemplate {
+
      case f @ FindNodes(_, _, _) => f
+
    }
+

                      
+
    override def ping: Observable[(Ping[A], Pong[A] => Task[Unit])] = serverTemplate {
+
      case p @ Ping(_, _) => p
+
    }
+

                      
+
    override def findNodes(to: NodeRecord[A], request: FindNodes[A]): Task[Nodes[A]] = {
+
      requestTemplate(to, request, { case n @ Nodes(_, _, _) => n })
+
    }
+

                      
+
    override def ping(to: NodeRecord[A], request: Ping[A]): Task[Pong[A]] = {
+
      requestTemplate(to, request, { case p @ Pong(_, _) => p })
+
    }
+

                      
+
    private def requestTemplate[Request <: KRequest[A], Response <: KResponse[A]](
+
        to: NodeRecord[A],
+
        message: Request,
+
        pf: PartialFunction[KMessage[A], Response]
+
    ): Task[Response] = {
      peerGroup
        .client(to.routingAddress)
        .bracket { clientChannel =>
-
          makeFindNodesRequest(message, clientChannel)
+
          sendRequest(message, clientChannel, pf)
        } { clientChannel =>
          clientChannel.close()
        }
-
    }

                      
-
    override def findNodes: Observable[(FindNodes[A], Nodes[A] => Task[Unit])] = {
-
      peerGroup.server().collectChannelCreated.mapTask { channel: Channel[A, KMessage[A]] =>
-
        channel.in
-
          .collect {
-
            case f @ FindNodes(_, _, _) =>
-
              (f, nodesTask(channel))
-
          }
-
          .headL
-
          .timeout(requestTimeout)
-
          .doOnFinish(closeIfAnError(channel))
-
      }
    }

                      
-
    private def makeFindNodesRequest(message: FindNodes[A], clientChannel: Channel[A, KMessage[A]]): Task[Nodes[A]] = {
+
    private def sendRequest[Request <: KRequest[A], Response <: KResponse[A]](
+
        message: Request,
+
        clientChannel: Channel[A, KMessage[A]],
+
        pf: PartialFunction[KMessage[A], Response]
+
    ): Task[Response] = {
      for {
        _ <- clientChannel.sendMessage(message).timeout(requestTimeout)
-
        nodes <- clientChannel.in
-
          .collect { case n @ Nodes(_, _, _) => n }
+
        response <- clientChannel.in
+
          .collect(pf)
          .headL
          .timeout(requestTimeout)
-
      } yield nodes
+
      } yield response
    }

                      
    private def closeIfAnError(
      maybeError.fold(Task.unit)(_ => channel.close())
    }

                      
-
    private def nodesTask(
+
    private def sendResponse(
        channel: Channel[A, KMessage[A]]
-
    ): Nodes[A] => Task[Unit] = { nodes =>
+
    ): KMessage[A] => Task[Unit] = { message =>
      channel
-
        .sendMessage(nodes)
+
        .sendMessage(message)
        .timeout(requestTimeout)
        .doOnFinish(_ => channel.close())
    }
+

                      
+
    private def serverTemplate[Request <: KRequest[A], Response <: KResponse[A]](
+
        pf: PartialFunction[KMessage[A], Request]
+
    ): Observable[(Request, KMessage[A] => Task[Unit])] = {
+
      peerGroup.server().collectChannelCreated.mapTask { channel: Channel[A, KMessage[A]] =>
+
        channel.in
+
          .collect(pf)
+
          .map(request => (request, sendResponse(channel)))
+
          .headL
+
          .timeout(requestTimeout)
+
          .doOnFinish(closeIfAnError(channel))
+
      }
+
    }
  }
}
import java.util.UUID
import java.util.concurrent.TimeoutException

                      
-
import io.iohk.scalanet.peergroup.kademlia.KMessage.KRequest.FindNodes
-
import io.iohk.scalanet.peergroup.kademlia.KMessage.KResponse.Nodes
+
import io.iohk.scalanet.peergroup.kademlia.KMessage.KRequest.{FindNodes, Ping}
+
import io.iohk.scalanet.peergroup.kademlia.KMessage.KResponse.{Nodes, Pong}
import io.iohk.scalanet.peergroup.{Channel, PeerGroup}
import io.iohk.scalanet.peergroup.kademlia.KNetwork.KNetworkScalanetImpl
import io.iohk.scalanet.peergroup.kademlia.KRouter.NodeRecord
import org.scalatest.FlatSpec
import org.scalatest.Matchers._
import org.scalatest.mockito.MockitoSugar._
-
import org.mockito.Mockito.{verify, when, never}
+
import org.mockito.Mockito.{never, verify, when}

                      
import scala.concurrent.duration._
import monix.execution.Scheduler.Implicits.global
import org.scalatest.concurrent.ScalaFutures._
import io.iohk.scalanet.TaskValues._
import KNetworkSpec._
import io.iohk.scalanet.peergroup.PeerGroup.ServerEvent.ChannelCreated
+
import io.iohk.scalanet.peergroup.kademlia.KMessage.{KRequest, KResponse}
+
import org.scalatest.prop.TableDrivenPropertyChecks._

                      
class KNetworkSpec extends FlatSpec {

                      
  implicit val patienceConfig = PatienceConfig(1 second)

                      
-
  "Server findNodes" should "not close server channels (it is the responsibility of the response handler)" in {
-
    val (network, peerGroup) = createKNetwork
-
    val channel = mock[Channel[String, KMessage[String]]]
-
    when(peerGroup.server()).thenReturn(Observable.eval(ChannelCreated(channel)))
-
    when(channel.in).thenReturn(Observable.eval(findNodes))
-
    when(channel.close()).thenReturn(Task.unit)
-

                      
-
    val (request, _) = network.findNodes.headL.runAsync.futureValue
-

                      
-
    request shouldBe findNodes
-
    verify(channel, never()).close()
-
  }
-

                      
-
  "Server findNodes" should "close server channels when a request does not arrive before a timeout" in {
-
    val (network, peerGroup) = createKNetwork
-
    val channel = mock[Channel[String, KMessage[String]]]
-
    when(peerGroup.server()).thenReturn(Observable.eval(ChannelCreated(channel)))
-
    when(channel.in).thenReturn(Observable.never)
-
    when(channel.close()).thenReturn(Task.unit)
-

                      
-
    val t = network.findNodes.headL.runAsync.failed.futureValue
-

                      
-
    t shouldBe a[TimeoutException]
-
    verify(channel).close()
-
  }
-

                      
-
  "Server findNodes" should "close server channel in the response task" in {
-
    val (network, peerGroup) = createKNetwork
-
    val channel = mock[Channel[String, KMessage[String]]]
-
    when(peerGroup.server()).thenReturn(Observable.eval(ChannelCreated(channel)))
-
    when(channel.in).thenReturn(Observable.eval(findNodes))
-
    when(channel.sendMessage(nodes)).thenReturn(Task.unit)
-
    when(channel.close()).thenReturn(Task.unit)
-

                      
-
    val (_, responseHandler) = network.findNodes.headL.runAsync.futureValue
-
    responseHandler(nodes).evaluated
-

                      
-
    verify(channel).close()
-
  }
-

                      
-
  "Server findNodes" should "close server channel in timed out response task" in {
-
    val (network, peerGroup) = createKNetwork
-
    val channel = mock[Channel[String, KMessage[String]]]
-
    when(peerGroup.server()).thenReturn(Observable.eval(ChannelCreated(channel)))
-
    when(channel.in).thenReturn(Observable.eval(findNodes))
-
    when(channel.sendMessage(nodes)).thenReturn(Task.never)
-
    when(channel.close()).thenReturn(Task.unit)
-

                      
-
    val (_, responseHandler) = network.findNodes.headL.runAsync.futureValue
-
    val t = responseHandler(nodes).failed.evaluated
-

                      
-
    t shouldBe a[TimeoutException]
-
    verify(channel).close()
-
  }
-

                      
-
  "Client findNodes" should "close client channels when requests are successful" in {
-
    val (network, peerGroup) = createKNetwork
-
    val client = mock[Channel[String, KMessage[String]]]
-
    when(peerGroup.client(targetRecord.routingAddress)).thenReturn(Task(client))
-
    when(client.sendMessage(findNodes)).thenReturn(Task.unit)
-
    when(client.in).thenReturn(Observable.eval(nodes))
-
    when(client.close()).thenReturn(Task.unit)
-

                      
-
    val response: Nodes[String] =
-
      network.findNodes(targetRecord, findNodes).evaluated
-

                      
-
    response shouldBe nodes
-
    verify(client).close()
-
  }
-

                      
-
  "Client findNodes" should "pass exception when client call fails" in {
-
    val (network, peerGroup) = createKNetwork
-
    val client = mock[Channel[String, KMessage[String]]]
-
    val exception = new Exception("failed")
-
    when(peerGroup.client(targetRecord.routingAddress))
-
      .thenReturn(Task.raiseError(exception))
-
    when(client.close()).thenReturn(Task.unit)
-

                      
-
    val t: Throwable =
-
      network.findNodes(targetRecord, findNodes).failed.evaluated
-

                      
-
    t shouldBe exception
-
  }
-

                      
-
  "Client findNodes" should "close client channels when sendMessage calls fail" in {
-
    val (network, peerGroup) = createKNetwork
-
    val client = mock[Channel[String, KMessage[String]]]
-
    val exception = new Exception("failed")
-
    when(peerGroup.client(targetRecord.routingAddress)).thenReturn(Task(client))
-
    when(client.sendMessage(findNodes)).thenReturn(Task.raiseError(exception))
-
    when(client.close()).thenReturn(Task.unit)
-

                      
-
    val t: Throwable =
-
      network.findNodes(targetRecord, findNodes).failed.evaluated
-

                      
-
    t shouldBe exception
-
    verify(client).close()
-
  }
-

                      
-
  "Client findNodes" should "close client channels when response fails to arrive" in {
-
    val (network, peerGroup) = createKNetwork
-
    val client = mock[Channel[String, KMessage[String]]]
-
    when(peerGroup.client(targetRecord.routingAddress)).thenReturn(Task(client))
-
    when(client.sendMessage(findNodes)).thenReturn(Task.unit)
-
    when(client.in).thenReturn(Observable.fromTask(Task.never))
-
    when(client.close()).thenReturn(Task.unit)
-

                      
-
    val t: Throwable =
-
      network.findNodes(targetRecord, findNodes).failed.evaluated
-

                      
-
    t shouldBe a[TimeoutException]
-
    verify(client).close()
+
  private val getFindNodesRequest: KNetwork[String] => Task[KRequest[String]] = getActualRequest(_.findNodes)
+
  private val getPingRequest: KNetwork[String] => Task[KRequest[String]] = getActualRequest(_.ping)
+

                      
+
  private val sendFindNodesResponse: Nodes[String] => KNetwork[String] => Task[Unit] = sendResponse(_.findNodes)
+
  private val sendPingResponse: Pong[String] => KNetwork[String] => Task[Unit] = sendResponse(_.ping)
+

                      
+
  private val sendFindNodesRequest: (NodeRecord[String], FindNodes[String]) => KNetwork[String] => Task[Nodes[String]] =
+
    (to, request) => network => network.findNodes(to, request)
+

                      
+
  private val sendPingRequest: (NodeRecord[String], Ping[String]) => KNetwork[String] => Task[Pong[String]] =
+
    (to, request) => network => network.ping(to, request)
+

                      
+
  private val rpcs = Table(
+
    ("Label", "Request", "Response", "Request extractor", "Response application", "Client RPC"),
+
    (
+
      "FIND_NODES",
+
      findNodes,
+
      nodes,
+
      getFindNodesRequest,
+
      sendFindNodesResponse(nodes),
+
      sendFindNodesRequest(targetRecord, findNodes)
+
    ),
+
    ("PING", ping, pong, getPingRequest, sendPingResponse(pong), sendPingRequest(targetRecord, ping))
+
  )
+

                      
+
  forAll(rpcs) { (label, request, response, requestExtractor, responseApplication, clientRpc) =>
+
    s"Server $label" should "not close server channels (it is the responsibility of the response handler)" in {
+
      val (network, peerGroup) = createKNetwork
+
      val channel = mock[Channel[String, KMessage[String]]]
+
      when(peerGroup.server())
+
        .thenReturn(Observable.eval(ChannelCreated(channel)))
+
      when(channel.in).thenReturn(Observable.eval(request))
+
      when(channel.close()).thenReturn(Task.unit)
+

                      
+
      val actualRequest = requestExtractor(network).runAsync.futureValue
+

                      
+
      actualRequest shouldBe request
+
      verify(channel, never()).close()
+
    }
+

                      
+
    s"Server $label" should "close server channels when a request does not arrive before a timeout" in {
+
      val (network, peerGroup) = createKNetwork
+
      val channel = mock[Channel[String, KMessage[String]]]
+
      when(peerGroup.server())
+
        .thenReturn(Observable.eval(ChannelCreated(channel)))
+
      when(channel.in).thenReturn(Observable.never)
+
      when(channel.close()).thenReturn(Task.unit)
+

                      
+
      val t = requestExtractor(network).runAsync.failed.futureValue
+

                      
+
      t shouldBe a[TimeoutException]
+
      verify(channel).close()
+
    }
+

                      
+
    s"Server $label" should "close server channel in the response task" in {
+
      val (network, peerGroup) = createKNetwork
+
      val channel = mock[Channel[String, KMessage[String]]]