Skip to content

Commit

Permalink
Fix concurrentRate not work in some case
Browse files Browse the repository at this point in the history
  • Loading branch information
821938089 committed Jan 22, 2025
1 parent 5913df1 commit 5caa3b8
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 128 deletions.
135 changes: 135 additions & 0 deletions app/src/main/java/io/legado/app/help/ConcurrentRateLimiter.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package io.legado.app.help

import io.legado.app.data.entities.BaseSource
import io.legado.app.exception.ConcurrentException
import io.legado.app.model.analyzeRule.AnalyzeUrl.ConcurrentRecord
import kotlinx.coroutines.delay

class ConcurrentRateLimiter(val source: BaseSource?) {

companion object {
private val concurrentRecordMap = hashMapOf<String, ConcurrentRecord>()
}

/**
* 开始访问,并发判断
*/
@Throws(ConcurrentException::class)
private fun fetchStart(): ConcurrentRecord? {
source ?: return null
val concurrentRate = source.concurrentRate
if (concurrentRate.isNullOrEmpty() || concurrentRate == "0") {
return null
}
val rateIndex = concurrentRate.indexOf("/")
var fetchRecord = concurrentRecordMap[source.getKey()]
if (fetchRecord == null) {
synchronized(concurrentRecordMap) {
fetchRecord = concurrentRecordMap[source.getKey()]
if (fetchRecord == null) {
fetchRecord = ConcurrentRecord(rateIndex > 0, System.currentTimeMillis(), 1)
concurrentRecordMap[source.getKey()] = fetchRecord
return fetchRecord
}
}
}
val waitTime: Int = synchronized(fetchRecord!!) {
try {
if (!fetchRecord.isConcurrent) {
//并发控制非 次数/毫秒
if (fetchRecord.frequency > 0) {
//已经有访问线程,直接等待
return@synchronized concurrentRate.toInt()
}
//没有线程访问,判断还剩多少时间可以访问
val nextTime = fetchRecord.time + concurrentRate.toInt()
if (System.currentTimeMillis() >= nextTime) {
fetchRecord.time = System.currentTimeMillis()
fetchRecord.frequency = 1
return@synchronized 0
}
return@synchronized (nextTime - System.currentTimeMillis()).toInt()
} else {
//并发控制为 次数/毫秒
val sj = concurrentRate.substring(rateIndex + 1)
val nextTime = fetchRecord.time + sj.toInt()
if (System.currentTimeMillis() >= nextTime) {
//已经过了限制时间,重置开始时间
fetchRecord.time = System.currentTimeMillis()
fetchRecord.frequency = 1
return@synchronized 0
}
val cs = concurrentRate.substring(0, rateIndex)
if (fetchRecord.frequency > cs.toInt()) {
return@synchronized (nextTime - System.currentTimeMillis()).toInt()
} else {
fetchRecord.frequency += 1
return@synchronized 0
}
}
} catch (_: Exception) {
return@synchronized 0
}
}
if (waitTime > 0) {
throw ConcurrentException(
"根据并发率还需等待${waitTime}毫秒才可以访问",
waitTime = waitTime
)
}
return fetchRecord
}

/**
* 访问结束
*/
fun fetchEnd(concurrentRecord: ConcurrentRecord?) {
if (concurrentRecord != null && !concurrentRecord.isConcurrent) {
synchronized(concurrentRecord) {
concurrentRecord.frequency -= 1
}
}
}

/**
* 获取并发记录,若处于并发限制状态下则会等待
*/
suspend fun getConcurrentRecord(): ConcurrentRecord? {
while (true) {
try {
return fetchStart()
} catch (e: ConcurrentException) {
delay(e.waitTime.toLong())
}
}
}

fun getConcurrentRecordBlocking(): ConcurrentRecord? {
while (true) {
try {
return fetchStart()
} catch (e: ConcurrentException) {
Thread.sleep(e.waitTime.toLong())
}
}
}

suspend inline fun <T> withLimit(block: () -> T): T {
val concurrentRecord = getConcurrentRecord()
try {
return block()
} finally {
fetchEnd(concurrentRecord)
}
}

inline fun <T> withLimitBlocking(block: () -> T): T {
val concurrentRecord = getConcurrentRecordBlocking()
try {
return block()
} finally {
fetchEnd(concurrentRecord)
}
}

}
57 changes: 35 additions & 22 deletions app/src/main/java/io/legado/app/help/JsExtensions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import io.legado.app.utils.toStringArray
import io.legado.app.utils.toastOnUi
import kotlinx.coroutines.Dispatchers.IO
import kotlinx.coroutines.async
import kotlinx.coroutines.ensureActive
import kotlinx.coroutines.runBlocking
import okio.use
import org.jsoup.Connection
Expand Down Expand Up @@ -358,13 +359,17 @@ interface JsExtensions : JsEncodeUtils {
val requestHeaders = if (getSource()?.enabledCookieJar == true) {
headers.toMutableMap().apply { put(cookieJarHeader, "1") }
} else headers
val response = Jsoup.connect(urlStr)
.sslSocketFactory(SSLHelper.unsafeSSLSocketFactory)
.ignoreContentType(true)
.followRedirects(false)
.headers(requestHeaders)
.method(Connection.Method.GET)
.execute()
val rateLimiter = ConcurrentRateLimiter(getSource())
val response = rateLimiter.withLimitBlocking {
context.ensureActive()
Jsoup.connect(urlStr)
.sslSocketFactory(SSLHelper.unsafeSSLSocketFactory)
.ignoreContentType(true)
.followRedirects(false)
.headers(requestHeaders)
.method(Connection.Method.GET)
.execute()
}
return response
}

Expand All @@ -375,13 +380,17 @@ interface JsExtensions : JsEncodeUtils {
val requestHeaders = if (getSource()?.enabledCookieJar == true) {
headers.toMutableMap().apply { put(cookieJarHeader, "1") }
} else headers
val response = Jsoup.connect(urlStr)
.sslSocketFactory(SSLHelper.unsafeSSLSocketFactory)
.ignoreContentType(true)
.followRedirects(false)
.headers(requestHeaders)
.method(Connection.Method.HEAD)
.execute()
val rateLimiter = ConcurrentRateLimiter(getSource())
val response = rateLimiter.withLimitBlocking {
context.ensureActive()
Jsoup.connect(urlStr)
.sslSocketFactory(SSLHelper.unsafeSSLSocketFactory)
.ignoreContentType(true)
.followRedirects(false)
.headers(requestHeaders)
.method(Connection.Method.HEAD)
.execute()
}
return response
}

Expand All @@ -392,14 +401,18 @@ interface JsExtensions : JsEncodeUtils {
val requestHeaders = if (getSource()?.enabledCookieJar == true) {
headers.toMutableMap().apply { put(cookieJarHeader, "1") }
} else headers
val response = Jsoup.connect(urlStr)
.sslSocketFactory(SSLHelper.unsafeSSLSocketFactory)
.ignoreContentType(true)
.followRedirects(false)
.requestBody(body)
.headers(requestHeaders)
.method(Connection.Method.POST)
.execute()
val rateLimiter = ConcurrentRateLimiter(getSource())
val response = rateLimiter.withLimitBlocking {
context.ensureActive()
Jsoup.connect(urlStr)
.sslSocketFactory(SSLHelper.unsafeSSLSocketFactory)
.ignoreContentType(true)
.followRedirects(false)
.requestBody(body)
.headers(requestHeaders)
.method(Connection.Method.POST)
.execute()
}
return response
}

Expand Down
110 changes: 4 additions & 106 deletions app/src/main/java/io/legado/app/model/analyzeRule/AnalyzeUrl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import io.legado.app.constant.AppPattern.dataUriRegex
import io.legado.app.data.entities.BaseSource
import io.legado.app.data.entities.Book
import io.legado.app.data.entities.BookChapter
import io.legado.app.exception.ConcurrentException
import io.legado.app.help.CacheManager
import io.legado.app.help.ConcurrentRateLimiter
import io.legado.app.help.JsExtensions
import io.legado.app.help.config.AppConfig
import io.legado.app.help.exoplayer.ExoPlayerHelper
Expand Down Expand Up @@ -46,7 +46,6 @@ import io.legado.app.utils.isJsonArray
import io.legado.app.utils.isJsonObject
import io.legado.app.utils.isXml
import io.legado.app.utils.splitNotBlank
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.OkHttpClient
Expand Down Expand Up @@ -86,7 +85,6 @@ class AnalyzeUrl(
companion object {
val paramPattern: Pattern = Pattern.compile("\\s*,\\s*(?=\\{)")
private val pagePattern = Pattern.compile("<(.*?)>")
private val concurrentRecordMap = hashMapOf<String, ConcurrentRecord>()
}

var ruleUrl = ""
Expand All @@ -110,6 +108,7 @@ class AnalyzeUrl(
private val enabledCookieJar = source?.enabledCookieJar ?: false
private val domain: String
private var webViewDelayTime: Long = 0
private val concurrentRateLimiter = ConcurrentRateLimiter(source)

// 服务器ID
var serverID: Long? = null
Expand Down Expand Up @@ -331,99 +330,6 @@ class AnalyzeUrl(
?: ""
}

/**
* 开始访问,并发判断
*/
@Throws(ConcurrentException::class)
private fun fetchStart(): ConcurrentRecord? {
source ?: return null
val concurrentRate = source.concurrentRate
if (concurrentRate.isNullOrEmpty() || concurrentRate == "0") {
return null
}
val rateIndex = concurrentRate.indexOf("/")
var fetchRecord = concurrentRecordMap[source.getKey()]
if (fetchRecord == null) {
synchronized(concurrentRecordMap) {
fetchRecord = concurrentRecordMap[source.getKey()]
if (fetchRecord == null) {
fetchRecord = ConcurrentRecord(rateIndex > 0, System.currentTimeMillis(), 1)
concurrentRecordMap[source.getKey()] = fetchRecord
return fetchRecord
}
}
}
val waitTime: Int = synchronized(fetchRecord!!) {
try {
if (!fetchRecord.isConcurrent) {
//并发控制非 次数/毫秒
if (fetchRecord.frequency > 0) {
//已经有访问线程,直接等待
return@synchronized concurrentRate.toInt()
}
//没有线程访问,判断还剩多少时间可以访问
val nextTime = fetchRecord.time + concurrentRate.toInt()
if (System.currentTimeMillis() >= nextTime) {
fetchRecord.time = System.currentTimeMillis()
fetchRecord.frequency = 1
return@synchronized 0
}
return@synchronized (nextTime - System.currentTimeMillis()).toInt()
} else {
//并发控制为 次数/毫秒
val sj = concurrentRate.substring(rateIndex + 1)
val nextTime = fetchRecord.time + sj.toInt()
if (System.currentTimeMillis() >= nextTime) {
//已经过了限制时间,重置开始时间
fetchRecord.time = System.currentTimeMillis()
fetchRecord.frequency = 1
return@synchronized 0
}
val cs = concurrentRate.substring(0, rateIndex)
if (fetchRecord.frequency > cs.toInt()) {
return@synchronized (nextTime - System.currentTimeMillis()).toInt()
} else {
fetchRecord.frequency += 1
return@synchronized 0
}
}
} catch (e: Exception) {
return@synchronized 0
}
}
if (waitTime > 0) {
throw ConcurrentException(
"根据并发率还需等待${waitTime}毫秒才可以访问",
waitTime = waitTime
)
}
return fetchRecord
}

/**
* 访问结束
*/
private fun fetchEnd(concurrentRecord: ConcurrentRecord?) {
if (concurrentRecord != null && !concurrentRecord.isConcurrent) {
synchronized(concurrentRecord) {
concurrentRecord.frequency -= 1
}
}
}

/**
* 获取并发记录,若处于并发限制状态下则会等待
*/
private suspend fun getConcurrentRecord(): ConcurrentRecord? {
while (true) {
try {
return fetchStart()
} catch (e: ConcurrentException) {
delay(e.waitTime.toLong())
}
}
}

/**
* 访问网站,返回StrResponse
*/
Expand All @@ -435,8 +341,7 @@ class AnalyzeUrl(
if (type != null) {
return StrResponse(url, HexUtil.encodeHexStr(getByteArrayAwait()))
}
val concurrentRecord = getConcurrentRecord()
try {
concurrentRateLimiter.withLimit {
setCookie()
val strResponse: StrResponse
if (this.useWebView && useWebView) {
Expand Down Expand Up @@ -500,9 +405,6 @@ class AnalyzeUrl(
}
}
return strResponse
} finally {
//saveCookie()
fetchEnd(concurrentRecord)
}
}

Expand All @@ -521,8 +423,7 @@ class AnalyzeUrl(
* 访问网站,返回Response
*/
suspend fun getResponseAwait(): Response {
val concurrentRecord = getConcurrentRecord()
try {
concurrentRateLimiter.withLimit {
setCookie()
val response = getClient().newCallResponse(retry) {
addHeaders(headerMap)
Expand All @@ -545,9 +446,6 @@ class AnalyzeUrl(
}
}
return response
} finally {
//saveCookie()
fetchEnd(concurrentRecord)
}
}

Expand Down

0 comments on commit 5caa3b8

Please sign in to comment.