use std::{ops::RangeBounds, sync::OnceLock};
use http::{header::HeaderName, Request, Response, StatusCode};
use tower::Service;
use tower_http::cors::CorsLayer;
use crate::layers::{
    body_to_bytes_response::BodyToBytesResponse, bytes_to_body_request::BytesToBodyRequest,
    catch_http_codes::CatchHttpCodes, form_urlencoded_request::FormUrlencodedRequest,
    json_request::JsonRequest, json_response::JsonResponse,
};
/// Header names registered through [`set_propagator`].
///
/// Stays unset until `set_propagator` is called; being a [`OnceLock`], it can
/// only be filled a single time.
static PROPAGATOR_HEADERS: OnceLock<Vec<HeaderName>> = OnceLock::new();
pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) {
    let headers = propagator
        .fields()
        .map(|h| HeaderName::try_from(h).unwrap())
        .collect();
    tracing::debug!(
        ?headers,
        "Headers allowed in CORS requests for trace propagators set"
    );
    PROPAGATOR_HEADERS
        .set(headers)
        .expect(concat!(module_path!(), "::set_propagator was called twice"));
}
/// Extension methods for CORS layers that deal with trace propagation headers.
pub trait CorsLayerExt {
    /// Allows the given `headers` together with the headers registered via
    /// [`set_propagator`], returning the updated layer.
    #[must_use]
    fn allow_otel_headers<H: IntoIterator<Item = HeaderName>>(self, headers: H) -> Self;
}
impl CorsLayerExt for CorsLayer {
    /// Sets the allowed headers of this layer to `headers` followed by the
    /// propagator headers registered through [`set_propagator`], if any.
    fn allow_otel_headers<H>(self, headers: H) -> Self
    where
        H: IntoIterator<Item = HeaderName>,
    {
        // When no propagator has been registered, there are no extra headers to add
        let propagated: &[HeaderName] = PROPAGATOR_HEADERS
            .get()
            .map(Vec::as_slice)
            .unwrap_or_default();

        // Caller-supplied headers come first, propagator headers are appended after
        let mut allowed: Vec<HeaderName> = headers.into_iter().collect();
        allowed.extend_from_slice(propagated);

        self.allow_headers(allowed)
    }
}
/// Extension methods that wrap a service in the layers defined in
/// `crate::layers`.
///
/// Every method consumes `self` and returns the wrapping service; see the
/// individual layer types for the transformation each one applies.
pub trait ServiceExt<Body>: Sized {
    /// Wraps this service in a [`BytesToBodyRequest`].
    fn request_bytes_to_body(self) -> BytesToBodyRequest<Self> {
        BytesToBodyRequest::new(self)
    }

    /// Wraps this service in a [`BodyToBytesResponse`].
    fn response_body_to_bytes(self) -> BodyToBytesResponse<Self> {
        BodyToBytesResponse::new(self)
    }

    /// Wraps this service in a [`JsonResponse`] parameterised over `T`.
    fn json_response<T>(self) -> JsonResponse<Self, T> {
        JsonResponse::new(self)
    }

    /// Wraps this service in a [`JsonRequest`] parameterised over `T`.
    fn json_request<T>(self) -> JsonRequest<Self, T> {
        JsonRequest::new(self)
    }

    /// Wraps this service in a [`FormUrlencodedRequest`] parameterised over `T`.
    fn form_urlencoded_request<T>(self) -> FormUrlencodedRequest<Self, T> {
        FormUrlencodedRequest::new(self)
    }

    /// Wraps this service in a [`CatchHttpCodes`] configured for exactly one
    /// status code.
    ///
    /// Shorthand for [`ServiceExt::catch_http_codes`] with the inclusive range
    /// `status_code..=status_code`.
    fn catch_http_code<M, ResBody, E>(
        self,
        status_code: StatusCode,
        mapper: M,
    ) -> CatchHttpCodes<Self, M>
    where
        M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
    {
        let only_this_code = status_code..=status_code;
        self.catch_http_codes(only_this_code, mapper)
    }

    /// Wraps this service in a [`CatchHttpCodes`] built from the given range of
    /// status codes and the `mapper` function.
    fn catch_http_codes<B, M, ResBody, E>(self, bounds: B, mapper: M) -> CatchHttpCodes<Self, M>
    where
        B: RangeBounds<StatusCode>,
        M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
    {
        CatchHttpCodes::new(self, bounds, mapper)
    }

    /// Wraps this service in a [`CatchHttpCodes`] covering the status codes
    /// from 400 (included) up to 600 (excluded).
    ///
    /// Shorthand for [`ServiceExt::catch_http_codes`] with that range.
    fn catch_http_errors<M, ResBody, E>(self, mapper: M) -> CatchHttpCodes<Self, M>
    where
        M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
    {
        let first_error = StatusCode::BAD_REQUEST;
        // 600 has no named constant, but it is within the range `from_u16` accepts
        let past_last_error = StatusCode::from_u16(600).unwrap();

        self.catch_http_codes(first_error..past_last_error, mapper)
    }
}
/// Blanket implementation: every [`Service`] accepting a `Request<B>` gets the
/// [`ServiceExt`] methods.
impl<S, B> ServiceExt<B> for S
where
    S: Service<Request<B>>,
{
}