zoukankan      html  css  js  c++  java
  • VC++基于LSP实现数据拦截


    LSP即分层服务提供商,Winsock 作为应用程序的 Windows 的网络套接字工具,可以由称为“分层服务提供商”的机制进行扩展。Winsock LSP 可用于非常广泛的实用用途,包括 Internet 家长控制 (parental control) 和 Web 内容筛选。在以前版本的 Windows XP 中,删除不正确的(也称为“buggy”)LSP 可能会导致注册表中的 Winsock 目录损坏,潜在地导致所有网络连接的丢失。


    请见代码



    #define UNICODE
    #define _UNICODE
    
    #include <Winsock2.h>
    #include <Ws2spi.h>
    #include <Sporder.h>
    #include <Windows.h>
    #include <stdio.h>
    #include <tchar.h>
    
    #include "PhoenixLSP.h"
    
    #include "../common/Debug.h"
    #include "../common/PMacRes.h"
    #include "Acl.h"
    
    #pragma comment(lib, "Ws2_32.lib")
    
    
    CAcl g_Acl;							// 访问列表,用来检查会话的访问权限
    
    WSPUPCALLTABLE g_pUpCallTable;		// 上层函数列表。如果LSP创建了自己的伪句柄,才使用这个函数列表
    WSPPROC_TABLE g_NextProcTable;		// 下层函数列表
    TCHAR	g_szCurrentApp[MAX_PATH];	// 当前调用本DLL的程序的名称
    
    
    BOOL APIENTRY DllMain( HANDLE hModule, 
                           DWORD  ul_reason_for_call, 
                           LPVOID lpReserved
    					 )
    {
    	switch (ul_reason_for_call)
    	{
    	case DLL_PROCESS_ATTACH:
    		{
    			// 取得主模块的名称
    			::GetModuleFileName(NULL, g_szCurrentApp, MAX_PATH);
    		}
    		break;
    	}
    	return TRUE;
    }
    
    int WSPAPI WSPStartup(
      WORD wVersionRequested,
      LPWSPDATA lpWSPData,
      LPWSAPROTOCOL_INFO lpProtocolInfo,
      WSPUPCALLTABLE UpcallTable,
      LPWSPPROC_TABLE lpProcTable
    )
    {
    	ODS1(L"  WSPStartup...  %s \n", g_szCurrentApp);
    	
    	if(lpProtocolInfo->ProtocolChain.ChainLen <= 1)
    	{	
    		return WSAEPROVIDERFAILEDINIT;
    	}
    	
    	// 保存向上调用的函数表指针(这里我们不使用它)
    	g_pUpCallTable = UpcallTable;
    
    	// 枚举协议,找到下层协议的WSAPROTOCOL_INFOW结构	
    	WSAPROTOCOL_INFOW	NextProtocolInfo;
    	int nTotalProtos;
    	LPWSAPROTOCOL_INFOW pProtoInfo = GetProvider(&nTotalProtos);
    	// 下层入口ID	
    	DWORD dwBaseEntryId = lpProtocolInfo->ProtocolChain.ChainEntries[1];
    	for(int i=0; i<nTotalProtos; i++)
    	{
    		if(pProtoInfo[i].dwCatalogEntryId == dwBaseEntryId)
    		{
    			memcpy(&NextProtocolInfo, &pProtoInfo[i], sizeof(NextProtocolInfo));
    			break;
    		}
    	}
    	if(i >= nTotalProtos)
    	{
    		ODS(L" WSPStartup:	Can not find underlying protocol \n");
    		return WSAEPROVIDERFAILEDINIT;
    	}
    
    	// 加载下层协议的DLL
    	int nError;
    	TCHAR szBaseProviderDll[MAX_PATH];
    	int nLen = MAX_PATH;
    	// 取得下层提供程序DLL路径
    	if(::WSCGetProviderPath(&NextProtocolInfo.ProviderId, szBaseProviderDll, &nLen, &nError) == SOCKET_ERROR)
    	{
    		ODS1(L" WSPStartup: WSCGetProviderPath() failed %d \n", nError);
    		return WSAEPROVIDERFAILEDINIT;
    	}
    	if(!::ExpandEnvironmentStrings(szBaseProviderDll, szBaseProviderDll, MAX_PATH))
    	{
    		ODS1(L" WSPStartup:  ExpandEnvironmentStrings() failed %d \n", ::GetLastError());
    		return WSAEPROVIDERFAILEDINIT;
    	}
    	// 加载下层提供程序
    	HMODULE hModule = ::LoadLibrary(szBaseProviderDll);
    	if(hModule == NULL)
    	{
    		ODS1(L" WSPStartup:  LoadLibrary() failed %d \n", ::GetLastError());
    		return WSAEPROVIDERFAILEDINIT;
    	}
    
    	// 导入下层提供程序的WSPStartup函数
    	LPWSPSTARTUP  pfnWSPStartup = NULL;
    	pfnWSPStartup = (LPWSPSTARTUP)::GetProcAddress(hModule, "WSPStartup");
    	if(pfnWSPStartup == NULL)
    	{
    		ODS1(L" WSPStartup:  GetProcAddress() failed %d \n", ::GetLastError());
    		return WSAEPROVIDERFAILEDINIT;
    	}
    
    	// 调用下层提供程序的WSPStartup函数
    	LPWSAPROTOCOL_INFOW pInfo = lpProtocolInfo;
    	if(NextProtocolInfo.ProtocolChain.ChainLen == BASE_PROTOCOL)
    		pInfo = &NextProtocolInfo;
    
    	int nRet = pfnWSPStartup(wVersionRequested, lpWSPData, pInfo, UpcallTable, lpProcTable);
    	if(nRet != ERROR_SUCCESS)
    	{
    		ODS1(L" WSPStartup:  underlying provider's WSPStartup() failed %d \n", nRet);
    		return nRet;
    	}
    
    	// 保存下层提供者的函数表
    	g_NextProcTable = *lpProcTable;
    
    	// 传给上层,截获对以下函数的调用
    	lpProcTable->lpWSPSocket = WSPSocket;
    	lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
    	lpProcTable->lpWSPBind = WSPBind;
    	lpProcTable->lpWSPAccept = WSPAccept;
    	lpProcTable->lpWSPConnect = WSPConnect;
    	lpProcTable->lpWSPSendTo = WSPSendTo;	
    	lpProcTable->lpWSPRecvFrom = WSPRecvFrom; 
    
    	FreeProvider(pProtoInfo);
    	return nRet;
    }
    
    
    
    SOCKET WSPAPI WSPSocket(
    	int			af,                               
    	int			type,                             
    	int			protocol,                         
    	LPWSAPROTOCOL_INFOW lpProtocolInfo,   
    	GROUP		g,                              
    	DWORD		dwFlags,                        
    	LPINT		lpErrno
    )
    {
    	// 首先调用下层函数创建套节字
    	SOCKET	s = g_NextProcTable.lpWSPSocket(af, type, protocol, lpProtocolInfo, g, dwFlags, lpErrno);
    	if(s == INVALID_SOCKET)
    		return s;
    
    	// 调用CAcl类的CheckSocket函数,设置会话属性
    	if (af == FROM_PROTOCOL_INFO)
    		af = lpProtocolInfo->iAddressFamily;
    	if (type == FROM_PROTOCOL_INFO)
    		type = lpProtocolInfo->iSocketType;
    	if (protocol == FROM_PROTOCOL_INFO)
    		protocol = lpProtocolInfo->iProtocol;
    
    	g_Acl.CheckSocket(s, af, type, protocol);
    
    	return s;
    }
    
    int WSPAPI WSPCloseSocket(
    	SOCKET		s,
    	LPINT		lpErrno
    )
    {
    	// 调用CAcl类的CheckCloseSocket函数,删除对应的会话
    	g_Acl.CheckCloseSocket(s);
    	return g_NextProcTable.lpWSPCloseSocket(s, lpErrno);
    }
    
    int WSPAPI WSPBind(SOCKET s, const struct sockaddr* name, int namelen, LPINT lpErrno)
    {
    	// 调用CAcl类的CheckBind函数,设置会话属性
    	g_Acl.CheckBind(s, name);
    	return g_NextProcTable.lpWSPBind(s, name, namelen, lpErrno);
    }
    
    int WSPAPI WSPConnect(
    	SOCKET			s,
    	const struct	sockaddr FAR * name,
    	int				namelen,
    	LPWSABUF		lpCallerData,
    	LPWSABUF		lpCalleeData,
    	LPQOS			lpSQOS,
    	LPQOS			lpGQOS,
    	LPINT			lpErrno
    )
    {
    	ODS1(L" WSPConnect...	%s", g_szCurrentApp);
    
    	// 检查是否允许连接到远程主机
    	if(g_Acl.CheckConnect(s, name) != PF_PASS)
    	{
    		*lpErrno = WSAECONNREFUSED;
    		ODS1(L" WSPConnect deny a query %s \n", g_szCurrentApp);
    		return SOCKET_ERROR;
    	} 
    
    	return g_NextProcTable.lpWSPConnect(s, name, namelen, lpCallerData, lpCalleeData, lpSQOS, lpGQOS, lpErrno);
    }
    
    SOCKET WSPAPI WSPAccept(
    	SOCKET			s,
    	struct sockaddr FAR *addr,
    	LPINT			addrlen,
    	LPCONDITIONPROC	lpfnCondition,
    	DWORD			dwCallbackData,
    	LPINT			lpErrno
    )
    {
    	ODS1(L"  PhoenixLSP:  WSPAccept  %s \n", g_szCurrentApp);
    
    	// 首先调用下层函数接收到来的连接
    	SOCKET	sNew	= g_NextProcTable.lpWSPAccept(s, addr, addrlen, lpfnCondition, dwCallbackData, lpErrno);
    	
    	// 检查是否允许,如果不允许,关闭新接收的连接
    	if (sNew != INVALID_SOCKET && g_Acl.CheckAccept(s, sNew, addr) != PF_PASS)
    	{
    		int iError;
    		g_NextProcTable.lpWSPCloseSocket(sNew, &iError);
    		*lpErrno = WSAECONNREFUSED;
    		return SOCKET_ERROR;
    	}
    
    	return sNew;
    }
    
    
    int WSPAPI WSPSendTo(
    	SOCKET			s,
    	LPWSABUF		lpBuffers,
    	DWORD			dwBufferCount,
    	LPDWORD			lpNumberOfBytesSent,
    	DWORD			dwFlags,
    	const struct sockaddr FAR * lpTo,
    	int				iTolen,
    	LPWSAOVERLAPPED	lpOverlapped,
    	LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
    	LPWSATHREADID	lpThreadId,
    	LPINT			lpErrno
    )
    {
    	ODS1(L" query send to... %s \n", g_szCurrentApp);
    	
    	// 检查是否允许发送数据
    	if (g_Acl.CheckSendTo(s, lpTo) != PF_PASS)
    	{
    		int		iError;
    		g_NextProcTable.lpWSPShutdown(s, SD_BOTH, &iError);
    		*lpErrno = WSAECONNABORTED;
    
    		ODS1(L" WSPSendTo deny query %s \n", g_szCurrentApp);
    
    		return SOCKET_ERROR;
    	}
    
    	// 调用下层发送函数
    	return g_NextProcTable.lpWSPSendTo(s, lpBuffers, dwBufferCount, 
    							lpNumberOfBytesSent, dwFlags, lpTo, iTolen, 
    								lpOverlapped, lpCompletionRoutine, lpThreadId, lpErrno);
    }
    
    int WSPAPI WSPRecvFrom (
    	SOCKET			s,
    	LPWSABUF		lpBuffers,
    	DWORD			dwBufferCount,
    	LPDWORD			lpNumberOfBytesRecvd,
    	LPDWORD			lpFlags,
    	struct sockaddr FAR * lpFrom,
    	LPINT			lpFromlen,
    	LPWSAOVERLAPPED lpOverlapped,
    	LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
    	LPWSATHREADID	lpThreadId,
    	LPINT			lpErrno
    )
    {
    	ODS1(L"  PhoenixLSP:  WSPRecvFrom %s \n", g_szCurrentApp);
    	
    	// 首先检查是否允许接收数据
    	if(g_Acl.CheckRecvFrom(s, lpFrom) != PF_PASS)
    	{
    		int		iError;
    		g_NextProcTable.lpWSPShutdown(s, SD_BOTH, &iError);
    		*lpErrno = WSAECONNABORTED;
    
    		ODS1(L" WSPRecvFrom deny query %s \n", g_szCurrentApp);
    		return SOCKET_ERROR;
    	}
    	
    	// 调用下层接收函数
    	return g_NextProcTable.lpWSPRecvFrom(s, lpBuffers, dwBufferCount, lpNumberOfBytesRecvd, 
    		lpFlags, lpFrom, lpFromlen, lpOverlapped, lpCompletionRoutine, lpThreadId, lpErrno);
    }
    
    
    
    LPWSAPROTOCOL_INFOW GetProvider(LPINT lpnTotalProtocols)
    {
    	DWORD dwSize = 0;
    	int nError;
    	LPWSAPROTOCOL_INFOW pProtoInfo = NULL;
    	
    	// 取得需要的长度
    	if(::WSCEnumProtocols(NULL, pProtoInfo, &dwSize, &nError) == SOCKET_ERROR)
    	{
    		if(nError != WSAENOBUFS)
    			return NULL;
    	}
    	
    	pProtoInfo = (LPWSAPROTOCOL_INFOW)::GlobalAlloc(GPTR, dwSize);
    	*lpnTotalProtocols = ::WSCEnumProtocols(NULL, pProtoInfo, &dwSize, &nError);
    	return pProtoInfo;
    }
    
    void FreeProvider(LPWSAPROTOCOL_INFOW pProtoInfo)
    {
    	::GlobalFree(pProtoInfo);
    }
    
    /*
    
    int WSPAPI WSPStartup(
      WORD wVersionRequested,
      LPWSPDATA lpWSPData,
      LPWSAPROTOCOL_INFO lpProtocolInfo,
      WSPUPCALLTABLE UpcallTable,
      LPWSPPROC_TABLE lpProcTable
    )
    {
    	ODS1(L"  PhoenixLSP:  WSPStartup  %s \n", g_szCurrentApp);
    
    	ODS1(L" %s", lpProtocolInfo->szProtocol);
    	
    	if(lpProtocolInfo->ProtocolChain.ChainLen <= 1)
    	{	
    		::OutputDebugString(L" Chain len <= 1 \n");
    		return WSAEPROVIDERFAILEDINIT;
    	}
    	
    	g_pUpCallTable = UpcallTable;
    	int nTotalProtos;
    	LPWSAPROTOCOL_INFOW pProtoInfo = GetProvider(&nTotalProtos);
    
    	
    	// 找到下层协议	
    	WSAPROTOCOL_INFOW	NextProtocolInfo;
    	// 下层入口ID	
    	DWORD dwBaseEntryId = lpProtocolInfo->ProtocolChain.ChainEntries[1];
    	for(int i=0; i<nTotalProtos; i++)
    	{
    		if(pProtoInfo[i].dwCatalogEntryId == dwBaseEntryId)
    		{
    			memcpy(&NextProtocolInfo, &pProtoInfo[i], sizeof(NextProtocolInfo));
    			break;
    		}
    	}
    	if(i >= nTotalProtos)
    	{
    		::OutputDebugString(L" Can not find next protocol <= 1 \n");
    		return WSAEPROVIDERFAILEDINIT;
    	}
    
    	// 加载下层协议的DLL
    	int nError;
    	TCHAR szBaseProviderDll[MAX_PATH];
    	int nLen = MAX_PATH;
    	if(::WSCGetProviderPath(&NextProtocolInfo.ProviderId, szBaseProviderDll, &nLen, &nError) == SOCKET_ERROR)
    	{
    		::OutputDebugString(L" WSCGetProviderPath() failed \n");
    		return WSAEPROVIDERFAILEDINIT;
    	}
    	if(!::ExpandEnvironmentStrings(szBaseProviderDll, szBaseProviderDll, MAX_PATH))
    	{
    		::OutputDebugString(L" ExpandEnvironmentStrings() failed \n");
    		return WSAEPROVIDERFAILEDINIT;
    	}
    	HMODULE hModule = ::LoadLibrary(szBaseProviderDll);
    	if(hModule == NULL)
    	{
    		::OutputDebugString(L" LoadLibrary() failed \n");
    		return WSAEPROVIDERFAILEDINIT;
    	}
    
    	// 调用下层协议的WSPStartup函数
    	LPWSPSTARTUP  proWSPStartup = NULL;
    	proWSPStartup = (LPWSPSTARTUP)::GetProcAddress(hModule, "WSPStartup");
    	if(proWSPStartup == NULL)
    	{
    		::OutputDebugString(L" GetProcAddress() failed \n");
    		return WSAEPROVIDERFAILEDINIT;
    	}
    
    	LPWSAPROTOCOL_INFOW pInfo = lpProtocolInfo;
    	if(NextProtocolInfo.ProtocolChain.ChainLen == BASE_PROTOCOL)
    		pInfo = &NextProtocolInfo;
    
    	int nRet = proWSPStartup(wVersionRequested, lpWSPData, pInfo, UpcallTable, lpProcTable);
    	if(nRet != ERROR_SUCCESS)
    	{
    		ODS1(L" next layer's WSPStartup() failed %d \n ", nRet);
    		return nRet;
    	}
    
    	// 保存下层协议的函数表
    	g_NextProcTable = *lpProcTable;
    
    	// 传给上层
    	lpProcTable->lpWSPSocket = WSPSocket;
    	lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
    
    	lpProcTable->lpWSPBind = WSPBind;
    
    	// tcp
    	lpProcTable->lpWSPAccept = WSPAccept;
    	lpProcTable->lpWSPConnect = WSPConnect;
    
    	// udp raw
    	lpProcTable->lpWSPSendTo = WSPSendTo;	
    	lpProcTable->lpWSPRecvFrom = WSPRecvFrom; 
    
    
    
    	FreeProvider(pProtoInfo);
    	return nRet;
    }
    
    */
    
    
    
    
    ////////////////////////////////////////////////////////////////////////////////
    
    /*	lpProcTable->lpWSPSend = WSPSend;
    	lpProcTable->lpWSPRecv = WSPRecv;
    */
    /*
    
      int WSPAPI WSPSend(
    	SOCKET			s,
    	LPWSABUF		lpBuffers,
    	DWORD			dwBufferCount,
    	LPDWORD			lpNumberOfBytesSent,
    	DWORD			dwFlags,
    	LPWSAOVERLAPPED	lpOverlapped,
    	LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
    	LPWSATHREADID	lpThreadId,
    	LPINT			lpErrno
    )
    {
    //	ODS1(L"  PhoenixLSP:  WSPSend  %s \n", g_szCurrentApp);
    
    	// ?? 多个Buf如何处理
    	if (g_Acl.CheckSend(s, lpBuffers[0].buf, *lpNumberOfBytesSent) != PF_PASS)
    	{
    		int		iError;
    		g_NextProcTable.lpWSPShutdown(s, SD_BOTH, &iError);
    		*lpErrno = WSAECONNABORTED;
    
    		ODS(L" deny a send ");
    		return SOCKET_ERROR;
    	}
    
    	return g_NextProcTable.lpWSPSend(s, lpBuffers, dwBufferCount, 
    				lpNumberOfBytesSent, dwFlags, lpOverlapped, lpCompletionRoutine, lpThreadId, lpErrno);
    }
    
    int WSPAPI WSPRecv(
    	SOCKET			s,
    	LPWSABUF		lpBuffers,
    	DWORD			dwBufferCount,
    	LPDWORD			lpNumberOfBytesRecvd,
    	LPDWORD			lpFlags,
    	LPWSAOVERLAPPED	lpOverlapped,
    	LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
    	LPWSATHREADID	lpThreadId,
    	LPINT			lpErrno
    )
    {
    	ODS1(L"  PhoenixLSP:  WSPRecv  %s \n", g_szCurrentApp);
    
    	if(lpOverlapped != NULL)
    	{
    		if(g_Acl.CheckRecv(s, NULL, 0) != PF_PASS)
    		{
    			int		iError;
    			g_NextProcTable.lpWSPShutdown(s, SD_BOTH, &iError);
    			*lpErrno = WSAECONNABORTED;
    
    			ODS(L"deny a recv");
    			return SOCKET_ERROR;
    		}
    		ODS(L" overlappped ");
    	}
    
    	int	iRet = g_NextProcTable.lpWSPRecv(s, lpBuffers, dwBufferCount, lpNumberOfBytesRecvd, lpFlags, lpOverlapped
    				, lpCompletionRoutine, lpThreadId, lpErrno);
    
    	if(iRet != SOCKET_ERROR && lpOverlapped == NULL)
    	{
    		if(g_Acl.CheckRecv(s, lpBuffers[0].buf, *lpNumberOfBytesRecvd) != PF_PASS)
    		{
    			int		iError;
    			g_NextProcTable.lpWSPShutdown(s, SD_BOTH, &iError);
    			*lpErrno = WSAECONNABORTED;
    
    			ODS(L"deny a recv");
    			return SOCKET_ERROR;
    		}
    	}
    	return iRet;
    }
    
      */
    


  • 相关阅读:
    SQL Server 自定义函数(Function)——参数默认值
    SQL Server返回插入数据的ID和受影响的行数
    SQL Server去重和判断是否为数字——OBJECT_ID的使用
    SQL Server插入数据和删除数据
    SQL Server语句创建数据库和表——并设置主外键关系
    SQL Server 数据分页查询
    C# ref、out、params与值类型参数修饰符
    C#字符串的方法
    C#数组的声明
    tcpdump的使用
  • 原文地址:https://www.cnblogs.com/new0801/p/6177674.html
Copyright © 2011-2022 走看看