原始套接字之用户与内核通信

1220阅读 · 0评论 · 2013-03-18 · spyhjl
分类:LINUX

==================================== header file (imp2.h) ==============================
#ifndef __IMP2_H__
#define __IMP2_H__

/*
 * Definitions shared by the imp2 user program and kernel module: the
 * netlink protocol number, the nlmsg_type values exchanged over it, and
 * the payload of a kernel -> user packet report.
 */

/* Makes the header self-contained: provides __u16 / __be32. */
#include <linux/types.h>

/* nlmsg_type values */
#define IMP2_U_PID   0  /* user -> kernel: register nlmsg_pid as the receiver */
#define IMP2_K_MSG   1  /* kernel -> user: payload is a struct packet_info */
#define IMP2_CLOSE   2  /* user -> kernel: unregister nlmsg_pid */

/* Netlink protocol number passed to socket() and netlink_kernel_create(). */
#define NL_IMP2      31

/*
 * Payload of an IMP2_K_MSG message.
 *
 * src/dest are copied unchanged from the IPv4 header (network byte order)
 * and are handed to inet_ntoa() as-is by the user program.
 * The ports are converted with ntohs() by the kernel module before being
 * sent, so they are host byte order and typed __u16 rather than __be16.
 * Field sizes are unchanged, so the wire layout stays the same.
 */
struct packet_info
{
  __be32 src;
  __be32 dest;
  __u16 source_port;
  __u16 dest_port;
};

#endif

==================================== user level src file (imp2_u.c) ===========================
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <signal.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <linux/netlink.h>
#include "imp2.h"

/*
 * Control request sent to the kernel module. It is only a netlink header:
 * nlmsg_type selects IMP2_U_PID or IMP2_CLOSE and no payload follows
 * (callers set nlmsg_len to NLMSG_LENGTH(0)).
 */
struct msg_to_kernel {
    struct nlmsghdr hdr;    /* the whole message */
};

/*
 * Receive buffer for one kernel -> user report: a netlink header directly
 * followed by the packet_info payload.
 * NOTE(review): assumes the payload starts right after the header, i.e.
 * NLMSG_DATA(&hdr) == &icmp_info; confirm if either structure changes.
 * Despite the field name, the kernel module currently reports TCP packets.
 */
struct u_packet_info {
    struct nlmsghdr    hdr;        /* netlink header (nlmsg_type IMP2_K_MSG) */
    struct packet_info icmp_info;  /* report payload */
};

/* Netlink socket, shared between main() and the SIGINT handler. */
static int skfd = 0;

static void sig_int(int signo)
{
    struct sockaddr_nl kpeer;
    struct msg_to_kernel message;

    memset(&kpeer, 0, sizeof(kpeer));
    kpeer.nl_family = AF_NETLINK;
    kpeer.nl_pid    = 0;
    kpeer.nl_groups = 0;

    memset(&message, 0, sizeof(message));
    message.hdr.nlmsg_len = NLMSG_LENGTH(0);
    message.hdr.nlmsg_flags = 0;
    message.hdr.nlmsg_type = IMP2_CLOSE;
    message.hdr.nlmsg_pid = getpid();

    sendto(skfd, &message, message.hdr.nlmsg_len, 0, (struct sockaddr *)(&kpeer),
            sizeof(kpeer));

    close(skfd);
    exit(0);
}

int main(void)
{
    struct sockaddr_nl local;
    struct sockaddr_nl kpeer;
    socklen_t kpeerlen;
    struct msg_to_kernel message;
    struct u_packet_info info;
    int rcvlen = 0;
    struct in_addr addr;

    skfd = socket(PF_NETLINK, SOCK_RAW, NL_IMP2);
    if(skfd < 0)
    {
        printf("can not create a netlink socket\n");
        exit(0);
    }

    memset(&local, 0, sizeof(local));
    local.nl_family = AF_NETLINK;
    local.nl_pid = getpid();
    local.nl_groups = 0;
    if(bind(skfd, (struct sockaddr*)&local, sizeof(local)) != 0)
    {
        printf("bind() error\n");
        return -1;
    }

    signal(SIGINT, sig_int);

    memset(&kpeer, 0, sizeof(kpeer));
    kpeer.nl_family = AF_NETLINK;
    kpeer.nl_pid = 0;
    kpeer.nl_groups = 0;

    memset(&message, 0, sizeof(message));
    message.hdr.nlmsg_len = NLMSG_LENGTH(0);
    message.hdr.nlmsg_flags = 0;
    message.hdr.nlmsg_type = IMP2_U_PID;
    message.hdr.nlmsg_pid = local.nl_pid;

    printf("local pid:%d\n", local.nl_pid);

    sendto(skfd, &message, message.hdr.nlmsg_len, 0,
            (struct sockaddr*)&kpeer, sizeof(kpeer));

    while(1)
    {
        kpeerlen = sizeof(struct sockaddr_nl);
        rcvlen = recvfrom(skfd, &info, sizeof(struct u_packet_info),
                0, (struct sockaddr*)&kpeer, &kpeerlen);

        addr.s_addr = info.icmp_info.src;
        printf("src: %s.%u, ", inet_ntoa(addr), (info.icmp_info.source_port));
        addr.s_addr = info.icmp_info.dest;
        printf("dest: %s.%u\n", inet_ntoa(addr), (info.icmp_info.dest_port));
    }

    return 0;
}

==================================== kernel level src (imp2_k.c) ============================
/*
 * Fallback definitions. The kernel build invoked by the makefile
 * ($(MAKE) -C $(KERNELDIR) ...) is expected to define these already; the
 * guards only define them when they are missing.
 */
#if !defined(__KERNEL__)
#define __KERNEL__
#endif

#if !defined(MODULE)
#define MODULE
#endif

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "imp2.h"

MODULE_LICENSE("GPL");
MODULE_AUTHOR("DRAGON");
MODULE_DESCRIPTION("test");
MODULE_ALIAS("IMP2");


/*
 * Semaphore (initial count 1) that nl_receive() takes with down_trylock()
 * to serialise processing of incoming control messages. DECLARE_MUTEX was
 * removed in 2.6.37; DEFINE_SEMAPHORE declares the same thing there.
 * static: nothing outside this file uses it.
 */
#if LINUX_VERSION_CODE >= 0x020625
static DEFINE_SEMAPHORE(receive_sem);
#else
static DECLARE_MUTEX(receive_sem);
#endif

/* Kernel-side NL_IMP2 netlink socket, created in init(), released in fini(). */
static struct sock *nlfd;

/*
 * Netlink port id (pid) of the registered user-space receiver; 0 = none.
 * Written in nl_receive(), read by the netfilter hook and send_to_user();
 * every access in this file goes through @lock using the _bh variants.
 */
static struct
{
    __u32 pid;
    rwlock_t lock;
} user_proc;

#if LINUX_VERSION_CODE >= 0x02061B
#if 1
static void nl_receive(struct  sk_buff *sk)
{
    struct sk_buff *skb = skb_get(sk);
    if(down_trylock(&receive_sem))
        return;


    if(skb != NULL)
    {
        struct nlmsghdr *nlh = NULL;
        if(skb->len >= sizeof(struct nlmsghdr))
        {

            nlh = (struct nlmsghdr *)skb->data;
            if((nlh->nlmsg_len >= sizeof(struct nlmsghdr))
                    && (skb->len >= nlh->nlmsg_len))
            {
                
                printk("type::%d::%x\n", nlh->nlmsg_type,nlh->nlmsg_type);
                printk("pid::%d::%x\n", nlh->nlmsg_pid,nlh->nlmsg_pid);
                if(nlh->nlmsg_type == IMP2_U_PID)
                {
                    printk("imp2 sured *******\n");
                    write_lock_bh(&user_proc.lock);
                    user_proc.pid = nlh->nlmsg_pid;
                    write_unlock_bh(&user_proc.lock);
                }
                else if(nlh->nlmsg_type == IMP2_CLOSE)
                {
                    printk("imp2 close *******\n");
                    write_lock_bh(&user_proc.lock);
                    if(nlh->nlmsg_pid == user_proc.pid)
                        user_proc.pid = 0;
                    write_unlock_bh(&user_proc.lock);
                }
            }
        }
        kfree_skb(skb);
    }
    printk("*********received out**********\n");
    up(&receive_sem);
}
#endif
#if 0
/*
 * Disabled alternative (#if 0) of the 2.6.27+ receive callback, kept for
 * reference only. Instead of handling the passed skb directly, it drains a
 * socket receive queue with the same loop shape as the pre-2.6.27 version
 * below, applying the same IMP2_U_PID / IMP2_CLOSE handling.
 */
static void nl_receive(struct  sk_buff *sk)
{
    do
    {
        struct sk_buff *skb = NULL;
        /* Give up immediately if another invocation holds the semaphore. */
        if(down_trylock(&receive_sem))
            return;

        printk("*********rrrrrrrrrrrrr**********\n");

        /* Despite its name, `sk` is an skb here; sk->sk is presumably the
         * socket it is attached to, whose receive queue is drained. */
        while((skb = skb_dequeue(&sk->sk->sk_receive_queue)) != NULL)
       // while((skb = skb_get(sk)) != NULL)
        {
            printk("*********received in**********\n");
            {
                struct nlmsghdr *nlh = NULL;

                /* Validate the header before trusting nlmsg_len. */
                if(skb->len >= sizeof(struct nlmsghdr))
                {
                    nlh = (struct nlmsghdr *)skb->data;
                    if((nlh->nlmsg_len >= sizeof(struct nlmsghdr))
                            && (skb->len >= nlh->nlmsg_len))
                    {
                        if(nlh->nlmsg_type == IMP2_U_PID)
                        {
                            printk("imp2 sured *******\n");
                            write_lock_bh(&user_proc.lock);
                            user_proc.pid = nlh->nlmsg_pid;
                            write_unlock_bh(&user_proc.lock);
                        }
                        else if(nlh->nlmsg_type == IMP2_CLOSE)
                        {
                            write_lock_bh(&user_proc.lock);
                            if(nlh->nlmsg_pid == user_proc.pid)
                                user_proc.pid = 0;
                            write_unlock_bh(&user_proc.lock);
                        }
                    }
                }
            }
            kfree_skb(skb);
        }
        printk("*********received out**********\n");
        up(&receive_sem);
        /* Loop again if more messages are queued on the kernel socket. */
    }while(nlfd && nlfd->sk_receive_queue.qlen);
}
#endif
#else 
static void nl_receive(struct sock *sk, int len)
{
    /* 通过skb = skb_dequeue(&sk->sk_receive_queue)得到实际数据 */
    do
    {
        struct sk_buff *skb;
        if(down_trylock(&receive_sem))
            return;

        while((skb = skb_dequeue(&sk->sk_receive_queue)) != NULL)
        {
            {
                struct nlmsghdr *nlh = NULL;

                if(skb->len >= sizeof(struct nlmsghdr))
                {
                    nlh = (struct nlmsghdr *)skb->data;
                    if((nlh->nlmsg_len >= sizeof(struct nlmsghdr))
                            && (skb->len >= nlh->nlmsg_len))
                    {
                        if(nlh->nlmsg_type == IMP2_U_PID)
                        {
                            write_lock_bh(&user_proc.lock);
                            user_proc.pid = nlh->nlmsg_pid;
                            write_unlock_bh(&user_proc.lock);
                        }
                        else if(nlh->nlmsg_type == IMP2_CLOSE)
                        {
                            write_lock_bh(&user_proc.lock);
                            if(nlh->nlmsg_pid == user_proc.pid)
                                user_proc.pid = 0;
                            write_unlock_bh(&user_proc.lock);
                        }
                    }
                }
            }
            kfree_skb(skb);
        }
        up(&receive_sem);
    }while(nlfd && nlfd->sk_receive_queue.qlen);
}
#endif

static int send_to_user(struct packet_info *info)
{
    int ret;
    int size;
    //unsigned char *old_tail;
    unsigned int old_tail;
    struct sk_buff *skb;
    struct nlmsghdr *nlh;
    struct packet_info *packet;

    /* NLMSG_SPACE rerturn the total lenth including the part 
     * of nlmsghdr header and real data */
    size = NLMSG_SPACE(sizeof(*info));

    skb = alloc_skb(size, GFP_ATOMIC);
    old_tail = skb->tail;


    nlh = NLMSG_PUT(skb, 0, 0, IMP2_K_MSG, size-sizeof(*nlh));
    packet = NLMSG_DATA(nlh);
    memset(packet, 0, sizeof(struct packet_info));

    packet->src = info->src;
    packet->dest = info->dest;
    packet->source_port = info->source_port;
    packet->dest_port = info->dest_port;

    nlh->nlmsg_len = skb->tail - old_tail;
    NETLINK_CB(skb).dst_group = 0;

    read_lock_bh(&user_proc.lock);
    ret = netlink_unicast(nlfd, skb, user_proc.pid, MSG_DONTWAIT);
    read_unlock_bh(&user_proc.lock);

    return ret;

nlmsg_failure:
    if(skb)
        kfree_skb(skb);
    return -1;
}

static unsigned int get_icmp(unsigned int hook,
        struct sk_buff *pskb,
        const struct net_device *in,
        const struct net_device *out,
        int (*okfn)(struct sk_buff *))
{

#if LINUX_VERSION_CODE >= 0x02061B
    struct iphdr *iph = ip_hdr(pskb);
#else
    struct iphdr *iph = (*pskb)->nh.iph;
#endif
    struct packet_info info;
    struct tcphdr *tcp;

    struct sk_buff *skb = skb_copy(pskb, GFP_ATOMIC);

#if 0
    if(iph->protocol == IPPROTO_ICMP){
        printk("icmp packet in......\n");
        read_lock_bh(&user_proc.lock);
        if(user_proc.pid != 0){
            read_unlock_bh(&user_proc.lock);
            info.src = iph->saddr;
            info.dest = iph->daddr;
            send_to_user(&info);
            printk("send icmp info......\n");
        }else
            read_unlock_bh(&user_proc.lock);
    } 
#endif
    if(iph->protocol == IPPROTO_TCP){
        printk("tcp packet in......\n");

        read_lock_bh(&user_proc.lock);
        if(user_proc.pid != 0){
            read_unlock_bh(&user_proc.lock);
            info.src = iph->saddr;
            info.dest = iph->daddr;

            //tcp = tcp_hdr(skb);
            tcp = (struct tcphdr*)(skb->data + skb->transport_header);

            info.source_port = ntohs(tcp->source);
            info.dest_port = ntohs(tcp->dest);
            printk("src: %d, dst:%d\n", tcp->source, tcp->dest);
            printk("src: %u, dst:%u\n", info.source_port, info.dest_port);

            send_to_user(&info);
            printk("send tcp info......\n");
        }else
            read_unlock_bh(&user_proc.lock);
    }

    return NF_ACCEPT;
}

/*
 * Netfilter registration of get_icmp(): IPv4 (PF_INET) packets at
 * NF_INET_PRE_ROUTING, priority NF_IP_PRI_FILTER - 1.
 */
static struct nf_hook_ops imp2_ops = {
    .pf       = PF_INET,
    .hooknum  = NF_INET_PRE_ROUTING,
    .priority = NF_IP_PRI_FILTER - 1,
    .hook     = get_icmp,
};

/*
 * Module init: create the NL_IMP2 kernel netlink socket, then register the
 * netfilter hook. The hook comes last because it sends through nlfd. If
 * hook registration fails, the netlink socket is released again.
 *
 * Returns 0 on success or a negative errno.
 */
static int __init init(void)
{
    int ret;

    printk(KERN_ALERT "module init\n");

    rwlock_init(&user_proc.lock);

#if LINUX_VERSION_CODE >= 0x020618
    nlfd = netlink_kernel_create(&init_net, NL_IMP2, 0,
            nl_receive, NULL, THIS_MODULE);
#elif LINUX_VERSION_CODE >= 0x020616
    nlfd = netlink_kernel_create(NL_IMP2, 0, nl_receive,
            THIS_MODULE);
#else
    nlfd = netlink_kernel_create(NL_IMP2, nl_receive);
#endif

    if (!nlfd) {
        printk(KERN_ERR "can not create a netlink socket\n");
        return -ENOMEM;
    }

    ret = nf_register_hook(&imp2_ops);
    if (ret < 0) {
        /* Do not leak the netlink socket when the hook cannot be added. */
#if LINUX_VERSION_CODE >= 0x020619
        netlink_kernel_release(nlfd);
#elif LINUX_VERSION_CODE >= 0x020618
        sock_release(nlfd->sk_socket);
#else
        sock_release(nlfd->socket);
#endif
        nlfd = NULL;
    }

    return ret;
}

/*
 * Module exit. The netfilter hook is unregistered first so that no hook
 * invocation (which sends through nlfd via send_to_user()) can run after
 * the netlink socket is released. Only then is the socket released.
 * netlink_kernel_release() (2.6.25+) is the release call for sockets
 * obtained from netlink_kernel_create().
 */
static void __exit fini(void)
{
    printk(KERN_ALERT "module exit\n");

    nf_unregister_hook(&imp2_ops);

    if (nlfd) {
#if LINUX_VERSION_CODE >= 0x020619
        netlink_kernel_release(nlfd);
#elif LINUX_VERSION_CODE >= 0x020618
        sock_release(nlfd->sk_socket);
#else
        sock_release(nlfd->socket);
#endif
        nlfd = NULL;
    }
}

/* Register init() and fini() as this module's load and unload entry points. */
module_init(init);
module_exit(fini);
==================================== Makefile ================================
# Kbuild reads this file with KERNELRELEASE set and only needs obj-m;
# a plain `make` in this directory sees it empty and uses the rules below.
ifneq ($(KERNELRELEASE),)
obj-m := imp2_k.o
else

KERNELDIR ?= /usr/src/linux/

.PHONY: modules install uninstall clean

# M= names the external module directory (SUBDIRS= is deprecated).
modules:
	$(MAKE) -C $(KERNELDIR) M=$(CURDIR) $@
install:
	insmod imp2_k.ko
uninstall:
	rmmod imp2_k.ko
# Real file target: rebuilt whenever the user program or header changes.
test: imp2_u.c imp2.h
	gcc -g -Wall imp2_u.c -o test
clean:
	rm -rf *.ko *.o *.mod.c *.symvers *.order test

endif
==================================== end ====================================

上一篇:Shell编程(一)基础语法
下一篇:Shell编程(二)练手篇