Rewrite RateLimitInterceptor (#7889)

This commit is contained in:
stevenyomi 2022-08-31 01:17:37 +08:00 committed by GitHub
parent 53f5ea7fe9
commit 532f662b05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 89 deletions

View File

@ -5,6 +5,8 @@ import okhttp3.Interceptor
import okhttp3.OkHttpClient import okhttp3.OkHttpClient
import okhttp3.Response import okhttp3.Response
import java.io.IOException import java.io.IOException
import java.util.ArrayDeque
import java.util.concurrent.Semaphore
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
/** /**
@ -25,54 +27,77 @@ fun OkHttpClient.Builder.rateLimit(
permits: Int, permits: Int,
period: Long = 1, period: Long = 1,
unit: TimeUnit = TimeUnit.SECONDS, unit: TimeUnit = TimeUnit.SECONDS,
) = addInterceptor(RateLimitInterceptor(permits, period, unit)) ) = addInterceptor(RateLimitInterceptor(null, permits, period, unit))
private class RateLimitInterceptor( /** We can probably accept domains or wildcards by comparing with [endsWith], etc. */
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
internal class RateLimitInterceptor(
private val host: String?,
private val permits: Int, private val permits: Int,
period: Long, period: Long,
unit: TimeUnit, unit: TimeUnit,
) : Interceptor { ) : Interceptor {
private val requestQueue = ArrayList<Long>(permits) private val requestQueue = ArrayDeque<Long>(permits)
private val rateLimitMillis = unit.toMillis(period) private val rateLimitMillis = unit.toMillis(period)
private val fairLock = Semaphore(1, true)
override fun intercept(chain: Interceptor.Chain): Response { override fun intercept(chain: Interceptor.Chain): Response {
// Ignore canceled calls, otherwise they would jam the queue val call = chain.call()
if (chain.call().isCanceled()) { if (call.isCanceled()) throw IOException("Canceled")
throw IOException()
val request = chain.request()
when (host) {
null, request.url.host -> {} // need rate limit
else -> return chain.proceed(request)
} }
try {
fairLock.acquire()
} catch (e: InterruptedException) {
throw IOException(e)
}
val requestQueue = this.requestQueue
val timestamp: Long
try {
synchronized(requestQueue) { synchronized(requestQueue) {
val now = SystemClock.elapsedRealtime() while (requestQueue.size >= permits) { // queue is full, remove expired entries
val waitTime = if (requestQueue.size < permits) { val periodStart = SystemClock.elapsedRealtime() - rateLimitMillis
0 var hasRemovedExpired = false
} else { while (requestQueue.isEmpty().not() && requestQueue.first <= periodStart) {
val oldestReq = requestQueue[0] requestQueue.removeFirst()
val newestReq = requestQueue[permits - 1] hasRemovedExpired = true
}
if (newestReq - oldestReq > rateLimitMillis) { if (call.isCanceled()) {
0 throw IOException("Canceled")
} else { } else if (hasRemovedExpired) {
oldestReq + rateLimitMillis - now // Remaining time break
} else try { // wait for the first entry to expire, or notified by cached response
(requestQueue as Object).wait(requestQueue.first - periodStart)
} catch (_: InterruptedException) {
continue
} }
} }
// Final check // add request to queue
if (chain.call().isCanceled()) { timestamp = SystemClock.elapsedRealtime()
throw IOException() requestQueue.addLast(timestamp)
}
} finally {
fairLock.release()
} }
if (requestQueue.size == permits) { val response = chain.proceed(request)
requestQueue.removeAt(0) if (response.networkResponse == null) { // response is cached, remove it from queue
} synchronized(requestQueue) {
if (waitTime > 0) { if (requestQueue.isEmpty() || timestamp < requestQueue.first) return@synchronized
requestQueue.add(now + waitTime) requestQueue.removeFirstOccurrence(timestamp)
Thread.sleep(waitTime) // Sleep inside synchronized to pause queued requests (requestQueue as Object).notifyAll()
} else {
requestQueue.add(now)
} }
} }
return chain.proceed(chain.request()) return response
} }
} }

View File

@ -1,11 +1,7 @@
package eu.kanade.tachiyomi.network.interceptor package eu.kanade.tachiyomi.network.interceptor
import android.os.SystemClock
import okhttp3.HttpUrl import okhttp3.HttpUrl
import okhttp3.Interceptor
import okhttp3.OkHttpClient import okhttp3.OkHttpClient
import okhttp3.Response
import java.io.IOException
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
/** /**
@ -28,58 +24,4 @@ fun OkHttpClient.Builder.rateLimitHost(
permits: Int, permits: Int,
period: Long = 1, period: Long = 1,
unit: TimeUnit = TimeUnit.SECONDS, unit: TimeUnit = TimeUnit.SECONDS,
) = addInterceptor(SpecificHostRateLimitInterceptor(httpUrl, permits, period, unit)) ) = addInterceptor(RateLimitInterceptor(httpUrl.host, permits, period, unit))
class SpecificHostRateLimitInterceptor(
httpUrl: HttpUrl,
private val permits: Int,
period: Long,
unit: TimeUnit,
) : Interceptor {
private val requestQueue = ArrayList<Long>(permits)
private val rateLimitMillis = unit.toMillis(period)
private val host = httpUrl.host
override fun intercept(chain: Interceptor.Chain): Response {
// Ignore canceled calls, otherwise they would jam the queue
if (chain.call().isCanceled()) {
throw IOException()
} else if (chain.request().url.host != host) {
return chain.proceed(chain.request())
}
synchronized(requestQueue) {
val now = SystemClock.elapsedRealtime()
val waitTime = if (requestQueue.size < permits) {
0
} else {
val oldestReq = requestQueue[0]
val newestReq = requestQueue[permits - 1]
if (newestReq - oldestReq > rateLimitMillis) {
0
} else {
oldestReq + rateLimitMillis - now // Remaining time
}
}
// Final check
if (chain.call().isCanceled()) {
throw IOException()
}
if (requestQueue.size == permits) {
requestQueue.removeAt(0)
}
if (waitTime > 0) {
requestQueue.add(now + waitTime)
Thread.sleep(waitTime) // Sleep inside synchronized to pause queued requests
} else {
requestQueue.add(now)
}
}
return chain.proceed(chain.request())
}
}