作者:肖波
个人博客:http://blog.csdn.net/eaglet ; http://www.cnblogs.com/eaglet
2007/7 南京
版本
CodeSmith 4.0
netTiers 2.0.1
背景
最近在项目中使用CodeSmith + netTiers 生成数据访问层DAL,感觉效果很好,减少了大量的简单重复劳动。
不过在使用过程中发现CodeSmith提供的方法不能完全满足项目需要,主要体现在两个方面:
1、Data.DataRepository.TableProvider.GetPaged 方法不支持带参数的查询条件,调用前必须手工拼接 SQL,
这可能导致 SQL 注入攻击。
2、 DataRepository.Provider.ExecuteDataSet 无法分页查询
为解决以上问题,我编写了如下代码对生成的DAL进行补充。这些代码既可以在DAL外部使用,也可以通过修改netTiers
模板内置到DAL中。
/// <summary>
/// Thrown by ParaWhereString when a parameterized condition cannot be expanded,
/// e.g. a referenced parameter is missing or its value has an unsupported type.
/// </summary>
public class ParaWhereStringException : Exception
{
    /// <summary>Creates the exception with no message.</summary>
    public ParaWhereStringException()
    {
    }

    /// <summary>Creates the exception with the given message.</summary>
    /// <param name="message">Description of the error.</param>
    public ParaWhereStringException(String message)
        : base(message)
    {
    }

    /// <summary>Creates the exception with a message and the exception that caused it.</summary>
    /// <param name="message">Description of the error.</param>
    /// <param name="innerException">The underlying cause.</param>
    public ParaWhereStringException(String message, Exception innerException)
        : base(message, innerException)
    {
    }
}
/// <summary>
/// Expands a parameterized WHERE condition into a literal one, e.g. turns
/// "Price &gt; @MinPrice and Name = @Name" into "Price &gt; 10 and Name = 'abc'".
/// </summary>
/// <remarks>
/// Text inside single-quoted string literals (including the '' escape) is copied unchanged, so an
/// "@name" there is never treated as a parameter; "@@name" server variables are copied unchanged too.
/// Parameter values are inlined as escaped SQL literals rather than sent as real command parameters,
/// which is weaker than true parameterization. The class keeps no state between calls.
/// </remarks>
public class ParaWhereString
{
    /// <summary>
    /// Replaces every "@name" parameter reference in <paramref name="whereString"/> with the literal
    /// value of the matching parameter.
    /// </summary>
    /// <param name="whereString">
    /// Parameterized condition, e.g. "Price&gt;@MinPrice and Price &lt; @MaxPrice".
    /// null or "" is returned unchanged.
    /// </param>
    /// <param name="parameters">
    /// Parameter values. Names match case-insensitively, with or without a leading '@'; a null
    /// Value is written as NULL. When this list is null, <paramref name="whereString"/> is returned unchanged.
    /// </param>
    /// <returns>The literal condition, e.g. "Price &gt; 10 and Price &lt; 100".</returns>
    /// <exception cref="ParaWhereStringException">
    /// A referenced parameter is not in <paramref name="parameters"/>, or its value has an unsupported type.
    /// </exception>
    public String GetWhereString(String whereString, List<SqlParameter> parameters)
    {
        if (parameters == null || String.IsNullOrEmpty(whereString))
        {
            return whereString;
        }

        Dictionary<String, object> values = BuildValueTable(parameters);
        StringBuilder result = new StringBuilder(whereString.Length);
        int pos = 0;
        while (pos < whereString.Length)
        {
            char ch = whereString[pos];
            int end;
            if (ch == '\'')
            {
                // String literal: copy it verbatim so quoted '@' characters are never substituted.
                end = SkipStringLiteral(whereString, pos);
                result.Append(whereString, pos, end - pos);
            }
            else if (ch == '@' && pos + 1 < whereString.Length && whereString[pos + 1] == '@')
            {
                // "@@ROWCOUNT" and similar are server variables, not parameters.
                end = SkipName(whereString, pos + 2);
                result.Append(whereString, pos, end - pos);
            }
            else if (ch == '@')
            {
                end = SkipName(whereString, pos + 1);
                String paraName = whereString.Substring(pos, end - pos);
                object value;
                if (!values.TryGetValue(paraName.Substring(1), out value))
                {
                    throw new ParaWhereStringException(String.Format("para={0} does not in parameters!",
                        paraName));
                }
                result.Append(GetParaValue(paraName, value));
            }
            else
            {
                end = pos + 1;
                result.Append(ch);
            }
            pos = end;
        }
        return result.ToString();
    }

    /// <summary>
    /// Maps parameter names (without any leading '@', compared case-insensitively) to their values.
    /// A null Value becomes DBNull.Value; for duplicate names the last parameter wins.
    /// </summary>
    private static Dictionary<String, object> BuildValueTable(List<SqlParameter> parameters)
    {
        Dictionary<String, object> values = new Dictionary<String, object>(StringComparer.OrdinalIgnoreCase);
        foreach (SqlParameter para in parameters)
        {
            if (para == null)
            {
                continue;
            }
            String name = para.ParameterName ?? String.Empty;
            if (name.Length > 0 && name[0] == '@')
            {
                name = name.Substring(1);
            }
            values[name] = para.Value ?? DBNull.Value;
        }
        return values;
    }

    /// <summary>
    /// Returns the index just past the string literal that opens at <paramref name="start"/>
    /// (a quote character), honouring the '' escape; an unterminated literal runs to the end.
    /// </summary>
    private static int SkipStringLiteral(String text, int start)
    {
        int i = start + 1;
        while (i < text.Length)
        {
            if (text[i] == '\'')
            {
                if (i + 1 < text.Length && text[i + 1] == '\'')
                {
                    i += 2;
                    continue;
                }
                return i + 1;
            }
            i++;
        }
        return text.Length;
    }

    /// <summary>Returns the index just past the run of letters, digits and '_' starting at <paramref name="start"/>.</summary>
    private static int SkipName(String text, int start)
    {
        int i = start;
        while (i < text.Length && (Char.IsLetterOrDigit(text[i]) || text[i] == '_'))
        {
            i++;
        }
        return i;
    }

    /// <summary>
    /// Formats a parameter value as a SQL literal. Numbers and dates use the invariant culture so the
    /// output does not depend on the thread's regional settings.
    /// </summary>
    /// <exception cref="ParaWhereStringException">The value's type is unsupported, or it is NaN/Infinity.</exception>
    private static String GetParaValue(String paraName, object value)
    {
        IFormatProvider invariant = System.Globalization.CultureInfo.InvariantCulture;

        if (value is DBNull)
        {
            return "NULL";
        }
        if (value is bool)
        {
            // SQL Server bit literal.
            return (bool)value ? "1" : "0";
        }

        String number = null;
        if ((value is int) || (value is uint) ||
            (value is short) || (value is ushort) ||
            (value is sbyte) || (value is byte) ||
            (value is long) || (value is ulong) ||
            (value is decimal))
        {
            number = ((IFormattable)value).ToString(null, invariant);
        }
        else if ((value is float) || (value is double))
        {
            double d = Convert.ToDouble(value, invariant);
            if (Double.IsNaN(d) || Double.IsInfinity(d))
            {
                throw new ParaWhereStringException(String.Format("invalid value of para={0}!", paraName));
            }
            // "R" round-trips the value instead of rounding it to 15 significant digits.
            number = ((IFormattable)value).ToString("R", invariant);
        }
        if (number != null)
        {
            // Parenthesize negatives so "x-@p" cannot become "x--5", which SQL parses as a comment.
            return number[0] == '-' ? "(" + number + ")" : number;
        }

        if ((value is string) || (value is char))
        {
            return "'" + value.ToString().Replace("'", "''") + "'";
        }
        if (value is Guid)
        {
            return "'" + ((Guid)value).ToString() + "'";
        }
        if (value is DateTime)
        {
            // ISO 8601 with 'T' is read the same way under every SQL Server DATEFORMAT/language
            // setting, and keeps the milliseconds.
            return "'" + ((DateTime)value).ToString("yyyy-MM-dd'T'HH:mm:ss.fff", invariant) + "'";
        }
        throw new ParaWhereStringException(String.Format("invalid type of para={0}!",
            paraName));
    }
}
/// <summary>
/// Extensions to the generated data repository: paged queries over a plain SELECT.
/// </summary>
public class DataRepositoryEx
{
    /// <summary>
    /// Returns one page of a SELECT query together with the total number of matching rows.
    /// Only plain SELECT queries are supported; stored procedures cannot be paged this way.
    /// </summary>
    /// <remarks>
    /// Runs two statements through DataRepository.Provider.ExecuteDataSet: a count(*) query, then a
    /// "select top N" query that fetches every row up to the end of the requested page. Rows that belong
    /// to earlier pages are then dropped in memory, so deep pages also fetch all preceding rows.
    /// Only the parameter values are escaped (via ParaWhereString). <paramref name="fields"/>,
    /// <paramref name="tableName"/> and <paramref name="orderBy"/> are concatenated into the SQL
    /// verbatim and must never come from untrusted input.
    /// </remarks>
    /// <param name="fields">Select list without a TOP keyword, e.g. "Price, ReleaseTime, RecName as Address".</param>
    /// <param name="tableName">Table to query.</param>
    /// <param name="condition">
    /// Parameterized WHERE condition without the "where" keyword, e.g. "Price &gt; @MinPrice and Price &lt; @MaxPrice".
    /// null, empty or blank means no WHERE clause.
    /// </param>
    /// <param name="parameters">Values for the parameters referenced in <paramref name="condition"/>.</param>
    /// <param name="orderBy">
    /// Clause appended after the WHERE clause, e.g. "order by ReleaseTime ASC"; a GROUP BY may also go here.
    /// It is not applied to the count query.
    /// </param>
    /// <param name="pageNo">Zero-based page number.</param>
    /// <param name="pageLength">Number of rows per page.</param>
    /// <param name="count">Receives the total number of rows matching <paramref name="condition"/>.</param>
    /// <returns>The rows of the requested page; empty when the page lies beyond the last row.</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="pageNo"/> is negative, <paramref name="pageLength"/> is not positive,
    /// or (pageNo + 1) * pageLength exceeds Int32.MaxValue.
    /// </exception>
    public static DataTable SelectPaged(String fields, String tableName,
        String condition, List<SqlParameter> parameters, String orderBy, int pageNo, int pageLength, out int count)
    {
        if (pageNo < 0)
        {
            throw new ArgumentOutOfRangeException("pageNo", pageNo, "pageNo must not be negative.");
        }
        if (pageLength <= 0)
        {
            throw new ArgumentOutOfRangeException("pageLength", pageLength, "pageLength must be positive.");
        }
        // Compute the page end in 64 bits so a huge pageNo is rejected instead of wrapping around.
        long pageEnd = ((long)pageNo + 1) * pageLength;
        if (pageEnd > int.MaxValue)
        {
            throw new ArgumentOutOfRangeException("pageNo", pageNo, "(pageNo + 1) * pageLength exceeds Int32.MaxValue.");
        }
        int upperBound = (int)pageEnd;
        int lowerBound = upperBound - pageLength;

        // Expand the parameters only when there is a real condition; a blank condition means "no WHERE".
        String whereClause = "";
        if (condition != null && condition.Trim().Length > 0)
        {
            ParaWhereString paraWhereStr = new ParaWhereString();
            whereClause = " where " + paraWhereStr.GetWhereString(condition, parameters);
        }

        String sql = String.Format("select count(*) cnt from {0}{1}", tableName, whereClause);
        DataSet ds = DataRepository.Provider.ExecuteDataSet(CommandType.Text, sql);
        count = Convert.ToInt32(ds.Tables[0].Rows[0]["cnt"], System.Globalization.CultureInfo.InvariantCulture);

        sql = String.Format("select top {0} {1} from {2}{3} ", upperBound, fields, tableName, whereClause);
        if (!String.IsNullOrEmpty(orderBy))
        {
            sql += orderBy;
        }
        ds = DataRepository.Provider.ExecuteDataSet(CommandType.Text, sql);

        DataTable page = ds.Tables[0];
        if (page.Rows.Count <= lowerBound)
        {
            page.Clear();
        }
        else
        {
            // Drop the rows that belong to earlier pages.
            for (int i = 0; i < lowerBound; i++)
            {
                page.Rows.RemoveAt(0);
            }
        }
        return page;
    }
}
/// <summary>
/// Thrown by ParaWhereString when a parameterized condition cannot be expanded,
/// e.g. a referenced parameter is missing or its value has an unsupported type.
/// </summary>
public class ParaWhereStringException : Exception
{
    /// <summary>Creates the exception with no message.</summary>
    public ParaWhereStringException()
    {
    }

    /// <summary>Creates the exception with the given message.</summary>
    /// <param name="message">Description of the error.</param>
    public ParaWhereStringException(String message)
        : base(message)
    {
    }

    /// <summary>Creates the exception with a message and the exception that caused it.</summary>
    /// <param name="message">Description of the error.</param>
    /// <param name="innerException">The underlying cause.</param>
    public ParaWhereStringException(String message, Exception innerException)
        : base(message, innerException)
    {
    }
}
/// <summary>
/// Expands a parameterized WHERE condition into a literal one, e.g. turns
/// "Price &gt; @MinPrice and Name = @Name" into "Price &gt; 10 and Name = 'abc'".
/// </summary>
/// <remarks>
/// Text inside single-quoted string literals (including the '' escape) is copied unchanged, so an
/// "@name" there is never treated as a parameter; "@@name" server variables are copied unchanged too.
/// Parameter values are inlined as escaped SQL literals rather than sent as real command parameters,
/// which is weaker than true parameterization. The class keeps no state between calls.
/// </remarks>
public class ParaWhereString
{
    /// <summary>
    /// Replaces every "@name" parameter reference in <paramref name="whereString"/> with the literal
    /// value of the matching parameter.
    /// </summary>
    /// <param name="whereString">
    /// Parameterized condition, e.g. "Price&gt;@MinPrice and Price &lt; @MaxPrice".
    /// null or "" is returned unchanged.
    /// </param>
    /// <param name="parameters">
    /// Parameter values. Names match case-insensitively, with or without a leading '@'; a null
    /// Value is written as NULL. When this list is null, <paramref name="whereString"/> is returned unchanged.
    /// </param>
    /// <returns>The literal condition, e.g. "Price &gt; 10 and Price &lt; 100".</returns>
    /// <exception cref="ParaWhereStringException">
    /// A referenced parameter is not in <paramref name="parameters"/>, or its value has an unsupported type.
    /// </exception>
    public String GetWhereString(String whereString, List<SqlParameter> parameters)
    {
        if (parameters == null || String.IsNullOrEmpty(whereString))
        {
            return whereString;
        }

        Dictionary<String, object> values = BuildValueTable(parameters);
        StringBuilder result = new StringBuilder(whereString.Length);
        int pos = 0;
        while (pos < whereString.Length)
        {
            char ch = whereString[pos];
            int end;
            if (ch == '\'')
            {
                // String literal: copy it verbatim so quoted '@' characters are never substituted.
                end = SkipStringLiteral(whereString, pos);
                result.Append(whereString, pos, end - pos);
            }
            else if (ch == '@' && pos + 1 < whereString.Length && whereString[pos + 1] == '@')
            {
                // "@@ROWCOUNT" and similar are server variables, not parameters.
                end = SkipName(whereString, pos + 2);
                result.Append(whereString, pos, end - pos);
            }
            else if (ch == '@')
            {
                end = SkipName(whereString, pos + 1);
                String paraName = whereString.Substring(pos, end - pos);
                object value;
                if (!values.TryGetValue(paraName.Substring(1), out value))
                {
                    throw new ParaWhereStringException(String.Format("para={0} does not in parameters!",
                        paraName));
                }
                result.Append(GetParaValue(paraName, value));
            }
            else
            {
                end = pos + 1;
                result.Append(ch);
            }
            pos = end;
        }
        return result.ToString();
    }

    /// <summary>
    /// Maps parameter names (without any leading '@', compared case-insensitively) to their values.
    /// A null Value becomes DBNull.Value; for duplicate names the last parameter wins.
    /// </summary>
    private static Dictionary<String, object> BuildValueTable(List<SqlParameter> parameters)
    {
        Dictionary<String, object> values = new Dictionary<String, object>(StringComparer.OrdinalIgnoreCase);
        foreach (SqlParameter para in parameters)
        {
            if (para == null)
            {
                continue;
            }
            String name = para.ParameterName ?? String.Empty;
            if (name.Length > 0 && name[0] == '@')
            {
                name = name.Substring(1);
            }
            values[name] = para.Value ?? DBNull.Value;
        }
        return values;
    }

    /// <summary>
    /// Returns the index just past the string literal that opens at <paramref name="start"/>
    /// (a quote character), honouring the '' escape; an unterminated literal runs to the end.
    /// </summary>
    private static int SkipStringLiteral(String text, int start)
    {
        int i = start + 1;
        while (i < text.Length)
        {
            if (text[i] == '\'')
            {
                if (i + 1 < text.Length && text[i + 1] == '\'')
                {
                    i += 2;
                    continue;
                }
                return i + 1;
            }
            i++;
        }
        return text.Length;
    }

    /// <summary>Returns the index just past the run of letters, digits and '_' starting at <paramref name="start"/>.</summary>
    private static int SkipName(String text, int start)
    {
        int i = start;
        while (i < text.Length && (Char.IsLetterOrDigit(text[i]) || text[i] == '_'))
        {
            i++;
        }
        return i;
    }

    /// <summary>
    /// Formats a parameter value as a SQL literal. Numbers and dates use the invariant culture so the
    /// output does not depend on the thread's regional settings.
    /// </summary>
    /// <exception cref="ParaWhereStringException">The value's type is unsupported, or it is NaN/Infinity.</exception>
    private static String GetParaValue(String paraName, object value)
    {
        IFormatProvider invariant = System.Globalization.CultureInfo.InvariantCulture;

        if (value is DBNull)
        {
            return "NULL";
        }
        if (value is bool)
        {
            // SQL Server bit literal.
            return (bool)value ? "1" : "0";
        }

        String number = null;
        if ((value is int) || (value is uint) ||
            (value is short) || (value is ushort) ||
            (value is sbyte) || (value is byte) ||
            (value is long) || (value is ulong) ||
            (value is decimal))
        {
            number = ((IFormattable)value).ToString(null, invariant);
        }
        else if ((value is float) || (value is double))
        {
            double d = Convert.ToDouble(value, invariant);
            if (Double.IsNaN(d) || Double.IsInfinity(d))
            {
                throw new ParaWhereStringException(String.Format("invalid value of para={0}!", paraName));
            }
            // "R" round-trips the value instead of rounding it to 15 significant digits.
            number = ((IFormattable)value).ToString("R", invariant);
        }
        if (number != null)
        {
            // Parenthesize negatives so "x-@p" cannot become "x--5", which SQL parses as a comment.
            return number[0] == '-' ? "(" + number + ")" : number;
        }

        if ((value is string) || (value is char))
        {
            return "'" + value.ToString().Replace("'", "''") + "'";
        }
        if (value is Guid)
        {
            return "'" + ((Guid)value).ToString() + "'";
        }
        if (value is DateTime)
        {
            // ISO 8601 with 'T' is read the same way under every SQL Server DATEFORMAT/language
            // setting, and keeps the milliseconds.
            return "'" + ((DateTime)value).ToString("yyyy-MM-dd'T'HH:mm:ss.fff", invariant) + "'";
        }
        throw new ParaWhereStringException(String.Format("invalid type of para={0}!",
            paraName));
    }
}
/// <summary>
/// Extensions to the generated data repository: paged queries over a plain SELECT.
/// </summary>
public class DataRepositoryEx
{
    /// <summary>
    /// Returns one page of a SELECT query together with the total number of matching rows.
    /// Only plain SELECT queries are supported; stored procedures cannot be paged this way.
    /// </summary>
    /// <remarks>
    /// Runs two statements through DataRepository.Provider.ExecuteDataSet: a count(*) query, then a
    /// "select top N" query that fetches every row up to the end of the requested page. Rows that belong
    /// to earlier pages are then dropped in memory, so deep pages also fetch all preceding rows.
    /// Only the parameter values are escaped (via ParaWhereString). <paramref name="fields"/>,
    /// <paramref name="tableName"/> and <paramref name="orderBy"/> are concatenated into the SQL
    /// verbatim and must never come from untrusted input.
    /// </remarks>
    /// <param name="fields">Select list without a TOP keyword, e.g. "Price, ReleaseTime, RecName as Address".</param>
    /// <param name="tableName">Table to query.</param>
    /// <param name="condition">
    /// Parameterized WHERE condition without the "where" keyword, e.g. "Price &gt; @MinPrice and Price &lt; @MaxPrice".
    /// null, empty or blank means no WHERE clause.
    /// </param>
    /// <param name="parameters">Values for the parameters referenced in <paramref name="condition"/>.</param>
    /// <param name="orderBy">
    /// Clause appended after the WHERE clause, e.g. "order by ReleaseTime ASC"; a GROUP BY may also go here.
    /// It is not applied to the count query.
    /// </param>
    /// <param name="pageNo">Zero-based page number.</param>
    /// <param name="pageLength">Number of rows per page.</param>
    /// <param name="count">Receives the total number of rows matching <paramref name="condition"/>.</param>
    /// <returns>The rows of the requested page; empty when the page lies beyond the last row.</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="pageNo"/> is negative, <paramref name="pageLength"/> is not positive,
    /// or (pageNo + 1) * pageLength exceeds Int32.MaxValue.
    /// </exception>
    public static DataTable SelectPaged(String fields, String tableName,
        String condition, List<SqlParameter> parameters, String orderBy, int pageNo, int pageLength, out int count)
    {
        if (pageNo < 0)
        {
            throw new ArgumentOutOfRangeException("pageNo", pageNo, "pageNo must not be negative.");
        }
        if (pageLength <= 0)
        {
            throw new ArgumentOutOfRangeException("pageLength", pageLength, "pageLength must be positive.");
        }
        // Compute the page end in 64 bits so a huge pageNo is rejected instead of wrapping around.
        long pageEnd = ((long)pageNo + 1) * pageLength;
        if (pageEnd > int.MaxValue)
        {
            throw new ArgumentOutOfRangeException("pageNo", pageNo, "(pageNo + 1) * pageLength exceeds Int32.MaxValue.");
        }
        int upperBound = (int)pageEnd;
        int lowerBound = upperBound - pageLength;

        // Expand the parameters only when there is a real condition; a blank condition means "no WHERE".
        String whereClause = "";
        if (condition != null && condition.Trim().Length > 0)
        {
            ParaWhereString paraWhereStr = new ParaWhereString();
            whereClause = " where " + paraWhereStr.GetWhereString(condition, parameters);
        }

        String sql = String.Format("select count(*) cnt from {0}{1}", tableName, whereClause);
        DataSet ds = DataRepository.Provider.ExecuteDataSet(CommandType.Text, sql);
        count = Convert.ToInt32(ds.Tables[0].Rows[0]["cnt"], System.Globalization.CultureInfo.InvariantCulture);

        sql = String.Format("select top {0} {1} from {2}{3} ", upperBound, fields, tableName, whereClause);
        if (!String.IsNullOrEmpty(orderBy))
        {
            sql += orderBy;
        }
        ds = DataRepository.Provider.ExecuteDataSet(CommandType.Text, sql);

        DataTable page = ds.Tables[0];
        if (page.Rows.Count <= lowerBound)
        {
            page.Clear();
        }
        else
        {
            // Drop the rows that belong to earlier pages.
            for (int i = 0; i < lowerBound; i++)
            {
                page.Rows.RemoveAt(0);
            }
        }
        return page;
    }
}
ParaWhereString 类用于将带参数的条件子句转换为不带参数的条件子句,供 GetPaged、GetAll 两个方法使用。这是一个通用类,也可以在
其他应用中用来把带参数的条件子句转换为最终的条件子句。注意:它是把参数值转义(单引号加倍)后直接嵌入 SQL 文本,安全性不如真正的参数化查询。
DataRepositoryEx 类提供分页查询的方法;其 fields、tableName、orderBy 参数会原样拼接进 SQL,不能来自用户输入。
ParaWhereString 调用示例
// Expand a parameterized condition into a literal one. Parameter names are matched
// case-insensitively ("MaxPrice" supplies @maxPrice), and the "@aaa" inside the quoted
// string literal is left untouched.
List<SqlParameter> priceRange = new List<SqlParameter>();
priceRange.Add(new SqlParameter("minPrice", 100));
priceRange.Add(new SqlParameter("MaxPrice", 1000));
string condition = "price>@minPrice and price <= @maxPrice and str like '%adb''@aaa dsafj'";
ParaWhereString expander = new ParaWhereString();
Console.WriteLine(expander.GetWhereString(condition, priceRange));
输出结果:
price>100 and price <= 1000 and str like '%adb''@aaa dsafj'
DataRepositoryEx 调用示例
用于测试的表结构
use Test
GO
-- Test table for the DataRepositoryEx.SelectPaged example: an identity key plus one int column.
-- Created once; running CREATE TABLE a second time fails because the table already exists.
Create Table Test
(
    id int identity (1,1) not null,
    a int
)
GO
向表Test中插入若干条连续的记录
查询分页数据示例
// First page (pageNo = 0, 10 rows per page) of rows with 3 <= id < 30, highest id first.
List<SqlParameter> idRange = new List<SqlParameter>();
idRange.Add(new SqlParameter("min", 3));
idRange.Add(new SqlParameter("max", 30));
int total;
DataTable page = SecUser.Cert.BLL.DataRepositoryEx.SelectPaged(
    "id, a", "test..test", "id >= @min and id < @max",
    idRange, "order by id DESC", 0, 10, out total);
Response.Write(String.Format("Count={0}", total));
for (int i = 0; i < page.Rows.Count; i++)
{
    Response.Write(String.Format("</p>{0}", page.Rows[i]["id"]));
}
查询结果:
Count=27
29
28
27
26
25
24
23
22
21
20