Skip to content

Commit 5bff4c5

Browse files
committed
Support for RFC 7766
1 parent ba3e4a4 commit 5bff4c5

36 files changed

+1210
-1129
lines changed

ARSoft.Tools.Net/ARSoft.Tools.Net.csproj

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
<Description>This project contains a complete managed .Net DNS and DNSSEC client, a DNS server and SPF and SenderID validation.</Description>
1414
<PackageProjectUrl>https://github.com/alexreinert/ARSoft.Tools.Net</PackageProjectUrl>
1515
<PackageTags>dns dnssec spf</PackageTags>
16-
<PackageLicenseUrl>https://github.com/alexreinert/ARSoft.Tools.Net/blob/master/LICENSE</PackageLicenseUrl>
16+
<PackageLicenseExpression>Apache-2.0</PackageLicenseExpression>
1717
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
1818
<Copyright>Copyright 2010..2023 Alexander Reinert</Copyright>
19-
<VersionPrefix>3.3.0</VersionPrefix>
19+
<VersionPrefix>3.4.0</VersionPrefix>
2020
</PropertyGroup>
2121

2222
<ItemGroup>

ARSoft.Tools.Net/Dns/DnsClient.cs

+11-23
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,11 @@ public DnsClient(IEnumerable<IPAddress> dnsServers, IClientTransport[] transport
9090

9191
DnsMessage message = new DnsMessage() { IsQuery = true, OperationCode = OperationCode.Query, IsRecursionDesired = true, IsEDnsEnabled = true };
9292

93-
if (options == null)
94-
{
95-
message.IsRecursionDesired = true;
96-
message.IsEDnsEnabled = true;
97-
}
98-
else
99-
{
100-
message.IsRecursionDesired = options.IsRecursionDesired;
101-
message.IsCheckingDisabled = options.IsCheckingDisabled;
102-
message.EDnsOptions = options.EDnsOptions;
103-
}
93+
options ??= DnsQueryOptions.DefaultQueryOptions;
94+
95+
message.IsRecursionDesired = options.IsRecursionDesired;
96+
message.IsCheckingDisabled = options.IsCheckingDisabled;
97+
message.EDnsOptions = options.EDnsOptions;
10498

10599
message.Questions.Add(new DnsQuestion(name, recordType, recordClass));
106100

@@ -120,19 +114,13 @@ public DnsClient(IEnumerable<IPAddress> dnsServers, IClientTransport[] transport
120114
{
121115
_ = name ?? throw new ArgumentNullException(nameof(name), "Name must be provided");
122116

123-
DnsMessage message = new DnsMessage() { IsQuery = true, OperationCode = OperationCode.Query, IsRecursionDesired = true, IsEDnsEnabled = true };
117+
options ??= DnsQueryOptions.DefaultQueryOptions;
124118

125-
if (options == null)
126-
{
127-
message.IsRecursionDesired = true;
128-
message.IsEDnsEnabled = true;
129-
}
130-
else
131-
{
132-
message.IsRecursionDesired = options.IsRecursionDesired;
133-
message.IsCheckingDisabled = options.IsCheckingDisabled;
134-
message.EDnsOptions = options.EDnsOptions;
135-
}
119+
var message = new DnsMessage { IsQuery = true, OperationCode = OperationCode.Query, IsRecursionDesired = true, IsEDnsEnabled = true };
120+
121+
message.IsRecursionDesired = options.IsRecursionDesired;
122+
message.IsCheckingDisabled = options.IsCheckingDisabled;
123+
message.EDnsOptions = options.EDnsOptions;
136124

137125
message.Questions.Add(new DnsQuestion(name, recordType, recordClass));
138126

ARSoft.Tools.Net/Dns/DnsClientBase.cs

+9-146
Original file line numberDiff line numberDiff line change
@@ -86,151 +86,13 @@ internal DnsClientBase(IEnumerable<IPAddress> servers, int queryTimeout, IClient
8686
protected TMessage? SendMessage<TMessage>(TMessage query)
8787
where TMessage : DnsMessageBase, new()
8888
{
89-
SelectTsigKey? tsigKeySelector;
90-
91-
var package = PrepareMessage(query, out tsigKeySelector, out var tsigOriginalMac);
92-
93-
TMessage? response = null;
94-
95-
foreach (var connection in GetConnections(package, query.IsReliableSendingRequested))
96-
{
97-
try
98-
{
99-
var receivedMessage = SendMessage<TMessage>(package, connection, tsigKeySelector, tsigOriginalMac);
100-
101-
if ((receivedMessage != null) && ValidateResponse(query, receivedMessage.Message))
102-
{
103-
if (receivedMessage.Message.ReturnCode == ReturnCode.ServerFailure)
104-
{
105-
response = receivedMessage.Message;
106-
continue;
107-
}
108-
109-
if (!receivedMessage.Message.IsReliableResendingRequested)
110-
return receivedMessage.Message;
111-
112-
var resendTransport = _transports.FirstOrDefault(t => t.SupportsReliableTransfer && t.MaximumAllowedQuerySize <= package.Length && t != connection.Transport);
113-
114-
if (resendTransport != null)
115-
{
116-
using (IClientConnection? resendConnection = resendTransport.Connect(new DnsClientEndpointInfo(false, receivedMessage.ResponderAddress.Address, receivedMessage.LocalAddress.Address), QueryTimeout))
117-
{
118-
if (resendConnection == null)
119-
{
120-
response = receivedMessage.Message;
121-
}
122-
else
123-
{
124-
var resendResponse = SendMessage<TMessage>(package, resendConnection, tsigKeySelector, tsigOriginalMac);
125-
126-
if ((resendResponse != null)
127-
&& ValidateResponse(query, resendResponse.Message)
128-
&& ((resendResponse.Message.ReturnCode != ReturnCode.ServerFailure)))
129-
{
130-
return resendResponse.Message;
131-
}
132-
else
133-
{
134-
resendConnection.MarkFaulty();
135-
response = receivedMessage.Message;
136-
}
137-
}
138-
}
139-
}
140-
}
141-
else
142-
{
143-
connection.MarkFaulty();
144-
}
145-
}
146-
catch (Exception e)
147-
{
148-
Trace.TraceError("Error on dns query: " + e);
149-
connection.MarkFaulty();
150-
}
151-
finally
152-
{
153-
connection.Dispose();
154-
}
155-
}
156-
157-
return response;
158-
}
159-
160-
private IEnumerable<IClientConnection> GetConnections(DnsRawPackage package, bool isReliableTransportRequested)
161-
{
162-
foreach (var transport in _transports)
163-
{
164-
if (transport.SupportsPooledConnections
165-
&& package.Length <= transport.MaximumAllowedQuerySize
166-
&& (!isReliableTransportRequested || transport.SupportsReliableTransfer))
167-
{
168-
foreach (var endpointInfo in _endpointInfos)
169-
{
170-
var connection = transport.GetPooledConnection(endpointInfo);
171-
if (connection != null)
172-
yield return connection;
173-
}
174-
}
175-
}
176-
177-
foreach (var transport in _transports)
178-
{
179-
if (package.Length <= transport.MaximumAllowedQuerySize
180-
&& (!isReliableTransportRequested || transport.SupportsReliableTransfer))
181-
{
182-
foreach (var endpointInfo in _endpointInfos)
183-
{
184-
var connection = transport.Connect(endpointInfo, QueryTimeout);
185-
if (connection != null)
186-
yield return connection;
187-
}
188-
}
189-
}
190-
}
191-
192-
private ReceivedMessage<TMessage>? SendMessage<TMessage>(DnsRawPackage package, IClientConnection connection, SelectTsigKey? tsigKeySelector, byte[]? tsigOriginalMac)
193-
where TMessage : DnsMessageBase, new()
194-
{
195-
if (!connection.Send(package))
196-
return null;
197-
198-
var resultData = connection.Receive();
199-
200-
if (resultData == null)
201-
return null;
202-
203-
var response = DnsMessageBase.Parse<TMessage>(resultData.ToArraySegment(false), tsigKeySelector, tsigOriginalMac);
204-
205-
var isNextMessageWaiting = response.IsNextMessageWaiting(false);
206-
207-
while (isNextMessageWaiting)
208-
{
209-
resultData = connection.Receive();
210-
211-
if (resultData == null)
212-
return null;
213-
214-
var nextResult = DnsMessageBase.Parse<TMessage>(resultData.ToArraySegment(false), tsigKeySelector, tsigOriginalMac);
215-
216-
if (nextResult.ReturnCode == ReturnCode.ServerFailure)
217-
return null;
218-
219-
response.AddSubsequentResponse(nextResult);
220-
isNextMessageWaiting = nextResult.IsNextMessageWaiting(true);
221-
}
222-
223-
return new ReceivedMessage<TMessage>(resultData.RemoteEndpoint, resultData.LocalEndpoint, response);
89+
return SendMessageAsync<TMessage>(query, CancellationToken.None).GetAwaiter().GetResult();
22490
}
22591

22692
protected List<TMessage> SendMessageParallel<TMessage>(TMessage message)
22793
where TMessage : DnsMessageBase, new()
22894
{
229-
var result = SendMessageParallelAsync(message, default);
230-
231-
result.Wait();
232-
233-
return result.Result;
95+
return SendMessageParallelAsync(message, default).GetAwaiter().GetResult();
23496
}
23597

23698
private bool ValidateResponse<TMessage>(TMessage message, TMessage response)
@@ -293,6 +155,8 @@ private DnsRawPackage PrepareMessage<TMessage>(TMessage message, out SelectTsigK
293155

294156
if ((receivedMessage != null) && ValidateResponse(query, receivedMessage.Message))
295157
{
158+
connection.RestartIdleTimeout(receivedMessage.Message.GetEDnsKeepAliveTimeout());
159+
296160
if (receivedMessage.Message.ReturnCode == ReturnCode.ServerFailure)
297161
{
298162
response = receivedMessage.Message;
@@ -320,6 +184,7 @@ private DnsRawPackage PrepareMessage<TMessage>(TMessage message, out SelectTsigK
320184
&& ValidateResponse(query, resendResponse.Message)
321185
&& ((resendResponse.Message.ReturnCode != ReturnCode.ServerFailure)))
322186
{
187+
resendConnection.RestartIdleTimeout(receivedMessage.Message.GetEDnsKeepAliveTimeout());
323188
return resendResponse.Message;
324189
}
325190
else
@@ -360,9 +225,7 @@ private DnsRawPackage PrepareMessage<TMessage>(TMessage message, out SelectTsigK
360225
{
361226
foreach (var endpointInfo in _endpointInfos)
362227
{
363-
var connection = transport.GetPooledConnection(endpointInfo);
364-
if (connection != null)
365-
yield return Task.FromResult<IClientConnection?>(connection);
228+
yield return transport.GetPooledConnectionAsync(endpointInfo, token);
366229
}
367230
}
368231
}
@@ -386,7 +249,7 @@ private DnsRawPackage PrepareMessage<TMessage>(TMessage message, out SelectTsigK
386249
if (!await connection.SendAsync(package, token))
387250
return null;
388251

389-
var resultData = await connection.ReceiveAsync(token);
252+
var resultData = await connection.ReceiveAsync(package.MessageIdentification, token);
390253

391254
if (resultData == null)
392255
return null;
@@ -397,7 +260,7 @@ private DnsRawPackage PrepareMessage<TMessage>(TMessage message, out SelectTsigK
397260

398261
while (isNextMessageWaiting)
399262
{
400-
resultData = await connection.ReceiveAsync(token);
263+
resultData = await connection.ReceiveAsync(package.MessageIdentification, token);
401264

402265
if (resultData == null)
403266
return null;
@@ -458,7 +321,7 @@ private async Task SendMessageParallelAsync<TMessage>(IClientTransport transport
458321
if (token.IsCancellationRequested)
459322
break;
460323

461-
var response = await connection.ReceiveAsync(token);
324+
var response = await connection.ReceiveAsync(package.MessageIdentification, token);
462325

463326
if (response == null)
464327
continue;

ARSoft.Tools.Net/Dns/DnsMessageBase.cs

+4-10
Original file line numberDiff line numberDiff line change
@@ -365,12 +365,7 @@ protected void ParseQuestionSection(IList<byte> data, ref int currentPosition, i
365365

366366
for (var i = 0; i < recordCount; i++)
367367
{
368-
DnsQuestion question = new DnsQuestion(
369-
ParseDomainName(data, ref currentPosition),
370-
(RecordType)ParseUShort(data, ref currentPosition),
371-
(RecordClass)ParseUShort(data, ref currentPosition));
372-
373-
questions.Add(question);
368+
questions.Add(DnsQuestion.Parse(data, ref currentPosition));
374369
}
375370

376371
SetQuestionSection(questions);
@@ -971,7 +966,7 @@ internal static TMessage ParseRfc8427Json<TMessage>(JsonElement json)
971966
msg.IsQuery = !ReadBoolFlag(prop.Value);
972967
break;
973968
case "Opcode":
974-
msg.OperationCodeInternal = (OperationCode)prop.Value.GetUInt16();
969+
msg.OperationCodeInternal = (OperationCode) prop.Value.GetUInt16();
975970
break;
976971
case "AA":
977972
msg.AAFlagInternal = ReadBoolFlag(prop.Value);
@@ -992,7 +987,7 @@ internal static TMessage ParseRfc8427Json<TMessage>(JsonElement json)
992987
msg.CDFlagInternal = ReadBoolFlag(prop.Value);
993988
break;
994989
case "RCODE":
995-
msg.ReturnCode = (ReturnCode)prop.Value.GetUInt16();
990+
msg.ReturnCode = (ReturnCode) prop.Value.GetUInt16();
996991
break;
997992
case "QNAME":
998993
qname = DomainName.Parse(prop.Value.GetString() ?? String.Empty);
@@ -1004,7 +999,7 @@ internal static TMessage ParseRfc8427Json<TMessage>(JsonElement json)
1004999
qtype = RecordTypeHelper.ParseShortString(prop.Value.GetString() ?? String.Empty);
10051000
break;
10061001
case "QCLASS":
1007-
qclass = (RecordClass)prop.Value.GetUInt16();
1002+
qclass = (RecordClass) prop.Value.GetUInt16();
10081003
break;
10091004
case "QCLASSname":
10101005
qclass = RecordClassHelper.ParseShortString(prop.Value.GetString() ?? String.Empty);
@@ -1082,6 +1077,5 @@ private static bool ReadBoolFlag(JsonElement json)
10821077
default: throw new JsonException("Not a valid boolean flag: '" + json.GetRawText() + "'");
10831078
}
10841079
}
1085-
10861080
}
10871081
}
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright and License
1+
#region Copyright and License
22
// Copyright 2010..2023 Alexander Reinert
33
//
44
// This file is part of the ARSoft.Tools.Net - C# DNS client/server and SPF Library (https://github.com/alexreinert/ARSoft.Tools.Net)
@@ -16,16 +16,12 @@
1616
// limitations under the License.
1717
#endregion
1818

19-
namespace ARSoft.Tools.Net.Dns
19+
namespace ARSoft.Tools.Net.Dns;
20+
21+
internal static class DnsMessageBaseExtensions
2022
{
21-
/// <summary>
22-
/// Interface of a pooled connection initiated by a client
23-
/// </summary>
24-
public interface IPoolableClientConnection : IClientConnection
23+
public static TimeSpan? GetEDnsKeepAliveTimeout(this DnsMessageBase message)
2524
{
26-
/// <summary>
27-
/// Returns a value indicating if the connection is still alive
28-
/// </summary>
29-
bool IsAlive { get; }
25+
return message.EDnsOptions?.Options?.OfType<TcpKeepAliveOption>()?.FirstOrDefault()?.Timeout;
3026
}
3127
}

ARSoft.Tools.Net/Dns/DnsQueryOptions.cs

+24
Original file line numberDiff line numberDiff line change
@@ -102,5 +102,29 @@ public bool IsDnsSecOk
102102
}
103103
}
104104
}
105+
106+
internal static DnsQueryOptions DefaultQueryOptions { get; } = new()
107+
{
108+
IsRecursionDesired = true,
109+
EDnsOptions = new(
110+
1232,
111+
new DnssecAlgorithmUnderstoodOption(EnumHelper<DnsSecAlgorithm>.Names.Keys.Where(a => a.IsSupported()).ToArray()),
112+
new DsHashUnderstoodOption(EnumHelper<DnsSecDigestType>.Names.Keys.Where(d => d.IsSupported()).ToArray()),
113+
new Nsec3HashUnderstoodOption(EnumHelper<NSec3HashAlgorithm>.Names.Keys.Where(a => a.IsSupported()).ToArray())
114+
)
115+
};
116+
117+
internal static DnsQueryOptions DefaultDnsSecQueryOptions { get; } = new()
118+
{
119+
IsRecursionDesired = true,
120+
IsCheckingDisabled = true,
121+
EDnsOptions = new(
122+
1232,
123+
new DnssecAlgorithmUnderstoodOption(EnumHelper<DnsSecAlgorithm>.Names.Keys.Where(a => a.IsSupported()).ToArray()),
124+
new DsHashUnderstoodOption(EnumHelper<DnsSecDigestType>.Names.Keys.Where(d => d.IsSupported()).ToArray()),
125+
new Nsec3HashUnderstoodOption(EnumHelper<NSec3HashAlgorithm>.Names.Keys.Where(a => a.IsSupported()).ToArray())
126+
),
127+
IsDnsSecOk = true
128+
};
105129
}
106130
}

0 commit comments

Comments
 (0)