Skip to content

Commit

Permalink
update Api Errors adding support for WWW-Authenticate header, update …
Browse files Browse the repository at this point in the history
…rejection and exception default handlers
  • Loading branch information
fupelaqu committed Nov 10, 2023
1 parent 966b10b commit afa96d0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 26 deletions.
26 changes: 26 additions & 0 deletions server/src/main/scala/app/softnetwork/api/server/ApiErrors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.Route
import app.softnetwork.serialization.commonFormats
import org.json4s.Formats
import sttp.model.headers.WWWAuthenticateChallenge
import sttp.model.{HeaderNames, StatusCode, Uri}
import sttp.tapir.EndpointOutput.OneOf
import sttp.tapir.server.PartialServerEndpointWithSecurityOutput
Expand Down Expand Up @@ -70,6 +71,23 @@ object ApiErrors extends SchemaDerivation with TapirJson4s {
}
)(_.left.get.toString())

case class UnauthorizedWithChallenge(scheme: String, realm: String) extends ErrorInfo {
override val message: String = "Unauthorized"
override def toString: String = WWWAuthenticateChallenge(scheme).realm(realm).toString()
}

implicit val unauthorizedWithChallengeCodec
: Codec[String, UnauthorizedWithChallenge, CodecFormat.TextPlain] =
Codec.string.mapDecode(s =>
WWWAuthenticateChallenge.parseSingle(s) match {
case Right(challenge) =>
DecodeResult.Value(
UnauthorizedWithChallenge(challenge.scheme, challenge.realm.getOrElse(""))
)
case Left(_) => DecodeResult.Error(s, new Exception("Cannot parse WWW-Authenticate header"))
}
)(_.toString())

implicit def apiError2Route(apiError: ErrorInfo)(implicit formats: Formats): Route =
apiError match {
case r: BadRequest => complete(HttpResponse(StatusCodes.BadRequest, entity = r))
Expand Down Expand Up @@ -134,13 +152,20 @@ object ApiErrors extends SchemaDerivation with TapirJson4s {
.example(ApiErrors.ErrorMessage("Test error message"))
)

val unauthorizedWithChallengeVariant: EndpointOutput.OneOfVariant[UnauthorizedWithChallenge] =
oneOfVariant(
statusCode(StatusCode.Unauthorized)
.and(header[UnauthorizedWithChallenge](HeaderNames.WwwAuthenticate))
)

val oneOfApiErrors: EndpointOutput.OneOf[ApiErrors.ErrorInfo, ApiErrors.ErrorInfo] =
oneOf[ApiErrors.ErrorInfo](
// returns required http code for different types of ErrorInfo.
// For secured endpoint you need to define
// all cases before defining security logic
forbiddenVariant,
unauthorizedVariant,
unauthorizedWithChallengeVariant,
notFoundVariant,
foundVariant,
badRequestVariant,
Expand Down Expand Up @@ -226,6 +251,7 @@ object ApiErrors extends SchemaDerivation with TapirJson4s {
body.errorOutVariants(
forbiddenVariant,
unauthorizedVariant,
unauthorizedWithChallengeVariant,
notFoundVariant,
foundVariant,
badRequestVariant,
Expand Down
36 changes: 10 additions & 26 deletions server/src/main/scala/app/softnetwork/api/server/ApiRoutes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import akka.http.scaladsl.server.{
Route,
ValidationRejection
}
import akka.http.scaladsl.settings.RoutingSettings
import app.softnetwork.api.server.config.ServerSettings
import org.json4s.Formats
import app.softnetwork.serialization._
Expand All @@ -31,33 +32,9 @@ trait ApiRoutes extends Directives with GrpcServices with DefaultComplete {

def log: Logger

val rejectionHandler: RejectionHandler =
RejectionHandler
.newBuilder()
.handle { case MissingCookieRejection(cookieName) =>
complete(HttpResponse(StatusCodes.BadRequest, entity = s"$cookieName cookie required"))
}
.handle { case AuthorizationFailedRejection =>
complete(StatusCodes.Forbidden)
}
.handle { case ValidationRejection(msg, _) =>
complete(HttpResponse(StatusCodes.InternalServerError, entity = msg))
}
.handleAll[MethodRejection] { methodRejections =>
val names = methodRejections.map(_.supported.name)
complete(
HttpResponse(
StatusCodes.MethodNotAllowed,
entity = s"Supported methods: ${names mkString " or "}!"
)
)
}
.handleNotFound {
complete(HttpResponse(StatusCodes.NotFound, entity = "Not found"))
}
.result()
val rejectionHandler: RejectionHandler = RejectionHandler.default

val exceptionHandler: ExceptionHandler =
lazy val exceptionHandler: ExceptionHandler =
ExceptionHandler { case e: TimeoutException =>
extractUri { uri =>
log.error(
Expand All @@ -67,6 +44,13 @@ trait ApiRoutes extends Directives with GrpcServices with DefaultComplete {
complete(HttpResponse(StatusCodes.InternalServerError, entity = "Timeout"))
}
}
.withFallback(
ExceptionHandler.default(
RoutingSettings(
ServerSettings.config
)
)
)

final def mainRoutes: ActorSystem[_] => Route = system => {
val routes = concat((HealthCheckService :: apiRoutes(system)).map(_.route): _*)
Expand Down

0 comments on commit afa96d0

Please sign in to comment.